cgoodmaker Claude Opus 4.6 committed on
Commit
672ed11
·
1 Parent(s): 1a97904

Remove unused files: old Gradio frontend, dead model code, test artifacts


- frontend/ — entire old Gradio UI, replaced by React
- models/explainability.py — old GradCAM, replaced by gradcam_tool.py
- models/medsiglip_convnext_fusion.py — abandoned experimental model
- models/monet_concepts.py — old concept scorer, replaced by monet_tool.py
- test_models.py — one-off test script
- KAGGLE_REPORT.md — temporary report file

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

KAGGLE_REPORT.md DELETED
@@ -1,104 +0,0 @@
- # SkinProAI: Explainable Multi-Model Dermatology Decision Support
-
- ## 1. System Architecture & AI Pipeline
-
- SkinProAI is a clinician-first dermatology decision-support system that orchestrates multiple specialised AI models through a four-phase analysis pipeline. The system is designed to run entirely on-device using Google's MedGemma 4B — a 4-billion-parameter vision-language model purpose-built for edge deployment in clinical settings — executing on consumer hardware (Apple M4 Pro / MPS) in bfloat16 precision without requiring cloud GPU infrastructure.
-
- ### 1.1 Multi-Model Orchestration via MCP
-
- Rather than relying on a single monolithic model, SkinProAI employs a Model Context Protocol (MCP) architecture that isolates each AI capability as an independent tool callable via JSON-RPC over subprocess stdio. This design provides fault isolation, independent model loading, and a clear audit trail of which model produced which output.
-
- | Tool | Model | Purpose |
- |------|-------|---------|
- | `monet_analyze` | MONET (contrastive vision-language) | Extract 7 dermoscopic concept scores: ulceration, hair structures, vascular patterns, erythema, pigmented structures, gel artefacts, skin markings |
- | `classify_lesion` | ConvNeXt Base + metadata MLP | 11-class lesion classification using dual-encoder architecture combining image features with MONET concept scores |
- | `generate_gradcam` | Grad-CAM on ConvNeXt | Generate visual attention heatmap overlays highlighting regions driving the classification decision |
- | `search_guidelines` | FAISS + all-MiniLM-L6-v2 | Retrieve relevant clinical guideline passages from a RAG index of 286 chunks across 7 dermatology PDFs |
- | `compare_images` | MONET + overlay | Temporal change detection comparing dermoscopic feature vectors between sequential images |
-
- All tools receive absolute file paths as input (never raw image data), enabling secure subprocess isolation between the main FastAPI process (Python 3.9) and the MCP server (Python 3.11).
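Because each tool call is plain JSON-RPC over subprocess stdio, the request envelope can be sketched in a few lines. This is a minimal illustration, not the project's actual MCP client; the method name follows the table above, and the `image_path` parameter name is an assumption:

```python
import json

def build_tool_call(method: str, image_path: str, request_id: int = 1) -> str:
    # Tools receive absolute file paths, never raw image bytes,
    # so the request body stays small and auditable.
    if not image_path.startswith("/"):
        raise ValueError("MCP tools expect absolute paths")
    return json.dumps({
        "jsonrpc": "2.0",
        "id": request_id,
        "method": method,
        "params": {"image_path": image_path},
    })

req = build_tool_call("monet_analyze", "/data/lesions/img_001.png")
```

Passing paths rather than pixels also keeps the Python 3.9 / 3.11 process boundary simple: only small JSON strings cross stdio.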
- ### 1.2 Four-Phase Analysis Pipeline
-
- Every image analysis follows a structured, transparent pipeline:
-
- **Phase 1 — Independent Visual Examination.** MedGemma 4B performs a systematic dermoscopic assessment *before* seeing any AI tool output. It evaluates pattern architecture, colour distribution, border characteristics, and structural features, producing differential diagnoses ranked by clinical probability. This deliberate sequencing prevents anchoring bias — the language model forms its own clinical impression independently.
-
- **Phase 2 — AI Classification Tools.** Three MCP tools execute in sequence: MONET extracts quantitative concept scores (0–1 range for each dermoscopic feature), ConvNeXt classifies the lesion into one of 11 diagnostic categories with calibrated probabilities, and Grad-CAM generates an attention overlay showing which image regions drove the classification. Each tool's output is streamed to the clinician in real time with visual bar charts.
-
- **Phase 3 — Reconciliation.** MedGemma receives both its own Phase 1 assessment and the Phase 2 tool outputs, then performs explicit agreement/disagreement analysis. It identifies where its visual findings align with or diverge from the quantitative classifiers, produces an integrated assessment with a stated confidence level, and explains its reasoning. This adversarial cross-check between independent assessments is central to the system's reliability.
-
- **Phase 4 — Management Guidance with RAG.** The system automatically queries a FAISS-indexed knowledge base of clinical dermatology guidelines (BAD, NICE, and specialist PDFs covering BCC, SCC, melanoma, actinic keratosis, contact dermatitis, lichen sclerosus, and cutaneous warts). Using the diagnosed condition as a search query, the RAG system retrieves the top-5 most relevant guideline passages via cosine similarity over sentence-transformer embeddings (all-MiniLM-L6-v2, 384 dimensions). MedGemma then synthesises lesion-specific management recommendations — biopsy, excision, monitoring, or discharge — grounded in the retrieved evidence, with inline superscript citations linking back to source documents and page numbers.
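The retrieval step reduces to a nearest-neighbour search over normalised embeddings. A minimal sketch of the cosine top-k maths, using NumPy in place of the FAISS index and toy 2-D vectors standing in for the 384-dimensional embeddings:

```python
import numpy as np

def top_k_passages(query_vec: np.ndarray, passage_vecs: np.ndarray, k: int = 5):
    # Cosine similarity over L2-normalised vectors is a plain dot product,
    # which is what an inner-product index computes after normalisation.
    q = query_vec / np.linalg.norm(query_vec)
    p = passage_vecs / np.linalg.norm(passage_vecs, axis=1, keepdims=True)
    scores = p @ q
    order = np.argsort(-scores)[:k]
    return order, scores[order]

# Three toy "passages"; the first is most aligned with the query.
passages = np.array([[1.0, 0.0], [0.7, 0.7], [0.0, 1.0]])
idx, sim = top_k_passages(np.array([1.0, 0.1]), passages, k=2)
```

The returned indices map back to chunk metadata (source PDF, page number), which is what makes the inline citations in Phase 4 traceable.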
- ---
-
- ## 2. AI Explainability
-
- Explainability is not an afterthought in SkinProAI — it is embedded at every layer of the architecture. The system provides three complementary forms of explanation: visual, quantitative, and narrative.
-
- ### 2.1 Visual Explainability: Grad-CAM Attention Maps
-
- Grad-CAM (Gradient-weighted Class Activation Mapping) hooks into the final convolutional layer of the ConvNeXt classifier to produce a spatial heatmap overlay on the original dermoscopic image. This directly answers the question "where is the model looking?" — showing clinicians which morphological features (border irregularity, colour variegation, structural asymmetry) are driving the classification. For temporal comparisons, the system generates side-by-side Grad-CAM pairs (previous vs. current) so clinicians can visually assess whether attention regions have shifted, expanded, or resolved.
-
- ### 2.2 Quantitative Explainability: MONET Concept Scores
-
- MONET provides human-interpretable concept decomposition by scoring seven clinically meaningful dermoscopic features on a continuous 0–1 scale. Unlike black-box classifiers that output only a label and probability, MONET's concept scores reveal *why* a lesion received its classification: a high vascular score combined with low pigmentation explains a vascular lesion diagnosis; high pigmented-structure scores with border irregularity support melanocytic concern. These scores are rendered as visual bar charts in the streaming UI, giving clinicians immediate quantitative insight into the feature profile driving the AI's assessment.
-
- When comparing sequential images over time, the system computes MONET feature deltas — the signed difference in each concept score between timepoints — enabling objective quantification of lesion evolution. A change from 0.3 to 0.7 in the vascular score, for example, signals new vessel formation that warrants clinical attention, independent of any subjective visual comparison.
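The delta computation itself is a per-concept signed difference. A minimal sketch, where the helper and concept names are illustrative rather than the project's actual API:

```python
def concept_deltas(previous: dict, current: dict) -> dict:
    # Signed change in each 0-1 concept score between two timepoints;
    # positive values mean the feature has become more pronounced.
    return {name: round(current[name] - previous[name], 3) for name in previous}

deltas = concept_deltas(
    {"vascular": 0.3, "erythema": 0.5},
    {"vascular": 0.7, "erythema": 0.45},
)
# deltas["vascular"] == 0.4 is the kind of shift that flags new vessel formation
```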
- ### 2.3 Narrative Explainability: MedGemma Reasoning Transparency
-
- The streaming interface exposes MedGemma's reasoning process through structured markup segments. `[THINKING]` blocks display the model's intermediate reasoning (differential construction, feature weighting, agreement analysis) in real time, with animated spinners that resolve to completion indicators as each reasoning phase finishes. `[RESPONSE]` blocks contain the synthesised clinical narrative. This staged transparency allows clinicians to follow the AI's analytical process rather than receiving only a final pronouncement.
-
- The reconciliation phase (Phase 3) is particularly significant for explainability: MedGemma explicitly states where its independent visual assessment agrees or disagrees with the quantitative classifiers, and explains the basis for its integrated conclusion. This adversarial structure makes disagreements visible and auditable.
-
- ### 2.4 Evidence Grounding: RAG Citations
-
- Management recommendations in Phase 4 include inline superscript references (e.g., "Wide local excision with 2 cm margins is recommended for tumours >2 mm Breslow thickness¹") linked to specific guideline documents and page numbers. This evidence chain — from diagnosis through to management recommendation — is fully traceable to published clinical guidelines, supporting regulatory compliance and clinical governance requirements.
-
- ---
-
- ## 3. Clinician-First Design & User Interface
-
- ### 3.1 Design Philosophy
-
- SkinProAI is built around the principle that AI should augment clinical decision-making, not replace it. The interface is designed for the workflow of a clinician reviewing dermoscopic images: patient selection, lesion documentation, temporal tracking, and structured analysis with clear next-step guidance.
-
- The system deliberately avoids presenting AI conclusions as definitive diagnoses. Instead, results are framed as ranked differentials with calibrated confidence scores, supported by quantitative feature evidence and visual attention maps. The clinician retains full agency over the diagnostic and management decision.
-
- ### 3.2 Streaming Real-Time Interface
-
- The UI employs Server-Sent Events (SSE) to stream analysis output in real time, maintaining clinician engagement during model inference. Rather than a loading spinner followed by a wall of text, clinicians observe the analysis unfolding phase by phase:
-
- - **Tool status lines** appear as compact single-line indicators with animated spinners that resolve to green completion dots, showing tool name and summary result (e.g., "ConvNeXt classification — Melanoma (89%)")
- - **Thinking indicators** display MedGemma's reasoning steps with spinners that transition to done states as each phase completes
- - **Streamed text** appears word-by-word, allowing clinicians to begin reading findings before analysis completes
-
- This streaming pattern reduces perceived latency and provides transparency into the multi-model pipeline's progress.
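Each of the update types above maps naturally onto a named SSE event. A minimal sketch of the wire format (the event name is illustrative, not the project's actual protocol):

```python
def sse_event(event: str, data: str) -> str:
    # SSE wire format: an "event:" name, one "data:" line per line of
    # payload, and a blank line terminating the event.
    payload = "".join(f"data: {line}\n" for line in data.splitlines())
    return f"event: {event}\n{payload}\n"

frame = sse_event("tool_status", "ConvNeXt classification - Melanoma (89%)")
```

On the client, an `EventSource` listener per event name would route tool status lines, thinking indicators, and streamed text to their respective UI elements.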
- ### 3.3 Temporal Lesion Tracking
-
- The data model supports longitudinal monitoring through a Patient → Lesion → LesionImage hierarchy. Each lesion maintains a timeline of images with timestamps, and the system automatically triggers temporal comparison when a new image is uploaded for a previously-analysed lesion. Comparison results include MONET feature deltas, side-by-side Grad-CAM overlays, and a status classification (Stable / Minor Change / Significant Change / Improved) rendered with colour-coded indicators.
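The hierarchy can be sketched as three small records (the field names are assumptions for illustration; the actual schema may differ):

```python
from dataclasses import dataclass, field
from datetime import datetime

@dataclass
class LesionImage:
    path: str                 # on-device image file (no cloud upload)
    captured_at: datetime

@dataclass
class Lesion:
    site: str
    images: list = field(default_factory=list)

    def comparison_due(self) -> bool:
        # A second (or later) upload for an already-imaged lesion is
        # what triggers the automatic temporal comparison.
        return len(self.images) >= 2

@dataclass
class Patient:
    name: str
    lesions: list = field(default_factory=list)

p = Patient("Demo")
lesion = Lesion(site="left forearm")
p.lesions.append(lesion)
lesion.images.append(LesionImage("/data/img_001.png", datetime(2024, 1, 5)))
lesion.images.append(LesionImage("/data/img_002.png", datetime(2024, 4, 5)))
```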
- This temporal architecture supports the clinical workflow for monitoring suspicious lesions over time — a common scenario in dermatology where serial dermoscopy is preferred over immediate biopsy for lesions with intermediate risk profiles.
-
- ### 3.4 Conversational Follow-Up
-
- After analysis completes, the system transitions to a conversational interface where clinicians can ask follow-up questions grounded in the analysis context. The chat maintains full awareness of the diagnosed condition, MONET features, classification results, and guideline context, enabling queries like "what are the excision margin recommendations for this depth?" or "how does the asymmetry score compare to the previous image?" The text-only chat pathway routes through MedGemma with the full analysis state, ensuring responses remain clinically contextualised.
-
- ### 3.5 Edge Deployment Architecture
-
- SkinProAI runs entirely on-device using MedGemma 4B in bfloat16 precision on consumer hardware (demonstrated on Apple M4 Pro with 24GB RAM using Metal Performance Shaders). No patient data leaves the device — all inference, image storage, and guideline retrieval execute locally. This edge-first architecture addresses data sovereignty requirements in clinical settings where patient images cannot be transmitted to cloud services. The system also supports containerised deployment via Docker for institutional environments, with optional GPU acceleration on CUDA-equipped hardware or HuggingFace Spaces.
-
- ### 3.6 Interface Components
-
- The frontend is built in React 18 with a glassmorphism-inspired design language. Key interface elements include:
-
- - **Patient grid** — Card-based patient selection with creation workflow
- - **Chat interface** — Unified conversation view combining image analysis, tool outputs, and text chat in a single scrollable thread
- - **Tool call cards** — Inline status indicators for each AI tool invocation (analyse, classify, Grad-CAM, guidelines search, compare) with expandable result summaries
- - **Image upload** — Drag-and-drop or click-to-upload with preview, supporting both initial analysis and temporal comparison workflows
- - **Post-analysis prompt** — Contextual hint guiding clinicians to ask follow-up questions, provide additional context, or upload comparison images
- - **Markdown rendering** — Clinical narratives rendered with proper formatting, headers, lists, and inline citations for readability
-
- The interface prioritises information density without clutter: tool outputs collapse to single-line summaries by default, reasoning phases are visually distinguished from conclusions, and temporal comparisons use colour-coded status indicators (green/amber/red/blue) that map to clinical urgency levels.
frontend/app.py DELETED
@@ -1,532 +0,0 @@
- """
- SkinProAI Frontend - Modular Gradio application
- """
-
- import gradio as gr
- from typing import Dict, Generator, Optional
- from datetime import datetime
- import sys
- import os
- import re
- import base64
-
- sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-
- from data.case_store import get_case_store
- from frontend.components.styles import MAIN_CSS
- from frontend.components.analysis_view import format_output
-
-
- # =============================================================================
- # CONFIG
- # =============================================================================
-
- class Config:
-     APP_TITLE = "SkinProAI"
-     SERVER_PORT = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
-     HF_SPACES = os.environ.get("SPACE_ID") is not None
-
-
- # =============================================================================
- # AGENT
- # =============================================================================
-
- class AnalysisAgent:
-     """Wrapper for the MedGemma analysis agent"""
-
-     def __init__(self):
-         self.model = None
-         self.loaded = False
-
-     def load(self):
-         if self.loaded:
-             return
-         from models.medgemma_agent import MedGemmaAgent
-         self.model = MedGemmaAgent(verbose=True)
-         self.model.load_model()
-         self.loaded = True
-
-     def analyze(self, image_path: str, question: str = "") -> Generator[str, None, None]:
-         if not self.loaded:
-             yield "[STAGE:loading]Loading AI models...[/STAGE]\n"
-             self.load()
-
-         for chunk in self.model.analyze_image_stream(image_path, question=question):
-             yield chunk
-
-     def management_guidance(self, confirmed: bool, feedback: str = None) -> Generator[str, None, None]:
-         for chunk in self.model.generate_management_guidance(confirmed, feedback):
-             yield chunk
-
-     def followup(self, message: str) -> Generator[str, None, None]:
-         if not self.loaded or not self.model.last_diagnosis:
-             yield "[ERROR]No analysis context available.[/ERROR]\n"
-             return
-         for chunk in self.model.chat_followup(message):
-             yield chunk
-
-     def reset(self):
-         if self.model:
-             self.model.reset_state()
-
-
- agent = AnalysisAgent()
- case_store = get_case_store()
-
-
- # =============================================================================
- # APP
- # =============================================================================
-
- with gr.Blocks(title=Config.APP_TITLE, css=MAIN_CSS, theme=gr.themes.Soft()) as app:
-
-     # =========================================================================
-     # STATE
-     # =========================================================================
-     state = gr.State({
-         "page": "patient_select",  # patient_select | analysis
-         "case_id": None,
-         "instance_id": None,
-         "output": "",
-         "gradcam_base64": None
-     })
-
-     # =========================================================================
-     # PAGE 1: PATIENT SELECTION
-     # =========================================================================
-     with gr.Group(visible=True, elem_classes=["patient-select-container"]) as page_patient:
-         gr.Markdown("# SkinProAI", elem_classes=["patient-select-title"])
-         gr.Markdown("Select a patient to continue or create a new case", elem_classes=["patient-select-subtitle"])
-
-         with gr.Row(elem_classes=["patient-grid"]):
-             btn_demo_melanoma = gr.Button("Demo: Melanocytic Lesion", elem_classes=["patient-card"])
-             btn_demo_ak = gr.Button("Demo: Actinic Keratosis", elem_classes=["patient-card"])
-             btn_new_patient = gr.Button("+ New Patient", variant="primary", elem_classes=["new-patient-btn"])
-
-     # =========================================================================
-     # PAGE 2: ANALYSIS
-     # =========================================================================
-     with gr.Group(visible=False) as page_analysis:
-
-         # Header
-         with gr.Row(elem_classes=["app-header"]):
-             gr.Markdown(f"**{Config.APP_TITLE}**", elem_classes=["app-title"])
-             btn_back = gr.Button("< Back to Patients", elem_classes=["back-btn"])
-
-         with gr.Row(elem_classes=["analysis-container"]):
-
-             # Sidebar (previous queries)
-             with gr.Column(scale=0, min_width=260, visible=False, elem_classes=["query-sidebar"]) as sidebar:
-                 gr.Markdown("### Previous Queries", elem_classes=["sidebar-header"])
-                 sidebar_list = gr.Column(elem_id="sidebar-queries")
-                 btn_new_query = gr.Button("+ New Query", size="sm", variant="primary")
-
-             # Main content
-             with gr.Column(scale=4, elem_classes=["main-content"]):
-
-                 # Input view (greeting style)
-                 with gr.Group(visible=True, elem_classes=["input-greeting"]) as view_input:
-                     gr.Markdown("What would you like to analyze?", elem_classes=["greeting-title"])
-                     gr.Markdown("Upload an image and describe what you'd like to know", elem_classes=["greeting-subtitle"])
-
-                     with gr.Column(elem_classes=["input-box-container"]):
-                         input_message = gr.Textbox(
-                             placeholder="Describe the lesion or ask a question...",
-                             show_label=False,
-                             lines=2,
-                             elem_classes=["message-input"]
-                         )
-
-                         input_image = gr.Image(
-                             type="pil",
-                             height=180,
-                             show_label=False,
-                             elem_classes=["image-preview"]
-                         )
-
-                         with gr.Row(elem_classes=["input-actions"]):
-                             gr.Markdown("*Upload a skin lesion image*")
-                             btn_analyze = gr.Button("Analyze", elem_classes=["send-btn"], interactive=False)
-
-                 # Results view (shown after analysis)
-                 with gr.Group(visible=False, elem_classes=["chat-view"]) as view_results:
-                     output_html = gr.HTML(
-                         value='<div class="analysis-output">Starting...</div>',
-                         elem_classes=["results-area"]
-                     )
-
-                 # Confirmation
-                 with gr.Group(visible=False, elem_classes=["confirm-buttons"]) as confirm_box:
-                     gr.Markdown("**Do you agree with this diagnosis?**")
-                     with gr.Row():
-                         btn_confirm_yes = gr.Button("Yes, continue", variant="primary", size="sm")
-                         btn_confirm_no = gr.Button("No, I disagree", variant="secondary", size="sm")
-                     input_feedback = gr.Textbox(label="Your assessment", placeholder="Enter diagnosis...", visible=False)
-                     btn_submit_feedback = gr.Button("Submit", visible=False, size="sm")
-
-                 # Follow-up
-                 with gr.Row(elem_classes=["chat-input-area"]):
-                     input_followup = gr.Textbox(placeholder="Ask a follow-up question...", show_label=False, lines=1, scale=4)
-                     btn_followup = gr.Button("Send", size="sm", scale=1)
-
-     # =========================================================================
-     # DYNAMIC SIDEBAR RENDERING
-     # =========================================================================
-     @gr.render(inputs=[state], triggers=[state.change])
-     def render_sidebar(s):
-         case_id = s.get("case_id")
-         if not case_id or s.get("page") != "analysis":
-             return
-
-         instances = case_store.list_instances(case_id)
-         current = s.get("instance_id")
-
-         for i, inst in enumerate(instances, 1):
-             diagnosis = "Pending"
-             if inst.analysis and inst.analysis.get("diagnosis"):
-                 d = inst.analysis["diagnosis"]
-                 diagnosis = d.get("class", "?")
-
-             label = f"#{i}: {diagnosis}"
-             variant = "primary" if inst.id == current else "secondary"
-             btn = gr.Button(label, size="sm", variant=variant, elem_classes=["query-item"])
-
-             # Attach click handler to load this instance
-             def load_instance(inst_id=inst.id, c_id=case_id):
-                 def _load(current_state):
-                     current_state["instance_id"] = inst_id
-                     instance = case_store.get_instance(c_id, inst_id)
-
-                     # Load saved output if available
-                     output_html = '<div class="analysis-output"><div class="result">Previous analysis loaded</div></div>'
-                     if instance and instance.analysis:
-                         diag = instance.analysis.get("diagnosis", {})
-                         output_html = f'<div class="analysis-output"><div class="result">Diagnosis: {diag.get("full_name", diag.get("class", "Unknown"))}</div></div>'
-
-                     return (
-                         current_state,
-                         gr.update(visible=False),  # view_input
-                         gr.update(visible=True),   # view_results
-                         output_html,
-                         gr.update(visible=False)   # confirm_box
-                     )
-                 return _load
-
-             btn.click(
-                 load_instance(),
-                 inputs=[state],
-                 outputs=[state, view_input, view_results, output_html, confirm_box]
-             )
-
-     # =========================================================================
-     # EVENT HANDLERS
-     # =========================================================================
-
-     def select_patient(case_id: str, s: Dict):
-         """Handle patient selection"""
-         s["case_id"] = case_id
-         s["page"] = "analysis"
-
-         instances = case_store.list_instances(case_id)
-         has_queries = len(instances) > 0
-
-         if has_queries:
-             # Load most recent
-             inst = instances[-1]
-             s["instance_id"] = inst.id
-
-             # Load image if exists
-             img = None
-             if inst.image_path and os.path.exists(inst.image_path):
-                 from PIL import Image
-                 img = Image.open(inst.image_path)
-
-             return (
-                 s,
-                 gr.update(visible=False),  # page_patient
-                 gr.update(visible=True),   # page_analysis
-                 gr.update(visible=True),   # sidebar
-                 gr.update(visible=False),  # view_input
-                 gr.update(visible=True),   # view_results
-                 '<div class="analysis-output"><div class="result">Previous analysis loaded</div></div>',
-                 gr.update(visible=False)   # confirm_box
-             )
-         else:
-             # New instance
-             inst = case_store.create_instance(case_id)
-             s["instance_id"] = inst.id
-             s["output"] = ""
-
-             return (
-                 s,
-                 gr.update(visible=False),
-                 gr.update(visible=True),
-                 gr.update(visible=False),  # sidebar hidden for new patient
-                 gr.update(visible=True),   # view_input
-                 gr.update(visible=False),  # view_results
-                 "",
-                 gr.update(visible=False)
-             )
-
-     def new_patient(s: Dict):
-         """Create new patient"""
-         case = case_store.create_case(f"Patient {datetime.now().strftime('%Y-%m-%d %H:%M')}")
-         return select_patient(case.id, s)
-
-     def go_back(s: Dict):
-         """Return to patient selection"""
-         s["page"] = "patient_select"
-         s["case_id"] = None
-         s["instance_id"] = None
-         s["output"] = ""
-
-         return (
-             s,
-             gr.update(visible=True),   # page_patient
-             gr.update(visible=False),  # page_analysis
-             gr.update(visible=False),  # sidebar
-             gr.update(visible=True),   # view_input
-             gr.update(visible=False),  # view_results
-             "",
-             gr.update(visible=False)   # confirm_box
-         )
-
-     def new_query(s: Dict):
-         """Start new query for current patient"""
-         case_id = s.get("case_id")
-         if not case_id:
-             return s, gr.update(), gr.update(), gr.update(), "", gr.update()
-
-         inst = case_store.create_instance(case_id)
-         s["instance_id"] = inst.id
-         s["output"] = ""
-         s["gradcam_base64"] = None
-
-         agent.reset()
-
-         return (
-             s,
-             gr.update(visible=True),   # view_input
-             gr.update(visible=False),  # view_results
-             None,  # clear image
-             "",    # clear output
-             gr.update(visible=False)   # confirm_box
-         )
-
-     def enable_analyze(img):
-         """Enable analyze button when image uploaded"""
-         return gr.update(interactive=img is not None)
-
-     def run_analysis(image, message, s: Dict):
-         """Run analysis on uploaded image"""
-         if image is None:
-             yield s, gr.update(), gr.update(), gr.update(), gr.update()
-             return
-
-         case_id = s["case_id"]
-         instance_id = s["instance_id"]
-
-         # Save image
-         image_path = case_store.save_image(case_id, instance_id, image)
-         case_store.update_analysis(case_id, instance_id, stage="analyzing", image_path=image_path)
-
-         agent.reset()
-         s["output"] = ""
-         gradcam_base64 = None
-         has_confirm = False
-
-         # Switch to results view
-         yield (
-             s,
-             gr.update(visible=False),  # view_input
-             gr.update(visible=True),   # view_results
-             '<div class="analysis-output">Starting analysis...</div>',
-             gr.update(visible=False)   # confirm_box
-         )
-
-         partial = ""
-         for chunk in agent.analyze(image_path, message or ""):
-             partial += chunk
-
-             # Check for GradCAM
-             if gradcam_base64 is None:
-                 match = re.search(r'\[GRADCAM_IMAGE:([^\]]+)\]', partial)
-                 if match:
-                     path = match.group(1)
-                     if os.path.exists(path):
-                         try:
-                             with open(path, "rb") as f:
-                                 gradcam_base64 = base64.b64encode(f.read()).decode('utf-8')
-                             s["gradcam_base64"] = gradcam_base64
-                         except:
-                             pass
-
-             if '[CONFIRM:' in partial:
-                 has_confirm = True
-
-             s["output"] = partial
-
-             yield (
-                 s,
-                 gr.update(visible=False),
-                 gr.update(visible=True),
-                 format_output(partial, gradcam_base64),
-                 gr.update(visible=has_confirm)
-             )
-
-         # Save analysis
-         if agent.model and agent.model.last_diagnosis:
-             diag = agent.model.last_diagnosis["predictions"][0]
-             case_store.update_analysis(
-                 case_id, instance_id,
-                 stage="awaiting_confirmation",
-                 analysis={"diagnosis": diag}
-             )
-
-     def confirm_yes(s: Dict):
-         """User confirmed diagnosis"""
-         partial = s.get("output", "")
-         gradcam = s.get("gradcam_base64")
-
-         for chunk in agent.management_guidance(confirmed=True):
-             partial += chunk
-             s["output"] = partial
-             yield s, format_output(partial, gradcam), gr.update(visible=False)
-
-         case_store.update_analysis(s["case_id"], s["instance_id"], stage="complete")
-
-     def confirm_no():
-         """Show feedback input"""
-         return gr.update(visible=True), gr.update(visible=True)
-
-     def submit_feedback(feedback: str, s: Dict):
-         """Submit user feedback"""
-         partial = s.get("output", "")
-         gradcam = s.get("gradcam_base64")
-
-         for chunk in agent.management_guidance(confirmed=False, feedback=feedback):
-             partial += chunk
-             s["output"] = partial
-             yield (
-                 s,
-                 format_output(partial, gradcam),
-                 gr.update(visible=False),
-                 gr.update(visible=False),
-                 gr.update(visible=False),
-                 ""
-             )
-
-         case_store.update_analysis(s["case_id"], s["instance_id"], stage="complete")
-
-     def send_followup(message: str, s: Dict):
-         """Send follow-up question"""
-         if not message.strip():
-             return s, gr.update(), ""
-
-         case_store.add_chat_message(s["case_id"], s["instance_id"], "user", message)
-
-         partial = s.get("output", "")
-         gradcam = s.get("gradcam_base64")
-
-         partial += f'\n<div class="chat-message user">You: {message}</div>\n'
-
-         response = ""
-         for chunk in agent.followup(message):
-             response += chunk
-             s["output"] = partial + response
-             yield s, format_output(partial + response, gradcam), ""
-
-         case_store.add_chat_message(s["case_id"], s["instance_id"], "assistant", response)
-
-     # =========================================================================
-     # WIRE EVENTS
-     # =========================================================================
-
-     # Patient selection
-     btn_demo_melanoma.click(
-         lambda s: select_patient("demo-melanoma", s),
-         inputs=[state],
-         outputs=[state, page_patient, page_analysis, sidebar, view_input, view_results, output_html, confirm_box]
-     )
-
-     btn_demo_ak.click(
-         lambda s: select_patient("demo-ak", s),
-         inputs=[state],
-         outputs=[state, page_patient, page_analysis, sidebar, view_input, view_results, output_html, confirm_box]
-     )
-
-     btn_new_patient.click(
-         new_patient,
-         inputs=[state],
-         outputs=[state, page_patient, page_analysis, sidebar, view_input, view_results, output_html, confirm_box]
-     )
-
-     # Navigation
-     btn_back.click(
-         go_back,
-         inputs=[state],
-         outputs=[state, page_patient, page_analysis, sidebar, view_input, view_results, output_html, confirm_box]
-     )
-
-     btn_new_query.click(
-         new_query,
-         inputs=[state],
-         outputs=[state, view_input, view_results, input_image, output_html, confirm_box]
-     )
-
-     # Analysis
-     input_image.change(enable_analyze, inputs=[input_image], outputs=[btn_analyze])
-
-     btn_analyze.click(
-         run_analysis,
-         inputs=[input_image, input_message, state],
-         outputs=[state, view_input, view_results, output_html, confirm_box]
-     )
-
-     # Confirmation
-     btn_confirm_yes.click(
-         confirm_yes,
-         inputs=[state],
-         outputs=[state, output_html, confirm_box]
-     )
-
-     btn_confirm_no.click(
-         confirm_no,
-         outputs=[input_feedback, btn_submit_feedback]
-     )
-
-     btn_submit_feedback.click(
-         submit_feedback,
-         inputs=[input_feedback, state],
-         outputs=[state, output_html, confirm_box, input_feedback, btn_submit_feedback, input_feedback]
-     )
-
-     # Follow-up
-     btn_followup.click(
-         send_followup,
-         inputs=[input_followup, state],
-         outputs=[state, output_html, input_followup]
-     )
-
-     input_followup.submit(
-         send_followup,
-         inputs=[input_followup, state],
-         outputs=[state, output_html, input_followup]
-     )
-
-
- # =============================================================================
- # MAIN
- # =============================================================================
-
- if __name__ == "__main__":
-     print(f"\n{'='*50}")
-     print(f" {Config.APP_TITLE}")
-     print(f"{'='*50}\n")
-
-     app.queue().launch(
-         server_name="0.0.0.0" if Config.HF_SPACES else "127.0.0.1",
-         server_port=Config.SERVER_PORT,
-         share=False,
-         show_error=True
-     )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
frontend/components/__init__.py DELETED
File without changes
frontend/components/analysis_view.py DELETED
@@ -1,214 +0,0 @@
- """
- Analysis View Component - Main analysis interface with input and results
- """
-
- import gradio as gr
- import re
- from typing import Optional
-
-
- def parse_markdown(text: str) -> str:
-     """Convert basic markdown to HTML"""
-     text = re.sub(r'\*\*(.+?)\*\*', r'<strong>\1</strong>', text)
-     text = re.sub(r'__(.+?)__', r'<strong>\1</strong>', text)
-     text = re.sub(r'\*(.+?)\*', r'<em>\1</em>', text)
-
-     # Bullet lists
-     lines = text.split('\n')
-     in_list = False
-     result = []
-     for line in lines:
-         stripped = line.strip()
-         if re.match(r'^[\*\-] ', stripped):
-             if not in_list:
-                 result.append('<ul>')
-                 in_list = True
-             item = re.sub(r'^[\*\-] ', '', stripped)
-             result.append(f'<li>{item}</li>')
-         else:
-             if in_list:
-                 result.append('</ul>')
-                 in_list = False
-             result.append(line)
-     if in_list:
-         result.append('</ul>')
-
-     return '\n'.join(result)
-
-
- # Regex patterns for output parsing
- _STAGE_RE = re.compile(r'\[STAGE:(\w+)\](.*?)\[/STAGE\]')
- _THINKING_RE = re.compile(r'\[THINKING\](.*?)\[/THINKING\]')
- _OBSERVATION_RE = re.compile(r'\[OBSERVATION\](.*?)\[/OBSERVATION\]')
- _TOOL_OUTPUT_RE = re.compile(r'\[TOOL_OUTPUT:(.*?)\]\n(.*?)\[/TOOL_OUTPUT\]', re.DOTALL)
- _RESULT_RE = re.compile(r'\[RESULT\](.*?)\[/RESULT\]')
- _ERROR_RE = re.compile(r'\[ERROR\](.*?)\[/ERROR\]')
- _GRADCAM_RE = re.compile(r'\[GRADCAM_IMAGE:[^\]]+\]\n?')
- _RESPONSE_RE = re.compile(r'\[RESPONSE\]\n(.*?)\n\[/RESPONSE\]', re.DOTALL)
- _COMPLETE_RE = re.compile(r'\[COMPLETE\](.*?)\[/COMPLETE\]')
- _CONFIRM_RE = re.compile(r'\[CONFIRM:(\w+)\](.*?)\[/CONFIRM\]')
- _REFERENCES_RE = re.compile(r'\[REFERENCES\](.*?)\[/REFERENCES\]', re.DOTALL)
- _REF_RE = re.compile(r'\[REF:([^:]+):([^:]+):([^:]+):([^:]+):([^\]]+)\]')
-
-
- def format_output(raw_text: str, gradcam_base64: Optional[str] = None) -> str:
-     """Convert tagged output to styled HTML"""
-     html = raw_text
-
-     # Stage headers
-     html = _STAGE_RE.sub(
-         r'<div class="stage"><span class="stage-indicator"></span><span class="stage-text">\2</span></div>',
-         html
-     )
-
-     # Thinking
-     html = _THINKING_RE.sub(r'<div class="thinking">\1</div>', html)
-
-     # Observations
-     html = _OBSERVATION_RE.sub(r'<div class="observation">\1</div>', html)
-
-     # Tool outputs
-     html = _TOOL_OUTPUT_RE.sub(
-         r'<div class="tool-output"><div class="tool-header">\1</div><pre class="tool-content">\2</pre></div>',
-         html
-     )
-
-     # Results
-     html = _RESULT_RE.sub(r'<div class="result">\1</div>', html)
-
-     # Errors
-     html = _ERROR_RE.sub(r'<div class="error">\1</div>', html)
-
-     # GradCAM image
-     if gradcam_base64:
-         img_html = f'<div class="gradcam-inline"><div class="gradcam-header">Attention Map</div><img src="data:image/png;base64,{gradcam_base64}" alt="Grad-CAM"></div>'
-         html = _GRADCAM_RE.sub(img_html, html)
-     else:
-         html = _GRADCAM_RE.sub('', html)
-
-     # Response section
-     def format_response(match):
-         content = match.group(1)
-         parsed = parse_markdown(content)
-         parsed = re.sub(r'\n\n+', '</p><p>', parsed)
-         parsed = parsed.replace('\n', '<br>')
-         return f'<div class="response"><p>{parsed}</p></div>'
-
-     html = _RESPONSE_RE.sub(format_response, html)
-
-     # Complete
-     html = _COMPLETE_RE.sub(r'<div class="complete">\1</div>', html)
-
-     # Confirmation
-     html = _CONFIRM_RE.sub(
-         r'<div class="confirm-box"><div class="confirm-text">\2</div></div>',
-         html
-     )
-
-     # References
-     def format_references(match):
-         ref_content = match.group(1)
-         refs_html = ['<div class="references"><div class="references-header">References</div><ul>']
-         for ref_match in _REF_RE.finditer(ref_content):
-             _, source, page, filename, superscript = ref_match.groups()
-             refs_html.append(
-                 f'<li><a href="guidelines/(unknown)#page={page}" target="_blank" class="ref-link">'
-                 f'<sup>{superscript}</sup> {source}, p.{page}</a></li>'
-             )
-         refs_html.append('</ul></div>')
-         return '\n'.join(refs_html)
-
-     html = _REFERENCES_RE.sub(format_references, html)
-
-     # Convert newlines
-     html = html.replace('\n', '<br>')
-
-     return f'<div class="analysis-output">{html}</div>'
-
-
- def create_analysis_view():
-     """
-     Create the analysis view component.
-
-     Returns:
-         Tuple of (container, components dict)
-     """
-     with gr.Group(visible=False, elem_classes=["analysis-container"]) as container:
-
-         with gr.Row():
-             # Main content area
-             with gr.Column(elem_classes=["main-content"]):
-
-                 # Input greeting (shown when no analysis yet)
-                 with gr.Group(visible=True, elem_classes=["input-greeting"]) as input_greeting:
-                     gr.Markdown("What would you like to analyze?", elem_classes=["greeting-title"])
-                     gr.Markdown("Upload an image and describe what you'd like to know", elem_classes=["greeting-subtitle"])
-
-                     with gr.Column(elem_classes=["input-box-container"]):
-                         message_input = gr.Textbox(
-                             placeholder="Describe the lesion or ask a question...",
-                             show_label=False,
-                             lines=3,
-                             elem_classes=["message-input"]
-                         )
-
-                         # Image upload (compact)
-                         image_input = gr.Image(
-                             label="",
-                             type="pil",
-                             height=180,
-                             elem_classes=["image-preview"],
-                             show_label=False
-                         )
-
-                         with gr.Row(elem_classes=["input-actions"]):
-                             upload_hint = gr.Markdown("*Upload a skin lesion image above*", visible=True)
-                             send_btn = gr.Button("Analyze", elem_classes=["send-btn"], interactive=False)
-
-                 # Chat/results view (shown after analysis starts)
-                 with gr.Group(visible=False, elem_classes=["chat-view"]) as chat_view:
-                     results_output = gr.HTML(
-                         value='<div class="analysis-output">Starting analysis...</div>',
-                         elem_classes=["results-area"]
-                     )
-
-                     # Confirmation buttons
-                     with gr.Group(visible=False, elem_classes=["confirm-buttons"]) as confirm_group:
-                         gr.Markdown("**Do you agree with this diagnosis?**")
-                         with gr.Row():
-                             confirm_yes_btn = gr.Button("Yes, continue", variant="primary", size="sm")
-                             confirm_no_btn = gr.Button("No, I disagree", variant="secondary", size="sm")
-                         feedback_input = gr.Textbox(
-                             label="Your assessment",
-                             placeholder="Enter your diagnosis...",
-                             visible=False
-                         )
-                         submit_feedback_btn = gr.Button("Submit", visible=False, size="sm")
-
-                     # Follow-up input
-                     with gr.Row(elem_classes=["chat-input-area"]):
-                         followup_input = gr.Textbox(
-                             placeholder="Ask a follow-up question...",
-                             show_label=False,
-                             lines=1
-                         )
-                         followup_btn = gr.Button("Send", size="sm", elem_classes=["send-btn"])
-
-     components = {
-         "input_greeting": input_greeting,
-         "chat_view": chat_view,
-         "message_input": message_input,
-         "image_input": image_input,
-         "send_btn": send_btn,
-         "results_output": results_output,
-         "confirm_group": confirm_group,
-         "confirm_yes_btn": confirm_yes_btn,
-         "confirm_no_btn": confirm_no_btn,
-         "feedback_input": feedback_input,
-         "submit_feedback_btn": submit_feedback_btn,
-         "followup_input": followup_input,
-         "followup_btn": followup_btn,
-         "upload_hint": upload_hint
-     }
-
-     return container, components
frontend/components/patient_select.py DELETED
@@ -1,48 +0,0 @@
- """
- Patient Selection Component - Landing page for selecting/creating patients
- """
-
- import gradio as gr
- from typing import Callable, List
- from data.case_store import get_case_store, Case
-
-
- def create_patient_select(on_patient_selected: Callable[[str], None]) -> gr.Group:
-     """
-     Create the patient selection page component.
-
-     Args:
-         on_patient_selected: Callback when a patient is selected (receives case_id)
-
-     Returns:
-         gr.Group containing the patient selection UI
-     """
-     case_store = get_case_store()
-
-     with gr.Group(visible=True, elem_classes=["patient-select-container"]) as container:
-         gr.Markdown("# SkinProAI", elem_classes=["patient-select-title"])
-         gr.Markdown("Select a patient to continue or create a new case", elem_classes=["patient-select-subtitle"])
-
-         with gr.Column(elem_classes=["patient-grid"]):
-             # Demo cases
-             demo_melanoma_btn = gr.Button(
-                 "Demo: Melanocytic Lesion",
-                 elem_classes=["patient-card"]
-             )
-             demo_ak_btn = gr.Button(
-                 "Demo: Actinic Keratosis",
-                 elem_classes=["patient-card"]
-             )
-
-             # New patient button
-             new_patient_btn = gr.Button(
-                 "+ New Patient",
-                 elem_classes=["new-patient-btn"]
-             )
-
-     return container, demo_melanoma_btn, demo_ak_btn, new_patient_btn
-
-
- def get_patient_cases() -> List[Case]:
-     """Get list of all patient cases"""
-     return get_case_store().list_cases()
frontend/components/sidebar.py DELETED
@@ -1,55 +0,0 @@
- """
- Sidebar Component - Shows previous queries for a patient
- """
-
- import gradio as gr
- from datetime import datetime
- from typing import List, Optional
- from data.case_store import get_case_store, Instance
-
-
- def format_query_item(instance: Instance, index: int) -> str:
-     """Format an instance as a query item for display"""
-     diagnosis = "Pending"
-     if instance.analysis and instance.analysis.get("diagnosis"):
-         diag = instance.analysis["diagnosis"]
-         diagnosis = diag.get("full_name", diag.get("class", "Unknown"))
-
-     try:
-         dt = datetime.fromisoformat(instance.created_at.replace('Z', '+00:00'))
-         date_str = dt.strftime("%b %d, %H:%M")
-     except:
-         date_str = "Unknown"
-
-     return f"Query #{index}: {diagnosis} ({date_str})"
-
-
- def create_sidebar():
-     """
-     Create the sidebar component for showing previous queries.
-
-     Returns:
-         Tuple of (container, components dict)
-     """
-     with gr.Column(visible=False, elem_classes=["query-sidebar"]) as container:
-         gr.Markdown("### Previous Queries", elem_classes=["sidebar-header"])
-
-         # Dynamic list of query buttons
-         query_list = gr.Column(elem_id="query-list")
-
-         # New query button
-         new_query_btn = gr.Button("+ New Query", size="sm", variant="primary")
-
-     components = {
-         "query_list": query_list,
-         "new_query_btn": new_query_btn
-     }
-
-     return container, components
-
-
- def get_queries_for_case(case_id: str) -> List[Instance]:
-     """Get all instances/queries for a case"""
-     if not case_id:
-         return []
-     return get_case_store().list_instances(case_id)
frontend/components/styles.py DELETED
@@ -1,517 +0,0 @@
- """
- CSS Styles for SkinProAI components
- """
-
- MAIN_CSS = """
- @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600&display=swap');
-
- * {
-     font-family: 'Inter', sans-serif !important;
- }
-
- .gradio-container {
-     max-width: 1200px !important;
-     margin: 0 auto !important;
- }
-
- /* Hide Gradio footer */
- .gradio-container footer { display: none !important; }
-
- /* ============================================
-    PATIENT SELECTION PAGE
-    ============================================ */
-
- .patient-select-container {
-     min-height: 80vh;
-     display: flex;
-     flex-direction: column;
-     align-items: center;
-     justify-content: center;
-     padding: 40px 20px;
- }
-
- .patient-select-title {
-     font-size: 32px;
-     font-weight: 600;
-     color: #111827;
-     margin-bottom: 8px;
-     text-align: center;
- }
-
- .patient-select-subtitle {
-     font-size: 16px;
-     color: #6b7280;
-     margin-bottom: 40px;
-     text-align: center;
- }
-
- .patient-grid {
-     display: flex;
-     gap: 20px;
-     flex-wrap: wrap;
-     justify-content: center;
-     max-width: 800px;
- }
-
- .patient-card {
-     background: white !important;
-     border: 2px solid #e5e7eb !important;
-     border-radius: 16px !important;
-     padding: 24px 32px !important;
-     min-width: 200px !important;
-     cursor: pointer;
-     transition: all 0.2s ease !important;
- }
-
- .patient-card:hover {
-     border-color: #6366f1 !important;
-     box-shadow: 0 8px 25px rgba(99, 102, 241, 0.15) !important;
-     transform: translateY(-2px);
- }
-
- .new-patient-btn {
-     background: #6366f1 !important;
-     color: white !important;
-     border: none !important;
-     border-radius: 12px !important;
-     padding: 16px 32px !important;
-     font-weight: 500 !important;
-     margin-top: 24px;
- }
-
- .new-patient-btn:hover {
-     background: #4f46e5 !important;
- }
-
- /* ============================================
-    ANALYSIS PAGE - MAIN LAYOUT
-    ============================================ */
-
- .analysis-container {
-     display: flex;
-     height: calc(100vh - 80px);
-     min-height: 600px;
- }
-
- /* Sidebar */
- .query-sidebar {
-     width: 280px;
-     background: #f9fafb;
-     border-right: 1px solid #e5e7eb;
-     padding: 20px;
-     overflow-y: auto;
-     flex-shrink: 0;
- }
-
- .sidebar-header {
-     font-size: 14px;
-     font-weight: 600;
-     color: #374151;
-     margin-bottom: 16px;
-     padding-bottom: 12px;
-     border-bottom: 1px solid #e5e7eb;
- }
-
- .query-item {
-     background: white;
-     border: 1px solid #e5e7eb;
-     border-radius: 8px;
-     padding: 12px;
-     margin-bottom: 8px;
-     cursor: pointer;
-     transition: all 0.15s;
- }
-
- .query-item:hover {
-     border-color: #6366f1;
-     background: #f5f3ff;
- }
-
- .query-item-title {
-     font-size: 13px;
-     font-weight: 500;
-     color: #111827;
-     margin-bottom: 4px;
- }
-
- .query-item-meta {
-     font-size: 11px;
-     color: #6b7280;
- }
-
- /* Main content area */
- .main-content {
-     flex: 1;
-     display: flex;
-     flex-direction: column;
-     padding: 24px;
-     overflow: hidden;
- }
-
- /* ============================================
-    INPUT AREA (Greeting style)
-    ============================================ */
-
- .input-greeting {
-     flex: 1;
-     display: flex;
-     flex-direction: column;
-     align-items: center;
-     justify-content: center;
-     padding: 40px;
- }
-
- .greeting-title {
-     font-size: 24px;
-     font-weight: 600;
-     color: #111827;
-     margin-bottom: 8px;
- }
-
- .greeting-subtitle {
-     font-size: 14px;
-     color: #6b7280;
-     margin-bottom: 32px;
- }
-
- .input-box-container {
-     width: 100%;
-     max-width: 600px;
-     background: white;
-     border: 2px solid #e5e7eb;
-     border-radius: 16px;
-     padding: 20px;
-     transition: border-color 0.2s;
- }
-
- .input-box-container:focus-within {
-     border-color: #6366f1;
- }
-
- .message-input textarea {
-     border: none !important;
-     resize: none !important;
-     font-size: 15px !important;
-     line-height: 1.5 !important;
-     padding: 0 !important;
- }
-
- .message-input textarea:focus {
-     box-shadow: none !important;
- }
-
- .input-actions {
-     display: flex;
-     align-items: center;
-     justify-content: space-between;
-     margin-top: 16px;
-     padding-top: 16px;
-     border-top: 1px solid #f3f4f6;
- }
-
- .upload-btn {
-     background: #f3f4f6 !important;
-     color: #374151 !important;
-     border: 1px solid #e5e7eb !important;
-     border-radius: 8px !important;
-     padding: 8px 16px !important;
-     font-size: 13px !important;
- }
-
- .upload-btn:hover {
-     background: #e5e7eb !important;
- }
-
- .send-btn {
-     background: #6366f1 !important;
-     color: white !important;
-     border: none !important;
-     border-radius: 8px !important;
-     padding: 10px 24px !important;
-     font-weight: 500 !important;
- }
-
- .send-btn:hover {
-     background: #4f46e5 !important;
- }
-
- .send-btn:disabled {
-     background: #d1d5db !important;
-     cursor: not-allowed;
- }
-
- /* Image preview */
- .image-preview {
-     margin-top: 16px;
-     border-radius: 12px;
-     overflow: hidden;
-     max-height: 200px;
- }
-
- .image-preview img {
-     max-height: 200px;
-     object-fit: contain;
- }
-
- /* ============================================
-    CHAT/RESULTS VIEW
-    ============================================ */
-
- .chat-view {
-     flex: 1;
-     display: flex;
-     flex-direction: column;
-     overflow: hidden;
- }
-
- .results-area {
-     flex: 1;
-     overflow-y: auto;
-     padding: 20px;
-     background: #ffffff;
-     border: 1px solid #e5e7eb;
-     border-radius: 12px;
-     margin-bottom: 16px;
- }
-
- /* Analysis output styling */
- .analysis-output {
-     line-height: 1.6;
-     color: #333;
- }
-
- .stage {
-     display: flex;
-     align-items: center;
-     gap: 10px;
-     padding: 8px 0;
-     font-weight: 500;
-     color: #1a1a1a;
-     margin-top: 12px;
- }
-
- .stage-indicator {
-     width: 8px;
-     height: 8px;
-     background: #6366f1;
-     border-radius: 50%;
-     animation: pulse 1.5s ease-in-out infinite;
- }
-
- @keyframes pulse {
-     0%, 100% { opacity: 1; transform: scale(1); }
-     50% { opacity: 0.5; transform: scale(0.8); }
- }
-
- .thinking {
-     color: #6b7280;
-     font-style: italic;
-     font-size: 13px;
-     padding: 4px 0 4px 16px;
-     border-left: 2px solid #e5e7eb;
-     margin: 4px 0;
- }
-
- .observation {
-     color: #374151;
-     font-size: 13px;
-     padding: 4px 0 4px 16px;
- }
-
- .tool-output {
-     background: #f8fafc;
-     border-radius: 8px;
-     margin: 12px 0;
-     overflow: hidden;
-     border: 1px solid #e2e8f0;
- }
-
- .tool-header {
-     background: #f1f5f9;
-     padding: 8px 12px;
-     font-weight: 500;
-     font-size: 13px;
-     color: #475569;
-     border-bottom: 1px solid #e2e8f0;
- }
-
- .tool-content {
-     padding: 12px;
-     margin: 0;
-     font-family: 'SF Mono', Monaco, monospace !important;
-     font-size: 12px;
-     line-height: 1.5;
-     white-space: pre-wrap;
-     color: #334155;
- }
-
- .result {
-     background: #ecfdf5;
-     border: 1px solid #a7f3d0;
-     border-radius: 8px;
-     padding: 12px 16px;
-     margin: 12px 0;
-     font-weight: 500;
-     color: #065f46;
- }
-
- .error {
-     background: #fef2f2;
-     border: 1px solid #fecaca;
-     border-radius: 8px;
-     padding: 12px 16px;
-     margin: 8px 0;
-     color: #b91c1c;
- }
-
- .response {
-     background: #ffffff;
-     border: 1px solid #e5e7eb;
-     border-radius: 8px;
-     padding: 16px;
-     margin: 16px 0;
-     line-height: 1.7;
- }
-
- .response ul, .response ol {
-     margin: 8px 0;
-     padding-left: 24px;
- }
-
- .response li {
-     margin: 4px 0;
- }
-
- .complete {
-     color: #6b7280;
-     font-size: 12px;
-     padding: 8px 0;
-     text-align: center;
- }
-
- /* Confirmation */
- .confirm-box {
-     background: #eff6ff;
-     border: 1px solid #bfdbfe;
-     border-radius: 8px;
-     padding: 16px;
-     margin: 16px 0;
-     text-align: center;
- }
-
- .confirm-buttons {
-     background: #f0f9ff;
-     border: 1px solid #bae6fd;
-     border-radius: 8px;
-     padding: 12px;
-     margin-top: 12px;
- }
-
- /* References */
- .references {
-     background: #f9fafb;
-     border: 1px solid #e5e7eb;
-     border-radius: 8px;
-     margin: 16px 0;
-     overflow: hidden;
- }
-
- .references-header {
-     background: #f3f4f6;
-     padding: 8px 12px;
-     font-weight: 500;
-     font-size: 13px;
-     border-bottom: 1px solid #e5e7eb;
- }
-
- .references ul {
-     list-style: none;
-     padding: 12px;
-     margin: 0;
- }
-
- .ref-link {
-     color: #6366f1;
-     text-decoration: none;
-     font-size: 13px;
- }
-
- .ref-link:hover {
-     text-decoration: underline;
- }
-
- /* GradCAM */
- .gradcam-inline {
-     margin: 16px 0;
-     background: #f8fafc;
-     border-radius: 8px;
-     overflow: hidden;
-     border: 1px solid #e2e8f0;
- }
-
- .gradcam-header {
-     background: #f1f5f9;
-     padding: 8px 12px;
-     font-weight: 500;
-     font-size: 13px;
-     border-bottom: 1px solid #e2e8f0;
- }
-
- .gradcam-inline img {
-     max-width: 100%;
-     max-height: 300px;
-     display: block;
-     margin: 12px auto;
- }
-
- /* Chat input at bottom */
- .chat-input-area {
-     background: white;
-     border: 1px solid #e5e7eb;
-     border-radius: 12px;
-     padding: 12px 16px;
-     display: flex;
-     gap: 12px;
-     align-items: flex-end;
- }
-
- .chat-input-area textarea {
-     flex: 1;
-     border: none !important;
-     resize: none !important;
-     font-size: 14px !important;
- }
-
- /* ============================================
-    HEADER
-    ============================================ */
-
- .app-header {
-     display: flex;
-     align-items: center;
-     justify-content: space-between;
-     padding: 16px 24px;
-     border-bottom: 1px solid #e5e7eb;
-     background: white;
- }
-
- .app-title {
-     font-size: 20px;
-     font-weight: 600;
-     color: #111827;
- }
-
- .back-btn {
-     background: transparent !important;
-     color: #6b7280 !important;
-     border: 1px solid #e5e7eb !important;
-     border-radius: 8px !important;
-     padding: 8px 16px !important;
-     font-size: 13px !important;
- }
-
- .back-btn:hover {
-     background: #f9fafb !important;
-     color: #111827 !important;
- }
- """
models/explainability.py DELETED
@@ -1,183 +0,0 @@
- # models/explainability.py
-
- import torch
- import torch.nn.functional as F
- import numpy as np
- import cv2
- from typing import Tuple
- from PIL import Image
-
- class GradCAM:
-     """
-     Gradient-weighted Class Activation Mapping
-     Shows which regions of image are important for prediction
-     """
-
-     def __init__(self, model: torch.nn.Module, target_layer: str = None):
-         """
-         Args:
-             model: The neural network
-             target_layer: Layer name to compute CAM on (usually last conv layer)
-         """
-         self.model = model
-         self.gradients = None
-         self.activations = None
-
-         # Auto-detect target layer if not specified
-         if target_layer is None:
-             # Use last ConvNeXt stage
-             self.target_layer = model.convnext.stages[-1]
-         else:
-             self.target_layer = dict(model.named_modules())[target_layer]
-
-         # Register hooks
-         self.target_layer.register_forward_hook(self._save_activation)
-         self.target_layer.register_full_backward_hook(self._save_gradient)
-
-     def _save_activation(self, module, input, output):
-         """Save forward activations"""
-         self.activations = output.detach()
-
-     def _save_gradient(self, module, grad_input, grad_output):
-         """Save backward gradients"""
-         self.gradients = grad_output[0].detach()
-
-     def generate_cam(
-         self,
-         image: torch.Tensor,
-         target_class: int = None
-     ) -> np.ndarray:
-         """
-         Generate Class Activation Map
-
-         Args:
-             image: Input image [1, 3, H, W]
-             target_class: Class to generate CAM for (None = predicted class)
-
-         Returns:
-             cam: Activation map [H, W] normalized to 0-1
-         """
-         self.model.eval()
-
-         # Forward pass
-         output = self.model(image)
-
-         # Use predicted class if not specified
-         if target_class is None:
-             target_class = output.argmax(dim=1).item()
-
-         # Zero gradients
-         self.model.zero_grad()
-
-         # Backward pass for target class
-         output[0, target_class].backward()
-
-         # Get gradients and activations
-         gradients = self.gradients[0]  # [C, H, W]
-         activations = self.activations[0]  # [C, H, W]
-
-         # Global average pooling of gradients
-         weights = gradients.mean(dim=(1, 2))  # [C]
-
-         # Weighted sum of activations
-         cam = torch.zeros(activations.shape[1:], dtype=torch.float32)
-         for i, w in enumerate(weights):
-             cam += w * activations[i]
-
-         # ReLU
-         cam = F.relu(cam)
-
-         # Normalize to 0-1
-         cam = cam.cpu().numpy()
-         cam = cam - cam.min()
-         if cam.max() > 0:
-             cam = cam / cam.max()
-
-         return cam
-
-     def overlay_cam_on_image(
-         self,
-         image: np.ndarray,  # [H, W, 3] RGB
-         cam: np.ndarray,  # [h, w]
-         alpha: float = 0.5,
-         colormap: int = cv2.COLORMAP_JET
-     ) -> np.ndarray:
-         """
-         Overlay CAM heatmap on original image
-
-         Returns:
-             overlay: [H, W, 3] RGB image with heatmap
-         """
-         H, W = image.shape[:2]
-
-         # Resize CAM to image size
-         cam_resized = cv2.resize(cam, (W, H))
-
-         # Convert to heatmap
-         heatmap = cv2.applyColorMap(
-             np.uint8(255 * cam_resized),
-             colormap
-         )
-         heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
-
-         # Blend with original image
-         overlay = (alpha * heatmap + (1 - alpha) * image).astype(np.uint8)
-
-         return overlay
-
- class AttentionVisualizer:
-     """Visualize MedSigLIP attention maps"""
-
-     def __init__(self, model):
-         self.model = model
-
-     def get_attention_maps(self, image: torch.Tensor) -> np.ndarray:
-         """
-         Extract attention maps from MedSigLIP
-
-         Returns:
-             attention: [num_heads, H, W] attention weights
-         """
-         # Forward pass
-         with torch.no_grad():
-             _ = self.model(image)
-
-         # Get last layer attention from MedSigLIP
-         # Shape: [batch, num_heads, seq_len, seq_len]
-         attention = self.model.medsiglip_features
-
-         # Average across heads and extract spatial attention
-         # This is model-dependent - adjust based on MedSigLIP architecture
-
-         # Placeholder implementation
-         # You'll need to adapt this to your specific MedSigLIP implementation
-         return np.random.rand(14, 14)  # Placeholder
-
-     def overlay_attention(
-         self,
-         image: np.ndarray,
-         attention: np.ndarray,
-         alpha: float = 0.6
-     ) -> np.ndarray:
-         """Overlay attention map on image"""
-         H, W = image.shape[:2]
-
-         # Resize attention to image size
-         attention_resized = cv2.resize(attention, (W, H))
-
-         # Normalize
-         attention_resized = (attention_resized - attention_resized.min())
-         if attention_resized.max() > 0:
-             attention_resized = attention_resized / attention_resized.max()
-
-         # Create colored overlay
-         heatmap = cv2.applyColorMap(
-             np.uint8(255 * attention_resized),
-             cv2.COLORMAP_VIRIDIS
-         )
-         heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
-
-         # Blend
-         overlay = (alpha * heatmap + (1 - alpha) * image).astype(np.uint8)
-
-         return overlay
models/medsiglip_convnext_fusion.py DELETED
@@ -1,224 +0,0 @@
- # models/medsiglip_convnext_fusion.py
-
- import torch
- import torch.nn as nn
- from typing import Dict, List, Tuple, Optional
- import numpy as np
- import timm
- from transformers import AutoModel, AutoProcessor
-
- class MedSigLIPConvNeXtFusion(nn.Module):
-     """
-     Your trained MedSigLIP-ConvNeXt fusion model from MILK10 challenge
-     Supports 11-class skin lesion classification
-     """
-
-     # Class names from your training
-     CLASS_NAMES = [
-         'AKIEC',    # Actinic Keratoses and Intraepithelial Carcinoma
-         'BCC',      # Basal Cell Carcinoma
-         'BEN_OTH',  # Benign Other
-         'BKL',      # Benign Keratosis-like Lesions
-         'DF',       # Dermatofibroma
-         'INF',      # Inflammatory
-         'MAL_OTH',  # Malignant Other
-         'MEL',      # Melanoma
-         'NV',       # Melanocytic Nevi
-         'SCCKA',    # Squamous Cell Carcinoma and Keratoacanthoma
-         'VASC'      # Vascular Lesions
-     ]
-
-     def __init__(
-         self,
-         num_classes: int = 11,
-         medsiglip_model: str = "google/medsiglip-base",
-         convnext_variant: str = "convnext_base",
-         fusion_dim: int = 512,
-         dropout: float = 0.3,
-         metadata_dim: int = 20  # For metadata features
-     ):
-         super().__init__()
-
-         self.num_classes = num_classes
-
-         # MedSigLIP Vision Encoder
-         print(f"Loading MedSigLIP: {medsiglip_model}")
-         self.medsiglip = AutoModel.from_pretrained(medsiglip_model)
-         self.medsiglip_processor = AutoProcessor.from_pretrained(medsiglip_model)
-
-         # ConvNeXt Backbone
-         print(f"Loading ConvNeXt: {convnext_variant}")
-         self.convnext = timm.create_model(
-             convnext_variant,
-             pretrained=True,
-             num_classes=0,
-             global_pool='avg'
-         )
-
-         # Feature dimensions
-         self.medsiglip_dim = self.medsiglip.config.hidden_size  # 768
-         self.convnext_dim = self.convnext.num_features  # 1024
-
-         # Optional metadata branch
-         self.use_metadata = metadata_dim > 0
-         if self.use_metadata:
-             self.metadata_encoder = nn.Sequential(
-                 nn.Linear(metadata_dim, 64),
-                 nn.LayerNorm(64),
-                 nn.GELU(),
-                 nn.Dropout(0.2),
-                 nn.Linear(64, 32)
-             )
-             total_dim = self.medsiglip_dim + self.convnext_dim + 32
-         else:
-             total_dim = self.medsiglip_dim + self.convnext_dim
-
-         # Fusion layers
-         self.fusion = nn.Sequential(
-             nn.Linear(total_dim, fusion_dim),
-             nn.LayerNorm(fusion_dim),
-             nn.GELU(),
-             nn.Dropout(dropout),
-             nn.Linear(fusion_dim, fusion_dim // 2),
-             nn.LayerNorm(fusion_dim // 2),
-             nn.GELU(),
-             nn.Dropout(dropout)
-         )
-
-         # Classification head
-         self.classifier = nn.Linear(fusion_dim // 2, num_classes)
-
-         # Store intermediate features for Grad-CAM
-         self.convnext_features = None
-         self.medsiglip_features = None
-
-         # Register hooks
-         self.convnext.stages[-1].register_forward_hook(self._save_convnext_features)
-
-     def _save_convnext_features(self, module, input, output):
-         """Hook to save ConvNeXt feature maps for Grad-CAM"""
-         self.convnext_features = output
-
-     def forward(
-         self,
-         image: torch.Tensor,
-         metadata: Optional[torch.Tensor] = None
-     ) -> torch.Tensor:
-         """
-         Forward pass
-
-         Args:
-             image: [B, 3, H, W] tensor
-             metadata: [B, metadata_dim] optional metadata features
-
-         Returns:
-             logits: [B, num_classes]
-         """
-         # MedSigLIP features
-         medsiglip_out = self.medsiglip.vision_model(image)
-         medsiglip_features = medsiglip_out.pooler_output  # [B, 768]
-
-         # ConvNeXt features
-         convnext_features = self.convnext(image)  # [B, 1024]
-
-         # Concatenate vision features
-         fused = torch.cat([medsiglip_features, convnext_features], dim=1)
-
-         # Add metadata if available
-         if self.use_metadata and metadata is not None:
-             metadata_features = self.metadata_encoder(metadata)
-             fused = torch.cat([fused, metadata_features], dim=1)
-
-         # Fusion layers
-         fused = self.fusion(fused)
-
-         # Classification
-         logits = self.classifier(fused)
-
-         return logits
-
-     def predict(
-         self,
-         image: torch.Tensor,
-         metadata: Optional[torch.Tensor] = None,
-         top_k: int = 5
-     ) -> Dict:
-         """
-         Get predictions with probabilities
-
-         Args:
-             image: [B, 3, H, W] or [3, H, W]
-             metadata: Optional metadata features
-             top_k: Number of top predictions
-
-         Returns:
-             Dictionary with predictions and features
-         """
-         if image.dim() == 3:
-             image = image.unsqueeze(0)
-
-         self.eval()
-         with torch.no_grad():
-             logits = self.forward(image, metadata)
-             probs = torch.softmax(logits, dim=1)
-
-         # Top-k predictions
-         top_probs, top_indices = torch.topk(
-             probs,
-             k=min(top_k, self.num_classes),
-             dim=1
-         )
-
-         # Format results
-         predictions = []
-         for i in range(top_probs.size(1)):
-             predictions.append({
-                 'class': self.CLASS_NAMES[top_indices[0, i].item()],
-                 'probability': top_probs[0, i].item(),
-                 'class_idx': top_indices[0, i].item()
-             })
-
-         return {
-             'predictions': predictions,
-             'all_probabilities': probs[0].cpu().numpy(),
-             'logits': logits[0].cpu().numpy(),
-             'convnext_features': self.convnext_features,
-             'medsiglip_features': self.medsiglip_features
-         }
-
-     @classmethod
-     def load_from_checkpoint(
-         cls,
-         medsiglip_path: str,
-         convnext_path: Optional[str] = None,
-         ensemble_weights: tuple = (0.6, 0.4),
-         device: str = 'cpu'
-     ):
-         """
-         Load model from your training checkpoints
-
-         Args:
-             medsiglip_path: Path to MedSigLIP model weights
-             convnext_path: Path to ConvNeXt model weights (optional)
-             ensemble_weights: (w_medsiglip, w_convnext)
-             device: Device to load on
-         """
-         model = cls(num_classes=11)
-
-         # Load MedSigLIP weights
-         print(f"Loading MedSigLIP from: {medsiglip_path}")
-         medsiglip_state = torch.load(medsiglip_path, map_location=device)
-
-         # Handle different checkpoint formats
-         if 'model_state_dict' in medsiglip_state:
-             model.load_state_dict(medsiglip_state['model_state_dict'])
-         else:
-             model.load_state_dict(medsiglip_state)
-
-         # Store ensemble weights for prediction fusion
-         model.ensemble_weights = ensemble_weights
-
-         model.to(device)
-         model.eval()
-
-         return model
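For reference, the top-k formatting inside the removed `predict()` reduces to a softmax followed by a descending argsort. A minimal NumPy sketch of just that step (the `top_k_predictions` name is hypothetical; `CLASS_NAMES` is copied from the deleted file above):

```python
import numpy as np

CLASS_NAMES = [
    'AKIEC', 'BCC', 'BEN_OTH', 'BKL', 'DF',
    'INF', 'MAL_OTH', 'MEL', 'NV', 'SCCKA', 'VASC'
]

def top_k_predictions(logits: np.ndarray, k: int = 5) -> list:
    """Convert raw class logits to the removed predict()'s output format."""
    # Numerically stable softmax over the 11 class logits
    exp = np.exp(logits - logits.max())
    probs = exp / exp.sum()
    # Indices of the k largest probabilities, highest first
    order = np.argsort(probs)[::-1][:k]
    return [
        {'class': CLASS_NAMES[i], 'probability': float(probs[i]), 'class_idx': int(i)}
        for i in order
    ]

# Logit spike at index 6 (MAL_OTH) should rank first
preds = top_k_predictions(np.array([0.1] * 6 + [2.0] + [0.1] * 4))
```

This mirrors the dictionaries the MCP layer consumed, without the torch, hook, or feature-caching machinery of the full model.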
 
models/monet_concepts.py DELETED
@@ -1,332 +0,0 @@
- # models/monet_concepts.py
-
- import torch
- import numpy as np
- from typing import Dict, List
- from dataclasses import dataclass
-
- @dataclass
- class ConceptScore:
-     """Single MONET concept with score and evidence"""
-     name: str
-     score: float
-     confidence: float
-     description: str
-     clinical_relevance: str  # How this affects diagnosis
-
- class MONETConceptScorer:
-     """
-     MONET concept scoring using your trained metadata patterns
-     Integrates the boosting logic from your ensemble code
-     """
-
-     # MONET concepts used in your training
-     CONCEPT_DEFINITIONS = {
-         'MONET_ulceration_crust': {
-             'description': 'Ulceration or crusting present',
-             'high_in': ['SCCKA', 'BCC', 'MAL_OTH'],
-             'low_in': ['NV', 'BKL'],
-             'threshold_high': 0.50
-         },
-         'MONET_erythema': {
-             'description': 'Redness or inflammation',
-             'high_in': ['INF', 'BCC', 'SCCKA'],
-             'low_in': ['MEL', 'NV'],
-             'threshold_high': 0.40
-         },
-         'MONET_pigmented': {
-             'description': 'Pigmentation present',
-             'high_in': ['MEL', 'NV', 'BKL'],
-             'low_in': ['BCC', 'SCCKA', 'INF'],
-             'threshold_high': 0.55
-         },
-         'MONET_vasculature_vessels': {
-             'description': 'Vascular structures visible',
-             'high_in': ['VASC', 'BCC'],
-             'low_in': ['MEL', 'NV'],
-             'threshold_high': 0.35
-         },
-         'MONET_hair': {
-             'description': 'Hair follicles present',
-             'high_in': ['NV', 'BKL'],
-             'low_in': ['BCC', 'MEL'],
-             'threshold_high': 0.30
-         },
-         'MONET_gel_water_drop_fluid_dermoscopy_liquid': {
-             'description': 'Gel/fluid artifacts',
-             'high_in': [],
-             'low_in': [],
-             'threshold_high': 0.40
-         },
-         'MONET_skin_markings_pen_ink_purple_pen': {
-             'description': 'Pen markings present',
-             'high_in': [],
-             'low_in': [],
-             'threshold_high': 0.40
-         }
-     }
-
-     # Class-specific patterns from your metadata boosting
-     CLASS_PATTERNS = {
-         'MAL_OTH': {
-             'sex': 'male',  # 88.9% male
-             'site_preference': 'trunk',
-             'age_range': (60, 80),
-             'key_concepts': {'MONET_ulceration_crust': 0.35}
-         },
-         'INF': {
-             'key_concepts': {
-                 'MONET_erythema': 0.42,
-                 'MONET_pigmented': (None, 0.30)  # Low pigmentation
-             }
-         },
-         'BEN_OTH': {
-             'site_preference': ['head', 'neck', 'face'],  # 47.7%
-             'key_concepts': {'MONET_pigmented': (0.30, 0.50)}
-         },
-         'DF': {
-             'site_preference': ['lower', 'leg', 'ankle', 'foot'],  # 65.4%
-             'age_range': (40, 65)
-         },
-         'SCCKA': {
-             'age_range': (65, None),
-             'key_concepts': {
-                 'MONET_ulceration_crust': 0.50,
-                 'MONET_pigmented': (None, 0.15)
-             }
-         },
-         'MEL': {
-             'age_range': (55, None),  # 61.8 years average
-             'key_concepts': {'MONET_pigmented': 0.55}
-         },
-         'NV': {
-             'age_range': (None, 45),  # 42.0 years average
-             'key_concepts': {'MONET_pigmented': 0.55}
-         }
-     }
-
-     def __init__(self):
-         """Initialize MONET scorer with class patterns"""
-         self.class_names = [
-             'AKIEC', 'BCC', 'BEN_OTH', 'BKL', 'DF',
-             'INF', 'MAL_OTH', 'MEL', 'NV', 'SCCKA', 'VASC'
-         ]
-
-     def compute_concept_scores(
-         self,
-         metadata: Dict[str, float]
-     ) -> Dict[str, ConceptScore]:
-         """
-         Compute MONET concept scores from metadata
-
-         Args:
-             metadata: Dictionary with MONET scores, age, sex, site, etc.
-
-         Returns:
-             Dictionary of concept scores
-         """
-         concept_scores = {}
-
-         for concept_name, definition in self.CONCEPT_DEFINITIONS.items():
-             score = metadata.get(concept_name, 0.0)
-
-             # Determine confidence based on how extreme the score is
-             if score > definition['threshold_high']:
-                 confidence = min((score - definition['threshold_high']) / 0.2, 1.0)
-                 level = "HIGH"
-             elif score < 0.2:
-                 confidence = min((0.2 - score) / 0.2, 1.0)
-                 level = "LOW"
-             else:
-                 confidence = 0.5
-                 level = "MODERATE"
-
-             # Clinical relevance
-             if level == "HIGH":
-                 relevant_classes = definition['high_in']
-                 clinical_relevance = f"Supports: {', '.join(relevant_classes)}"
-             elif level == "LOW":
-                 excluded_classes = definition['low_in']
-                 clinical_relevance = f"Against: {', '.join(excluded_classes)}"
-             else:
-                 clinical_relevance = "Non-specific"
-
-             concept_scores[concept_name] = ConceptScore(
-                 name=concept_name.replace('MONET_', '').replace('_', ' ').title(),
-                 score=score,
-                 confidence=confidence,
-                 description=f"{definition['description']} ({level})",
-                 clinical_relevance=clinical_relevance
-             )
-
-         return concept_scores
-
-     def apply_metadata_boosting(
-         self,
-         probs: np.ndarray,
-         metadata: Dict
-     ) -> np.ndarray:
-         """
-         Apply your metadata boosting logic
-         This is directly from your ensemble optimization code
-
-         Args:
-             probs: [11] probability array
-             metadata: Dictionary with age, sex, site, MONET scores
-
-         Returns:
-             boosted_probs: [11] adjusted probabilities
-         """
-         boosted_probs = probs.copy()
-
-         # 1. MAL_OTH boosting
-         if metadata.get('sex') == 'male':
-             site = str(metadata.get('site', '')).lower()
-             if 'trunk' in site:
-                 age = metadata.get('age_approx', 60)
-                 ulceration = metadata.get('MONET_ulceration_crust', 0)
-
-                 score = 0
-                 score += 3 if metadata.get('sex') == 'male' else 0
-                 score += 2 if 'trunk' in site else 0
-                 score += 1 if 60 <= age <= 80 else 0
-                 score += 2 if ulceration > 0.35 else 0
-
-                 confidence = score / 8.0
-                 if confidence > 0.5:
-                     boosted_probs[6] *= (1.0 + confidence)  # MAL_OTH index
-
-         # 2. INF boosting
-         erythema = metadata.get('MONET_erythema', 0)
-         pigmentation = metadata.get('MONET_pigmented', 0)
-
-         if erythema > 0.42 and pigmentation < 0.30:
-             confidence = min((erythema - 0.42) / 0.10 + 0.5, 1.0)
-             boosted_probs[5] *= (1.0 + confidence * 0.8)  # INF index
-
-         # 3. BEN_OTH boosting
-         site = str(metadata.get('site', '')).lower()
-         is_head_neck = any(x in site for x in ['head', 'neck', 'face'])
-
-         if is_head_neck and 0.30 < pigmentation < 0.50:
-             ulceration = metadata.get('MONET_ulceration_crust', 0)
-             confidence = 0.7 if ulceration < 0.30 else 0.4
-             boosted_probs[2] *= (1.0 + confidence * 0.5)  # BEN_OTH index
-
-         # 4. DF boosting
-         is_lower_ext = any(x in site for x in ['lower', 'leg', 'ankle', 'foot'])
-
-         if is_lower_ext:
-             age = metadata.get('age_approx', 60)
-             if 40 <= age <= 65:
-                 boosted_probs[4] *= 1.8  # DF index
-             elif 30 <= age <= 75:
-                 boosted_probs[4] *= 1.5
-
-         # 5. SCCKA boosting
-         ulceration = metadata.get('MONET_ulceration_crust', 0)
-         age = metadata.get('age_approx', 60)
-
-         if ulceration > 0.50 and age >= 65 and pigmentation < 0.15:
-             boosted_probs[9] *= 1.9  # SCCKA index
-         elif ulceration > 0.45 and age >= 60 and pigmentation < 0.20:
-             boosted_probs[9] *= 1.5
-
-         # 6. MEL vs NV age separation
-         if pigmentation > 0.55:
-             if age >= 55:
-                 age_score = min((age - 55) / 20.0, 1.0)
-                 boosted_probs[7] *= (1.0 + age_score * 0.5)  # MEL
-                 boosted_probs[8] *= (1.0 - age_score * 0.3)  # NV
-             elif age <= 45:
-                 age_score = min((45 - age) / 30.0, 1.0)
-                 boosted_probs[7] *= (1.0 - age_score * 0.3)  # MEL
-                 boosted_probs[8] *= (1.0 + age_score * 0.5)  # NV
-
-         # 7. Exclusions based on pigmentation/erythema
-         if pigmentation > 0.50:
-             boosted_probs[0] *= 0.7  # AKIEC
-             boosted_probs[1] *= 0.6  # BCC
-             boosted_probs[5] *= 0.5  # INF
-             boosted_probs[9] *= 0.3  # SCCKA
-
-         if erythema > 0.40:
-             boosted_probs[7] *= 0.7  # MEL
-             boosted_probs[8] *= 0.7  # NV
-
-         if pigmentation < 0.20:
-             boosted_probs[7] *= 0.5  # MEL
-             boosted_probs[8] *= 0.5  # NV
-
-         # Renormalize
-         return boosted_probs / boosted_probs.sum()
-
-     def explain_prediction(
-         self,
-         probs: np.ndarray,
-         concept_scores: Dict[str, ConceptScore],
-         metadata: Dict
-     ) -> str:
-         """
-         Generate natural language explanation
-
-         Args:
-             probs: Class probabilities
-             concept_scores: MONET concept scores
-             metadata: Clinical metadata
-
-         Returns:
-             Natural language explanation
-         """
-         predicted_idx = np.argmax(probs)
-         predicted_class = self.class_names[predicted_idx]
-         confidence = probs[predicted_idx]
-
-         explanation = f"**Primary Diagnosis: {predicted_class}**\n"
-         explanation += f"Confidence: {confidence:.1%}\n\n"
-
-         # Key MONET features
-         explanation += "**Key Dermoscopic Features:**\n"
-
-         sorted_concepts = sorted(
-             concept_scores.values(),
-             key=lambda x: x.score * x.confidence,
-             reverse=True
-         )
-
-         for i, concept in enumerate(sorted_concepts[:5], 1):
-             if concept.score > 0.3 or concept.score < 0.2:
-                 explanation += f"{i}. {concept.name}: {concept.score:.2f} - {concept.description}\n"
-                 if concept.clinical_relevance != "Non-specific":
-                     explanation += f"   → {concept.clinical_relevance}\n"
-
-         # Clinical context
-         explanation += "\n**Clinical Context:**\n"
-         if 'age_approx' in metadata:
-             explanation += f"• Age: {metadata['age_approx']} years\n"
-         if 'sex' in metadata:
-             explanation += f"• Sex: {metadata['sex']}\n"
-         if 'site' in metadata:
-             explanation += f"• Location: {metadata['site']}\n"
-
-         return explanation
-
-     def get_top_concepts(
-         self,
-         concept_scores: Dict[str, ConceptScore],
-         top_k: int = 5,
-         min_score: float = 0.3
-     ) -> List[ConceptScore]:
-         """Get top-k most important concepts"""
-         filtered = [
-             cs for cs in concept_scores.values()
-             if cs.score >= min_score or cs.score < 0.2  # High or low
-         ]
-
-         sorted_concepts = sorted(
-             filtered,
-             key=lambda x: x.score * x.confidence,
-             reverse=True
-         )
-
-         return sorted_concepts[:top_k]
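The boosting pattern in the removed `apply_metadata_boosting` — multiply selected class probabilities by heuristic factors, then divide by the sum to restore a valid distribution — can be sketched standalone. Here only the INF rule (rule 2 in the file above) is reproduced in plain NumPy; the function name is hypothetical and the other six rules are omitted for brevity:

```python
import numpy as np

def boost_and_renormalize(probs: np.ndarray, metadata: dict) -> np.ndarray:
    """Apply the INF heuristic (high erythema, low pigmentation) and renormalize."""
    boosted = probs.copy()
    erythema = metadata.get('MONET_erythema', 0.0)
    pigmentation = metadata.get('MONET_pigmented', 0.0)
    # Boost INF (class index 5) when the dermoscopic pattern matches
    if erythema > 0.42 and pigmentation < 0.30:
        confidence = min((erythema - 0.42) / 0.10 + 0.5, 1.0)
        boosted[5] *= 1.0 + confidence * 0.8
    # Renormalize so the boosted vector is still a probability distribution
    return boosted / boosted.sum()

probs = np.full(11, 1 / 11)  # uniform prior over the 11 classes
out = boost_and_renormalize(probs, {'MONET_erythema': 0.60, 'MONET_pigmented': 0.10})
```

Because every rule only multiplies entries by positive factors before the final normalization, the rules compose in any order and the output always sums to one.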
 
test_models.py DELETED
@@ -1,86 +0,0 @@
- #!/usr/bin/env python3
- """Test script to verify model loading"""
-
- import torch
- import torch.nn as nn
- import timm
- from transformers import AutoModel, AutoProcessor
- import numpy as np
-
- DEVICE = "cpu"
- print(f"Device: {DEVICE}")
-
- # ConvNeXt model definition (matching checkpoint)
- class ConvNeXtDualEncoder(nn.Module):
-     def __init__(self, model_name="convnext_base.fb_in22k_ft_in1k",
-                  metadata_dim=19, num_classes=11, dropout=0.3):
-         super().__init__()
-         self.backbone = timm.create_model(model_name, pretrained=False, num_classes=0)
-         backbone_dim = self.backbone.num_features
-         self.meta_mlp = nn.Sequential(
-             nn.Linear(metadata_dim, 64), nn.LayerNorm(64), nn.GELU(), nn.Dropout(dropout)
-         )
-         fusion_dim = backbone_dim * 2 + 64
-         self.classifier = nn.Sequential(
-             nn.Linear(fusion_dim, 512), nn.LayerNorm(512), nn.GELU(), nn.Dropout(dropout),
-             nn.Linear(512, 256), nn.LayerNorm(256), nn.GELU(), nn.Dropout(dropout),
-             nn.Linear(256, num_classes)
-         )
-
-     def forward(self, clinical_img, derm_img=None, metadata=None):
-         clinical_features = self.backbone(clinical_img)
-         derm_features = self.backbone(derm_img) if derm_img is not None else clinical_features
-         if metadata is not None:
-             meta_features = self.meta_mlp(metadata)
-         else:
-             meta_features = torch.zeros(clinical_features.size(0), 64, device=clinical_features.device)
-         fused = torch.cat([clinical_features, derm_features, meta_features], dim=1)
-         return self.classifier(fused)
-
-
- # MedSigLIP model definition
- class MedSigLIPClassifier(nn.Module):
-     def __init__(self, num_classes=11, model_name="google/siglip-base-patch16-384"):
-         super().__init__()
-         self.siglip = AutoModel.from_pretrained(model_name)
-         self.processor = AutoProcessor.from_pretrained(model_name)
-         hidden_dim = self.siglip.config.vision_config.hidden_size
-         self.classifier = nn.Sequential(
-             nn.Linear(hidden_dim, 512), nn.LayerNorm(512), nn.GELU(), nn.Dropout(0.3),
-             nn.Linear(512, num_classes)
-         )
-         for param in self.siglip.parameters():
-             param.requires_grad = False
-
-     def forward(self, pixel_values):
-         vision_outputs = self.siglip.vision_model(pixel_values=pixel_values)
-         pooled_features = vision_outputs.pooler_output
-         return self.classifier(pooled_features)
-
-
- if __name__ == "__main__":
-     print("\n[1/2] Loading ConvNeXt...")
-     convnext_model = ConvNeXtDualEncoder()
-     ckpt = torch.load("models/seed42_fold0.pt", map_location=DEVICE, weights_only=False)
-     convnext_model.load_state_dict(ckpt)
-     convnext_model.eval()
-     print("  ConvNeXt loaded!")
-
-     print("\n[2/2] Loading MedSigLIP...")
-     medsiglip_model = MedSigLIPClassifier()
-     medsiglip_model.eval()
-     print("  MedSigLIP loaded!")
-
-     # Quick inference test
-     print("\nTesting inference...")
-     dummy_img = torch.randn(1, 3, 384, 384)
-     with torch.no_grad():
-         convnext_out = convnext_model(dummy_img)
-     print(f"  ConvNeXt output: {convnext_out.shape}")
-
-     dummy_pil = np.random.randint(0, 255, (384, 384, 3), dtype=np.uint8)
-     siglip_input = medsiglip_model.processor(images=[dummy_pil], return_tensors="pt")
-     siglip_out = medsiglip_model(siglip_input["pixel_values"])
-     print(f"  MedSigLIP output: {siglip_out.shape}")
-
-     print("\nAll tests passed!")