Commit 672ed11
Parent(s): 1a97904

Remove unused files: old Gradio frontend, dead model code, test artifacts

- frontend/ – entire old Gradio UI, replaced by React
- models/explainability.py – old GradCAM, replaced by gradcam_tool.py
- models/medsiglip_convnext_fusion.py – abandoned experimental model
- models/monet_concepts.py – old concept scorer, replaced by monet_tool.py
- test_models.py – one-off test script
- KAGGLE_REPORT.md – temporary report file

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- KAGGLE_REPORT.md +0 -104
- frontend/app.py +0 -532
- frontend/components/__init__.py +0 -0
- frontend/components/analysis_view.py +0 -214
- frontend/components/patient_select.py +0 -48
- frontend/components/sidebar.py +0 -55
- frontend/components/styles.py +0 -517
- models/explainability.py +0 -183
- models/medsiglip_convnext_fusion.py +0 -224
- models/monet_concepts.py +0 -332
- test_models.py +0 -86
KAGGLE_REPORT.md
DELETED
@@ -1,104 +0,0 @@

# SkinProAI: Explainable Multi-Model Dermatology Decision Support

## 1. System Architecture & AI Pipeline

SkinProAI is a clinician-first dermatology decision-support system that orchestrates multiple specialised AI models through a four-phase analysis pipeline. The system is designed to run entirely on-device using Google's MedGemma 4B – a 4-billion-parameter vision-language model purpose-built for edge deployment in clinical settings – executing on consumer hardware (Apple M4 Pro / MPS) in bfloat16 precision without requiring cloud GPU infrastructure.

### 1.1 Multi-Model Orchestration via MCP

Rather than relying on a single monolithic model, SkinProAI employs a Model Context Protocol (MCP) architecture that isolates each AI capability as an independent tool callable via JSON-RPC over subprocess stdio. This design provides fault isolation, independent model loading, and a clear audit trail of which model produced which output.
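
The tool-call framing can be sketched as follows. The envelope below is a simplified illustration of JSON-RPC 2.0 over line-delimited stdio, and `make_tool_call` is a hypothetical helper, not the project's actual client code:

```python
import json

def make_tool_call(request_id: int, tool: str, arguments: dict) -> str:
    """Frame an MCP tool invocation as one JSON-RPC 2.0 request line.

    The MCP server reads one JSON object per line from stdin and writes
    the result to stdout. Tool names match the table below; the exact
    envelope here is a simplified sketch, not the full MCP handshake.
    """
    return json.dumps({
        "jsonrpc": "2.0",
        "id": request_id,
        "method": "tools/call",
        "params": {"name": tool, "arguments": arguments},
    })

# Tools receive absolute file paths, never raw image bytes.
line = make_tool_call(1, "classify_lesion", {"image_path": "/data/lesions/img_001.png"})
```

Passing paths rather than pixel data keeps the JSON payload small and lets each subprocess open the file under its own interpreter.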

| Tool | Model | Purpose |
|------|-------|---------|
| `monet_analyze` | MONET (contrastive vision-language) | Extract 7 dermoscopic concept scores: ulceration, hair structures, vascular patterns, erythema, pigmented structures, gel artefacts, skin markings |
| `classify_lesion` | ConvNeXt Base + metadata MLP | 11-class lesion classification using a dual-encoder architecture combining image features with MONET concept scores |
| `generate_gradcam` | Grad-CAM on ConvNeXt | Generate visual attention heatmap overlays highlighting regions driving the classification decision |
| `search_guidelines` | FAISS + all-MiniLM-L6-v2 | Retrieve relevant clinical guideline passages from a RAG index of 286 chunks across 7 dermatology PDFs |
| `compare_images` | MONET + overlay | Temporal change detection comparing dermoscopic feature vectors between sequential images |

All tools receive absolute file paths as input (never raw image data), enabling secure subprocess isolation between the main FastAPI process (Python 3.9) and the MCP server (Python 3.11).

### 1.2 Four-Phase Analysis Pipeline

Every image analysis follows a structured, transparent pipeline:

**Phase 1 – Independent Visual Examination.** MedGemma 4B performs a systematic dermoscopic assessment *before* seeing any AI tool output. It evaluates pattern architecture, colour distribution, border characteristics, and structural features, producing differential diagnoses ranked by clinical probability. This deliberate sequencing prevents anchoring bias: the language model forms its own clinical impression independently.

**Phase 2 – AI Classification Tools.** Three MCP tools execute in sequence: MONET extracts quantitative concept scores (0–1 range for each dermoscopic feature), ConvNeXt classifies the lesion into one of 11 diagnostic categories with calibrated probabilities, and Grad-CAM generates an attention overlay showing which image regions drove the classification. Each tool's output is streamed to the clinician in real time with visual bar charts.

**Phase 3 – Reconciliation.** MedGemma receives both its own Phase 1 assessment and the Phase 2 tool outputs, then performs explicit agreement/disagreement analysis. It identifies where its visual findings align with or diverge from the quantitative classifiers, produces an integrated assessment with a stated confidence level, and explains its reasoning. This adversarial cross-check between independent assessments is central to the system's reliability.

**Phase 4 – Management Guidance with RAG.** The system automatically queries a FAISS-indexed knowledge base of clinical dermatology guidelines (BAD, NICE, and specialist PDFs covering BCC, SCC, melanoma, actinic keratosis, contact dermatitis, lichen sclerosus, and cutaneous warts). Using the diagnosed condition as a search query, the RAG system retrieves the top-5 most relevant guideline passages via cosine similarity over sentence-transformer embeddings (all-MiniLM-L6-v2, 384 dimensions). MedGemma then synthesises lesion-specific management recommendations – biopsy, excision, monitoring, or discharge – grounded in the retrieved evidence, with inline superscript citations linking back to source documents and page numbers.
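
The retrieval step can be sketched in plain numpy. Assuming the index rows are the L2-normalised chunk embeddings, the dot product equals cosine similarity; the production system delegates this same search to FAISS:

```python
import numpy as np

def top_k_passages(query_vec: np.ndarray, index: np.ndarray, k: int = 5) -> np.ndarray:
    """Return indices of the k most similar guideline chunks.

    `index` has one L2-normalised embedding per row (384-dim in the real
    system), so a dot product against the normalised query is cosine
    similarity. This mirrors what FAISS does over an inner-product index.
    """
    q = query_vec / np.linalg.norm(query_vec)
    scores = index @ q                     # cosine similarity per chunk
    return np.argsort(scores)[::-1][:k]    # highest-scoring chunks first
```

The returned indices map back to chunk metadata (source PDF, page number), which is what the citation layer needs.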

---

## 2. AI Explainability

Explainability is not an afterthought in SkinProAI; it is embedded at every layer of the architecture. The system provides three complementary forms of explanation: visual, quantitative, and narrative.

### 2.1 Visual Explainability: Grad-CAM Attention Maps

Grad-CAM (Gradient-weighted Class Activation Mapping) hooks into the final convolutional layer of the ConvNeXt classifier to produce a spatial heatmap overlay on the original dermoscopic image. This directly answers the question "where is the model looking?", showing clinicians which morphological features (border irregularity, colour variegation, structural asymmetry) are driving the classification. For temporal comparisons, the system generates side-by-side Grad-CAM pairs (previous vs. current) so clinicians can visually assess whether attention regions have shifted, expanded, or resolved.
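
The core Grad-CAM arithmetic is compact. The numpy sketch below assumes the activations and class-score gradients have already been captured from the final convolutional stage (via forward/backward hooks in the actual PyTorch implementation):

```python
import numpy as np

def gradcam_heatmap(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Grad-CAM core computation as a numpy sketch.

    `activations` (C, H, W) are the final-stage feature maps and
    `gradients` (C, H, W) are the gradients of the target class score
    w.r.t. those maps. Each channel is weighted by its spatially averaged
    gradient, the weighted maps are summed, ReLU keeps only regions that
    positively support the class, and the result is scaled to [0, 1].
    """
    weights = gradients.mean(axis=(1, 2))                     # global-average-pool the gradients
    cam = np.maximum((weights[:, None, None] * activations).sum(axis=0), 0.0)
    if cam.max() > 0:
        cam = cam / cam.max()                                 # normalise for overlay rendering
    return cam
```

The normalised map is then upsampled to the input resolution and alpha-blended over the dermoscopic image.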

### 2.2 Quantitative Explainability: MONET Concept Scores

MONET provides human-interpretable concept decomposition by scoring seven clinically meaningful dermoscopic features on a continuous 0–1 scale. Unlike black-box classifiers that output only a label and probability, MONET's concept scores reveal *why* a lesion received its classification: a high vascular score combined with low pigmentation explains a vascular lesion diagnosis; high pigmented-structure scores with border irregularity support melanocytic concern. These scores are rendered as visual bar charts in the streaming UI, giving clinicians immediate quantitative insight into the feature profile driving the AI's assessment.

When comparing sequential images over time, the system computes MONET feature deltas – the signed difference in each concept score between timepoints – enabling objective quantification of lesion evolution. A change from 0.3 to 0.7 in the vascular score, for example, signals new vessel formation that warrants clinical attention, independent of any subjective visual comparison.
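
A minimal sketch of the delta computation follows. `change_status` and its thresholds are illustrative placeholders rather than the system's calibrated values, and the real status set also includes "Improved":

```python
def monet_deltas(prev: dict, curr: dict) -> dict:
    """Signed per-concept change between two MONET score dictionaries."""
    return {k: round(curr[k] - prev[k], 3) for k in prev}

def change_status(deltas: dict, minor: float = 0.1, significant: float = 0.3) -> str:
    """Map the largest absolute delta to a status label.

    Thresholds here are illustrative, not the system's calibrated values.
    """
    worst = max(abs(d) for d in deltas.values())
    if worst >= significant:
        return "Significant Change"
    if worst >= minor:
        return "Minor Change"
    return "Stable"
```

Keeping the deltas signed matters: a vascular score rising 0.3 → 0.7 and one falling 0.7 → 0.3 have the same magnitude but opposite clinical meaning.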

### 2.3 Narrative Explainability: MedGemma Reasoning Transparency

The streaming interface exposes MedGemma's reasoning process through structured markup segments. `[THINKING]` blocks display the model's intermediate reasoning (differential construction, feature weighting, agreement analysis) in real time, with animated spinners that resolve to completion indicators as each reasoning phase finishes. `[RESPONSE]` blocks contain the synthesised clinical narrative. This staged transparency allows clinicians to follow the AI's analytical process rather than receiving only a final pronouncement.
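
Segmenting such a stream can be sketched with a regex. The exact closing-tag form is an assumption here, so treat this as illustrative rather than the UI's actual parser:

```python
import re

# Matches [THINKING]...[/THINKING] and [RESPONSE]...[/RESPONSE] pairs;
# the closing-tag syntax is assumed for this sketch.
SEGMENT_RE = re.compile(r"\[(THINKING|RESPONSE)\](.*?)\[/\1\]", re.DOTALL)

def parse_segments(stream_text: str):
    """Split streamed output into (kind, body) segments in order."""
    return [(m.group(1), m.group(2).strip()) for m in SEGMENT_RE.finditer(stream_text)]
```

The renderer can then style `THINKING` segments as collapsible reasoning and `RESPONSE` segments as the clinical narrative.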

The reconciliation phase (Phase 3) is particularly significant for explainability: MedGemma explicitly states where its independent visual assessment agrees or disagrees with the quantitative classifiers, and explains the basis for its integrated conclusion. This adversarial structure makes disagreements visible and auditable.

### 2.4 Evidence Grounding: RAG Citations

Management recommendations in Phase 4 include inline superscript references (e.g., "Wide local excision with 2cm margins is recommended for tumours >2mm Breslow thickness¹") linked to specific guideline documents and page numbers. This evidence chain, from diagnosis through to management recommendation, is fully traceable to published clinical guidelines, supporting regulatory compliance and clinical governance requirements.

---

## 3. Clinician-First Design & User Interface

### 3.1 Design Philosophy

SkinProAI is built around the principle that AI should augment clinical decision-making, not replace it. The interface is designed for the workflow of a clinician reviewing dermoscopic images: patient selection, lesion documentation, temporal tracking, and structured analysis with clear next-step guidance.

The system deliberately avoids presenting AI conclusions as definitive diagnoses. Instead, results are framed as ranked differentials with calibrated confidence scores, supported by quantitative feature evidence and visual attention maps. The clinician retains full agency over the diagnostic and management decision.

### 3.2 Streaming Real-Time Interface

The UI employs Server-Sent Events (SSE) to stream analysis output in real time, maintaining clinician engagement during model inference. Rather than a loading spinner followed by a wall of text, clinicians observe the analysis unfolding phase by phase:
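
The wire format itself is simple. The helper below frames one SSE message; the `event:`/`data:` field names come from the SSE standard, while the helper is an illustrative sketch rather than the project's endpoint code:

```python
def sse_event(data: str, event: str = None) -> str:
    """Frame one Server-Sent Events message.

    An SSE message is a series of `field: value` lines terminated by a
    blank line; multi-line payloads need one `data:` line per line.
    """
    lines = []
    if event:
        lines.append(f"event: {event}")
    for part in data.splitlines() or [""]:
        lines.append(f"data: {part}")
    return "\n".join(lines) + "\n\n"
```

Each analysis chunk, tool status update, or thinking-phase transition becomes one such message on the open HTTP connection.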

- **Tool status lines** appear as compact single-line indicators with animated spinners that resolve to green completion dots, showing tool name and summary result (e.g., "ConvNeXt classification → Melanoma (89%)")
- **Thinking indicators** display MedGemma's reasoning steps with spinners that transition to done states as each phase completes
- **Streamed text** appears word-by-word, allowing clinicians to begin reading findings before analysis completes

This streaming pattern reduces perceived latency and provides transparency into the multi-model pipeline's progress.

### 3.3 Temporal Lesion Tracking

The data model supports longitudinal monitoring through a Patient → Lesion → LesionImage hierarchy. Each lesion maintains a timeline of images with timestamps, and the system automatically triggers temporal comparison when a new image is uploaded for a previously analysed lesion. Comparison results include MONET feature deltas, side-by-side Grad-CAM overlays, and a status classification (Stable / Minor Change / Significant Change / Improved) rendered with colour-coded indicators.
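
The hierarchy might be modelled roughly as below. The class and field names are hypothetical, sketching the Patient → Lesion → LesionImage relationship rather than reproducing the backend's actual data layer:

```python
from dataclasses import dataclass, field
from datetime import datetime
from typing import List

# Hypothetical field names sketching the Patient -> Lesion -> LesionImage
# hierarchy; the real models live in the backend's data layer.

@dataclass
class LesionImage:
    path: str
    captured_at: datetime

@dataclass
class Lesion:
    site: str
    images: List[LesionImage] = field(default_factory=list)

    def timeline(self) -> List[LesionImage]:
        """Images in chronological order; the two most recent get compared."""
        return sorted(self.images, key=lambda im: im.captured_at)

@dataclass
class Patient:
    name: str
    lesions: List[Lesion] = field(default_factory=list)
```

Anchoring images to a lesion (rather than directly to a patient) is what lets an upload automatically trigger comparison against that lesion's previous timepoint.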

This temporal architecture supports the clinical workflow for monitoring suspicious lesions over time – a common scenario in dermatology where serial dermoscopy is preferred over immediate biopsy for lesions with intermediate risk profiles.

### 3.4 Conversational Follow-Up

After analysis completes, the system transitions to a conversational interface where clinicians can ask follow-up questions grounded in the analysis context. The chat maintains full awareness of the diagnosed condition, MONET features, classification results, and guideline context, enabling queries like "what are the excision margin recommendations for this depth?" or "how does the asymmetry score compare to the previous image?" The text-only chat pathway routes through MedGemma with the full analysis state, ensuring responses remain clinically contextualised.

### 3.5 Edge Deployment Architecture

SkinProAI runs entirely on-device using MedGemma 4B in bfloat16 precision on consumer hardware (demonstrated on Apple M4 Pro with 24GB RAM using Metal Performance Shaders). No patient data leaves the device: all inference, image storage, and guideline retrieval execute locally. This edge-first architecture addresses data sovereignty requirements in clinical settings where patient images cannot be transmitted to cloud services. The system also supports containerised deployment via Docker for institutional environments, with optional GPU acceleration on CUDA-equipped hardware or HuggingFace Spaces.

### 3.6 Interface Components

The frontend is built in React 18 with a glassmorphism-inspired design language. Key interface elements include:

- **Patient grid** – Card-based patient selection with creation workflow
- **Chat interface** – Unified conversation view combining image analysis, tool outputs, and text chat in a single scrollable thread
- **Tool call cards** – Inline status indicators for each AI tool invocation (analyse, classify, Grad-CAM, guidelines search, compare) with expandable result summaries
- **Image upload** – Drag-and-drop or click-to-upload with preview, supporting both initial analysis and temporal comparison workflows
- **Post-analysis prompt** – Contextual hint guiding clinicians to ask follow-up questions, provide additional context, or upload comparison images
- **Markdown rendering** – Clinical narratives rendered with proper formatting, headers, lists, and inline citations for readability

The interface prioritises information density without clutter: tool outputs collapse to single-line summaries by default, reasoning phases are visually distinguished from conclusions, and temporal comparisons use colour-coded status indicators (green/amber/red/blue) that map to clinical urgency levels.
frontend/app.py
DELETED
@@ -1,532 +0,0 @@
"""
SkinProAI Frontend - Modular Gradio application
"""

import gradio as gr
from typing import Dict, Generator, Optional
from datetime import datetime
import sys
import os
import re
import base64

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from data.case_store import get_case_store
from frontend.components.styles import MAIN_CSS
from frontend.components.analysis_view import format_output


# =============================================================================
# CONFIG
# =============================================================================

class Config:
    APP_TITLE = "SkinProAI"
    SERVER_PORT = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
    HF_SPACES = os.environ.get("SPACE_ID") is not None


# =============================================================================
# AGENT
# =============================================================================

class AnalysisAgent:
    """Wrapper for the MedGemma analysis agent"""

    def __init__(self):
        self.model = None
        self.loaded = False

    def load(self):
        if self.loaded:
            return
        from models.medgemma_agent import MedGemmaAgent
        self.model = MedGemmaAgent(verbose=True)
        self.model.load_model()
        self.loaded = True

    def analyze(self, image_path: str, question: str = "") -> Generator[str, None, None]:
        if not self.loaded:
            yield "[STAGE:loading]Loading AI models...[/STAGE]\n"
            self.load()

        for chunk in self.model.analyze_image_stream(image_path, question=question):
            yield chunk

    def management_guidance(self, confirmed: bool, feedback: str = None) -> Generator[str, None, None]:
        for chunk in self.model.generate_management_guidance(confirmed, feedback):
            yield chunk

    def followup(self, message: str) -> Generator[str, None, None]:
        if not self.loaded or not self.model.last_diagnosis:
            yield "[ERROR]No analysis context available.[/ERROR]\n"
            return
        for chunk in self.model.chat_followup(message):
            yield chunk

    def reset(self):
        if self.model:
            self.model.reset_state()


agent = AnalysisAgent()
case_store = get_case_store()


# =============================================================================
# APP
# =============================================================================

with gr.Blocks(title=Config.APP_TITLE, css=MAIN_CSS, theme=gr.themes.Soft()) as app:

    # =========================================================================
    # STATE
    # =========================================================================
    state = gr.State({
        "page": "patient_select",  # patient_select | analysis
        "case_id": None,
        "instance_id": None,
        "output": "",
        "gradcam_base64": None
    })

    # =========================================================================
    # PAGE 1: PATIENT SELECTION
    # =========================================================================
    with gr.Group(visible=True, elem_classes=["patient-select-container"]) as page_patient:
        gr.Markdown("# SkinProAI", elem_classes=["patient-select-title"])
        gr.Markdown("Select a patient to continue or create a new case", elem_classes=["patient-select-subtitle"])

        with gr.Row(elem_classes=["patient-grid"]):
            btn_demo_melanoma = gr.Button("Demo: Melanocytic Lesion", elem_classes=["patient-card"])
            btn_demo_ak = gr.Button("Demo: Actinic Keratosis", elem_classes=["patient-card"])
            btn_new_patient = gr.Button("+ New Patient", variant="primary", elem_classes=["new-patient-btn"])

    # =========================================================================
    # PAGE 2: ANALYSIS
    # =========================================================================
    with gr.Group(visible=False) as page_analysis:

        # Header
        with gr.Row(elem_classes=["app-header"]):
            gr.Markdown(f"**{Config.APP_TITLE}**", elem_classes=["app-title"])
            btn_back = gr.Button("< Back to Patients", elem_classes=["back-btn"])

        with gr.Row(elem_classes=["analysis-container"]):

            # Sidebar (previous queries)
            with gr.Column(scale=0, min_width=260, visible=False, elem_classes=["query-sidebar"]) as sidebar:
                gr.Markdown("### Previous Queries", elem_classes=["sidebar-header"])
                sidebar_list = gr.Column(elem_id="sidebar-queries")
                btn_new_query = gr.Button("+ New Query", size="sm", variant="primary")

            # Main content
            with gr.Column(scale=4, elem_classes=["main-content"]):

                # Input view (greeting style)
                with gr.Group(visible=True, elem_classes=["input-greeting"]) as view_input:
                    gr.Markdown("What would you like to analyze?", elem_classes=["greeting-title"])
                    gr.Markdown("Upload an image and describe what you'd like to know", elem_classes=["greeting-subtitle"])

                    with gr.Column(elem_classes=["input-box-container"]):
                        input_message = gr.Textbox(
                            placeholder="Describe the lesion or ask a question...",
                            show_label=False,
                            lines=2,
                            elem_classes=["message-input"]
                        )

                        input_image = gr.Image(
                            type="pil",
                            height=180,
                            show_label=False,
                            elem_classes=["image-preview"]
                        )

                        with gr.Row(elem_classes=["input-actions"]):
                            gr.Markdown("*Upload a skin lesion image*")
                            btn_analyze = gr.Button("Analyze", elem_classes=["send-btn"], interactive=False)

                # Results view (shown after analysis)
                with gr.Group(visible=False, elem_classes=["chat-view"]) as view_results:
                    output_html = gr.HTML(
                        value='<div class="analysis-output">Starting...</div>',
                        elem_classes=["results-area"]
                    )

                # Confirmation
                with gr.Group(visible=False, elem_classes=["confirm-buttons"]) as confirm_box:
                    gr.Markdown("**Do you agree with this diagnosis?**")
                    with gr.Row():
                        btn_confirm_yes = gr.Button("Yes, continue", variant="primary", size="sm")
                        btn_confirm_no = gr.Button("No, I disagree", variant="secondary", size="sm")
                    input_feedback = gr.Textbox(label="Your assessment", placeholder="Enter diagnosis...", visible=False)
                    btn_submit_feedback = gr.Button("Submit", visible=False, size="sm")

                # Follow-up
                with gr.Row(elem_classes=["chat-input-area"]):
                    input_followup = gr.Textbox(placeholder="Ask a follow-up question...", show_label=False, lines=1, scale=4)
                    btn_followup = gr.Button("Send", size="sm", scale=1)

    # =========================================================================
    # DYNAMIC SIDEBAR RENDERING
    # =========================================================================
    @gr.render(inputs=[state], triggers=[state.change])
    def render_sidebar(s):
        case_id = s.get("case_id")
        if not case_id or s.get("page") != "analysis":
            return

        instances = case_store.list_instances(case_id)
        current = s.get("instance_id")

        for i, inst in enumerate(instances, 1):
            diagnosis = "Pending"
            if inst.analysis and inst.analysis.get("diagnosis"):
                d = inst.analysis["diagnosis"]
                diagnosis = d.get("class", "?")

            label = f"#{i}: {diagnosis}"
            variant = "primary" if inst.id == current else "secondary"
            btn = gr.Button(label, size="sm", variant=variant, elem_classes=["query-item"])

            # Attach click handler to load this instance
            def load_instance(inst_id=inst.id, c_id=case_id):
                def _load(current_state):
                    current_state["instance_id"] = inst_id
                    instance = case_store.get_instance(c_id, inst_id)

                    # Load saved output if available
                    output_html = '<div class="analysis-output"><div class="result">Previous analysis loaded</div></div>'
                    if instance and instance.analysis:
                        diag = instance.analysis.get("diagnosis", {})
                        output_html = f'<div class="analysis-output"><div class="result">Diagnosis: {diag.get("full_name", diag.get("class", "Unknown"))}</div></div>'

                    return (
                        current_state,
                        gr.update(visible=False),  # view_input
                        gr.update(visible=True),   # view_results
                        output_html,
                        gr.update(visible=False)   # confirm_box
                    )
                return _load

            btn.click(
                load_instance(),
                inputs=[state],
                outputs=[state, view_input, view_results, output_html, confirm_box]
            )

    # =========================================================================
    # EVENT HANDLERS
    # =========================================================================

    def select_patient(case_id: str, s: Dict):
        """Handle patient selection"""
        s["case_id"] = case_id
        s["page"] = "analysis"

        instances = case_store.list_instances(case_id)
        has_queries = len(instances) > 0

        if has_queries:
            # Load most recent
            inst = instances[-1]
            s["instance_id"] = inst.id

            # Load image if exists
            img = None
            if inst.image_path and os.path.exists(inst.image_path):
                from PIL import Image
                img = Image.open(inst.image_path)

            return (
                s,
                gr.update(visible=False),  # page_patient
                gr.update(visible=True),   # page_analysis
                gr.update(visible=True),   # sidebar
                gr.update(visible=False),  # view_input
                gr.update(visible=True),   # view_results
                '<div class="analysis-output"><div class="result">Previous analysis loaded</div></div>',
                gr.update(visible=False)   # confirm_box
            )
        else:
            # New instance
            inst = case_store.create_instance(case_id)
            s["instance_id"] = inst.id
            s["output"] = ""

            return (
                s,
                gr.update(visible=False),
                gr.update(visible=True),
                gr.update(visible=False),  # sidebar hidden for new patient
                gr.update(visible=True),   # view_input
                gr.update(visible=False),  # view_results
                "",
                gr.update(visible=False)
            )

    def new_patient(s: Dict):
        """Create new patient"""
        case = case_store.create_case(f"Patient {datetime.now().strftime('%Y-%m-%d %H:%M')}")
        return select_patient(case.id, s)

    def go_back(s: Dict):
        """Return to patient selection"""
        s["page"] = "patient_select"
        s["case_id"] = None
        s["instance_id"] = None
        s["output"] = ""

        return (
            s,
            gr.update(visible=True),   # page_patient
            gr.update(visible=False),  # page_analysis
            gr.update(visible=False),  # sidebar
            gr.update(visible=True),   # view_input
            gr.update(visible=False),  # view_results
            "",
            gr.update(visible=False)   # confirm_box
        )

    def new_query(s: Dict):
        """Start new query for current patient"""
        case_id = s.get("case_id")
        if not case_id:
            return s, gr.update(), gr.update(), gr.update(), "", gr.update()

        inst = case_store.create_instance(case_id)
        s["instance_id"] = inst.id
        s["output"] = ""
        s["gradcam_base64"] = None

        agent.reset()

        return (
            s,
            gr.update(visible=True),   # view_input
            gr.update(visible=False),  # view_results
            None,  # clear image
            "",    # clear output
            gr.update(visible=False)   # confirm_box
        )

    def enable_analyze(img):
        """Enable analyze button when image uploaded"""
        return gr.update(interactive=img is not None)

    def run_analysis(image, message, s: Dict):
        """Run analysis on uploaded image"""
        if image is None:
            yield s, gr.update(), gr.update(), gr.update(), gr.update()
            return

        case_id = s["case_id"]
        instance_id = s["instance_id"]

        # Save image
        image_path = case_store.save_image(case_id, instance_id, image)
        case_store.update_analysis(case_id, instance_id, stage="analyzing", image_path=image_path)

        agent.reset()
        s["output"] = ""
        gradcam_base64 = None
        has_confirm = False

        # Switch to results view
        yield (
            s,
            gr.update(visible=False),  # view_input
            gr.update(visible=True),   # view_results
            '<div class="analysis-output">Starting analysis...</div>',
            gr.update(visible=False)   # confirm_box
        )

        partial = ""
        for chunk in agent.analyze(image_path, message or ""):
            partial += chunk

            # Check for GradCAM
            if gradcam_base64 is None:
                match = re.search(r'\[GRADCAM_IMAGE:([^\]]+)\]', partial)
                if match:
                    path = match.group(1)
                    if os.path.exists(path):
                        try:
                            with open(path, "rb") as f:
                                gradcam_base64 = base64.b64encode(f.read()).decode('utf-8')
                            s["gradcam_base64"] = gradcam_base64
                        except:
                            pass

            if '[CONFIRM:' in partial:
                has_confirm = True

            s["output"] = partial

            yield (
                s,
                gr.update(visible=False),
                gr.update(visible=True),
                format_output(partial, gradcam_base64),
                gr.update(visible=has_confirm)
            )

        # Save analysis
        if agent.model and agent.model.last_diagnosis:
            diag = agent.model.last_diagnosis["predictions"][0]
            case_store.update_analysis(
                case_id, instance_id,
                stage="awaiting_confirmation",
                analysis={"diagnosis": diag}
            )

    def confirm_yes(s: Dict):
        """User confirmed diagnosis"""
        partial = s.get("output", "")
        gradcam = s.get("gradcam_base64")

        for chunk in agent.management_guidance(confirmed=True):
            partial += chunk
            s["output"] = partial
            yield s, format_output(partial, gradcam), gr.update(visible=False)

        case_store.update_analysis(s["case_id"], s["instance_id"], stage="complete")

    def confirm_no():
        """Show feedback input"""
        return gr.update(visible=True), gr.update(visible=True)

    def submit_feedback(feedback: str, s: Dict):
        """Submit user feedback"""
        partial = s.get("output", "")
        gradcam = s.get("gradcam_base64")

        for chunk in agent.management_guidance(confirmed=False, feedback=feedback):
            partial += chunk
            s["output"] = partial
            yield (
                s,
                format_output(partial, gradcam),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
                ""
            )

        case_store.update_analysis(s["case_id"], s["instance_id"], stage="complete")

    def send_followup(message: str, s: Dict):
        """Send follow-up question"""
        if not message.strip():
            return s, gr.update(), ""

        case_store.add_chat_message(s["case_id"], s["instance_id"], "user", message)

        partial = s.get("output", "")
        gradcam = s.get("gradcam_base64")

        partial += f'\n<div class="chat-message user">You: {message}</div>\n'

        response = ""
        for chunk in agent.followup(message):
            response += chunk
            s["output"] = partial + response
            yield s, format_output(partial + response, gradcam), ""

        case_store.add_chat_message(s["case_id"], s["instance_id"], "assistant", response)

    # =========================================================================
    # WIRE EVENTS
    # =========================================================================

    # Patient selection
    btn_demo_melanoma.click(
        lambda s: select_patient("demo-melanoma", s),
        inputs=[state],
        outputs=[state, page_patient, page_analysis, sidebar, view_input, view_results, output_html, confirm_box]
    )

    btn_demo_ak.click(
        lambda s: select_patient("demo-ak", s),
        inputs=[state],
        outputs=[state, page_patient, page_analysis, sidebar, view_input, view_results, output_html, confirm_box]
    )

    btn_new_patient.click(
btn_new_patient.click(
|
| 459 |
-
new_patient,
|
| 460 |
-
inputs=[state],
|
| 461 |
-
outputs=[state, page_patient, page_analysis, sidebar, view_input, view_results, output_html, confirm_box]
|
| 462 |
-
)
|
| 463 |
-
|
| 464 |
-
# Navigation
|
| 465 |
-
btn_back.click(
|
| 466 |
-
go_back,
|
| 467 |
-
inputs=[state],
|
| 468 |
-
outputs=[state, page_patient, page_analysis, sidebar, view_input, view_results, output_html, confirm_box]
|
| 469 |
-
)
|
| 470 |
-
|
| 471 |
-
btn_new_query.click(
|
| 472 |
-
new_query,
|
| 473 |
-
inputs=[state],
|
| 474 |
-
outputs=[state, view_input, view_results, input_image, output_html, confirm_box]
|
| 475 |
-
)
|
| 476 |
-
|
| 477 |
-
# Analysis
|
| 478 |
-
input_image.change(enable_analyze, inputs=[input_image], outputs=[btn_analyze])
|
| 479 |
-
|
| 480 |
-
btn_analyze.click(
|
| 481 |
-
run_analysis,
|
| 482 |
-
inputs=[input_image, input_message, state],
|
| 483 |
-
outputs=[state, view_input, view_results, output_html, confirm_box]
|
| 484 |
-
)
|
| 485 |
-
|
| 486 |
-
# Confirmation
|
| 487 |
-
btn_confirm_yes.click(
|
| 488 |
-
confirm_yes,
|
| 489 |
-
inputs=[state],
|
| 490 |
-
outputs=[state, output_html, confirm_box]
|
| 491 |
-
)
|
| 492 |
-
|
| 493 |
-
btn_confirm_no.click(
|
| 494 |
-
confirm_no,
|
| 495 |
-
outputs=[input_feedback, btn_submit_feedback]
|
| 496 |
-
)
|
| 497 |
-
|
| 498 |
-
btn_submit_feedback.click(
|
| 499 |
-
submit_feedback,
|
| 500 |
-
inputs=[input_feedback, state],
|
| 501 |
-
outputs=[state, output_html, confirm_box, input_feedback, btn_submit_feedback, input_feedback]
|
| 502 |
-
)
|
| 503 |
-
|
| 504 |
-
# Follow-up
|
| 505 |
-
btn_followup.click(
|
| 506 |
-
send_followup,
|
| 507 |
-
inputs=[input_followup, state],
|
| 508 |
-
outputs=[state, output_html, input_followup]
|
| 509 |
-
)
|
| 510 |
-
|
| 511 |
-
input_followup.submit(
|
| 512 |
-
send_followup,
|
| 513 |
-
inputs=[input_followup, state],
|
| 514 |
-
outputs=[state, output_html, input_followup]
|
| 515 |
-
)
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
# =============================================================================
|
| 519 |
-
# MAIN
|
| 520 |
-
# =============================================================================
|
| 521 |
-
|
| 522 |
-
if __name__ == "__main__":
|
| 523 |
-
print(f"\n{'='*50}")
|
| 524 |
-
print(f" {Config.APP_TITLE}")
|
| 525 |
-
print(f"{'='*50}\n")
|
| 526 |
-
|
| 527 |
-
app.queue().launch(
|
| 528 |
-
server_name="0.0.0.0" if Config.HF_SPACES else "127.0.0.1",
|
| 529 |
-
server_port=Config.SERVER_PORT,
|
| 530 |
-
share=False,
|
| 531 |
-
show_error=True
|
| 532 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
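The deleted handlers above all follow the same Gradio streaming idiom: a generator callback that accumulates model chunks into session state and yields one UI update per chunk. A minimal self-contained sketch of that pattern (the names `stream_handler` and the `<div>` wrapper are illustrative, not from the original file):

```python
def stream_handler(chunks, state):
    # Accumulate streamed chunks into session state and yield a
    # (state, rendered_html) pair after each one, mirroring the
    # `yield s, format_output(...)` shape of the removed callbacks.
    partial = state.get("output", "")
    for chunk in chunks:
        partial += chunk
        state["output"] = partial
        yield state, f'<div class="analysis-output">{partial}</div>'


state = {"output": ""}
for _, html in stream_handler(["Analyzing", "...", " done"], state):
    pass  # in Gradio, each yield would repaint the HTML component
print(html)  # → <div class="analysis-output">Analyzing... done</div>
```

Wired to a `gr.Button.click` with `outputs=[state, output_html]`, each `yield` repaints the page without waiting for the full generation to finish.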
frontend/components/__init__.py
DELETED
File without changes
frontend/components/analysis_view.py
DELETED
@@ -1,214 +0,0 @@
"""
Analysis View Component - Main analysis interface with input and results
"""

import gradio as gr
import re
from typing import Optional


def parse_markdown(text: str) -> str:
    """Convert basic markdown to HTML"""
    text = re.sub(r'\*\*(.+?)\*\*', r'<strong>\1</strong>', text)
    text = re.sub(r'__(.+?)__', r'<strong>\1</strong>', text)
    text = re.sub(r'\*(.+?)\*', r'<em>\1</em>', text)

    # Bullet lists
    lines = text.split('\n')
    in_list = False
    result = []
    for line in lines:
        stripped = line.strip()
        if re.match(r'^[\*\-] ', stripped):
            if not in_list:
                result.append('<ul>')
                in_list = True
            item = re.sub(r'^[\*\-] ', '', stripped)
            result.append(f'<li>{item}</li>')
        else:
            if in_list:
                result.append('</ul>')
                in_list = False
            result.append(line)
    if in_list:
        result.append('</ul>')

    return '\n'.join(result)


# Regex patterns for output parsing
_STAGE_RE = re.compile(r'\[STAGE:(\w+)\](.*?)\[/STAGE\]')
_THINKING_RE = re.compile(r'\[THINKING\](.*?)\[/THINKING\]')
_OBSERVATION_RE = re.compile(r'\[OBSERVATION\](.*?)\[/OBSERVATION\]')
_TOOL_OUTPUT_RE = re.compile(r'\[TOOL_OUTPUT:(.*?)\]\n(.*?)\[/TOOL_OUTPUT\]', re.DOTALL)
_RESULT_RE = re.compile(r'\[RESULT\](.*?)\[/RESULT\]')
_ERROR_RE = re.compile(r'\[ERROR\](.*?)\[/ERROR\]')
_GRADCAM_RE = re.compile(r'\[GRADCAM_IMAGE:[^\]]+\]\n?')
_RESPONSE_RE = re.compile(r'\[RESPONSE\]\n(.*?)\n\[/RESPONSE\]', re.DOTALL)
_COMPLETE_RE = re.compile(r'\[COMPLETE\](.*?)\[/COMPLETE\]')
_CONFIRM_RE = re.compile(r'\[CONFIRM:(\w+)\](.*?)\[/CONFIRM\]')
_REFERENCES_RE = re.compile(r'\[REFERENCES\](.*?)\[/REFERENCES\]', re.DOTALL)
_REF_RE = re.compile(r'\[REF:([^:]+):([^:]+):([^:]+):([^:]+):([^\]]+)\]')


def format_output(raw_text: str, gradcam_base64: Optional[str] = None) -> str:
    """Convert tagged output to styled HTML"""
    html = raw_text

    # Stage headers
    html = _STAGE_RE.sub(
        r'<div class="stage"><span class="stage-indicator"></span><span class="stage-text">\2</span></div>',
        html
    )

    # Thinking
    html = _THINKING_RE.sub(r'<div class="thinking">\1</div>', html)

    # Observations
    html = _OBSERVATION_RE.sub(r'<div class="observation">\1</div>', html)

    # Tool outputs
    html = _TOOL_OUTPUT_RE.sub(
        r'<div class="tool-output"><div class="tool-header">\1</div><pre class="tool-content">\2</pre></div>',
        html
    )

    # Results
    html = _RESULT_RE.sub(r'<div class="result">\1</div>', html)

    # Errors
    html = _ERROR_RE.sub(r'<div class="error">\1</div>', html)

    # GradCAM image
    if gradcam_base64:
        img_html = f'<div class="gradcam-inline"><div class="gradcam-header">Attention Map</div><img src="data:image/png;base64,{gradcam_base64}" alt="Grad-CAM"></div>'
        html = _GRADCAM_RE.sub(img_html, html)
    else:
        html = _GRADCAM_RE.sub('', html)

    # Response section
    def format_response(match):
        content = match.group(1)
        parsed = parse_markdown(content)
        parsed = re.sub(r'\n\n+', '</p><p>', parsed)
        parsed = parsed.replace('\n', '<br>')
        return f'<div class="response"><p>{parsed}</p></div>'

    html = _RESPONSE_RE.sub(format_response, html)

    # Complete
    html = _COMPLETE_RE.sub(r'<div class="complete">\1</div>', html)

    # Confirmation
    html = _CONFIRM_RE.sub(
        r'<div class="confirm-box"><div class="confirm-text">\2</div></div>',
        html
    )

    # References
    def format_references(match):
        ref_content = match.group(1)
        refs_html = ['<div class="references"><div class="references-header">References</div><ul>']
        for ref_match in _REF_RE.finditer(ref_content):
            _, source, page, filename, superscript = ref_match.groups()
            refs_html.append(
                f'<li><a href="guidelines/{filename}#page={page}" target="_blank" class="ref-link">'
                f'<sup>{superscript}</sup> {source}, p.{page}</a></li>'
            )
        refs_html.append('</ul></div>')
        return '\n'.join(refs_html)

    html = _REFERENCES_RE.sub(format_references, html)

    # Convert newlines
    html = html.replace('\n', '<br>')

    return f'<div class="analysis-output">{html}</div>'


def create_analysis_view():
    """
    Create the analysis view component.

    Returns:
        Tuple of (container, components dict)
    """
    with gr.Group(visible=False, elem_classes=["analysis-container"]) as container:

        with gr.Row():
            # Main content area
            with gr.Column(elem_classes=["main-content"]):

                # Input greeting (shown when no analysis yet)
                with gr.Group(visible=True, elem_classes=["input-greeting"]) as input_greeting:
                    gr.Markdown("What would you like to analyze?", elem_classes=["greeting-title"])
                    gr.Markdown("Upload an image and describe what you'd like to know", elem_classes=["greeting-subtitle"])

                    with gr.Column(elem_classes=["input-box-container"]):
                        message_input = gr.Textbox(
                            placeholder="Describe the lesion or ask a question...",
                            show_label=False,
                            lines=3,
                            elem_classes=["message-input"]
                        )

                        # Image upload (compact)
                        image_input = gr.Image(
                            label="",
                            type="pil",
                            height=180,
                            elem_classes=["image-preview"],
                            show_label=False
                        )

                        with gr.Row(elem_classes=["input-actions"]):
                            upload_hint = gr.Markdown("*Upload a skin lesion image above*", visible=True)
                            send_btn = gr.Button("Analyze", elem_classes=["send-btn"], interactive=False)

                # Chat/results view (shown after analysis starts)
                with gr.Group(visible=False, elem_classes=["chat-view"]) as chat_view:
                    results_output = gr.HTML(
                        value='<div class="analysis-output">Starting analysis...</div>',
                        elem_classes=["results-area"]
                    )

                    # Confirmation buttons
                    with gr.Group(visible=False, elem_classes=["confirm-buttons"]) as confirm_group:
                        gr.Markdown("**Do you agree with this diagnosis?**")
                        with gr.Row():
                            confirm_yes_btn = gr.Button("Yes, continue", variant="primary", size="sm")
                            confirm_no_btn = gr.Button("No, I disagree", variant="secondary", size="sm")
                        feedback_input = gr.Textbox(
                            label="Your assessment",
                            placeholder="Enter your diagnosis...",
                            visible=False
                        )
                        submit_feedback_btn = gr.Button("Submit", visible=False, size="sm")

                    # Follow-up input
                    with gr.Row(elem_classes=["chat-input-area"]):
                        followup_input = gr.Textbox(
                            placeholder="Ask a follow-up question...",
                            show_label=False,
                            lines=1
                        )
                        followup_btn = gr.Button("Send", size="sm", elem_classes=["send-btn"])

    components = {
        "input_greeting": input_greeting,
        "chat_view": chat_view,
        "message_input": message_input,
        "image_input": image_input,
        "send_btn": send_btn,
        "results_output": results_output,
        "confirm_group": confirm_group,
        "confirm_yes_btn": confirm_yes_btn,
        "confirm_no_btn": confirm_no_btn,
        "feedback_input": feedback_input,
        "submit_feedback_btn": submit_feedback_btn,
        "followup_input": followup_input,
        "followup_btn": followup_btn,
        "upload_hint": upload_hint
    }

    return container, components
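The removed `format_output` turned a bracket-tag streaming protocol into HTML with precompiled regex substitutions. A reduced, self-contained sketch of that substitution step, using two of the tag patterns with the same regex shapes as the deleted module:

```python
import re

# Two of the tag patterns the removed formatter compiled at module scope.
STAGE_RE = re.compile(r'\[STAGE:(\w+)\](.*?)\[/STAGE\]')
RESULT_RE = re.compile(r'\[RESULT\](.*?)\[/RESULT\]')


def render(raw: str) -> str:
    # Each pass rewrites one tag type into a styled <div>; group 2 of the
    # stage tag is the display text, group 1 (the stage name) is dropped.
    html = STAGE_RE.sub(r'<div class="stage"><span class="stage-text">\2</span></div>', raw)
    html = RESULT_RE.sub(r'<div class="result">\1</div>', html)
    return html


print(render('[STAGE:classify]Running classifier[/STAGE] [RESULT]Done[/RESULT]'))
```

Because substitution happens per tag type, partial streams render gracefully: an unclosed tag simply passes through untouched until its closing marker arrives.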
frontend/components/patient_select.py
DELETED
@@ -1,48 +0,0 @@
"""
Patient Selection Component - Landing page for selecting/creating patients
"""

import gradio as gr
from typing import Callable, List
from data.case_store import get_case_store, Case


def create_patient_select(on_patient_selected: Callable[[str], None]) -> gr.Group:
    """
    Create the patient selection page component.

    Args:
        on_patient_selected: Callback when a patient is selected (receives case_id)

    Returns:
        gr.Group containing the patient selection UI
    """
    case_store = get_case_store()

    with gr.Group(visible=True, elem_classes=["patient-select-container"]) as container:
        gr.Markdown("# SkinProAI", elem_classes=["patient-select-title"])
        gr.Markdown("Select a patient to continue or create a new case", elem_classes=["patient-select-subtitle"])

        with gr.Column(elem_classes=["patient-grid"]):
            # Demo cases
            demo_melanoma_btn = gr.Button(
                "Demo: Melanocytic Lesion",
                elem_classes=["patient-card"]
            )
            demo_ak_btn = gr.Button(
                "Demo: Actinic Keratosis",
                elem_classes=["patient-card"]
            )

            # New patient button
            new_patient_btn = gr.Button(
                "+ New Patient",
                elem_classes=["new-patient-btn"]
            )

    return container, demo_melanoma_btn, demo_ak_btn, new_patient_btn


def get_patient_cases() -> List[Case]:
    """Get list of all patient cases"""
    return get_case_store().list_cases()
frontend/components/sidebar.py
DELETED
@@ -1,55 +0,0 @@
"""
Sidebar Component - Shows previous queries for a patient
"""

import gradio as gr
from datetime import datetime
from typing import List, Optional
from data.case_store import get_case_store, Instance


def format_query_item(instance: Instance, index: int) -> str:
    """Format an instance as a query item for display"""
    diagnosis = "Pending"
    if instance.analysis and instance.analysis.get("diagnosis"):
        diag = instance.analysis["diagnosis"]
        diagnosis = diag.get("full_name", diag.get("class", "Unknown"))

    try:
        dt = datetime.fromisoformat(instance.created_at.replace('Z', '+00:00'))
        date_str = dt.strftime("%b %d, %H:%M")
    except:
        date_str = "Unknown"

    return f"Query #{index}: {diagnosis} ({date_str})"


def create_sidebar():
    """
    Create the sidebar component for showing previous queries.

    Returns:
        Tuple of (container, components dict)
    """
    with gr.Column(visible=False, elem_classes=["query-sidebar"]) as container:
        gr.Markdown("### Previous Queries", elem_classes=["sidebar-header"])

        # Dynamic list of query buttons
        query_list = gr.Column(elem_id="query-list")

        # New query button
        new_query_btn = gr.Button("+ New Query", size="sm", variant="primary")

    components = {
        "query_list": query_list,
        "new_query_btn": new_query_btn
    }

    return container, components


def get_queries_for_case(case_id: str) -> List[Instance]:
    """Get all instances/queries for a case"""
    if not case_id:
        return []
    return get_case_store().list_instances(case_id)
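The sidebar's `format_query_item` normalizes trailing-`Z` ISO-8601 timestamps before parsing: `datetime.fromisoformat` only accepts a bare `Z` suffix from Python 3.11 onward, so the `replace('Z', '+00:00')` keeps older interpreters working. The same conversion in isolation:

```python
from datetime import datetime


def short_date(iso_ts: str) -> str:
    # Replace a trailing 'Z' with an explicit UTC offset so that
    # datetime.fromisoformat accepts the string on Python < 3.11 too.
    dt = datetime.fromisoformat(iso_ts.replace('Z', '+00:00'))
    return dt.strftime("%b %d, %H:%M")


print(short_date("2024-03-05T14:30:00Z"))  # → Mar 05, 14:30
```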
frontend/components/styles.py
DELETED
@@ -1,517 +0,0 @@
"""
CSS Styles for SkinProAI components
"""

MAIN_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600&display=swap');

* {
    font-family: 'Inter', sans-serif !important;
}

.gradio-container {
    max-width: 1200px !important;
    margin: 0 auto !important;
}

/* Hide Gradio footer */
.gradio-container footer { display: none !important; }

/* ============================================
   PATIENT SELECTION PAGE
   ============================================ */

.patient-select-container {
    min-height: 80vh;
    display: flex;
    flex-direction: column;
    align-items: center;
    justify-content: center;
    padding: 40px 20px;
}

.patient-select-title {
    font-size: 32px;
    font-weight: 600;
    color: #111827;
    margin-bottom: 8px;
    text-align: center;
}

.patient-select-subtitle {
    font-size: 16px;
    color: #6b7280;
    margin-bottom: 40px;
    text-align: center;
}

.patient-grid {
    display: flex;
    gap: 20px;
    flex-wrap: wrap;
    justify-content: center;
    max-width: 800px;
}

.patient-card {
    background: white !important;
    border: 2px solid #e5e7eb !important;
    border-radius: 16px !important;
    padding: 24px 32px !important;
    min-width: 200px !important;
    cursor: pointer;
    transition: all 0.2s ease !important;
}

.patient-card:hover {
    border-color: #6366f1 !important;
    box-shadow: 0 8px 25px rgba(99, 102, 241, 0.15) !important;
    transform: translateY(-2px);
}

.new-patient-btn {
    background: #6366f1 !important;
    color: white !important;
    border: none !important;
    border-radius: 12px !important;
    padding: 16px 32px !important;
    font-weight: 500 !important;
    margin-top: 24px;
}

.new-patient-btn:hover {
    background: #4f46e5 !important;
}

/* ============================================
   ANALYSIS PAGE - MAIN LAYOUT
   ============================================ */

.analysis-container {
    display: flex;
    height: calc(100vh - 80px);
    min-height: 600px;
}

/* Sidebar */
.query-sidebar {
    width: 280px;
    background: #f9fafb;
    border-right: 1px solid #e5e7eb;
    padding: 20px;
    overflow-y: auto;
    flex-shrink: 0;
}

.sidebar-header {
    font-size: 14px;
    font-weight: 600;
    color: #374151;
    margin-bottom: 16px;
    padding-bottom: 12px;
    border-bottom: 1px solid #e5e7eb;
}

.query-item {
    background: white;
    border: 1px solid #e5e7eb;
    border-radius: 8px;
    padding: 12px;
    margin-bottom: 8px;
    cursor: pointer;
    transition: all 0.15s;
}

.query-item:hover {
    border-color: #6366f1;
    background: #f5f3ff;
}

.query-item-title {
    font-size: 13px;
    font-weight: 500;
    color: #111827;
    margin-bottom: 4px;
}

.query-item-meta {
    font-size: 11px;
    color: #6b7280;
}

/* Main content area */
.main-content {
    flex: 1;
    display: flex;
    flex-direction: column;
    padding: 24px;
    overflow: hidden;
}

/* ============================================
   INPUT AREA (Greeting style)
   ============================================ */

.input-greeting {
    flex: 1;
    display: flex;
    flex-direction: column;
    align-items: center;
    justify-content: center;
    padding: 40px;
}

.greeting-title {
    font-size: 24px;
    font-weight: 600;
    color: #111827;
    margin-bottom: 8px;
}

.greeting-subtitle {
    font-size: 14px;
    color: #6b7280;
    margin-bottom: 32px;
}

.input-box-container {
    width: 100%;
    max-width: 600px;
    background: white;
    border: 2px solid #e5e7eb;
    border-radius: 16px;
    padding: 20px;
    transition: border-color 0.2s;
}

.input-box-container:focus-within {
    border-color: #6366f1;
}

.message-input textarea {
    border: none !important;
    resize: none !important;
    font-size: 15px !important;
    line-height: 1.5 !important;
    padding: 0 !important;
}

.message-input textarea:focus {
    box-shadow: none !important;
}

.input-actions {
    display: flex;
    align-items: center;
    justify-content: space-between;
    margin-top: 16px;
    padding-top: 16px;
    border-top: 1px solid #f3f4f6;
}

.upload-btn {
    background: #f3f4f6 !important;
    color: #374151 !important;
    border: 1px solid #e5e7eb !important;
    border-radius: 8px !important;
    padding: 8px 16px !important;
    font-size: 13px !important;
}

.upload-btn:hover {
    background: #e5e7eb !important;
}

.send-btn {
    background: #6366f1 !important;
    color: white !important;
    border: none !important;
    border-radius: 8px !important;
    padding: 10px 24px !important;
    font-weight: 500 !important;
}

.send-btn:hover {
    background: #4f46e5 !important;
}

.send-btn:disabled {
    background: #d1d5db !important;
    cursor: not-allowed;
}

/* Image preview */
.image-preview {
    margin-top: 16px;
    border-radius: 12px;
    overflow: hidden;
    max-height: 200px;
}

.image-preview img {
    max-height: 200px;
    object-fit: contain;
}

/* ============================================
   CHAT/RESULTS VIEW
   ============================================ */

.chat-view {
    flex: 1;
    display: flex;
    flex-direction: column;
    overflow: hidden;
}

.results-area {
    flex: 1;
    overflow-y: auto;
    padding: 20px;
    background: #ffffff;
    border: 1px solid #e5e7eb;
    border-radius: 12px;
    margin-bottom: 16px;
}

/* Analysis output styling */
.analysis-output {
    line-height: 1.6;
    color: #333;
}

.stage {
    display: flex;
    align-items: center;
    gap: 10px;
    padding: 8px 0;
    font-weight: 500;
    color: #1a1a1a;
    margin-top: 12px;
}

.stage-indicator {
    width: 8px;
    height: 8px;
    background: #6366f1;
    border-radius: 50%;
    animation: pulse 1.5s ease-in-out infinite;
}

@keyframes pulse {
    0%, 100% { opacity: 1; transform: scale(1); }
    50% { opacity: 0.5; transform: scale(0.8); }
}

.thinking {
    color: #6b7280;
    font-style: italic;
    font-size: 13px;
    padding: 4px 0 4px 16px;
    border-left: 2px solid #e5e7eb;
    margin: 4px 0;
}

.observation {
    color: #374151;
    font-size: 13px;
    padding: 4px 0 4px 16px;
}

.tool-output {
    background: #f8fafc;
    border-radius: 8px;
    margin: 12px 0;
    overflow: hidden;
    border: 1px solid #e2e8f0;
}

.tool-header {
    background: #f1f5f9;
    padding: 8px 12px;
    font-weight: 500;
    font-size: 13px;
    color: #475569;
    border-bottom: 1px solid #e2e8f0;
}

.tool-content {
    padding: 12px;
    margin: 0;
    font-family: 'SF Mono', Monaco, monospace !important;
    font-size: 12px;
    line-height: 1.5;
    white-space: pre-wrap;
    color: #334155;
}

.result {
    background: #ecfdf5;
    border: 1px solid #a7f3d0;
    border-radius: 8px;
    padding: 12px 16px;
    margin: 12px 0;
    font-weight: 500;
    color: #065f46;
}

.error {
    background: #fef2f2;
    border: 1px solid #fecaca;
    border-radius: 8px;
    padding: 12px 16px;
    margin: 8px 0;
    color: #b91c1c;
}

.response {
    background: #ffffff;
    border: 1px solid #e5e7eb;
    border-radius: 8px;
    padding: 16px;
    margin: 16px 0;
    line-height: 1.7;
}

.response ul, .response ol {
    margin: 8px 0;
    padding-left: 24px;
}

.response li {
    margin: 4px 0;
}

.complete {
    color: #6b7280;
    font-size: 12px;
    padding: 8px 0;
    text-align: center;
}

/* Confirmation */
.confirm-box {
    background: #eff6ff;
    border: 1px solid #bfdbfe;
    border-radius: 8px;
    padding: 16px;
    margin: 16px 0;
    text-align: center;
}

.confirm-buttons {
    background: #f0f9ff;
    border: 1px solid #bae6fd;
|
| 405 |
-
border-radius: 8px;
|
| 406 |
-
padding: 12px;
|
| 407 |
-
margin-top: 12px;
|
| 408 |
-
}
|
| 409 |
-
|
| 410 |
-
/* References */
|
| 411 |
-
.references {
|
| 412 |
-
background: #f9fafb;
|
| 413 |
-
border: 1px solid #e5e7eb;
|
| 414 |
-
border-radius: 8px;
|
| 415 |
-
margin: 16px 0;
|
| 416 |
-
overflow: hidden;
|
| 417 |
-
}
|
| 418 |
-
|
| 419 |
-
.references-header {
|
| 420 |
-
background: #f3f4f6;
|
| 421 |
-
padding: 8px 12px;
|
| 422 |
-
font-weight: 500;
|
| 423 |
-
font-size: 13px;
|
| 424 |
-
border-bottom: 1px solid #e5e7eb;
|
| 425 |
-
}
|
| 426 |
-
|
| 427 |
-
.references ul {
|
| 428 |
-
list-style: none;
|
| 429 |
-
padding: 12px;
|
| 430 |
-
margin: 0;
|
| 431 |
-
}
|
| 432 |
-
|
| 433 |
-
.ref-link {
|
| 434 |
-
color: #6366f1;
|
| 435 |
-
text-decoration: none;
|
| 436 |
-
font-size: 13px;
|
| 437 |
-
}
|
| 438 |
-
|
| 439 |
-
.ref-link:hover {
|
| 440 |
-
text-decoration: underline;
|
| 441 |
-
}
|
| 442 |
-
|
| 443 |
-
/* GradCAM */
|
| 444 |
-
.gradcam-inline {
|
| 445 |
-
margin: 16px 0;
|
| 446 |
-
background: #f8fafc;
|
| 447 |
-
border-radius: 8px;
|
| 448 |
-
overflow: hidden;
|
| 449 |
-
border: 1px solid #e2e8f0;
|
| 450 |
-
}
|
| 451 |
-
|
| 452 |
-
.gradcam-header {
|
| 453 |
-
background: #f1f5f9;
|
| 454 |
-
padding: 8px 12px;
|
| 455 |
-
font-weight: 500;
|
| 456 |
-
font-size: 13px;
|
| 457 |
-
border-bottom: 1px solid #e2e8f0;
|
| 458 |
-
}
|
| 459 |
-
|
| 460 |
-
.gradcam-inline img {
|
| 461 |
-
max-width: 100%;
|
| 462 |
-
max-height: 300px;
|
| 463 |
-
display: block;
|
| 464 |
-
margin: 12px auto;
|
| 465 |
-
}
|
| 466 |
-
|
| 467 |
-
/* Chat input at bottom */
|
| 468 |
-
.chat-input-area {
|
| 469 |
-
background: white;
|
| 470 |
-
border: 1px solid #e5e7eb;
|
| 471 |
-
border-radius: 12px;
|
| 472 |
-
padding: 12px 16px;
|
| 473 |
-
display: flex;
|
| 474 |
-
gap: 12px;
|
| 475 |
-
align-items: flex-end;
|
| 476 |
-
}
|
| 477 |
-
|
| 478 |
-
.chat-input-area textarea {
|
| 479 |
-
flex: 1;
|
| 480 |
-
border: none !important;
|
| 481 |
-
resize: none !important;
|
| 482 |
-
font-size: 14px !important;
|
| 483 |
-
}
|
| 484 |
-
|
| 485 |
-
/* ============================================
|
| 486 |
-
HEADER
|
| 487 |
-
============================================ */
|
| 488 |
-
|
| 489 |
-
.app-header {
|
| 490 |
-
display: flex;
|
| 491 |
-
align-items: center;
|
| 492 |
-
justify-content: space-between;
|
| 493 |
-
padding: 16px 24px;
|
| 494 |
-
border-bottom: 1px solid #e5e7eb;
|
| 495 |
-
background: white;
|
| 496 |
-
}
|
| 497 |
-
|
| 498 |
-
.app-title {
|
| 499 |
-
font-size: 20px;
|
| 500 |
-
font-weight: 600;
|
| 501 |
-
color: #111827;
|
| 502 |
-
}
|
| 503 |
-
|
| 504 |
-
.back-btn {
|
| 505 |
-
background: transparent !important;
|
| 506 |
-
color: #6b7280 !important;
|
| 507 |
-
border: 1px solid #e5e7eb !important;
|
| 508 |
-
border-radius: 8px !important;
|
| 509 |
-
padding: 8px 16px !important;
|
| 510 |
-
font-size: 13px !important;
|
| 511 |
-
}
|
| 512 |
-
|
| 513 |
-
.back-btn:hover {
|
| 514 |
-
background: #f9fafb !important;
|
| 515 |
-
color: #111827 !important;
|
| 516 |
-
}
|
| 517 |
-
"""
|
models/explainability.py
DELETED
@@ -1,183 +0,0 @@
-# models/explainability.py
-
-import torch
-import torch.nn.functional as F
-import numpy as np
-import cv2
-from typing import Tuple
-from PIL import Image
-
-class GradCAM:
-    """
-    Gradient-weighted Class Activation Mapping
-    Shows which regions of image are important for prediction
-    """
-
-    def __init__(self, model: torch.nn.Module, target_layer: str = None):
-        """
-        Args:
-            model: The neural network
-            target_layer: Layer name to compute CAM on (usually last conv layer)
-        """
-        self.model = model
-        self.gradients = None
-        self.activations = None
-
-        # Auto-detect target layer if not specified
-        if target_layer is None:
-            # Use last ConvNeXt stage
-            self.target_layer = model.convnext.stages[-1]
-        else:
-            self.target_layer = dict(model.named_modules())[target_layer]
-
-        # Register hooks
-        self.target_layer.register_forward_hook(self._save_activation)
-        self.target_layer.register_full_backward_hook(self._save_gradient)
-
-    def _save_activation(self, module, input, output):
-        """Save forward activations"""
-        self.activations = output.detach()
-
-    def _save_gradient(self, module, grad_input, grad_output):
-        """Save backward gradients"""
-        self.gradients = grad_output[0].detach()
-
-    def generate_cam(
-        self,
-        image: torch.Tensor,
-        target_class: int = None
-    ) -> np.ndarray:
-        """
-        Generate Class Activation Map
-
-        Args:
-            image: Input image [1, 3, H, W]
-            target_class: Class to generate CAM for (None = predicted class)
-
-        Returns:
-            cam: Activation map [H, W] normalized to 0-1
-        """
-        self.model.eval()
-
-        # Forward pass
-        output = self.model(image)
-
-        # Use predicted class if not specified
-        if target_class is None:
-            target_class = output.argmax(dim=1).item()
-
-        # Zero gradients
-        self.model.zero_grad()
-
-        # Backward pass for target class
-        output[0, target_class].backward()
-
-        # Get gradients and activations
-        gradients = self.gradients[0]      # [C, H, W]
-        activations = self.activations[0]  # [C, H, W]
-
-        # Global average pooling of gradients
-        weights = gradients.mean(dim=(1, 2))  # [C]
-
-        # Weighted sum of activations
-        cam = torch.zeros(activations.shape[1:], dtype=torch.float32)
-        for i, w in enumerate(weights):
-            cam += w * activations[i]
-
-        # ReLU
-        cam = F.relu(cam)
-
-        # Normalize to 0-1
-        cam = cam.cpu().numpy()
-        cam = cam - cam.min()
-        if cam.max() > 0:
-            cam = cam / cam.max()
-
-        return cam
-
-    def overlay_cam_on_image(
-        self,
-        image: np.ndarray,  # [H, W, 3] RGB
-        cam: np.ndarray,    # [h, w]
-        alpha: float = 0.5,
-        colormap: int = cv2.COLORMAP_JET
-    ) -> np.ndarray:
-        """
-        Overlay CAM heatmap on original image
-
-        Returns:
-            overlay: [H, W, 3] RGB image with heatmap
-        """
-        H, W = image.shape[:2]
-
-        # Resize CAM to image size
-        cam_resized = cv2.resize(cam, (W, H))
-
-        # Convert to heatmap
-        heatmap = cv2.applyColorMap(
-            np.uint8(255 * cam_resized),
-            colormap
-        )
-        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
-
-        # Blend with original image
-        overlay = (alpha * heatmap + (1 - alpha) * image).astype(np.uint8)
-
-        return overlay
-
-class AttentionVisualizer:
-    """Visualize MedSigLIP attention maps"""
-
-    def __init__(self, model):
-        self.model = model
-
-    def get_attention_maps(self, image: torch.Tensor) -> np.ndarray:
-        """
-        Extract attention maps from MedSigLIP
-
-        Returns:
-            attention: [num_heads, H, W] attention weights
-        """
-        # Forward pass
-        with torch.no_grad():
-            _ = self.model(image)
-
-        # Get last layer attention from MedSigLIP
-        # Shape: [batch, num_heads, seq_len, seq_len]
-        attention = self.model.medsiglip_features
-
-        # Average across heads and extract spatial attention
-        # This is model-dependent - adjust based on MedSigLIP architecture
-
-        # Placeholder implementation
-        # You'll need to adapt this to your specific MedSigLIP implementation
-        return np.random.rand(14, 14)  # Placeholder
-
-    def overlay_attention(
-        self,
-        image: np.ndarray,
-        attention: np.ndarray,
-        alpha: float = 0.6
-    ) -> np.ndarray:
-        """Overlay attention map on image"""
-        H, W = image.shape[:2]
-
-        # Resize attention to image size
-        attention_resized = cv2.resize(attention, (W, H))
-
-        # Normalize
-        attention_resized = (attention_resized - attention_resized.min())
-        if attention_resized.max() > 0:
-            attention_resized = attention_resized / attention_resized.max()
-
-        # Create colored overlay
-        heatmap = cv2.applyColorMap(
-            np.uint8(255 * attention_resized),
-            cv2.COLORMAP_VIRIDIS
-        )
-        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
-
-        # Blend
-        overlay = (alpha * heatmap + (1 - alpha) * image).astype(np.uint8)
-
-        return overlay
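For reference, the core Grad-CAM arithmetic that the deleted `generate_cam` implemented can be sketched in plain NumPy: the channel weights are the spatially averaged gradients, and the map is the ReLU of the weighted sum of activation channels, min-max normalized. The toy activations and gradients below are invented purely for illustration.

```python
import numpy as np

def gradcam_map(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """activations, gradients: [C, H, W] captured at the target conv layer."""
    weights = gradients.mean(axis=(1, 2))             # [C]: GAP of gradients
    cam = np.tensordot(weights, activations, axes=1)  # [H, W]: weighted sum
    cam = np.maximum(cam, 0.0)                        # ReLU
    cam -= cam.min()                                  # min-max normalize to 0-1
    if cam.max() > 0:
        cam /= cam.max()
    return cam

# Toy example: two 2x2 channels with uniform per-channel gradients 1.0 and 0.5
acts = np.array([[[1.0, 0.0], [0.0, 0.0]],
                 [[0.0, 2.0], [0.0, 0.0]]])
grads = np.array([1.0, 0.5]).reshape(2, 1, 1) * np.ones((2, 2, 2))
cam = gradcam_map(acts, grads)  # [[1.0, 1.0], [0.0, 0.0]]
```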
models/medsiglip_convnext_fusion.py
DELETED
@@ -1,224 +0,0 @@
-# models/medsiglip_convnext_fusion.py
-
-import torch
-import torch.nn as nn
-from typing import Dict, List, Tuple, Optional
-import numpy as np
-import timm
-from transformers import AutoModel, AutoProcessor
-
-class MedSigLIPConvNeXtFusion(nn.Module):
-    """
-    Your trained MedSigLIP-ConvNeXt fusion model from MILK10 challenge
-    Supports 11-class skin lesion classification
-    """
-
-    # Class names from your training
-    CLASS_NAMES = [
-        'AKIEC',    # Actinic Keratoses and Intraepithelial Carcinoma
-        'BCC',      # Basal Cell Carcinoma
-        'BEN_OTH',  # Benign Other
-        'BKL',      # Benign Keratosis-like Lesions
-        'DF',       # Dermatofibroma
-        'INF',      # Inflammatory
-        'MAL_OTH',  # Malignant Other
-        'MEL',      # Melanoma
-        'NV',       # Melanocytic Nevi
-        'SCCKA',    # Squamous Cell Carcinoma and Keratoacanthoma
-        'VASC'      # Vascular Lesions
-    ]
-
-    def __init__(
-        self,
-        num_classes: int = 11,
-        medsiglip_model: str = "google/medsiglip-base",
-        convnext_variant: str = "convnext_base",
-        fusion_dim: int = 512,
-        dropout: float = 0.3,
-        metadata_dim: int = 20  # For metadata features
-    ):
-        super().__init__()
-
-        self.num_classes = num_classes
-
-        # MedSigLIP Vision Encoder
-        print(f"Loading MedSigLIP: {medsiglip_model}")
-        self.medsiglip = AutoModel.from_pretrained(medsiglip_model)
-        self.medsiglip_processor = AutoProcessor.from_pretrained(medsiglip_model)
-
-        # ConvNeXt Backbone
-        print(f"Loading ConvNeXt: {convnext_variant}")
-        self.convnext = timm.create_model(
-            convnext_variant,
-            pretrained=True,
-            num_classes=0,
-            global_pool='avg'
-        )
-
-        # Feature dimensions
-        self.medsiglip_dim = self.medsiglip.config.hidden_size  # 768
-        self.convnext_dim = self.convnext.num_features          # 1024
-
-        # Optional metadata branch
-        self.use_metadata = metadata_dim > 0
-        if self.use_metadata:
-            self.metadata_encoder = nn.Sequential(
-                nn.Linear(metadata_dim, 64),
-                nn.LayerNorm(64),
-                nn.GELU(),
-                nn.Dropout(0.2),
-                nn.Linear(64, 32)
-            )
-            total_dim = self.medsiglip_dim + self.convnext_dim + 32
-        else:
-            total_dim = self.medsiglip_dim + self.convnext_dim
-
-        # Fusion layers
-        self.fusion = nn.Sequential(
-            nn.Linear(total_dim, fusion_dim),
-            nn.LayerNorm(fusion_dim),
-            nn.GELU(),
-            nn.Dropout(dropout),
-            nn.Linear(fusion_dim, fusion_dim // 2),
-            nn.LayerNorm(fusion_dim // 2),
-            nn.GELU(),
-            nn.Dropout(dropout)
-        )
-
-        # Classification head
-        self.classifier = nn.Linear(fusion_dim // 2, num_classes)
-
-        # Store intermediate features for Grad-CAM
-        self.convnext_features = None
-        self.medsiglip_features = None
-
-        # Register hooks
-        self.convnext.stages[-1].register_forward_hook(self._save_convnext_features)
-
-    def _save_convnext_features(self, module, input, output):
-        """Hook to save ConvNeXt feature maps for Grad-CAM"""
-        self.convnext_features = output
-
-    def forward(
-        self,
-        image: torch.Tensor,
-        metadata: Optional[torch.Tensor] = None
-    ) -> torch.Tensor:
-        """
-        Forward pass
-
-        Args:
-            image: [B, 3, H, W] tensor
-            metadata: [B, metadata_dim] optional metadata features
-
-        Returns:
-            logits: [B, num_classes]
-        """
-        # MedSigLIP features
-        medsiglip_out = self.medsiglip.vision_model(image)
-        medsiglip_features = medsiglip_out.pooler_output  # [B, 768]
-
-        # ConvNeXt features
-        convnext_features = self.convnext(image)  # [B, 1024]
-
-        # Concatenate vision features
-        fused = torch.cat([medsiglip_features, convnext_features], dim=1)
-
-        # Add metadata if available
-        if self.use_metadata and metadata is not None:
-            metadata_features = self.metadata_encoder(metadata)
-            fused = torch.cat([fused, metadata_features], dim=1)
-
-        # Fusion layers
-        fused = self.fusion(fused)
-
-        # Classification
-        logits = self.classifier(fused)
-
-        return logits
-
-    def predict(
-        self,
-        image: torch.Tensor,
-        metadata: Optional[torch.Tensor] = None,
-        top_k: int = 5
-    ) -> Dict:
-        """
-        Get predictions with probabilities
-
-        Args:
-            image: [B, 3, H, W] or [3, H, W]
-            metadata: Optional metadata features
-            top_k: Number of top predictions
-
-        Returns:
-            Dictionary with predictions and features
-        """
-        if image.dim() == 3:
-            image = image.unsqueeze(0)
-
-        self.eval()
-        with torch.no_grad():
-            logits = self.forward(image, metadata)
-            probs = torch.softmax(logits, dim=1)
-
-            # Top-k predictions
-            top_probs, top_indices = torch.topk(
-                probs,
-                k=min(top_k, self.num_classes),
-                dim=1
-            )
-
-        # Format results
-        predictions = []
-        for i in range(top_probs.size(1)):
-            predictions.append({
-                'class': self.CLASS_NAMES[top_indices[0, i].item()],
-                'probability': top_probs[0, i].item(),
-                'class_idx': top_indices[0, i].item()
-            })
-
-        return {
-            'predictions': predictions,
-            'all_probabilities': probs[0].cpu().numpy(),
-            'logits': logits[0].cpu().numpy(),
-            'convnext_features': self.convnext_features,
-            'medsiglip_features': self.medsiglip_features
-        }
-
-    @classmethod
-    def load_from_checkpoint(
-        cls,
-        medsiglip_path: str,
-        convnext_path: Optional[str] = None,
-        ensemble_weights: tuple = (0.6, 0.4),
-        device: str = 'cpu'
-    ):
-        """
-        Load model from your training checkpoints
-
-        Args:
-            medsiglip_path: Path to MedSigLIP model weights
-            convnext_path: Path to ConvNeXt model weights (optional)
-            ensemble_weights: (w_medsiglip, w_convnext)
-            device: Device to load on
-        """
-        model = cls(num_classes=11)
-
-        # Load MedSigLIP weights
-        print(f"Loading MedSigLIP from: {medsiglip_path}")
-        medsiglip_state = torch.load(medsiglip_path, map_location=device)
-
-        # Handle different checkpoint formats
-        if 'model_state_dict' in medsiglip_state:
-            model.load_state_dict(medsiglip_state['model_state_dict'])
-        else:
-            model.load_state_dict(medsiglip_state)
-
-        # Store ensemble weights for prediction fusion
-        model.ensemble_weights = ensemble_weights
-
-        model.to(device)
-        model.eval()
-
-        return model
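The fusion in this deleted model is plain feature concatenation: with the dimensions stated in its comments (768-d MedSigLIP pooled output, 1024-d ConvNeXt global-average features, 32-d encoded metadata), the fusion MLP receives a 1824-wide input. A minimal NumPy sketch of that wiring, with random vectors standing in for real encoder outputs:

```python
import numpy as np

rng = np.random.default_rng(0)
medsiglip_feat = rng.standard_normal((1, 768))   # stand-in for pooler_output
convnext_feat = rng.standard_normal((1, 1024))   # stand-in for ConvNeXt features
metadata_feat = rng.standard_normal((1, 32))     # stand-in for metadata_encoder output

# Concatenation fusion: total_dim = 768 + 1024 + 32 = 1824
fused = np.concatenate([medsiglip_feat, convnext_feat, metadata_feat], axis=1)
print(fused.shape)  # (1, 1824)
```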
models/monet_concepts.py
DELETED
@@ -1,332 +0,0 @@
-# models/monet_concepts.py
-
-import torch
-import numpy as np
-from typing import Dict, List
-from dataclasses import dataclass
-
-@dataclass
-class ConceptScore:
-    """Single MONET concept with score and evidence"""
-    name: str
-    score: float
-    confidence: float
-    description: str
-    clinical_relevance: str  # How this affects diagnosis
-
-class MONETConceptScorer:
-    """
-    MONET concept scoring using your trained metadata patterns
-    Integrates the boosting logic from your ensemble code
-    """
-
-    # MONET concepts used in your training
-    CONCEPT_DEFINITIONS = {
-        'MONET_ulceration_crust': {
-            'description': 'Ulceration or crusting present',
-            'high_in': ['SCCKA', 'BCC', 'MAL_OTH'],
-            'low_in': ['NV', 'BKL'],
-            'threshold_high': 0.50
-        },
-        'MONET_erythema': {
-            'description': 'Redness or inflammation',
-            'high_in': ['INF', 'BCC', 'SCCKA'],
-            'low_in': ['MEL', 'NV'],
-            'threshold_high': 0.40
-        },
-        'MONET_pigmented': {
-            'description': 'Pigmentation present',
-            'high_in': ['MEL', 'NV', 'BKL'],
-            'low_in': ['BCC', 'SCCKA', 'INF'],
-            'threshold_high': 0.55
-        },
-        'MONET_vasculature_vessels': {
-            'description': 'Vascular structures visible',
-            'high_in': ['VASC', 'BCC'],
-            'low_in': ['MEL', 'NV'],
-            'threshold_high': 0.35
-        },
-        'MONET_hair': {
-            'description': 'Hair follicles present',
-            'high_in': ['NV', 'BKL'],
-            'low_in': ['BCC', 'MEL'],
-            'threshold_high': 0.30
-        },
-        'MONET_gel_water_drop_fluid_dermoscopy_liquid': {
-            'description': 'Gel/fluid artifacts',
-            'high_in': [],
-            'low_in': [],
-            'threshold_high': 0.40
-        },
-        'MONET_skin_markings_pen_ink_purple_pen': {
-            'description': 'Pen markings present',
-            'high_in': [],
-            'low_in': [],
-            'threshold_high': 0.40
-        }
-    }
-
-    # Class-specific patterns from your metadata boosting
-    CLASS_PATTERNS = {
-        'MAL_OTH': {
-            'sex': 'male',  # 88.9% male
-            'site_preference': 'trunk',
-            'age_range': (60, 80),
-            'key_concepts': {'MONET_ulceration_crust': 0.35}
-        },
-        'INF': {
-            'key_concepts': {
-                'MONET_erythema': 0.42,
-                'MONET_pigmented': (None, 0.30)  # Low pigmentation
-            }
-        },
-        'BEN_OTH': {
-            'site_preference': ['head', 'neck', 'face'],  # 47.7%
-            'key_concepts': {'MONET_pigmented': (0.30, 0.50)}
-        },
-        'DF': {
-            'site_preference': ['lower', 'leg', 'ankle', 'foot'],  # 65.4%
-            'age_range': (40, 65)
-        },
-        'SCCKA': {
-            'age_range': (65, None),
-            'key_concepts': {
-                'MONET_ulceration_crust': 0.50,
-                'MONET_pigmented': (None, 0.15)
-            }
-        },
-        'MEL': {
-            'age_range': (55, None),  # 61.8 years average
-            'key_concepts': {'MONET_pigmented': 0.55}
-        },
-        'NV': {
-            'age_range': (None, 45),  # 42.0 years average
-            'key_concepts': {'MONET_pigmented': 0.55}
-        }
-    }
-
-    def __init__(self):
-        """Initialize MONET scorer with class patterns"""
-        self.class_names = [
-            'AKIEC', 'BCC', 'BEN_OTH', 'BKL', 'DF',
-            'INF', 'MAL_OTH', 'MEL', 'NV', 'SCCKA', 'VASC'
-        ]
-
-    def compute_concept_scores(
-        self,
-        metadata: Dict[str, float]
-    ) -> Dict[str, ConceptScore]:
-        """
-        Compute MONET concept scores from metadata
-
-        Args:
-            metadata: Dictionary with MONET scores, age, sex, site, etc.
-
-        Returns:
-            Dictionary of concept scores
-        """
-        concept_scores = {}
-
-        for concept_name, definition in self.CONCEPT_DEFINITIONS.items():
-            score = metadata.get(concept_name, 0.0)
-
-            # Determine confidence based on how extreme the score is
-            if score > definition['threshold_high']:
-                confidence = min((score - definition['threshold_high']) / 0.2, 1.0)
-                level = "HIGH"
-            elif score < 0.2:
-                confidence = min((0.2 - score) / 0.2, 1.0)
-                level = "LOW"
-            else:
-                confidence = 0.5
-                level = "MODERATE"
-
-            # Clinical relevance
-            if level == "HIGH":
-                relevant_classes = definition['high_in']
-                clinical_relevance = f"Supports: {', '.join(relevant_classes)}"
-            elif level == "LOW":
-                excluded_classes = definition['low_in']
-                clinical_relevance = f"Against: {', '.join(excluded_classes)}"
-            else:
-                clinical_relevance = "Non-specific"
-
-            concept_scores[concept_name] = ConceptScore(
-                name=concept_name.replace('MONET_', '').replace('_', ' ').title(),
-                score=score,
-                confidence=confidence,
-                description=f"{definition['description']} ({level})",
-                clinical_relevance=clinical_relevance
-            )
-
-        return concept_scores
-
-    def apply_metadata_boosting(
-        self,
-        probs: np.ndarray,
-        metadata: Dict
-    ) -> np.ndarray:
-        """
-        Apply your metadata boosting logic
-        This is directly from your ensemble optimization code
-
-        Args:
-            probs: [11] probability array
-            metadata: Dictionary with age, sex, site, MONET scores
-
-        Returns:
-            boosted_probs: [11] adjusted probabilities
-        """
-        boosted_probs = probs.copy()
-
-        # 1. MAL_OTH boosting
-        if metadata.get('sex') == 'male':
-            site = str(metadata.get('site', '')).lower()
-            if 'trunk' in site:
-                age = metadata.get('age_approx', 60)
-                ulceration = metadata.get('MONET_ulceration_crust', 0)
-
-                score = 0
-                score += 3 if metadata.get('sex') == 'male' else 0
-                score += 2 if 'trunk' in site else 0
-                score += 1 if 60 <= age <= 80 else 0
-                score += 2 if ulceration > 0.35 else 0
-
-                confidence = score / 8.0
-                if confidence > 0.5:
-                    boosted_probs[6] *= (1.0 + confidence)  # MAL_OTH index
-
-        # 2. INF boosting
-        erythema = metadata.get('MONET_erythema', 0)
-        pigmentation = metadata.get('MONET_pigmented', 0)
-
-        if erythema > 0.42 and pigmentation < 0.30:
-            confidence = min((erythema - 0.42) / 0.10 + 0.5, 1.0)
-            boosted_probs[5] *= (1.0 + confidence * 0.8)  # INF index
-
-        # 3. BEN_OTH boosting
-        site = str(metadata.get('site', '')).lower()
-        is_head_neck = any(x in site for x in ['head', 'neck', 'face'])
-
-        if is_head_neck and 0.30 < pigmentation < 0.50:
-            ulceration = metadata.get('MONET_ulceration_crust', 0)
-            confidence = 0.7 if ulceration < 0.30 else 0.4
-            boosted_probs[2] *= (1.0 + confidence * 0.5)  # BEN_OTH index
-
-        # 4. DF boosting
-        is_lower_ext = any(x in site for x in ['lower', 'leg', 'ankle', 'foot'])
-
-        if is_lower_ext:
-            age = metadata.get('age_approx', 60)
-            if 40 <= age <= 65:
-                boosted_probs[4] *= 1.8  # DF index
-            elif 30 <= age <= 75:
-                boosted_probs[4] *= 1.5
-
-        # 5. SCCKA boosting
-        ulceration = metadata.get('MONET_ulceration_crust', 0)
-        age = metadata.get('age_approx', 60)
-
-        if ulceration > 0.50 and age >= 65 and pigmentation < 0.15:
-            boosted_probs[9] *= 1.9  # SCCKA index
-        elif ulceration > 0.45 and age >= 60 and pigmentation < 0.20:
-            boosted_probs[9] *= 1.5
-
-        # 6. MEL vs NV age separation
-        if pigmentation > 0.55:
-            if age >= 55:
-                age_score = min((age - 55) / 20.0, 1.0)
-                boosted_probs[7] *= (1.0 + age_score * 0.5)  # MEL
-                boosted_probs[8] *= (1.0 - age_score * 0.3)  # NV
-            elif age <= 45:
-                age_score = min((45 - age) / 30.0, 1.0)
-                boosted_probs[7] *= (1.0 - age_score * 0.3)  # MEL
-                boosted_probs[8] *= (1.0 + age_score * 0.5)  # NV
-
# 7. Exclusions based on pigmentation/erythema
|
| 247 |
-
if pigmentation > 0.50:
|
| 248 |
-
boosted_probs[0] *= 0.7 # AKIEC
|
| 249 |
-
boosted_probs[1] *= 0.6 # BCC
|
| 250 |
-
boosted_probs[5] *= 0.5 # INF
|
| 251 |
-
boosted_probs[9] *= 0.3 # SCCKA
|
| 252 |
-
|
| 253 |
-
if erythema > 0.40:
|
| 254 |
-
boosted_probs[7] *= 0.7 # MEL
|
| 255 |
-
boosted_probs[8] *= 0.7 # NV
|
| 256 |
-
|
| 257 |
-
if pigmentation < 0.20:
|
| 258 |
-
boosted_probs[7] *= 0.5 # MEL
|
| 259 |
-
boosted_probs[8] *= 0.5 # NV
|
| 260 |
-
|
| 261 |
-
# Renormalize
|
| 262 |
-
return boosted_probs / boosted_probs.sum()
|
| 263 |
-
|
| 264 |
-
def explain_prediction(
|
| 265 |
-
self,
|
| 266 |
-
probs: np.ndarray,
|
| 267 |
-
concept_scores: Dict[str, ConceptScore],
|
| 268 |
-
metadata: Dict
|
| 269 |
-
) -> str:
|
| 270 |
-
"""
|
| 271 |
-
Generate natural language explanation
|
| 272 |
-
|
| 273 |
-
Args:
|
| 274 |
-
probs: Class probabilities
|
| 275 |
-
concept_scores: MONET concept scores
|
| 276 |
-
metadata: Clinical metadata
|
| 277 |
-
|
| 278 |
-
Returns:
|
| 279 |
-
Natural language explanation
|
| 280 |
-
"""
|
| 281 |
-
predicted_idx = np.argmax(probs)
|
| 282 |
-
predicted_class = self.class_names[predicted_idx]
|
| 283 |
-
confidence = probs[predicted_idx]
|
| 284 |
-
|
| 285 |
-
explanation = f"**Primary Diagnosis: {predicted_class}**\n"
|
| 286 |
-
explanation += f"Confidence: {confidence:.1%}\n\n"
|
| 287 |
-
|
| 288 |
-
# Key MONET features
|
| 289 |
-
explanation += "**Key Dermoscopic Features:**\n"
|
| 290 |
-
|
| 291 |
-
sorted_concepts = sorted(
|
| 292 |
-
concept_scores.values(),
|
| 293 |
-
key=lambda x: x.score * x.confidence,
|
| 294 |
-
reverse=True
|
| 295 |
-
)
|
| 296 |
-
|
| 297 |
-
for i, concept in enumerate(sorted_concepts[:5], 1):
|
| 298 |
-
if concept.score > 0.3 or concept.score < 0.2:
|
| 299 |
-
explanation += f"{i}. {concept.name}: {concept.score:.2f} - {concept.description}\n"
|
| 300 |
-
if concept.clinical_relevance != "Non-specific":
|
| 301 |
-
explanation += f" β {concept.clinical_relevance}\n"
|
| 302 |
-
|
| 303 |
-
# Clinical context
|
| 304 |
-
explanation += "\n**Clinical Context:**\n"
|
| 305 |
-
if 'age_approx' in metadata:
|
| 306 |
-
explanation += f"β’ Age: {metadata['age_approx']} years\n"
|
| 307 |
-
if 'sex' in metadata:
|
| 308 |
-
explanation += f"β’ Sex: {metadata['sex']}\n"
|
| 309 |
-
if 'site' in metadata:
|
| 310 |
-
explanation += f"β’ Location: {metadata['site']}\n"
|
| 311 |
-
|
| 312 |
-
return explanation
|
| 313 |
-
|
| 314 |
-
def get_top_concepts(
|
| 315 |
-
self,
|
| 316 |
-
concept_scores: Dict[str, ConceptScore],
|
| 317 |
-
top_k: int = 5,
|
| 318 |
-
min_score: float = 0.3
|
| 319 |
-
) -> List[ConceptScore]:
|
| 320 |
-
"""Get top-k most important concepts"""
|
| 321 |
-
filtered = [
|
| 322 |
-
cs for cs in concept_scores.values()
|
| 323 |
-
if cs.score >= min_score or cs.score < 0.2 # High or low
|
| 324 |
-
]
|
| 325 |
-
|
| 326 |
-
sorted_concepts = sorted(
|
| 327 |
-
filtered,
|
| 328 |
-
key=lambda x: x.score * x.confidence,
|
| 329 |
-
reverse=True
|
| 330 |
-
)
|
| 331 |
-
|
| 332 |
-
return sorted_concepts[:top_k]
|
test_models.py
DELETED
```python
#!/usr/bin/env python3
"""Test script to verify model loading"""

import torch
import torch.nn as nn
import timm
from transformers import AutoModel, AutoProcessor
import numpy as np

DEVICE = "cpu"
print(f"Device: {DEVICE}")

# ConvNeXt model definition (matching checkpoint)
class ConvNeXtDualEncoder(nn.Module):
    def __init__(self, model_name="convnext_base.fb_in22k_ft_in1k",
                 metadata_dim=19, num_classes=11, dropout=0.3):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=False, num_classes=0)
        backbone_dim = self.backbone.num_features
        self.meta_mlp = nn.Sequential(
            nn.Linear(metadata_dim, 64), nn.LayerNorm(64), nn.GELU(), nn.Dropout(dropout)
        )
        fusion_dim = backbone_dim * 2 + 64
        self.classifier = nn.Sequential(
            nn.Linear(fusion_dim, 512), nn.LayerNorm(512), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(512, 256), nn.LayerNorm(256), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(256, num_classes)
        )

    def forward(self, clinical_img, derm_img=None, metadata=None):
        clinical_features = self.backbone(clinical_img)
        derm_features = self.backbone(derm_img) if derm_img is not None else clinical_features
        if metadata is not None:
            meta_features = self.meta_mlp(metadata)
        else:
            meta_features = torch.zeros(clinical_features.size(0), 64, device=clinical_features.device)
        fused = torch.cat([clinical_features, derm_features, meta_features], dim=1)
        return self.classifier(fused)


# MedSigLIP model definition
class MedSigLIPClassifier(nn.Module):
    def __init__(self, num_classes=11, model_name="google/siglip-base-patch16-384"):
        super().__init__()
        self.siglip = AutoModel.from_pretrained(model_name)
        self.processor = AutoProcessor.from_pretrained(model_name)
        hidden_dim = self.siglip.config.vision_config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, 512), nn.LayerNorm(512), nn.GELU(), nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )
        for param in self.siglip.parameters():
            param.requires_grad = False

    def forward(self, pixel_values):
        vision_outputs = self.siglip.vision_model(pixel_values=pixel_values)
        pooled_features = vision_outputs.pooler_output
        return self.classifier(pooled_features)


if __name__ == "__main__":
    print("\n[1/2] Loading ConvNeXt...")
    convnext_model = ConvNeXtDualEncoder()
    ckpt = torch.load("models/seed42_fold0.pt", map_location=DEVICE, weights_only=False)
    convnext_model.load_state_dict(ckpt)
    convnext_model.eval()
    print("  ConvNeXt loaded!")

    print("\n[2/2] Loading MedSigLIP...")
    medsiglip_model = MedSigLIPClassifier()
    medsiglip_model.eval()
    print("  MedSigLIP loaded!")

    # Quick inference test
    print("\nTesting inference...")
    dummy_img = torch.randn(1, 3, 384, 384)
    with torch.no_grad():
        convnext_out = convnext_model(dummy_img)
        print(f"  ConvNeXt output: {convnext_out.shape}")

        dummy_pil = np.random.randint(0, 255, (384, 384, 3), dtype=np.uint8)
        siglip_input = medsiglip_model.processor(images=[dummy_pil], return_tensors="pt")
        siglip_out = medsiglip_model(siglip_input["pixel_values"])
        print(f"  MedSigLIP output: {siglip_out.shape}")

    print("\nAll tests passed!")
```
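The classifier head width in `ConvNeXtDualEncoder` follows directly from the fusion: two backbone feature vectors (clinical and dermoscopic) plus the 64-dimensional metadata embedding are concatenated. A quick arithmetic check (no model download; `1024` is the ConvNeXt-Base feature dimension, assumed here rather than queried from `timm`):

```python
# Fusion width wired into ConvNeXtDualEncoder: two backbone feature
# vectors concatenated with the 64-d metadata embedding.
backbone_dim = 1024                 # convnext_base final feature dimension
metadata_embed_dim = 64             # output width of meta_mlp
fusion_dim = backbone_dim * 2 + metadata_embed_dim
print(fusion_dim)                   # 2112
```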