cgoodmaker commited on
Commit
86f402d
·
0 Parent(s):

Initial commit — SkinProAI dermoscopic analysis platform

Browse files

Patient-level chat UI with streaming AI analysis pipeline:
- MedGemma visual examination with stage-by-stage cascade
- ConvNeXt classifier with confidence scores and differential
- MONET feature extraction and Grad-CAM attention maps
- Temporal comparison between sequential lesion images
- Persistent chat history with full cascade replay on reload
- FastAPI backend with SSE streaming + React/TypeScript frontend
- Docker build ready for Hugging Face Spaces deployment

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .dockerignore +28 -0
  2. .gitattributes +1 -0
  3. .gitignore +45 -0
  4. Dockerfile +40 -0
  5. README.md +44 -0
  6. backend/__init__.py +0 -0
  7. backend/main.py +63 -0
  8. backend/requirements.txt +3 -0
  9. backend/routes/__init__.py +1 -0
  10. backend/routes/analysis.py +181 -0
  11. backend/routes/chat.py +97 -0
  12. backend/routes/lesions.py +241 -0
  13. backend/routes/patients.py +72 -0
  14. backend/services/__init__.py +0 -0
  15. backend/services/analysis_service.py +146 -0
  16. backend/services/chat_service.py +197 -0
  17. data/case_store.py +507 -0
  18. frontend/app.py +532 -0
  19. frontend/components/__init__.py +0 -0
  20. frontend/components/analysis_view.py +214 -0
  21. frontend/components/patient_select.py +48 -0
  22. frontend/components/sidebar.py +55 -0
  23. frontend/components/styles.py +517 -0
  24. guidelines/index/chunks.json +0 -0
  25. guidelines/index/faiss.index +3 -0
  26. mcp_server/__init__.py +0 -0
  27. mcp_server/server.py +286 -0
  28. mcp_server/tool_registry.py +55 -0
  29. models/convnext_classifier.py +383 -0
  30. models/explainability.py +183 -0
  31. models/gradcam_tool.py +285 -0
  32. models/guidelines_rag.py +349 -0
  33. models/medgemma_agent.py +927 -0
  34. models/medsiglip_convnext_fusion.py +224 -0
  35. models/monet_concepts.py +332 -0
  36. models/monet_tool.py +354 -0
  37. models/overlay_tool.py +335 -0
  38. requirements.txt +15 -0
  39. test_models.py +86 -0
  40. web/index.html +16 -0
  41. web/package-lock.json +0 -0
  42. web/package.json +24 -0
  43. web/src/App.tsx +14 -0
  44. web/src/components/MessageContent.css +250 -0
  45. web/src/components/MessageContent.tsx +254 -0
  46. web/src/components/ToolCallCard.css +338 -0
  47. web/src/components/ToolCallCard.tsx +207 -0
  48. web/src/index.css +38 -0
  49. web/src/main.tsx +10 -0
  50. web/src/pages/ChatPage.css +340 -0
.dockerignore ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ venv/
2
+ .venv/
3
+ __pycache__/
4
+ *.pyc
5
+ *.pyo
6
+ *.pyd
7
+ .Python
8
+ .env
9
+
10
+ # Frontend dev artifacts (only dist is needed)
11
+ web/node_modules/
12
+ web/.vite/
13
+
14
+ # Local data — don't ship patient records or uploads
15
+ data/uploads/
16
+ data/patient_chats/
17
+ data/lesions/
18
+ data/patients.json
19
+
20
+ # Misc
21
+ .git/
22
+ .gitignore
23
+ *.md
24
+ test*.py
25
+ test*.jpg
26
+ test*.png
27
+ frontend/
28
+ mcp_server/
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ guidelines/index/faiss.index filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.pyc
4
+ *.pyo
5
+ *.pyd
6
+ *.egg-info/
7
+ .Python
8
+ venv/
9
+ .venv/
10
+ *.egg
11
+
12
+ # Environment
13
+ .env
14
+ .env.*
15
+
16
+ # Node
17
+ web/node_modules/
18
+ web/dist/
19
+ web/.vite/
20
+
21
+ # Patient data — never commit
22
+ data/uploads/
23
+ data/patient_chats/
24
+ data/lesions/
25
+ data/patients.json
26
+
27
+ # Model weights (large binaries — store separately)
28
+ models/*.pt
29
+ models/*.pth
30
+ models/*.bin
31
+ models/*.safetensors
32
+
33
+ # macOS
34
+ .DS_Store
35
+
36
+ # Test artifacts
37
+ test*.jpg
38
+ test*.png
39
+ *.log
40
+
41
+ # Clinical guidelines PDFs (copyrighted — obtain separately)
42
+ guidelines/*.pdf
43
+
44
+ # Temp
45
+ /tmp/
Dockerfile ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Install Node.js for building the React frontend
6
+ RUN apt-get update && \
7
+ apt-get install -y curl && \
8
+ curl -fsSL https://deb.nodesource.com/setup_20.x | bash - && \
9
+ apt-get install -y nodejs && \
10
+ rm -rf /var/lib/apt/lists/*
11
+
12
+ # Install Python dependencies
13
+ COPY requirements.txt ml-requirements.txt
14
+ COPY backend/requirements.txt api-requirements.txt
15
+ RUN pip install --no-cache-dir -r ml-requirements.txt -r api-requirements.txt
16
+
17
+ # Build React frontend
18
+ COPY web/ web/
19
+ WORKDIR /app/web
20
+ RUN npm ci && npm run build
21
+
22
+ WORKDIR /app
23
+
24
+ # Copy application source
25
+ COPY models/ models/
26
+ COPY backend/ backend/
27
+ COPY data/case_store.py data/case_store.py
28
+ COPY guidelines/ guidelines/
29
+
30
+ # Runtime directories (writable by the app)
31
+ RUN mkdir -p data/uploads data/patient_chats data/lesions && \
32
+ echo '{"patients": []}' > data/patients.json
33
+
34
+ # HF Spaces runs as a non-root user — ensure data dirs are writable
35
+ RUN chmod -R 777 data/
36
+
37
+ # HF Spaces uses port 7860
38
+ EXPOSE 7860
39
+
40
+ CMD ["uvicorn", "backend.main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: SkinProAI
3
+ emoji: 🔬
4
+ colorFrom: blue
5
+ colorTo: indigo
6
+ sdk: docker
7
+ app_port: 7860
8
+ pinned: false
9
+ ---
10
+
11
+ # SkinProAI
12
+
13
+ AI-assisted dermoscopic lesion analysis for clinical decision support.
14
+
15
+ ## Features
16
+
17
+ - **Patient management** — create and select patient profiles
18
+ - **Image analysis** — upload dermoscopic images for automated assessment via MedGemma visual examination, MONET feature extraction, and ConvNeXt classification
19
+ - **Temporal comparison** — sequential images are automatically compared to detect change over time
20
+ - **Grad-CAM visualisation** — attention maps highlight regions driving the classification
21
+ - **Persistent chat history** — full analysis cascade is stored and replayed on reload
22
+
23
+ ## Architecture
24
+
25
+ | Layer | Technology |
26
+ |-------|-----------|
27
+ | Frontend | React 18 + TypeScript (Vite) |
28
+ | Backend | FastAPI + uvicorn |
29
+ | Vision-language model | MedGemma (Google) via Hugging Face |
30
+ | Classifier | ConvNeXt fine-tuned on ISIC HAM10000 |
31
+ | Feature extraction | MONET skin concept probes |
32
+ | Explainability | Grad-CAM |
33
+
34
+ ## Usage
35
+
36
+ 1. Open the app and create a patient record
37
+ 2. Click the patient card to open the chat
38
+ 3. Attach a dermoscopic image and send — analysis runs automatically
39
+ 4. Upload further images for the same patient to trigger temporal comparison
40
+ 5. Ask follow-up questions in text to query the AI about the findings
41
+
42
+ ## Disclaimer
43
+
44
+ SkinProAI is a research prototype intended for educational and investigational use only. It is **not** a certified medical device and must not be used as a substitute for professional clinical judgement.
backend/__init__.py ADDED
File without changes
backend/main.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SkinProAI FastAPI Backend
3
+ """
4
+
5
+ from fastapi import FastAPI
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from fastapi.staticfiles import StaticFiles
8
+ from pathlib import Path
9
+ import sys
10
+
11
+ # Add project root to path for model imports
12
+ sys.path.insert(0, str(Path(__file__).parent.parent))
13
+
14
+ from backend.routes import patients, lesions, analysis, chat
15
+
16
+ app = FastAPI(title="SkinProAI API", version="1.0.0")
17
+
18
+ # CORS middleware
19
+ app.add_middleware(
20
+ CORSMiddleware,
21
+ allow_origins=["*"],
22
+ allow_credentials=True,
23
+ allow_methods=["*"],
24
+ allow_headers=["*"],
25
+ )
26
+
27
+ # API routes — analysis must be registered BEFORE patients so the literal
28
+ # /gradcam route is not shadowed by the parameterised /{patient_id} route.
29
+ app.include_router(analysis.router, prefix="/api/patients", tags=["analysis"])
30
+ app.include_router(chat.router, prefix="/api/patients", tags=["chat"])
31
+ app.include_router(patients.router, prefix="/api/patients", tags=["patients"])
32
+ app.include_router(lesions.router, prefix="/api/patients", tags=["lesions"])
33
+
34
+ # Ensure upload directories exist
35
+ UPLOADS_DIR = Path(__file__).parent.parent / "data" / "uploads"
36
+ UPLOADS_DIR.mkdir(parents=True, exist_ok=True)
37
+
38
+ # Serve uploaded images
39
+ if UPLOADS_DIR.exists():
40
+ app.mount("/uploads", StaticFiles(directory=str(UPLOADS_DIR)), name="uploads")
41
+
42
+ # Serve React build (production)
43
+ BUILD_DIR = Path(__file__).parent.parent / "web" / "dist"
44
+ if BUILD_DIR.exists():
45
+ app.mount("/", StaticFiles(directory=str(BUILD_DIR), html=True), name="static")
46
+
47
+
48
+ @app.on_event("shutdown")
49
+ async def shutdown_event():
50
+ from backend.services.analysis_service import get_analysis_service
51
+ svc = get_analysis_service()
52
+ if svc.agent.mcp_client:
53
+ svc.agent.mcp_client.stop()
54
+
55
+
56
+ @app.get("/api/health")
57
+ def health_check():
58
+ return {"status": "ok"}
59
+
60
+
61
+ if __name__ == "__main__":
62
+ import uvicorn
63
+ uvicorn.run(app, host="0.0.0.0", port=8000)
backend/requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ fastapi>=0.100.0
2
+ uvicorn[standard]>=0.23.0
3
+ python-multipart>=0.0.6
backend/routes/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from . import patients, lesions, analysis
backend/routes/analysis.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Analysis Routes - Image analysis with SSE streaming
3
+ """
4
+
5
+ from fastapi import APIRouter, Query, HTTPException
6
+ from fastapi.responses import StreamingResponse, FileResponse
7
+ from pathlib import Path
8
+ import json
9
+ import tempfile
10
+
11
+ from backend.services.analysis_service import get_analysis_service
12
+ from data.case_store import get_case_store
13
+
14
+ router = APIRouter()
15
+
16
+
17
+ @router.get("/gradcam")
18
+ def get_gradcam_by_path(path: str = Query(...)):
19
+ """Serve a temp visualization image (GradCAM or comparison overlay)"""
20
+ if not path:
21
+ raise HTTPException(status_code=400, detail="No path provided")
22
+
23
+ temp_dir = Path(tempfile.gettempdir()).resolve()
24
+ resolved_path = Path(path).resolve()
25
+ if not str(resolved_path).startswith(str(temp_dir)):
26
+ raise HTTPException(status_code=403, detail="Access denied")
27
+
28
+ allowed_suffixes = ("_gradcam.png", "_comparison.png")
29
+ if not any(resolved_path.name.endswith(s) for s in allowed_suffixes):
30
+ raise HTTPException(status_code=400, detail="Invalid image path")
31
+
32
+ if resolved_path.exists():
33
+ return FileResponse(str(resolved_path), media_type="image/png")
34
+ raise HTTPException(status_code=404, detail="Image not found")
35
+
36
+
37
+ @router.post("/{patient_id}/lesions/{lesion_id}/images/{image_id}/analyze")
38
+ async def analyze_image(
39
+ patient_id: str,
40
+ lesion_id: str,
41
+ image_id: str,
42
+ question: str = Query(None)
43
+ ):
44
+ """Analyze an image with SSE streaming"""
45
+ store = get_case_store()
46
+
47
+ # Verify image exists
48
+ img = store.get_image(patient_id, lesion_id, image_id)
49
+ if not img:
50
+ raise HTTPException(status_code=404, detail="Image not found")
51
+ if not img.image_path:
52
+ raise HTTPException(status_code=400, detail="Image has no file uploaded")
53
+
54
+ service = get_analysis_service()
55
+
56
+ async def generate():
57
+ try:
58
+ for chunk in service.analyze(patient_id, lesion_id, image_id, question):
59
+ yield f"data: {json.dumps(chunk)}\n\n"
60
+ yield "data: [DONE]\n\n"
61
+ except Exception as e:
62
+ yield f"data: {json.dumps(f'[ERROR]{str(e)}[/ERROR]')}\n\n"
63
+ yield "data: [DONE]\n\n"
64
+
65
+ return StreamingResponse(
66
+ generate(),
67
+ media_type="text/event-stream",
68
+ headers={
69
+ "Cache-Control": "no-cache",
70
+ "Connection": "keep-alive",
71
+ }
72
+ )
73
+
74
+
75
+ @router.post("/{patient_id}/lesions/{lesion_id}/images/{image_id}/confirm")
76
+ async def confirm_diagnosis(
77
+ patient_id: str,
78
+ lesion_id: str,
79
+ image_id: str,
80
+ confirmed: bool = Query(...),
81
+ feedback: str = Query(None)
82
+ ):
83
+ """Confirm or reject diagnosis and get management guidance"""
84
+ service = get_analysis_service()
85
+
86
+ async def generate():
87
+ try:
88
+ for chunk in service.confirm(patient_id, lesion_id, image_id, confirmed, feedback):
89
+ yield f"data: {json.dumps(chunk)}\n\n"
90
+ yield "data: [DONE]\n\n"
91
+ except Exception as e:
92
+ yield f"data: {json.dumps(f'[ERROR]{str(e)}[/ERROR]')}\n\n"
93
+ yield "data: [DONE]\n\n"
94
+
95
+ return StreamingResponse(
96
+ generate(),
97
+ media_type="text/event-stream",
98
+ headers={
99
+ "Cache-Control": "no-cache",
100
+ "Connection": "keep-alive",
101
+ }
102
+ )
103
+
104
+
105
+ @router.post("/{patient_id}/lesions/{lesion_id}/images/{image_id}/compare")
106
+ async def compare_to_previous(
107
+ patient_id: str,
108
+ lesion_id: str,
109
+ image_id: str
110
+ ):
111
+ """Compare this image to the previous one in the timeline"""
112
+ store = get_case_store()
113
+
114
+ # Get current and previous images
115
+ current_img = store.get_image(patient_id, lesion_id, image_id)
116
+ if not current_img:
117
+ raise HTTPException(status_code=404, detail="Image not found")
118
+
119
+ previous_img = store.get_previous_image(patient_id, lesion_id, image_id)
120
+ if not previous_img:
121
+ raise HTTPException(status_code=400, detail="No previous image to compare")
122
+
123
+ service = get_analysis_service()
124
+
125
+ async def generate():
126
+ try:
127
+ for chunk in service.compare_images(
128
+ patient_id, lesion_id,
129
+ previous_img.image_path,
130
+ current_img.image_path,
131
+ image_id
132
+ ):
133
+ yield f"data: {json.dumps(chunk)}\n\n"
134
+ yield "data: [DONE]\n\n"
135
+ except Exception as e:
136
+ yield f"data: {json.dumps(f'[ERROR]{str(e)}[/ERROR]')}\n\n"
137
+ yield "data: [DONE]\n\n"
138
+
139
+ return StreamingResponse(
140
+ generate(),
141
+ media_type="text/event-stream",
142
+ headers={
143
+ "Cache-Control": "no-cache",
144
+ "Connection": "keep-alive",
145
+ }
146
+ )
147
+
148
+
149
+ @router.post("/{patient_id}/lesions/{lesion_id}/chat")
150
+ async def chat_message(
151
+ patient_id: str,
152
+ lesion_id: str,
153
+ message: dict
154
+ ):
155
+ """Send a chat message with SSE streaming response"""
156
+ store = get_case_store()
157
+
158
+ lesion = store.get_lesion(patient_id, lesion_id)
159
+ if not lesion:
160
+ raise HTTPException(status_code=404, detail="Lesion not found")
161
+
162
+ service = get_analysis_service()
163
+ content = message.get("content", "")
164
+
165
+ async def generate():
166
+ try:
167
+ for chunk in service.chat_followup(patient_id, lesion_id, content):
168
+ yield f"data: {json.dumps(chunk)}\n\n"
169
+ yield "data: [DONE]\n\n"
170
+ except Exception as e:
171
+ yield f"data: {json.dumps(f'[ERROR]{str(e)}[/ERROR]')}\n\n"
172
+ yield "data: [DONE]\n\n"
173
+
174
+ return StreamingResponse(
175
+ generate(),
176
+ media_type="text/event-stream",
177
+ headers={
178
+ "Cache-Control": "no-cache",
179
+ "Connection": "keep-alive",
180
+ }
181
+ )
backend/routes/chat.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chat Routes - Patient-level chat with image analysis tools
3
+ """
4
+
5
+ import asyncio
6
+ import json
7
+ import threading
8
+ from typing import Optional
9
+
10
+ from fastapi import APIRouter, HTTPException, UploadFile, File, Form
11
+ from fastapi.responses import StreamingResponse
12
+
13
+ from data.case_store import get_case_store
14
+ from backend.services.chat_service import get_chat_service
15
+
16
+ router = APIRouter()
17
+
18
+
19
+ @router.get("/{patient_id}/chat")
20
+ def get_chat_history(patient_id: str):
21
+ """Get patient-level chat history"""
22
+ store = get_case_store()
23
+ if not store.get_patient(patient_id):
24
+ raise HTTPException(status_code=404, detail="Patient not found")
25
+ messages = store.get_patient_chat_history(patient_id)
26
+ return {"messages": messages}
27
+
28
+
29
+ @router.delete("/{patient_id}/chat")
30
+ def clear_chat(patient_id: str):
31
+ """Clear patient-level chat history"""
32
+ store = get_case_store()
33
+ if not store.get_patient(patient_id):
34
+ raise HTTPException(status_code=404, detail="Patient not found")
35
+ store.clear_patient_chat_history(patient_id)
36
+ return {"success": True}
37
+
38
+
39
+ @router.post("/{patient_id}/chat")
40
+ async def post_chat_message(
41
+ patient_id: str,
42
+ content: str = Form(""),
43
+ image: Optional[UploadFile] = File(None),
44
+ ):
45
+ """Send a chat message, optionally with an image — SSE streaming response.
46
+
47
+ The sync ML generator runs in a background thread so it never blocks the
48
+ event loop. Events flow through an asyncio.Queue, so each SSE event is
49
+ flushed to the browser the moment it is produced (spinner shows instantly).
50
+ """
51
+ store = get_case_store()
52
+ if not store.get_patient(patient_id):
53
+ raise HTTPException(status_code=404, detail="Patient not found")
54
+
55
+ image_bytes = None
56
+ if image and image.filename:
57
+ image_bytes = await image.read()
58
+
59
+ chat_service = get_chat_service()
60
+
61
+ async def generate():
62
+ loop = asyncio.get_event_loop()
63
+ queue: asyncio.Queue = asyncio.Queue()
64
+
65
+ _SENTINEL = object()
66
+
67
+ def run_sync():
68
+ try:
69
+ for event in chat_service.stream_chat(patient_id, content, image_bytes):
70
+ loop.call_soon_threadsafe(queue.put_nowait, event)
71
+ except Exception as e:
72
+ loop.call_soon_threadsafe(
73
+ queue.put_nowait,
74
+ {"type": "error", "message": str(e)},
75
+ )
76
+ finally:
77
+ loop.call_soon_threadsafe(queue.put_nowait, _SENTINEL)
78
+
79
+ thread = threading.Thread(target=run_sync, daemon=True)
80
+ thread.start()
81
+
82
+ while True:
83
+ event = await queue.get()
84
+ if event is _SENTINEL:
85
+ break
86
+ yield f"data: {json.dumps(event)}\n\n"
87
+
88
+ yield f"data: {json.dumps({'type': 'done'})}\n\n"
89
+
90
+ return StreamingResponse(
91
+ generate(),
92
+ media_type="text/event-stream",
93
+ headers={
94
+ "Cache-Control": "no-cache",
95
+ "Connection": "keep-alive",
96
+ },
97
+ )
backend/routes/lesions.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Lesion Routes - CRUD for lesions and images
3
+ """
4
+
5
+ from fastapi import APIRouter, HTTPException, UploadFile, File
6
+ from fastapi.responses import FileResponse
7
+ from pydantic import BaseModel
8
+ from dataclasses import asdict
9
+ from pathlib import Path
10
+ from PIL import Image
11
+ import io
12
+
13
+ from data.case_store import get_case_store
14
+
15
+ router = APIRouter()
16
+
17
+
18
+ class CreateLesionRequest(BaseModel):
19
+ name: str
20
+ location: str = ""
21
+
22
+
23
+ class UpdateLesionRequest(BaseModel):
24
+ name: str = None
25
+ location: str = None
26
+
27
+
28
+ # -------------------------------------------------------------------------
29
+ # Lesion CRUD
30
+ # -------------------------------------------------------------------------
31
+
32
+ @router.get("/{patient_id}/lesions")
33
+ def list_lesions(patient_id: str):
34
+ """List all lesions for a patient"""
35
+ store = get_case_store()
36
+
37
+ patient = store.get_patient(patient_id)
38
+ if not patient:
39
+ raise HTTPException(status_code=404, detail="Patient not found")
40
+
41
+ lesions = store.list_lesions(patient_id)
42
+
43
+ result = []
44
+ for lesion in lesions:
45
+ images = store.list_images(patient_id, lesion.id)
46
+ # Get the most recent image as thumbnail
47
+ latest_image = images[-1] if images else None
48
+
49
+ result.append({
50
+ "id": lesion.id,
51
+ "patient_id": lesion.patient_id,
52
+ "name": lesion.name,
53
+ "location": lesion.location,
54
+ "created_at": lesion.created_at,
55
+ "image_count": len(images),
56
+ "latest_image": asdict(latest_image) if latest_image else None
57
+ })
58
+
59
+ return {"lesions": result}
60
+
61
+
62
+ @router.post("/{patient_id}/lesions")
63
+ def create_lesion(patient_id: str, req: CreateLesionRequest):
64
+ """Create a new lesion for a patient"""
65
+ store = get_case_store()
66
+
67
+ patient = store.get_patient(patient_id)
68
+ if not patient:
69
+ raise HTTPException(status_code=404, detail="Patient not found")
70
+
71
+ lesion = store.create_lesion(patient_id, req.name, req.location)
72
+ return {
73
+ "lesion": {
74
+ **asdict(lesion),
75
+ "image_count": 0,
76
+ "images": []
77
+ }
78
+ }
79
+
80
+
81
+ @router.get("/{patient_id}/lesions/{lesion_id}")
82
+ def get_lesion(patient_id: str, lesion_id: str):
83
+ """Get a lesion with all its images"""
84
+ store = get_case_store()
85
+
86
+ lesion = store.get_lesion(patient_id, lesion_id)
87
+ if not lesion:
88
+ raise HTTPException(status_code=404, detail="Lesion not found")
89
+
90
+ images = store.list_images(patient_id, lesion_id)
91
+
92
+ return {
93
+ "lesion": {
94
+ **asdict(lesion),
95
+ "image_count": len(images),
96
+ "images": [asdict(img) for img in images]
97
+ }
98
+ }
99
+
100
+
101
+ @router.patch("/{patient_id}/lesions/{lesion_id}")
102
+ def update_lesion(patient_id: str, lesion_id: str, req: UpdateLesionRequest):
103
+ """Update a lesion's name or location"""
104
+ store = get_case_store()
105
+
106
+ lesion = store.get_lesion(patient_id, lesion_id)
107
+ if not lesion:
108
+ raise HTTPException(status_code=404, detail="Lesion not found")
109
+
110
+ store.update_lesion(patient_id, lesion_id, req.name, req.location)
111
+
112
+ # Return updated lesion
113
+ lesion = store.get_lesion(patient_id, lesion_id)
114
+ images = store.list_images(patient_id, lesion_id)
115
+
116
+ return {
117
+ "lesion": {
118
+ **asdict(lesion),
119
+ "image_count": len(images),
120
+ "images": [asdict(img) for img in images]
121
+ }
122
+ }
123
+
124
+
125
+ @router.delete("/{patient_id}/lesions/{lesion_id}")
126
+ def delete_lesion(patient_id: str, lesion_id: str):
127
+ """Delete a lesion and all its images"""
128
+ store = get_case_store()
129
+
130
+ lesion = store.get_lesion(patient_id, lesion_id)
131
+ if not lesion:
132
+ raise HTTPException(status_code=404, detail="Lesion not found")
133
+
134
+ store.delete_lesion(patient_id, lesion_id)
135
+ return {"success": True}
136
+
137
+
138
+ # -------------------------------------------------------------------------
139
+ # Image CRUD
140
+ # -------------------------------------------------------------------------
141
+
142
+ @router.post("/{patient_id}/lesions/{lesion_id}/images")
143
+ async def upload_image(patient_id: str, lesion_id: str, image: UploadFile = File(...)):
144
+ """Upload a new image to a lesion's timeline"""
145
+ store = get_case_store()
146
+
147
+ lesion = store.get_lesion(patient_id, lesion_id)
148
+ if not lesion:
149
+ raise HTTPException(status_code=404, detail="Lesion not found")
150
+
151
+ try:
152
+ # Create image record
153
+ img_record = store.add_image(patient_id, lesion_id)
154
+
155
+ # Save the actual image file
156
+ pil_image = Image.open(io.BytesIO(await image.read())).convert("RGB")
157
+ image_path = store.save_lesion_image(patient_id, lesion_id, img_record.id, pil_image)
158
+
159
+ # Update image record with path
160
+ store.update_image(patient_id, lesion_id, img_record.id, image_path=image_path)
161
+
162
+ # Return updated record
163
+ img_record = store.get_image(patient_id, lesion_id, img_record.id)
164
+ return {"image": asdict(img_record)}
165
+
166
+ except Exception as e:
167
+ raise HTTPException(status_code=400, detail=f"Failed to upload image: {e}")
168
+
169
+
170
+ @router.get("/{patient_id}/lesions/{lesion_id}/images/{image_id}")
171
+ def get_image_record(patient_id: str, lesion_id: str, image_id: str):
172
+ """Get an image record"""
173
+ store = get_case_store()
174
+
175
+ img = store.get_image(patient_id, lesion_id, image_id)
176
+ if not img:
177
+ raise HTTPException(status_code=404, detail="Image not found")
178
+
179
+ return {"image": asdict(img)}
180
+
181
+
182
+ @router.get("/{patient_id}/lesions/{lesion_id}/images/{image_id}/file")
183
+ def get_image_file(patient_id: str, lesion_id: str, image_id: str):
184
+ """Get the actual image file"""
185
+ store = get_case_store()
186
+
187
+ img = store.get_image(patient_id, lesion_id, image_id)
188
+ if not img or not img.image_path:
189
+ raise HTTPException(status_code=404, detail="Image not found")
190
+
191
+ path = Path(img.image_path)
192
+ if not path.exists():
193
+ raise HTTPException(status_code=404, detail="Image file not found")
194
+
195
+ return FileResponse(str(path), media_type="image/png")
196
+
197
+
198
+ @router.get("/{patient_id}/lesions/{lesion_id}/images/{image_id}/gradcam")
199
+ def get_gradcam_file(patient_id: str, lesion_id: str, image_id: str):
200
+ """Get the GradCAM visualization for an image"""
201
+ store = get_case_store()
202
+
203
+ img = store.get_image(patient_id, lesion_id, image_id)
204
+ if not img or not img.gradcam_path:
205
+ raise HTTPException(status_code=404, detail="GradCAM not found")
206
+
207
+ path = Path(img.gradcam_path)
208
+ if not path.exists():
209
+ raise HTTPException(status_code=404, detail="GradCAM file not found")
210
+
211
+ return FileResponse(str(path), media_type="image/png")
212
+
213
+
214
+ # -------------------------------------------------------------------------
215
+ # Chat
216
+ # -------------------------------------------------------------------------
217
+
218
+ @router.get("/{patient_id}/lesions/{lesion_id}/chat")
219
+ def get_chat_history(patient_id: str, lesion_id: str):
220
+ """Get chat history for a lesion"""
221
+ store = get_case_store()
222
+
223
+ lesion = store.get_lesion(patient_id, lesion_id)
224
+ if not lesion:
225
+ raise HTTPException(status_code=404, detail="Lesion not found")
226
+
227
+ messages = store.get_chat_history(patient_id, lesion_id)
228
+ return {"messages": [asdict(m) for m in messages]}
229
+
230
+
231
+ @router.delete("/{patient_id}/lesions/{lesion_id}/chat")
232
+ def clear_chat_history(patient_id: str, lesion_id: str):
233
+ """Clear chat history for a lesion"""
234
+ store = get_case_store()
235
+
236
+ lesion = store.get_lesion(patient_id, lesion_id)
237
+ if not lesion:
238
+ raise HTTPException(status_code=404, detail="Lesion not found")
239
+
240
+ store.clear_chat_history(patient_id, lesion_id)
241
+ return {"success": True}
backend/routes/patients.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Patient Routes - CRUD for patients
3
+ """
4
+
5
+ from fastapi import APIRouter, HTTPException
6
+ from pydantic import BaseModel
7
+ from dataclasses import asdict
8
+
9
+ from data.case_store import get_case_store
10
+
11
+ router = APIRouter()
12
+
13
+
14
+ class CreatePatientRequest(BaseModel):
15
+ name: str
16
+
17
+
18
+ @router.get("")
19
+ def list_patients():
20
+ """List all patients with lesion counts"""
21
+ store = get_case_store()
22
+ patients = store.list_patients()
23
+
24
+ result = []
25
+ for p in patients:
26
+ result.append({
27
+ **asdict(p),
28
+ "lesion_count": store.get_patient_lesion_count(p.id)
29
+ })
30
+
31
+ return {"patients": result}
32
+
33
+
34
+ @router.post("")
35
+ def create_patient(req: CreatePatientRequest):
36
+ """Create a new patient"""
37
+ store = get_case_store()
38
+ patient = store.create_patient(req.name)
39
+ return {
40
+ "patient": {
41
+ **asdict(patient),
42
+ "lesion_count": 0
43
+ }
44
+ }
45
+
46
+
47
+ @router.get("/{patient_id}")
48
+ def get_patient(patient_id: str):
49
+ """Get a patient by ID"""
50
+ store = get_case_store()
51
+ patient = store.get_patient(patient_id)
52
+ if not patient:
53
+ raise HTTPException(status_code=404, detail="Patient not found")
54
+
55
+ return {
56
+ "patient": {
57
+ **asdict(patient),
58
+ "lesion_count": store.get_patient_lesion_count(patient_id)
59
+ }
60
+ }
61
+
62
+
63
+ @router.delete("/{patient_id}")
64
+ def delete_patient(patient_id: str):
65
+ """Delete a patient and all their lesions"""
66
+ store = get_case_store()
67
+ patient = store.get_patient(patient_id)
68
+ if not patient:
69
+ raise HTTPException(status_code=404, detail="Patient not found")
70
+
71
+ store.delete_patient(patient_id)
72
+ return {"success": True}
backend/services/__init__.py ADDED
File without changes
backend/services/analysis_service.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Analysis Service - Wraps MedGemmaAgent for API use
3
+ """
4
+
5
+ from pathlib import Path
6
+ from dataclasses import asdict
7
+ from typing import Optional, Generator
8
+
9
+ from models.medgemma_agent import MedGemmaAgent
10
+ from data.case_store import get_case_store
11
+
12
+
13
+ class AnalysisService:
14
+ """Singleton service for managing analysis operations"""
15
+
16
+ _instance = None
17
+
18
+ def __init__(self):
19
+ self.agent = MedGemmaAgent(verbose=True)
20
+ self.store = get_case_store()
21
+ self._loaded = False
22
+
23
+ def _ensure_loaded(self):
24
+ """Lazy load the ML models"""
25
+ if not self._loaded:
26
+ self.agent.load_model()
27
+ self._loaded = True
28
+
29
+ def analyze(
30
+ self,
31
+ patient_id: str,
32
+ lesion_id: str,
33
+ image_id: str,
34
+ question: Optional[str] = None
35
+ ) -> Generator[str, None, None]:
36
+ """Run analysis on an image, yielding streaming chunks"""
37
+ self._ensure_loaded()
38
+
39
+ image = self.store.get_image(patient_id, lesion_id, image_id)
40
+ if not image or not image.image_path:
41
+ yield "[ERROR]No image uploaded[/ERROR]"
42
+ return
43
+
44
+ # Update stage
45
+ self.store.update_image(patient_id, lesion_id, image_id, stage="analyzing")
46
+
47
+ # Reset agent state for new analysis
48
+ self.agent.reset_state()
49
+
50
+ # Run analysis with question
51
+ for chunk in self.agent.analyze_image_stream(image.image_path, question=question or ""):
52
+ yield chunk
53
+
54
+ # Save diagnosis after analysis
55
+ if self.agent.last_diagnosis:
56
+ analysis_data = {
57
+ "diagnosis": self.agent.last_diagnosis["predictions"][0]["class"],
58
+ "full_name": self.agent.last_diagnosis["predictions"][0]["full_name"],
59
+ "confidence": self.agent.last_diagnosis["predictions"][0]["probability"],
60
+ "all_predictions": self.agent.last_diagnosis["predictions"]
61
+ }
62
+
63
+ # Save MONET features if available
64
+ if self.agent.last_monet_result:
65
+ analysis_data["monet_features"] = self.agent.last_monet_result.get("features", {})
66
+
67
+ self.store.update_image(
68
+ patient_id, lesion_id, image_id,
69
+ stage="awaiting_confirmation",
70
+ analysis=analysis_data
71
+ )
72
+
73
+ def confirm(
74
+ self,
75
+ patient_id: str,
76
+ lesion_id: str,
77
+ image_id: str,
78
+ confirmed: bool,
79
+ feedback: Optional[str] = None
80
+ ) -> Generator[str, None, None]:
81
+ """Confirm diagnosis and generate management guidance"""
82
+ for chunk in self.agent.generate_management_guidance(confirmed, feedback):
83
+ yield chunk
84
+
85
+ # Update stage to complete
86
+ self.store.update_image(patient_id, lesion_id, image_id, stage="complete")
87
+
88
+ def chat_followup(
89
+ self,
90
+ patient_id: str,
91
+ lesion_id: str,
92
+ message: str
93
+ ) -> Generator[str, None, None]:
94
+ """Handle follow-up chat messages"""
95
+ # Save user message
96
+ self.store.add_chat_message(patient_id, lesion_id, "user", message)
97
+
98
+ # Generate response
99
+ response = ""
100
+ for chunk in self.agent.chat_followup(message):
101
+ response += chunk
102
+ yield chunk
103
+
104
+ # Save assistant response
105
+ self.store.add_chat_message(patient_id, lesion_id, "assistant", response)
106
+
107
+ def get_chat_history(self, patient_id: str, lesion_id: str):
108
+ """Get chat history for a lesion"""
109
+ messages = self.store.get_chat_history(patient_id, lesion_id)
110
+ return [asdict(m) for m in messages]
111
+
112
+ def compare_images(
113
+ self,
114
+ patient_id: str,
115
+ lesion_id: str,
116
+ previous_image_path: str,
117
+ current_image_path: str,
118
+ current_image_id: str
119
+ ) -> Generator[str, None, None]:
120
+ """Compare two images and assess changes"""
121
+ self._ensure_loaded()
122
+
123
+ # Run comparison
124
+ comparison_result = None
125
+ for chunk in self.agent.compare_followup_images(previous_image_path, current_image_path):
126
+ yield chunk
127
+
128
+ # Extract comparison status from agent if available
129
+ # Default to STABLE if we can't determine
130
+ comparison_data = {
131
+ "status": "STABLE",
132
+ "summary": "Comparison complete"
133
+ }
134
+
135
+ # Update the current image with comparison data
136
+ self.store.update_image(
137
+ patient_id, lesion_id, current_image_id,
138
+ comparison=comparison_data
139
+ )
140
+
141
+
142
+ def get_analysis_service() -> AnalysisService:
143
+ """Get or create AnalysisService singleton"""
144
+ if AnalysisService._instance is None:
145
+ AnalysisService._instance = AnalysisService()
146
+ return AnalysisService._instance
backend/services/chat_service.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chat Service - Patient-level chat with tool dispatch and streaming
3
+ """
4
+
5
+ import io
6
+ import re
7
+ import uuid
8
+ from typing import Generator, Optional
9
+ from pathlib import Path
10
+ from PIL import Image as PILImage
11
+
12
+ from data.case_store import get_case_store
13
+ from backend.services.analysis_service import get_analysis_service
14
+
15
+
16
+ def _extract_response_text(raw: str) -> str:
17
+ """Pull clean text out of [RESPONSE]...[/RESPONSE]; strip all other tags."""
18
+ # Grab the RESPONSE block first
19
+ match = re.search(r'\[RESPONSE\](.*?)\[/RESPONSE\]', raw, re.DOTALL)
20
+ if match:
21
+ return match.group(1).strip()
22
+ # Fallback: strip every known markup tag
23
+ clean = re.sub(
24
+ r'\[(STAGE:[^\]]+|THINKING|RESPONSE|/RESPONSE|/THINKING|/STAGE'
25
+ r'|ERROR|/ERROR|RESULT|/RESULT|CONFIRM:\d+|/CONFIRM)\]',
26
+ '', raw
27
+ )
28
+ return clean.strip()
29
+
30
+
31
+ class ChatService:
32
+ _instance = None
33
+
34
+ def __init__(self):
35
+ self.store = get_case_store()
36
+
37
+ def _get_image_url(self, patient_id: str, lesion_id: str, image_id: str) -> str:
38
+ return f"/uploads/{patient_id}/{lesion_id}/{image_id}/image.png"
39
+
40
+ def stream_chat(
41
+ self,
42
+ patient_id: str,
43
+ content: str,
44
+ image_bytes: Optional[bytes] = None,
45
+ ) -> Generator[dict, None, None]:
46
+ """Main chat handler — yields SSE event dicts."""
47
+ analysis_service = get_analysis_service()
48
+
49
+ if image_bytes:
50
+ # ----------------------------------------------------------------
51
+ # Image path: analyze (and optionally compare).
52
+ # We do NOT stream the raw verbose analysis text to the chat bubble —
53
+ # the tool card IS the display artefact. We accumulate the text
54
+ # internally, extract the clean [RESPONSE] block, and put it in
55
+ # tool_result.summary so the expanded card can show it.
56
+ # ----------------------------------------------------------------
57
+ lesion = self.store.get_or_create_chat_lesion(patient_id)
58
+
59
+ img_record = self.store.add_image(patient_id, lesion.id)
60
+ pil_image = PILImage.open(io.BytesIO(image_bytes)).convert("RGB")
61
+ abs_path = self.store.save_lesion_image(
62
+ patient_id, lesion.id, img_record.id, pil_image
63
+ )
64
+ self.store.update_image(patient_id, lesion.id, img_record.id, image_path=abs_path)
65
+
66
+ user_image_url = self._get_image_url(patient_id, lesion.id, img_record.id)
67
+ self.store.add_patient_chat_message(
68
+ patient_id, "user", content, image_url=user_image_url
69
+ )
70
+
71
+ # ---- tool: analyze_image ----------------------------------------
72
+ call_id = f"tc-{uuid.uuid4().hex[:6]}"
73
+ yield {"type": "tool_start", "tool": "analyze_image", "call_id": call_id}
74
+
75
+ analysis_text = ""
76
+ for chunk in analysis_service.analyze(patient_id, lesion.id, img_record.id):
77
+ yield {"type": "text", "content": chunk}
78
+ analysis_text += chunk
79
+
80
+ updated_img = self.store.get_image(patient_id, lesion.id, img_record.id)
81
+ analysis_result: dict = {
82
+ "image_url": user_image_url,
83
+ "summary": _extract_response_text(analysis_text),
84
+ "diagnosis": None,
85
+ "full_name": None,
86
+ "confidence": None,
87
+ "all_predictions": [],
88
+ }
89
+ if updated_img and updated_img.analysis:
90
+ a = updated_img.analysis
91
+ analysis_result.update({
92
+ "diagnosis": a.get("diagnosis"),
93
+ "full_name": a.get("full_name"),
94
+ "confidence": a.get("confidence"),
95
+ "all_predictions": a.get("all_predictions", []),
96
+ })
97
+
98
+ yield {
99
+ "type": "tool_result",
100
+ "tool": "analyze_image",
101
+ "call_id": call_id,
102
+ "result": analysis_result,
103
+ }
104
+
105
+ # ---- tool: compare_images (if a previous image exists) ----------
106
+ previous_img = self.store.get_previous_image(patient_id, lesion.id, img_record.id)
107
+ compare_call_id = None
108
+ compare_result = None
109
+ compare_text = ""
110
+
111
+ if (
112
+ previous_img
113
+ and previous_img.image_path
114
+ and Path(previous_img.image_path).exists()
115
+ ):
116
+ compare_call_id = f"tc-{uuid.uuid4().hex[:6]}"
117
+ yield {
118
+ "type": "tool_start",
119
+ "tool": "compare_images",
120
+ "call_id": compare_call_id,
121
+ }
122
+
123
+ for chunk in analysis_service.compare_images(
124
+ patient_id,
125
+ lesion.id,
126
+ previous_img.image_path,
127
+ abs_path,
128
+ img_record.id,
129
+ ):
130
+ yield {"type": "text", "content": chunk}
131
+ compare_text += chunk
132
+
133
+ updated_img2 = self.store.get_image(patient_id, lesion.id, img_record.id)
134
+ compare_result = {
135
+ "prev_image_url": self._get_image_url(patient_id, lesion.id, previous_img.id),
136
+ "curr_image_url": user_image_url,
137
+ "status_label": "STABLE",
138
+ "feature_changes": {},
139
+ "summary": _extract_response_text(compare_text),
140
+ }
141
+ if updated_img2 and updated_img2.comparison:
142
+ c = updated_img2.comparison
143
+ compare_result.update({
144
+ "status_label": c.get("status", "STABLE"),
145
+ "feature_changes": c.get("feature_changes", {}),
146
+ })
147
+ if c.get("summary"):
148
+ compare_result["summary"] = c["summary"]
149
+
150
+ yield {
151
+ "type": "tool_result",
152
+ "tool": "compare_images",
153
+ "call_id": compare_call_id,
154
+ "result": compare_result,
155
+ }
156
+
157
+ # Save assistant message
158
+ tool_calls_data = [{
159
+ "id": call_id,
160
+ "tool": "analyze_image",
161
+ "status": "complete",
162
+ "result": analysis_result,
163
+ }]
164
+ if compare_call_id and compare_result:
165
+ tool_calls_data.append({
166
+ "id": compare_call_id,
167
+ "tool": "compare_images",
168
+ "status": "complete",
169
+ "result": compare_result,
170
+ })
171
+
172
+ self.store.add_patient_chat_message(
173
+ patient_id, "assistant", analysis_text + compare_text,
174
+ tool_calls=tool_calls_data,
175
+ )
176
+
177
+ else:
178
+ # ----------------------------------------------------------------
179
+ # Text-only chat — stream chunks; tags are stripped on the frontend
180
+ # ----------------------------------------------------------------
181
+ self.store.add_patient_chat_message(patient_id, "user", content)
182
+
183
+ analysis_service._ensure_loaded()
184
+ response_text = ""
185
+ for chunk in analysis_service.agent.chat_followup(content):
186
+ yield {"type": "text", "content": chunk}
187
+ response_text += chunk
188
+
189
+ self.store.add_patient_chat_message(
190
+ patient_id, "assistant", _extract_response_text(response_text)
191
+ )
192
+
193
+
194
+ def get_chat_service() -> ChatService:
195
+ if ChatService._instance is None:
196
+ ChatService._instance = ChatService()
197
+ return ChatService._instance
data/case_store.py ADDED
@@ -0,0 +1,507 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Case Store - JSON-based persistence for patients, lesions, and images
3
+ """
4
+
5
+ import json
6
+ import uuid
7
+ import shutil
8
+ from pathlib import Path
9
+ from datetime import datetime
10
+ from typing import List, Dict, Optional, Any
11
+ from dataclasses import dataclass, field, asdict
12
+ from PIL import Image as PILImage
13
+
14
+
15
+ @dataclass
16
+ class ChatMessage:
17
+ role: str # "user" or "assistant"
18
+ content: str
19
+ timestamp: str = field(default_factory=lambda: datetime.utcnow().isoformat())
20
+
21
+
22
+ @dataclass
23
+ class LesionImage:
24
+ """A single image capture of a lesion at a point in time"""
25
+ id: str
26
+ lesion_id: str
27
+ timestamp: str = field(default_factory=lambda: datetime.utcnow().isoformat())
28
+ image_path: Optional[str] = None
29
+ gradcam_path: Optional[str] = None
30
+ analysis: Optional[Dict[str, Any]] = None # {diagnosis, confidence, monet_features}
31
+ comparison: Optional[Dict[str, Any]] = None # {status, feature_changes, summary}
32
+ is_original: bool = False
33
+ stage: str = "pending" # pending, analyzing, complete, error
34
+
35
+
36
+ @dataclass
37
+ class Lesion:
38
+ """A tracked lesion that can have multiple images over time"""
39
+ id: str
40
+ patient_id: str
41
+ name: str # User-provided label (e.g., "Left shoulder mole")
42
+ location: str = "" # Body location
43
+ created_at: str = field(default_factory=lambda: datetime.utcnow().isoformat())
44
+ chat_history: List[Dict] = field(default_factory=list)
45
+
46
+
47
+ @dataclass
48
+ class Patient:
49
+ """A patient who can have multiple lesions"""
50
+ id: str
51
+ name: str
52
+ created_at: str = field(default_factory=lambda: datetime.utcnow().isoformat())
53
+
54
+
55
+ class CaseStore:
56
+ """JSON-based persistence for patients, lesions, and images"""
57
+
58
+ def __init__(self, data_dir: str = None):
59
+ if data_dir is None:
60
+ data_dir = Path(__file__).parent
61
+ self.data_dir = Path(data_dir)
62
+ self.patients_file = self.data_dir / "patients.json"
63
+ self.lesions_dir = self.data_dir / "lesions"
64
+ self.uploads_dir = self.data_dir / "uploads"
65
+
66
+ # Ensure directories exist
67
+ self.lesions_dir.mkdir(parents=True, exist_ok=True)
68
+ self.uploads_dir.mkdir(parents=True, exist_ok=True)
69
+
70
+ # Initialize patients file if needed
71
+ if not self.patients_file.exists():
72
+ self._init_patients_file()
73
+
74
+ def _init_patients_file(self):
75
+ """Initialize patients file"""
76
+ data = {"patients": []}
77
+ with open(self.patients_file, 'w') as f:
78
+ json.dump(data, f, indent=2)
79
+
80
+ def _load_patients_data(self) -> Dict:
81
+ """Load patients JSON file"""
82
+ with open(self.patients_file, 'r') as f:
83
+ return json.load(f)
84
+
85
+ def _save_patients_data(self, data: Dict):
86
+ """Save patients JSON file"""
87
+ with open(self.patients_file, 'w') as f:
88
+ json.dump(data, f, indent=2)
89
+
90
+ # -------------------------------------------------------------------------
91
+ # Patient Methods
92
+ # -------------------------------------------------------------------------
93
+
94
+ def list_patients(self) -> List[Patient]:
95
+ """List all patients"""
96
+ data = self._load_patients_data()
97
+ return [Patient(**p) for p in data.get("patients", [])]
98
+
99
+ def get_patient(self, patient_id: str) -> Optional[Patient]:
100
+ """Get a patient by ID"""
101
+ data = self._load_patients_data()
102
+ for p in data.get("patients", []):
103
+ if p["id"] == patient_id:
104
+ return Patient(**p)
105
+ return None
106
+
107
+ def create_patient(self, name: str) -> Patient:
108
+ """Create a new patient"""
109
+ patient = Patient(
110
+ id=f"patient-{uuid.uuid4().hex[:8]}",
111
+ name=name
112
+ )
113
+
114
+ data = self._load_patients_data()
115
+ data["patients"].append(asdict(patient))
116
+ self._save_patients_data(data)
117
+
118
+ # Create directory for this patient's lesions
119
+ (self.lesions_dir / patient.id).mkdir(exist_ok=True)
120
+
121
+ return patient
122
+
123
+ def delete_patient(self, patient_id: str):
124
+ """Delete a patient and all their lesions"""
125
+ data = self._load_patients_data()
126
+ data["patients"] = [p for p in data["patients"] if p["id"] != patient_id]
127
+ self._save_patients_data(data)
128
+
129
+ # Delete lesion files
130
+ patient_lesions_dir = self.lesions_dir / patient_id
131
+ if patient_lesions_dir.exists():
132
+ shutil.rmtree(patient_lesions_dir)
133
+
134
+ # Delete uploads
135
+ patient_uploads_dir = self.uploads_dir / patient_id
136
+ if patient_uploads_dir.exists():
137
+ shutil.rmtree(patient_uploads_dir)
138
+
139
+ # Delete patient chat history
140
+ patient_chat_file = self.data_dir / "patient_chats" / f"{patient_id}.json"
141
+ if patient_chat_file.exists():
142
+ patient_chat_file.unlink()
143
+
144
+ def get_patient_lesion_count(self, patient_id: str) -> int:
145
+ """Get number of lesions for a patient"""
146
+ return len(self.list_lesions(patient_id))
147
+
148
+ # -------------------------------------------------------------------------
149
+ # Lesion Methods
150
+ # -------------------------------------------------------------------------
151
+
152
+ def _get_lesion_path(self, patient_id: str, lesion_id: str) -> Path:
153
+ """Get path to lesion JSON file"""
154
+ return self.lesions_dir / patient_id / f"{lesion_id}.json"
155
+
156
+ def list_lesions(self, patient_id: str) -> List[Lesion]:
157
+ """List all lesions for a patient"""
158
+ patient_dir = self.lesions_dir / patient_id
159
+ if not patient_dir.exists():
160
+ return []
161
+
162
+ lesions = []
163
+ for f in sorted(patient_dir.glob("*.json")):
164
+ with open(f, 'r') as fp:
165
+ data = json.load(fp)
166
+ # Only load lesion data, not images
167
+ lesion_data = {k: v for k, v in data.items() if k != 'images'}
168
+ lesions.append(Lesion(**lesion_data))
169
+
170
+ lesions.sort(key=lambda x: x.created_at)
171
+ return lesions
172
+
173
+ def get_lesion(self, patient_id: str, lesion_id: str) -> Optional[Lesion]:
174
+ """Get a lesion by ID"""
175
+ path = self._get_lesion_path(patient_id, lesion_id)
176
+ if not path.exists():
177
+ return None
178
+
179
+ with open(path, 'r') as f:
180
+ data = json.load(f)
181
+ lesion_data = {k: v for k, v in data.items() if k != 'images'}
182
+ return Lesion(**lesion_data)
183
+
184
+ def create_lesion(self, patient_id: str, name: str, location: str = "") -> Lesion:
185
+ """Create a new lesion for a patient"""
186
+ lesion = Lesion(
187
+ id=f"lesion-{uuid.uuid4().hex[:8]}",
188
+ patient_id=patient_id,
189
+ name=name,
190
+ location=location
191
+ )
192
+
193
+ # Ensure patient directory exists
194
+ patient_dir = self.lesions_dir / patient_id
195
+ patient_dir.mkdir(exist_ok=True)
196
+
197
+ # Save lesion with empty images array
198
+ self._save_lesion_data(patient_id, lesion.id, {
199
+ **asdict(lesion),
200
+ "images": []
201
+ })
202
+
203
+ return lesion
204
+
205
+ def _save_lesion_data(self, patient_id: str, lesion_id: str, data: Dict):
206
+ """Save lesion data to JSON file"""
207
+ path = self._get_lesion_path(patient_id, lesion_id)
208
+ with open(path, 'w') as f:
209
+ json.dump(data, f, indent=2)
210
+
211
+ def _load_lesion_data(self, patient_id: str, lesion_id: str) -> Optional[Dict]:
212
+ """Load full lesion data including images"""
213
+ path = self._get_lesion_path(patient_id, lesion_id)
214
+ if not path.exists():
215
+ return None
216
+
217
+ with open(path, 'r') as f:
218
+ return json.load(f)
219
+
220
+ def delete_lesion(self, patient_id: str, lesion_id: str):
221
+ """Delete a lesion and all its images"""
222
+ path = self._get_lesion_path(patient_id, lesion_id)
223
+ if path.exists():
224
+ path.unlink()
225
+
226
+ # Delete uploads for this lesion
227
+ lesion_uploads_dir = self.uploads_dir / patient_id / lesion_id
228
+ if lesion_uploads_dir.exists():
229
+ shutil.rmtree(lesion_uploads_dir)
230
+
231
+ def update_lesion(self, patient_id: str, lesion_id: str, name: str = None, location: str = None):
232
+ """Update lesion name or location"""
233
+ data = self._load_lesion_data(patient_id, lesion_id)
234
+ if data is None:
235
+ return
236
+
237
+ if name is not None:
238
+ data["name"] = name
239
+ if location is not None:
240
+ data["location"] = location
241
+
242
+ self._save_lesion_data(patient_id, lesion_id, data)
243
+
244
+ # -------------------------------------------------------------------------
245
+ # LesionImage Methods
246
+ # -------------------------------------------------------------------------
247
+
248
+ def list_images(self, patient_id: str, lesion_id: str) -> List[LesionImage]:
249
+ """List all images for a lesion"""
250
+ data = self._load_lesion_data(patient_id, lesion_id)
251
+ if data is None:
252
+ return []
253
+
254
+ images = [LesionImage(**img) for img in data.get("images", [])]
255
+ images.sort(key=lambda x: x.timestamp)
256
+ return images
257
+
258
+ def get_image(self, patient_id: str, lesion_id: str, image_id: str) -> Optional[LesionImage]:
259
+ """Get an image by ID"""
260
+ data = self._load_lesion_data(patient_id, lesion_id)
261
+ if data is None:
262
+ return None
263
+
264
+ for img in data.get("images", []):
265
+ if img["id"] == image_id:
266
+ return LesionImage(**img)
267
+ return None
268
+
269
+ def add_image(self, patient_id: str, lesion_id: str) -> LesionImage:
270
+ """Add a new image to a lesion's timeline"""
271
+ data = self._load_lesion_data(patient_id, lesion_id)
272
+ if data is None:
273
+ raise ValueError(f"Lesion {lesion_id} not found")
274
+
275
+ # Check if this is the first image
276
+ is_first = len(data.get("images", [])) == 0
277
+
278
+ image = LesionImage(
279
+ id=f"img-{uuid.uuid4().hex[:8]}",
280
+ lesion_id=lesion_id,
281
+ is_original=is_first
282
+ )
283
+
284
+ if "images" not in data:
285
+ data["images"] = []
286
+ data["images"].append(asdict(image))
287
+ self._save_lesion_data(patient_id, lesion_id, data)
288
+
289
+ return image
290
+
291
+ def update_image(
292
+ self,
293
+ patient_id: str,
294
+ lesion_id: str,
295
+ image_id: str,
296
+ image_path: str = None,
297
+ gradcam_path: str = None,
298
+ analysis: Dict = None,
299
+ comparison: Dict = None,
300
+ stage: str = None
301
+ ):
302
+ """Update an image's data"""
303
+ data = self._load_lesion_data(patient_id, lesion_id)
304
+ if data is None:
305
+ return
306
+
307
+ for img in data.get("images", []):
308
+ if img["id"] == image_id:
309
+ if image_path is not None:
310
+ img["image_path"] = image_path
311
+ if gradcam_path is not None:
312
+ img["gradcam_path"] = gradcam_path
313
+ if analysis is not None:
314
+ img["analysis"] = analysis
315
+ if comparison is not None:
316
+ img["comparison"] = comparison
317
+ if stage is not None:
318
+ img["stage"] = stage
319
+ break
320
+
321
+ self._save_lesion_data(patient_id, lesion_id, data)
322
+
323
+ def save_lesion_image(
324
+ self,
325
+ patient_id: str,
326
+ lesion_id: str,
327
+ image_id: str,
328
+ image: PILImage.Image,
329
+ filename: str = "image.png"
330
+ ) -> str:
331
+ """Save an uploaded image file, return the path"""
332
+ upload_dir = self.uploads_dir / patient_id / lesion_id / image_id
333
+ upload_dir.mkdir(parents=True, exist_ok=True)
334
+
335
+ image_path = upload_dir / filename
336
+ image.save(image_path)
337
+
338
+ return str(image_path)
339
+
340
+ def get_previous_image(
341
+ self,
342
+ patient_id: str,
343
+ lesion_id: str,
344
+ current_image_id: str
345
+ ) -> Optional[LesionImage]:
346
+ """Get the image before the current one (for comparison)"""
347
+ images = self.list_images(patient_id, lesion_id)
348
+
349
+ for i, img in enumerate(images):
350
+ if img.id == current_image_id and i > 0:
351
+ return images[i - 1]
352
+ return None
353
+
354
+ # -------------------------------------------------------------------------
355
+ # Chat Methods (scoped to lesion)
356
+ # -------------------------------------------------------------------------
357
+
358
+ def add_chat_message(self, patient_id: str, lesion_id: str, role: str, content: str):
359
+ """Add a chat message to a lesion"""
360
+ data = self._load_lesion_data(patient_id, lesion_id)
361
+ if data is None:
362
+ return
363
+
364
+ message = ChatMessage(role=role, content=content)
365
+ if "chat_history" not in data:
366
+ data["chat_history"] = []
367
+ data["chat_history"].append(asdict(message))
368
+ self._save_lesion_data(patient_id, lesion_id, data)
369
+
370
+ def get_chat_history(self, patient_id: str, lesion_id: str) -> List[ChatMessage]:
371
+ """Get chat history for a lesion"""
372
+ data = self._load_lesion_data(patient_id, lesion_id)
373
+ if data is None:
374
+ return []
375
+
376
+ return [ChatMessage(**m) for m in data.get("chat_history", [])]
377
+
378
+ def clear_chat_history(self, patient_id: str, lesion_id: str):
379
+ """Clear chat history for a lesion"""
380
+ data = self._load_lesion_data(patient_id, lesion_id)
381
+ if data is None:
382
+ return
383
+
384
+ data["chat_history"] = []
385
+ self._save_lesion_data(patient_id, lesion_id, data)
386
+
387
+ # -------------------------------------------------------------------------
388
+ # Patient-level Chat Methods
389
+ # -------------------------------------------------------------------------
390
+
391
+ def _get_patient_chat_file(self, patient_id: str) -> Path:
392
+ """Get path to patient-level chat JSON file"""
393
+ chat_dir = self.data_dir / "patient_chats"
394
+ chat_dir.mkdir(exist_ok=True)
395
+ return chat_dir / f"{patient_id}.json"
396
+
397
+ def get_patient_chat_history(self, patient_id: str) -> List[dict]:
398
+ """Get chat history for a patient"""
399
+ chat_file = self._get_patient_chat_file(patient_id)
400
+ if not chat_file.exists():
401
+ return []
402
+ with open(chat_file, 'r') as f:
403
+ data = json.load(f)
404
+ return data.get("messages", [])
405
+
406
+ def add_patient_chat_message(
407
+ self,
408
+ patient_id: str,
409
+ role: str,
410
+ content: str,
411
+ image_url: Optional[str] = None,
412
+ tool_calls: Optional[list] = None
413
+ ):
414
+ """Add a message to patient-level chat history"""
415
+ chat_file = self._get_patient_chat_file(patient_id)
416
+ if chat_file.exists():
417
+ with open(chat_file, 'r') as f:
418
+ data = json.load(f)
419
+ else:
420
+ data = {"messages": []}
421
+
422
+ message: Dict[str, Any] = {
423
+ "id": f"msg-{uuid.uuid4().hex[:8]}",
424
+ "role": role,
425
+ "content": content,
426
+ "timestamp": datetime.utcnow().isoformat(),
427
+ }
428
+ if image_url is not None:
429
+ message["image_url"] = image_url
430
+ if tool_calls is not None:
431
+ message["tool_calls"] = tool_calls
432
+
433
+ data["messages"].append(message)
434
+ with open(chat_file, 'w') as f:
435
+ json.dump(data, f, indent=2)
436
+
437
+ def clear_patient_chat_history(self, patient_id: str):
438
+ """Clear patient-level chat history"""
439
+ chat_file = self._get_patient_chat_file(patient_id)
440
+ with open(chat_file, 'w') as f:
441
+ json.dump({"messages": []}, f)
442
+
443
+ def get_or_create_chat_lesion(self, patient_id: str) -> 'Lesion':
444
+ """Get or create the internal chat-images lesion for a patient"""
445
+ for lesion in self.list_lesions(patient_id):
446
+ if lesion.name == "__chat_images__":
447
+ return lesion
448
+ return self.create_lesion(patient_id, "__chat_images__", "internal")
449
+
450
+ def get_latest_chat_image(self, patient_id: str) -> Optional['LesionImage']:
451
+ """Get the most recently analyzed chat image for a patient"""
452
+ lesion = self.get_or_create_chat_lesion(patient_id)
453
+ images = self.list_images(patient_id, lesion.id)
454
+ for img in reversed(images):
455
+ if img.analysis is not None:
456
+ return img
457
+ return None
458
+
459
+
460
+ # Singleton instance
461
+ _store_instance = None
462
+
463
+
464
+ def get_case_store() -> CaseStore:
465
+ """Get or create CaseStore singleton"""
466
+ global _store_instance
467
+ if _store_instance is None:
468
+ _store_instance = CaseStore()
469
+ return _store_instance
470
+
471
+
472
+ if __name__ == "__main__":
473
+ # Test the store
474
+ store = CaseStore()
475
+
476
+ print("Patients:")
477
+ for patient in store.list_patients():
478
+ print(f" - {patient.id}: {patient.name}")
479
+
480
+ # Create a test patient
481
+ print("\nCreating test patient...")
482
+ patient = store.create_patient("Test Patient")
483
+ print(f" Created: {patient.id}")
484
+
485
+ # Create a lesion
486
+ print("\nCreating lesion...")
487
+ lesion = store.create_lesion(patient.id, "Left shoulder mole", "Left shoulder")
488
+ print(f" Created: {lesion.id}")
489
+
490
+ # Add an image
491
+ print("\nAdding image...")
492
+ image = store.add_image(patient.id, lesion.id)
493
+ print(f" Created: {image.id} (is_original={image.is_original})")
494
+
495
+ # Add another image
496
+ image2 = store.add_image(patient.id, lesion.id)
497
+ print(f" Created: {image2.id} (is_original={image2.is_original})")
498
+
499
+ # List images
500
+ print(f"\nImages for lesion {lesion.id}:")
501
+ for img in store.list_images(patient.id, lesion.id):
502
+ print(f" - {img.id}: original={img.is_original}, stage={img.stage}")
503
+
504
+ # Cleanup
505
+ print("\nCleaning up test patient...")
506
+ store.delete_patient(patient.id)
507
+ print("Done!")
frontend/app.py ADDED
@@ -0,0 +1,532 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SkinProAI Frontend - Modular Gradio application
3
+ """
4
+
5
+ import gradio as gr
6
+ from typing import Dict, Generator, Optional
7
+ from datetime import datetime
8
+ import sys
9
+ import os
10
+ import re
11
+ import base64
12
+
13
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
14
+
15
+ from data.case_store import get_case_store
16
+ from frontend.components.styles import MAIN_CSS
17
+ from frontend.components.analysis_view import format_output
18
+
19
+
20
+ # =============================================================================
21
+ # CONFIG
22
+ # =============================================================================
23
+
24
+ class Config:
25
+ APP_TITLE = "SkinProAI"
26
+ SERVER_PORT = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
27
+ HF_SPACES = os.environ.get("SPACE_ID") is not None
28
+
29
+
30
+ # =============================================================================
31
+ # AGENT
32
+ # =============================================================================
33
+
34
+ class AnalysisAgent:
35
+ """Wrapper for the MedGemma analysis agent"""
36
+
37
+ def __init__(self):
38
+ self.model = None
39
+ self.loaded = False
40
+
41
+ def load(self):
42
+ if self.loaded:
43
+ return
44
+ from models.medgemma_agent import MedGemmaAgent
45
+ self.model = MedGemmaAgent(verbose=True)
46
+ self.model.load_model()
47
+ self.loaded = True
48
+
49
+ def analyze(self, image_path: str, question: str = "") -> Generator[str, None, None]:
50
+ if not self.loaded:
51
+ yield "[STAGE:loading]Loading AI models...[/STAGE]\n"
52
+ self.load()
53
+
54
+ for chunk in self.model.analyze_image_stream(image_path, question=question):
55
+ yield chunk
56
+
57
+ def management_guidance(self, confirmed: bool, feedback: str = None) -> Generator[str, None, None]:
58
+ for chunk in self.model.generate_management_guidance(confirmed, feedback):
59
+ yield chunk
60
+
61
+ def followup(self, message: str) -> Generator[str, None, None]:
62
+ if not self.loaded or not self.model.last_diagnosis:
63
+ yield "[ERROR]No analysis context available.[/ERROR]\n"
64
+ return
65
+ for chunk in self.model.chat_followup(message):
66
+ yield chunk
67
+
68
+ def reset(self):
69
+ if self.model:
70
+ self.model.reset_state()
71
+
72
+
73
+ agent = AnalysisAgent()
74
+ case_store = get_case_store()
75
+
76
+
77
+ # =============================================================================
78
+ # APP
79
+ # =============================================================================
80
+
81
+ with gr.Blocks(title=Config.APP_TITLE, css=MAIN_CSS, theme=gr.themes.Soft()) as app:
82
+
83
+ # =========================================================================
84
+ # STATE
85
+ # =========================================================================
86
+ state = gr.State({
87
+ "page": "patient_select", # patient_select | analysis
88
+ "case_id": None,
89
+ "instance_id": None,
90
+ "output": "",
91
+ "gradcam_base64": None
92
+ })
93
+
94
+ # =========================================================================
95
+ # PAGE 1: PATIENT SELECTION
96
+ # =========================================================================
97
+ with gr.Group(visible=True, elem_classes=["patient-select-container"]) as page_patient:
98
+ gr.Markdown("# SkinProAI", elem_classes=["patient-select-title"])
99
+ gr.Markdown("Select a patient to continue or create a new case", elem_classes=["patient-select-subtitle"])
100
+
101
+ with gr.Row(elem_classes=["patient-grid"]):
102
+ btn_demo_melanoma = gr.Button("Demo: Melanocytic Lesion", elem_classes=["patient-card"])
103
+ btn_demo_ak = gr.Button("Demo: Actinic Keratosis", elem_classes=["patient-card"])
104
+ btn_new_patient = gr.Button("+ New Patient", variant="primary", elem_classes=["new-patient-btn"])
105
+
106
+ # =========================================================================
107
+ # PAGE 2: ANALYSIS
108
+ # =========================================================================
109
+ with gr.Group(visible=False) as page_analysis:
110
+
111
+ # Header
112
+ with gr.Row(elem_classes=["app-header"]):
113
+ gr.Markdown(f"**{Config.APP_TITLE}**", elem_classes=["app-title"])
114
+ btn_back = gr.Button("< Back to Patients", elem_classes=["back-btn"])
115
+
116
+ with gr.Row(elem_classes=["analysis-container"]):
117
+
118
+ # Sidebar (previous queries)
119
+ with gr.Column(scale=0, min_width=260, visible=False, elem_classes=["query-sidebar"]) as sidebar:
120
+ gr.Markdown("### Previous Queries", elem_classes=["sidebar-header"])
121
+ sidebar_list = gr.Column(elem_id="sidebar-queries")
122
+ btn_new_query = gr.Button("+ New Query", size="sm", variant="primary")
123
+
124
+ # Main content
125
+ with gr.Column(scale=4, elem_classes=["main-content"]):
126
+
127
+ # Input view (greeting style)
128
+ with gr.Group(visible=True, elem_classes=["input-greeting"]) as view_input:
129
+ gr.Markdown("What would you like to analyze?", elem_classes=["greeting-title"])
130
+ gr.Markdown("Upload an image and describe what you'd like to know", elem_classes=["greeting-subtitle"])
131
+
132
+ with gr.Column(elem_classes=["input-box-container"]):
133
+ input_message = gr.Textbox(
134
+ placeholder="Describe the lesion or ask a question...",
135
+ show_label=False,
136
+ lines=2,
137
+ elem_classes=["message-input"]
138
+ )
139
+
140
+ input_image = gr.Image(
141
+ type="pil",
142
+ height=180,
143
+ show_label=False,
144
+ elem_classes=["image-preview"]
145
+ )
146
+
147
+ with gr.Row(elem_classes=["input-actions"]):
148
+ gr.Markdown("*Upload a skin lesion image*")
149
+ btn_analyze = gr.Button("Analyze", elem_classes=["send-btn"], interactive=False)
150
+
151
+ # Results view (shown after analysis)
152
+ with gr.Group(visible=False, elem_classes=["chat-view"]) as view_results:
153
+ output_html = gr.HTML(
154
+ value='<div class="analysis-output">Starting...</div>',
155
+ elem_classes=["results-area"]
156
+ )
157
+
158
+ # Confirmation
159
+ with gr.Group(visible=False, elem_classes=["confirm-buttons"]) as confirm_box:
160
+ gr.Markdown("**Do you agree with this diagnosis?**")
161
+ with gr.Row():
162
+ btn_confirm_yes = gr.Button("Yes, continue", variant="primary", size="sm")
163
+ btn_confirm_no = gr.Button("No, I disagree", variant="secondary", size="sm")
164
+ input_feedback = gr.Textbox(label="Your assessment", placeholder="Enter diagnosis...", visible=False)
165
+ btn_submit_feedback = gr.Button("Submit", visible=False, size="sm")
166
+
167
+ # Follow-up
168
+ with gr.Row(elem_classes=["chat-input-area"]):
169
+ input_followup = gr.Textbox(placeholder="Ask a follow-up question...", show_label=False, lines=1, scale=4)
170
+ btn_followup = gr.Button("Send", size="sm", scale=1)
171
+
172
+ # =========================================================================
173
+ # DYNAMIC SIDEBAR RENDERING
174
+ # =========================================================================
175
+ @gr.render(inputs=[state], triggers=[state.change])
176
+ def render_sidebar(s):
177
+ case_id = s.get("case_id")
178
+ if not case_id or s.get("page") != "analysis":
179
+ return
180
+
181
+ instances = case_store.list_instances(case_id)
182
+ current = s.get("instance_id")
183
+
184
+ for i, inst in enumerate(instances, 1):
185
+ diagnosis = "Pending"
186
+ if inst.analysis and inst.analysis.get("diagnosis"):
187
+ d = inst.analysis["diagnosis"]
188
+ diagnosis = d.get("class", "?")
189
+
190
+ label = f"#{i}: {diagnosis}"
191
+ variant = "primary" if inst.id == current else "secondary"
192
+ btn = gr.Button(label, size="sm", variant=variant, elem_classes=["query-item"])
193
+
194
+ # Attach click handler to load this instance
195
+ def load_instance(inst_id=inst.id, c_id=case_id):
196
+ def _load(current_state):
197
+ current_state["instance_id"] = inst_id
198
+ instance = case_store.get_instance(c_id, inst_id)
199
+
200
+ # Load saved output if available
201
+ output_html = '<div class="analysis-output"><div class="result">Previous analysis loaded</div></div>'
202
+ if instance and instance.analysis:
203
+ diag = instance.analysis.get("diagnosis", {})
204
+ output_html = f'<div class="analysis-output"><div class="result">Diagnosis: {diag.get("full_name", diag.get("class", "Unknown"))}</div></div>'
205
+
206
+ return (
207
+ current_state,
208
+ gr.update(visible=False), # view_input
209
+ gr.update(visible=True), # view_results
210
+ output_html,
211
+ gr.update(visible=False) # confirm_box
212
+ )
213
+ return _load
214
+
215
+ btn.click(
216
+ load_instance(),
217
+ inputs=[state],
218
+ outputs=[state, view_input, view_results, output_html, confirm_box]
219
+ )
220
+
221
+ # =========================================================================
222
+ # EVENT HANDLERS
223
+ # =========================================================================
224
+
225
+ def select_patient(case_id: str, s: Dict):
226
+ """Handle patient selection"""
227
+ s["case_id"] = case_id
228
+ s["page"] = "analysis"
229
+
230
+ instances = case_store.list_instances(case_id)
231
+ has_queries = len(instances) > 0
232
+
233
+ if has_queries:
234
+ # Load most recent
235
+ inst = instances[-1]
236
+ s["instance_id"] = inst.id
237
+
238
+ # Load image if exists
239
+ img = None
240
+ if inst.image_path and os.path.exists(inst.image_path):
241
+ from PIL import Image
242
+ img = Image.open(inst.image_path)
243
+
244
+ return (
245
+ s,
246
+ gr.update(visible=False), # page_patient
247
+ gr.update(visible=True), # page_analysis
248
+ gr.update(visible=True), # sidebar
249
+ gr.update(visible=False), # view_input
250
+ gr.update(visible=True), # view_results
251
+ '<div class="analysis-output"><div class="result">Previous analysis loaded</div></div>',
252
+ gr.update(visible=False) # confirm_box
253
+ )
254
+ else:
255
+ # New instance
256
+ inst = case_store.create_instance(case_id)
257
+ s["instance_id"] = inst.id
258
+ s["output"] = ""
259
+
260
+ return (
261
+ s,
262
+ gr.update(visible=False),
263
+ gr.update(visible=True),
264
+ gr.update(visible=False), # sidebar hidden for new patient
265
+ gr.update(visible=True), # view_input
266
+ gr.update(visible=False), # view_results
267
+ "",
268
+ gr.update(visible=False)
269
+ )
270
+
271
+ def new_patient(s: Dict):
272
+ """Create new patient"""
273
+ case = case_store.create_case(f"Patient {datetime.now().strftime('%Y-%m-%d %H:%M')}")
274
+ return select_patient(case.id, s)
275
+
276
+ def go_back(s: Dict):
277
+ """Return to patient selection"""
278
+ s["page"] = "patient_select"
279
+ s["case_id"] = None
280
+ s["instance_id"] = None
281
+ s["output"] = ""
282
+
283
+ return (
284
+ s,
285
+ gr.update(visible=True), # page_patient
286
+ gr.update(visible=False), # page_analysis
287
+ gr.update(visible=False), # sidebar
288
+ gr.update(visible=True), # view_input
289
+ gr.update(visible=False), # view_results
290
+ "",
291
+ gr.update(visible=False) # confirm_box
292
+ )
293
+
294
+ def new_query(s: Dict):
295
+ """Start new query for current patient"""
296
+ case_id = s.get("case_id")
297
+ if not case_id:
298
+ return s, gr.update(), gr.update(), gr.update(), "", gr.update()
299
+
300
+ inst = case_store.create_instance(case_id)
301
+ s["instance_id"] = inst.id
302
+ s["output"] = ""
303
+ s["gradcam_base64"] = None
304
+
305
+ agent.reset()
306
+
307
+ return (
308
+ s,
309
+ gr.update(visible=True), # view_input
310
+ gr.update(visible=False), # view_results
311
+ None, # clear image
312
+ "", # clear output
313
+ gr.update(visible=False) # confirm_box
314
+ )
315
+
316
+ def enable_analyze(img):
317
+ """Enable analyze button when image uploaded"""
318
+ return gr.update(interactive=img is not None)
319
+
320
+ def run_analysis(image, message, s: Dict):
321
+ """Run analysis on uploaded image"""
322
+ if image is None:
323
+ yield s, gr.update(), gr.update(), gr.update(), gr.update()
324
+ return
325
+
326
+ case_id = s["case_id"]
327
+ instance_id = s["instance_id"]
328
+
329
+ # Save image
330
+ image_path = case_store.save_image(case_id, instance_id, image)
331
+ case_store.update_analysis(case_id, instance_id, stage="analyzing", image_path=image_path)
332
+
333
+ agent.reset()
334
+ s["output"] = ""
335
+ gradcam_base64 = None
336
+ has_confirm = False
337
+
338
+ # Switch to results view
339
+ yield (
340
+ s,
341
+ gr.update(visible=False), # view_input
342
+ gr.update(visible=True), # view_results
343
+ '<div class="analysis-output">Starting analysis...</div>',
344
+ gr.update(visible=False) # confirm_box
345
+ )
346
+
347
+ partial = ""
348
+ for chunk in agent.analyze(image_path, message or ""):
349
+ partial += chunk
350
+
351
+ # Check for GradCAM
352
+ if gradcam_base64 is None:
353
+ match = re.search(r'\[GRADCAM_IMAGE:([^\]]+)\]', partial)
354
+ if match:
355
+ path = match.group(1)
356
+ if os.path.exists(path):
357
+ try:
358
+ with open(path, "rb") as f:
359
+ gradcam_base64 = base64.b64encode(f.read()).decode('utf-8')
360
+ s["gradcam_base64"] = gradcam_base64
361
+ except:
362
+ pass
363
+
364
+ if '[CONFIRM:' in partial:
365
+ has_confirm = True
366
+
367
+ s["output"] = partial
368
+
369
+ yield (
370
+ s,
371
+ gr.update(visible=False),
372
+ gr.update(visible=True),
373
+ format_output(partial, gradcam_base64),
374
+ gr.update(visible=has_confirm)
375
+ )
376
+
377
+ # Save analysis
378
+ if agent.model and agent.model.last_diagnosis:
379
+ diag = agent.model.last_diagnosis["predictions"][0]
380
+ case_store.update_analysis(
381
+ case_id, instance_id,
382
+ stage="awaiting_confirmation",
383
+ analysis={"diagnosis": diag}
384
+ )
385
+
386
+ def confirm_yes(s: Dict):
387
+ """User confirmed diagnosis"""
388
+ partial = s.get("output", "")
389
+ gradcam = s.get("gradcam_base64")
390
+
391
+ for chunk in agent.management_guidance(confirmed=True):
392
+ partial += chunk
393
+ s["output"] = partial
394
+ yield s, format_output(partial, gradcam), gr.update(visible=False)
395
+
396
+ case_store.update_analysis(s["case_id"], s["instance_id"], stage="complete")
397
+
398
+ def confirm_no():
399
+ """Show feedback input"""
400
+ return gr.update(visible=True), gr.update(visible=True)
401
+
402
+ def submit_feedback(feedback: str, s: Dict):
403
+ """Submit user feedback"""
404
+ partial = s.get("output", "")
405
+ gradcam = s.get("gradcam_base64")
406
+
407
+ for chunk in agent.management_guidance(confirmed=False, feedback=feedback):
408
+ partial += chunk
409
+ s["output"] = partial
410
+ yield (
411
+ s,
412
+ format_output(partial, gradcam),
413
+ gr.update(visible=False),
414
+ gr.update(visible=False),
415
+ gr.update(visible=False),
416
+ ""
417
+ )
418
+
419
+ case_store.update_analysis(s["case_id"], s["instance_id"], stage="complete")
420
+
421
+ def send_followup(message: str, s: Dict):
422
+ """Send follow-up question"""
423
+ if not message.strip():
424
+ return s, gr.update(), ""
425
+
426
+ case_store.add_chat_message(s["case_id"], s["instance_id"], "user", message)
427
+
428
+ partial = s.get("output", "")
429
+ gradcam = s.get("gradcam_base64")
430
+
431
+ partial += f'\n<div class="chat-message user">You: {message}</div>\n'
432
+
433
+ response = ""
434
+ for chunk in agent.followup(message):
435
+ response += chunk
436
+ s["output"] = partial + response
437
+ yield s, format_output(partial + response, gradcam), ""
438
+
439
+ case_store.add_chat_message(s["case_id"], s["instance_id"], "assistant", response)
440
+
441
+ # =========================================================================
442
+ # WIRE EVENTS
443
+ # =========================================================================
444
+
445
+ # Patient selection
446
+ btn_demo_melanoma.click(
447
+ lambda s: select_patient("demo-melanoma", s),
448
+ inputs=[state],
449
+ outputs=[state, page_patient, page_analysis, sidebar, view_input, view_results, output_html, confirm_box]
450
+ )
451
+
452
+ btn_demo_ak.click(
453
+ lambda s: select_patient("demo-ak", s),
454
+ inputs=[state],
455
+ outputs=[state, page_patient, page_analysis, sidebar, view_input, view_results, output_html, confirm_box]
456
+ )
457
+
458
+ btn_new_patient.click(
459
+ new_patient,
460
+ inputs=[state],
461
+ outputs=[state, page_patient, page_analysis, sidebar, view_input, view_results, output_html, confirm_box]
462
+ )
463
+
464
+ # Navigation
465
+ btn_back.click(
466
+ go_back,
467
+ inputs=[state],
468
+ outputs=[state, page_patient, page_analysis, sidebar, view_input, view_results, output_html, confirm_box]
469
+ )
470
+
471
+ btn_new_query.click(
472
+ new_query,
473
+ inputs=[state],
474
+ outputs=[state, view_input, view_results, input_image, output_html, confirm_box]
475
+ )
476
+
477
+ # Analysis
478
+ input_image.change(enable_analyze, inputs=[input_image], outputs=[btn_analyze])
479
+
480
+ btn_analyze.click(
481
+ run_analysis,
482
+ inputs=[input_image, input_message, state],
483
+ outputs=[state, view_input, view_results, output_html, confirm_box]
484
+ )
485
+
486
+ # Confirmation
487
+ btn_confirm_yes.click(
488
+ confirm_yes,
489
+ inputs=[state],
490
+ outputs=[state, output_html, confirm_box]
491
+ )
492
+
493
+ btn_confirm_no.click(
494
+ confirm_no,
495
+ outputs=[input_feedback, btn_submit_feedback]
496
+ )
497
+
498
+ btn_submit_feedback.click(
499
+ submit_feedback,
500
+ inputs=[input_feedback, state],
501
+ outputs=[state, output_html, confirm_box, input_feedback, btn_submit_feedback, input_feedback]
502
+ )
503
+
504
+ # Follow-up
505
+ btn_followup.click(
506
+ send_followup,
507
+ inputs=[input_followup, state],
508
+ outputs=[state, output_html, input_followup]
509
+ )
510
+
511
+ input_followup.submit(
512
+ send_followup,
513
+ inputs=[input_followup, state],
514
+ outputs=[state, output_html, input_followup]
515
+ )
516
+
517
+
518
+ # =============================================================================
519
+ # MAIN
520
+ # =============================================================================
521
+
522
+ if __name__ == "__main__":
523
+ print(f"\n{'='*50}")
524
+ print(f" {Config.APP_TITLE}")
525
+ print(f"{'='*50}\n")
526
+
527
+ app.queue().launch(
528
+ server_name="0.0.0.0" if Config.HF_SPACES else "127.0.0.1",
529
+ server_port=Config.SERVER_PORT,
530
+ share=False,
531
+ show_error=True
532
+ )
frontend/components/__init__.py ADDED
File without changes
frontend/components/analysis_view.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Analysis View Component - Main analysis interface with input and results
3
+ """
4
+
5
+ import gradio as gr
6
+ import re
7
+ from typing import Optional
8
+
9
+
10
+ def parse_markdown(text: str) -> str:
11
+ """Convert basic markdown to HTML"""
12
+ text = re.sub(r'\*\*(.+?)\*\*', r'<strong>\1</strong>', text)
13
+ text = re.sub(r'__(.+?)__', r'<strong>\1</strong>', text)
14
+ text = re.sub(r'\*(.+?)\*', r'<em>\1</em>', text)
15
+
16
+ # Bullet lists
17
+ lines = text.split('\n')
18
+ in_list = False
19
+ result = []
20
+ for line in lines:
21
+ stripped = line.strip()
22
+ if re.match(r'^[\*\-] ', stripped):
23
+ if not in_list:
24
+ result.append('<ul>')
25
+ in_list = True
26
+ item = re.sub(r'^[\*\-] ', '', stripped)
27
+ result.append(f'<li>{item}</li>')
28
+ else:
29
+ if in_list:
30
+ result.append('</ul>')
31
+ in_list = False
32
+ result.append(line)
33
+ if in_list:
34
+ result.append('</ul>')
35
+
36
+ return '\n'.join(result)
37
+
38
+
39
+ # Regex patterns for output parsing
40
+ _STAGE_RE = re.compile(r'\[STAGE:(\w+)\](.*?)\[/STAGE\]')
41
+ _THINKING_RE = re.compile(r'\[THINKING\](.*?)\[/THINKING\]')
42
+ _OBSERVATION_RE = re.compile(r'\[OBSERVATION\](.*?)\[/OBSERVATION\]')
43
+ _TOOL_OUTPUT_RE = re.compile(r'\[TOOL_OUTPUT:(.*?)\]\n(.*?)\[/TOOL_OUTPUT\]', re.DOTALL)
44
+ _RESULT_RE = re.compile(r'\[RESULT\](.*?)\[/RESULT\]')
45
+ _ERROR_RE = re.compile(r'\[ERROR\](.*?)\[/ERROR\]')
46
+ _GRADCAM_RE = re.compile(r'\[GRADCAM_IMAGE:[^\]]+\]\n?')
47
+ _RESPONSE_RE = re.compile(r'\[RESPONSE\]\n(.*?)\n\[/RESPONSE\]', re.DOTALL)
48
+ _COMPLETE_RE = re.compile(r'\[COMPLETE\](.*?)\[/COMPLETE\]')
49
+ _CONFIRM_RE = re.compile(r'\[CONFIRM:(\w+)\](.*?)\[/CONFIRM\]')
50
+ _REFERENCES_RE = re.compile(r'\[REFERENCES\](.*?)\[/REFERENCES\]', re.DOTALL)
51
+ _REF_RE = re.compile(r'\[REF:([^:]+):([^:]+):([^:]+):([^:]+):([^\]]+)\]')
52
+
53
+
54
+ def format_output(raw_text: str, gradcam_base64: Optional[str] = None) -> str:
55
+ """Convert tagged output to styled HTML"""
56
+ html = raw_text
57
+
58
+ # Stage headers
59
+ html = _STAGE_RE.sub(
60
+ r'<div class="stage"><span class="stage-indicator"></span><span class="stage-text">\2</span></div>',
61
+ html
62
+ )
63
+
64
+ # Thinking
65
+ html = _THINKING_RE.sub(r'<div class="thinking">\1</div>', html)
66
+
67
+ # Observations
68
+ html = _OBSERVATION_RE.sub(r'<div class="observation">\1</div>', html)
69
+
70
+ # Tool outputs
71
+ html = _TOOL_OUTPUT_RE.sub(
72
+ r'<div class="tool-output"><div class="tool-header">\1</div><pre class="tool-content">\2</pre></div>',
73
+ html
74
+ )
75
+
76
+ # Results
77
+ html = _RESULT_RE.sub(r'<div class="result">\1</div>', html)
78
+
79
+ # Errors
80
+ html = _ERROR_RE.sub(r'<div class="error">\1</div>', html)
81
+
82
+ # GradCAM image
83
+ if gradcam_base64:
84
+ img_html = f'<div class="gradcam-inline"><div class="gradcam-header">Attention Map</div><img src="data:image/png;base64,{gradcam_base64}" alt="Grad-CAM"></div>'
85
+ html = _GRADCAM_RE.sub(img_html, html)
86
+ else:
87
+ html = _GRADCAM_RE.sub('', html)
88
+
89
+ # Response section
90
+ def format_response(match):
91
+ content = match.group(1)
92
+ parsed = parse_markdown(content)
93
+ parsed = re.sub(r'\n\n+', '</p><p>', parsed)
94
+ parsed = parsed.replace('\n', '<br>')
95
+ return f'<div class="response"><p>{parsed}</p></div>'
96
+
97
+ html = _RESPONSE_RE.sub(format_response, html)
98
+
99
+ # Complete
100
+ html = _COMPLETE_RE.sub(r'<div class="complete">\1</div>', html)
101
+
102
+ # Confirmation
103
+ html = _CONFIRM_RE.sub(
104
+ r'<div class="confirm-box"><div class="confirm-text">\2</div></div>',
105
+ html
106
+ )
107
+
108
+ # References
109
+ def format_references(match):
110
+ ref_content = match.group(1)
111
+ refs_html = ['<div class="references"><div class="references-header">References</div><ul>']
112
+ for ref_match in _REF_RE.finditer(ref_content):
113
+ _, source, page, filename, superscript = ref_match.groups()
114
+ refs_html.append(
115
+ f'<li><a href="guidelines/{filename}#page={page}" target="_blank" class="ref-link">'
116
+ f'<sup>{superscript}</sup> {source}, p.{page}</a></li>'
117
+ )
118
+ refs_html.append('</ul></div>')
119
+ return '\n'.join(refs_html)
120
+
121
+ html = _REFERENCES_RE.sub(format_references, html)
122
+
123
+ # Convert newlines
124
+ html = html.replace('\n', '<br>')
125
+
126
+ return f'<div class="analysis-output">{html}</div>'
127
+
128
+
129
+ def create_analysis_view():
130
+ """
131
+ Create the analysis view component.
132
+
133
+ Returns:
134
+ Tuple of (container, components dict)
135
+ """
136
+ with gr.Group(visible=False, elem_classes=["analysis-container"]) as container:
137
+
138
+ with gr.Row():
139
+ # Main content area
140
+ with gr.Column(elem_classes=["main-content"]):
141
+
142
+ # Input greeting (shown when no analysis yet)
143
+ with gr.Group(visible=True, elem_classes=["input-greeting"]) as input_greeting:
144
+ gr.Markdown("What would you like to analyze?", elem_classes=["greeting-title"])
145
+ gr.Markdown("Upload an image and describe what you'd like to know", elem_classes=["greeting-subtitle"])
146
+
147
+ with gr.Column(elem_classes=["input-box-container"]):
148
+ message_input = gr.Textbox(
149
+ placeholder="Describe the lesion or ask a question...",
150
+ show_label=False,
151
+ lines=3,
152
+ elem_classes=["message-input"]
153
+ )
154
+
155
+ # Image upload (compact)
156
+ image_input = gr.Image(
157
+ label="",
158
+ type="pil",
159
+ height=180,
160
+ elem_classes=["image-preview"],
161
+ show_label=False
162
+ )
163
+
164
+ with gr.Row(elem_classes=["input-actions"]):
165
+ upload_hint = gr.Markdown("*Upload a skin lesion image above*", visible=True)
166
+ send_btn = gr.Button("Analyze", elem_classes=["send-btn"], interactive=False)
167
+
168
+ # Chat/results view (shown after analysis starts)
169
+ with gr.Group(visible=False, elem_classes=["chat-view"]) as chat_view:
170
+ results_output = gr.HTML(
171
+ value='<div class="analysis-output">Starting analysis...</div>',
172
+ elem_classes=["results-area"]
173
+ )
174
+
175
+ # Confirmation buttons
176
+ with gr.Group(visible=False, elem_classes=["confirm-buttons"]) as confirm_group:
177
+ gr.Markdown("**Do you agree with this diagnosis?**")
178
+ with gr.Row():
179
+ confirm_yes_btn = gr.Button("Yes, continue", variant="primary", size="sm")
180
+ confirm_no_btn = gr.Button("No, I disagree", variant="secondary", size="sm")
181
+ feedback_input = gr.Textbox(
182
+ label="Your assessment",
183
+ placeholder="Enter your diagnosis...",
184
+ visible=False
185
+ )
186
+ submit_feedback_btn = gr.Button("Submit", visible=False, size="sm")
187
+
188
+ # Follow-up input
189
+ with gr.Row(elem_classes=["chat-input-area"]):
190
+ followup_input = gr.Textbox(
191
+ placeholder="Ask a follow-up question...",
192
+ show_label=False,
193
+ lines=1
194
+ )
195
+ followup_btn = gr.Button("Send", size="sm", elem_classes=["send-btn"])
196
+
197
+ components = {
198
+ "input_greeting": input_greeting,
199
+ "chat_view": chat_view,
200
+ "message_input": message_input,
201
+ "image_input": image_input,
202
+ "send_btn": send_btn,
203
+ "results_output": results_output,
204
+ "confirm_group": confirm_group,
205
+ "confirm_yes_btn": confirm_yes_btn,
206
+ "confirm_no_btn": confirm_no_btn,
207
+ "feedback_input": feedback_input,
208
+ "submit_feedback_btn": submit_feedback_btn,
209
+ "followup_input": followup_input,
210
+ "followup_btn": followup_btn,
211
+ "upload_hint": upload_hint
212
+ }
213
+
214
+ return container, components
frontend/components/patient_select.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Patient Selection Component - Landing page for selecting/creating patients
3
+ """
4
+
5
+ import gradio as gr
6
+ from typing import Callable, List
7
+ from data.case_store import get_case_store, Case
8
+
9
+
10
+ def create_patient_select(on_patient_selected: Callable[[str], None]) -> gr.Group:
11
+ """
12
+ Create the patient selection page component.
13
+
14
+ Args:
15
+ on_patient_selected: Callback when a patient is selected (receives case_id)
16
+
17
+ Returns:
18
+ gr.Group containing the patient selection UI
19
+ """
20
+ case_store = get_case_store()
21
+
22
+ with gr.Group(visible=True, elem_classes=["patient-select-container"]) as container:
23
+ gr.Markdown("# SkinProAI", elem_classes=["patient-select-title"])
24
+ gr.Markdown("Select a patient to continue or create a new case", elem_classes=["patient-select-subtitle"])
25
+
26
+ with gr.Column(elem_classes=["patient-grid"]):
27
+ # Demo cases
28
+ demo_melanoma_btn = gr.Button(
29
+ "Demo: Melanocytic Lesion",
30
+ elem_classes=["patient-card"]
31
+ )
32
+ demo_ak_btn = gr.Button(
33
+ "Demo: Actinic Keratosis",
34
+ elem_classes=["patient-card"]
35
+ )
36
+
37
+ # New patient button
38
+ new_patient_btn = gr.Button(
39
+ "+ New Patient",
40
+ elem_classes=["new-patient-btn"]
41
+ )
42
+
43
+ return container, demo_melanoma_btn, demo_ak_btn, new_patient_btn
44
+
45
+
46
+ def get_patient_cases() -> List[Case]:
47
+ """Get list of all patient cases"""
48
+ return get_case_store().list_cases()
frontend/components/sidebar.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sidebar Component - Shows previous queries for a patient
3
+ """
4
+
5
+ import gradio as gr
6
+ from datetime import datetime
7
+ from typing import List, Optional
8
+ from data.case_store import get_case_store, Instance
9
+
10
+
11
+ def format_query_item(instance: Instance, index: int) -> str:
12
+ """Format an instance as a query item for display"""
13
+ diagnosis = "Pending"
14
+ if instance.analysis and instance.analysis.get("diagnosis"):
15
+ diag = instance.analysis["diagnosis"]
16
+ diagnosis = diag.get("full_name", diag.get("class", "Unknown"))
17
+
18
+ try:
19
+ dt = datetime.fromisoformat(instance.created_at.replace('Z', '+00:00'))
20
+ date_str = dt.strftime("%b %d, %H:%M")
21
+ except:
22
+ date_str = "Unknown"
23
+
24
+ return f"Query #{index}: {diagnosis} ({date_str})"
25
+
26
+
27
+ def create_sidebar():
28
+ """
29
+ Create the sidebar component for showing previous queries.
30
+
31
+ Returns:
32
+ Tuple of (container, components dict)
33
+ """
34
+ with gr.Column(visible=False, elem_classes=["query-sidebar"]) as container:
35
+ gr.Markdown("### Previous Queries", elem_classes=["sidebar-header"])
36
+
37
+ # Dynamic list of query buttons
38
+ query_list = gr.Column(elem_id="query-list")
39
+
40
+ # New query button
41
+ new_query_btn = gr.Button("+ New Query", size="sm", variant="primary")
42
+
43
+ components = {
44
+ "query_list": query_list,
45
+ "new_query_btn": new_query_btn
46
+ }
47
+
48
+ return container, components
49
+
50
+
51
+ def get_queries_for_case(case_id: str) -> List[Instance]:
52
+ """Get all instances/queries for a case"""
53
+ if not case_id:
54
+ return []
55
+ return get_case_store().list_instances(case_id)
frontend/components/styles.py ADDED
@@ -0,0 +1,517 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CSS Styles for SkinProAI components
3
+ """
4
+
5
+ MAIN_CSS = """
6
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600&display=swap');
7
+
8
+ * {
9
+ font-family: 'Inter', sans-serif !important;
10
+ }
11
+
12
+ .gradio-container {
13
+ max-width: 1200px !important;
14
+ margin: 0 auto !important;
15
+ }
16
+
17
+ /* Hide Gradio footer */
18
+ .gradio-container footer { display: none !important; }
19
+
20
+ /* ============================================
21
+ PATIENT SELECTION PAGE
22
+ ============================================ */
23
+
24
+ .patient-select-container {
25
+ min-height: 80vh;
26
+ display: flex;
27
+ flex-direction: column;
28
+ align-items: center;
29
+ justify-content: center;
30
+ padding: 40px 20px;
31
+ }
32
+
33
+ .patient-select-title {
34
+ font-size: 32px;
35
+ font-weight: 600;
36
+ color: #111827;
37
+ margin-bottom: 8px;
38
+ text-align: center;
39
+ }
40
+
41
+ .patient-select-subtitle {
42
+ font-size: 16px;
43
+ color: #6b7280;
44
+ margin-bottom: 40px;
45
+ text-align: center;
46
+ }
47
+
48
+ .patient-grid {
49
+ display: flex;
50
+ gap: 20px;
51
+ flex-wrap: wrap;
52
+ justify-content: center;
53
+ max-width: 800px;
54
+ }
55
+
56
+ .patient-card {
57
+ background: white !important;
58
+ border: 2px solid #e5e7eb !important;
59
+ border-radius: 16px !important;
60
+ padding: 24px 32px !important;
61
+ min-width: 200px !important;
62
+ cursor: pointer;
63
+ transition: all 0.2s ease !important;
64
+ }
65
+
66
+ .patient-card:hover {
67
+ border-color: #6366f1 !important;
68
+ box-shadow: 0 8px 25px rgba(99, 102, 241, 0.15) !important;
69
+ transform: translateY(-2px);
70
+ }
71
+
72
+ .new-patient-btn {
73
+ background: #6366f1 !important;
74
+ color: white !important;
75
+ border: none !important;
76
+ border-radius: 12px !important;
77
+ padding: 16px 32px !important;
78
+ font-weight: 500 !important;
79
+ margin-top: 24px;
80
+ }
81
+
82
+ .new-patient-btn:hover {
83
+ background: #4f46e5 !important;
84
+ }
85
+
86
+ /* ============================================
87
+ ANALYSIS PAGE - MAIN LAYOUT
88
+ ============================================ */
89
+
90
+ .analysis-container {
91
+ display: flex;
92
+ height: calc(100vh - 80px);
93
+ min-height: 600px;
94
+ }
95
+
96
+ /* Sidebar */
97
+ .query-sidebar {
98
+ width: 280px;
99
+ background: #f9fafb;
100
+ border-right: 1px solid #e5e7eb;
101
+ padding: 20px;
102
+ overflow-y: auto;
103
+ flex-shrink: 0;
104
+ }
105
+
106
+ .sidebar-header {
107
+ font-size: 14px;
108
+ font-weight: 600;
109
+ color: #374151;
110
+ margin-bottom: 16px;
111
+ padding-bottom: 12px;
112
+ border-bottom: 1px solid #e5e7eb;
113
+ }
114
+
115
+ .query-item {
116
+ background: white;
117
+ border: 1px solid #e5e7eb;
118
+ border-radius: 8px;
119
+ padding: 12px;
120
+ margin-bottom: 8px;
121
+ cursor: pointer;
122
+ transition: all 0.15s;
123
+ }
124
+
125
+ .query-item:hover {
126
+ border-color: #6366f1;
127
+ background: #f5f3ff;
128
+ }
129
+
130
+ .query-item-title {
131
+ font-size: 13px;
132
+ font-weight: 500;
133
+ color: #111827;
134
+ margin-bottom: 4px;
135
+ }
136
+
137
+ .query-item-meta {
138
+ font-size: 11px;
139
+ color: #6b7280;
140
+ }
141
+
142
+ /* Main content area */
143
+ .main-content {
144
+ flex: 1;
145
+ display: flex;
146
+ flex-direction: column;
147
+ padding: 24px;
148
+ overflow: hidden;
149
+ }
150
+
151
+ /* ============================================
152
+ INPUT AREA (Greeting style)
153
+ ============================================ */
154
+
155
+ .input-greeting {
156
+ flex: 1;
157
+ display: flex;
158
+ flex-direction: column;
159
+ align-items: center;
160
+ justify-content: center;
161
+ padding: 40px;
162
+ }
163
+
164
+ .greeting-title {
165
+ font-size: 24px;
166
+ font-weight: 600;
167
+ color: #111827;
168
+ margin-bottom: 8px;
169
+ }
170
+
171
+ .greeting-subtitle {
172
+ font-size: 14px;
173
+ color: #6b7280;
174
+ margin-bottom: 32px;
175
+ }
176
+
177
+ .input-box-container {
178
+ width: 100%;
179
+ max-width: 600px;
180
+ background: white;
181
+ border: 2px solid #e5e7eb;
182
+ border-radius: 16px;
183
+ padding: 20px;
184
+ transition: border-color 0.2s;
185
+ }
186
+
187
+ .input-box-container:focus-within {
188
+ border-color: #6366f1;
189
+ }
190
+
191
+ .message-input textarea {
192
+ border: none !important;
193
+ resize: none !important;
194
+ font-size: 15px !important;
195
+ line-height: 1.5 !important;
196
+ padding: 0 !important;
197
+ }
198
+
199
+ .message-input textarea:focus {
200
+ box-shadow: none !important;
201
+ }
202
+
203
+ .input-actions {
204
+ display: flex;
205
+ align-items: center;
206
+ justify-content: space-between;
207
+ margin-top: 16px;
208
+ padding-top: 16px;
209
+ border-top: 1px solid #f3f4f6;
210
+ }
211
+
212
+ .upload-btn {
213
+ background: #f3f4f6 !important;
214
+ color: #374151 !important;
215
+ border: 1px solid #e5e7eb !important;
216
+ border-radius: 8px !important;
217
+ padding: 8px 16px !important;
218
+ font-size: 13px !important;
219
+ }
220
+
221
+ .upload-btn:hover {
222
+ background: #e5e7eb !important;
223
+ }
224
+
225
+ .send-btn {
226
+ background: #6366f1 !important;
227
+ color: white !important;
228
+ border: none !important;
229
+ border-radius: 8px !important;
230
+ padding: 10px 24px !important;
231
+ font-weight: 500 !important;
232
+ }
233
+
234
+ .send-btn:hover {
235
+ background: #4f46e5 !important;
236
+ }
237
+
238
+ .send-btn:disabled {
239
+ background: #d1d5db !important;
240
+ cursor: not-allowed;
241
+ }
242
+
243
+ /* Image preview */
244
+ .image-preview {
245
+ margin-top: 16px;
246
+ border-radius: 12px;
247
+ overflow: hidden;
248
+ max-height: 200px;
249
+ }
250
+
251
+ .image-preview img {
252
+ max-height: 200px;
253
+ object-fit: contain;
254
+ }
255
+
256
+ /* ============================================
257
+ CHAT/RESULTS VIEW
258
+ ============================================ */
259
+
260
+ .chat-view {
261
+ flex: 1;
262
+ display: flex;
263
+ flex-direction: column;
264
+ overflow: hidden;
265
+ }
266
+
267
+ .results-area {
268
+ flex: 1;
269
+ overflow-y: auto;
270
+ padding: 20px;
271
+ background: #ffffff;
272
+ border: 1px solid #e5e7eb;
273
+ border-radius: 12px;
274
+ margin-bottom: 16px;
275
+ }
276
+
277
+ /* Analysis output styling */
278
+ .analysis-output {
279
+ line-height: 1.6;
280
+ color: #333;
281
+ }
282
+
283
+ .stage {
284
+ display: flex;
285
+ align-items: center;
286
+ gap: 10px;
287
+ padding: 8px 0;
288
+ font-weight: 500;
289
+ color: #1a1a1a;
290
+ margin-top: 12px;
291
+ }
292
+
293
+ .stage-indicator {
294
+ width: 8px;
295
+ height: 8px;
296
+ background: #6366f1;
297
+ border-radius: 50%;
298
+ animation: pulse 1.5s ease-in-out infinite;
299
+ }
300
+
301
+ @keyframes pulse {
302
+ 0%, 100% { opacity: 1; transform: scale(1); }
303
+ 50% { opacity: 0.5; transform: scale(0.8); }
304
+ }
305
+
306
+ .thinking {
307
+ color: #6b7280;
308
+ font-style: italic;
309
+ font-size: 13px;
310
+ padding: 4px 0 4px 16px;
311
+ border-left: 2px solid #e5e7eb;
312
+ margin: 4px 0;
313
+ }
314
+
315
+ .observation {
316
+ color: #374151;
317
+ font-size: 13px;
318
+ padding: 4px 0 4px 16px;
319
+ }
320
+
321
+ .tool-output {
322
+ background: #f8fafc;
323
+ border-radius: 8px;
324
+ margin: 12px 0;
325
+ overflow: hidden;
326
+ border: 1px solid #e2e8f0;
327
+ }
328
+
329
+ .tool-header {
330
+ background: #f1f5f9;
331
+ padding: 8px 12px;
332
+ font-weight: 500;
333
+ font-size: 13px;
334
+ color: #475569;
335
+ border-bottom: 1px solid #e2e8f0;
336
+ }
337
+
338
+ .tool-content {
339
+ padding: 12px;
340
+ margin: 0;
341
+ font-family: 'SF Mono', Monaco, monospace !important;
342
+ font-size: 12px;
343
+ line-height: 1.5;
344
+ white-space: pre-wrap;
345
+ color: #334155;
346
+ }
347
+
348
+ .result {
349
+ background: #ecfdf5;
350
+ border: 1px solid #a7f3d0;
351
+ border-radius: 8px;
352
+ padding: 12px 16px;
353
+ margin: 12px 0;
354
+ font-weight: 500;
355
+ color: #065f46;
356
+ }
357
+
358
+ .error {
359
+ background: #fef2f2;
360
+ border: 1px solid #fecaca;
361
+ border-radius: 8px;
362
+ padding: 12px 16px;
363
+ margin: 8px 0;
364
+ color: #b91c1c;
365
+ }
366
+
367
+ .response {
368
+ background: #ffffff;
369
+ border: 1px solid #e5e7eb;
370
+ border-radius: 8px;
371
+ padding: 16px;
372
+ margin: 16px 0;
373
+ line-height: 1.7;
374
+ }
375
+
376
+ .response ul, .response ol {
377
+ margin: 8px 0;
378
+ padding-left: 24px;
379
+ }
380
+
381
+ .response li {
382
+ margin: 4px 0;
383
+ }
384
+
385
+ .complete {
386
+ color: #6b7280;
387
+ font-size: 12px;
388
+ padding: 8px 0;
389
+ text-align: center;
390
+ }
391
+
392
+ /* Confirmation */
393
+ .confirm-box {
394
+ background: #eff6ff;
395
+ border: 1px solid #bfdbfe;
396
+ border-radius: 8px;
397
+ padding: 16px;
398
+ margin: 16px 0;
399
+ text-align: center;
400
+ }
401
+
402
+ .confirm-buttons {
403
+ background: #f0f9ff;
404
+ border: 1px solid #bae6fd;
405
+ border-radius: 8px;
406
+ padding: 12px;
407
+ margin-top: 12px;
408
+ }
409
+
410
+ /* References */
411
+ .references {
412
+ background: #f9fafb;
413
+ border: 1px solid #e5e7eb;
414
+ border-radius: 8px;
415
+ margin: 16px 0;
416
+ overflow: hidden;
417
+ }
418
+
419
+ .references-header {
420
+ background: #f3f4f6;
421
+ padding: 8px 12px;
422
+ font-weight: 500;
423
+ font-size: 13px;
424
+ border-bottom: 1px solid #e5e7eb;
425
+ }
426
+
427
+ .references ul {
428
+ list-style: none;
429
+ padding: 12px;
430
+ margin: 0;
431
+ }
432
+
433
+ .ref-link {
434
+ color: #6366f1;
435
+ text-decoration: none;
436
+ font-size: 13px;
437
+ }
438
+
439
+ .ref-link:hover {
440
+ text-decoration: underline;
441
+ }
442
+
443
+ /* GradCAM */
444
+ .gradcam-inline {
445
+ margin: 16px 0;
446
+ background: #f8fafc;
447
+ border-radius: 8px;
448
+ overflow: hidden;
449
+ border: 1px solid #e2e8f0;
450
+ }
451
+
452
+ .gradcam-header {
453
+ background: #f1f5f9;
454
+ padding: 8px 12px;
455
+ font-weight: 500;
456
+ font-size: 13px;
457
+ border-bottom: 1px solid #e2e8f0;
458
+ }
459
+
460
+ .gradcam-inline img {
461
+ max-width: 100%;
462
+ max-height: 300px;
463
+ display: block;
464
+ margin: 12px auto;
465
+ }
466
+
467
+ /* Chat input at bottom */
468
+ .chat-input-area {
469
+ background: white;
470
+ border: 1px solid #e5e7eb;
471
+ border-radius: 12px;
472
+ padding: 12px 16px;
473
+ display: flex;
474
+ gap: 12px;
475
+ align-items: flex-end;
476
+ }
477
+
478
+ .chat-input-area textarea {
479
+ flex: 1;
480
+ border: none !important;
481
+ resize: none !important;
482
+ font-size: 14px !important;
483
+ }
484
+
485
+ /* ============================================
486
+ HEADER
487
+ ============================================ */
488
+
489
+ .app-header {
490
+ display: flex;
491
+ align-items: center;
492
+ justify-content: space-between;
493
+ padding: 16px 24px;
494
+ border-bottom: 1px solid #e5e7eb;
495
+ background: white;
496
+ }
497
+
498
+ .app-title {
499
+ font-size: 20px;
500
+ font-weight: 600;
501
+ color: #111827;
502
+ }
503
+
504
+ .back-btn {
505
+ background: transparent !important;
506
+ color: #6b7280 !important;
507
+ border: 1px solid #e5e7eb !important;
508
+ border-radius: 8px !important;
509
+ padding: 8px 16px !important;
510
+ font-size: 13px !important;
511
+ }
512
+
513
+ .back-btn:hover {
514
+ background: #f9fafb !important;
515
+ color: #111827 !important;
516
+ }
517
+ """
guidelines/index/chunks.json ADDED
The diff for this file is too large to render. See raw diff
 
guidelines/index/faiss.index ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:faf9cd914f52b84a55d486a156b4756f28dbc1a92abeafc121077402e1fa53f4
3
+ size 145965
mcp_server/__init__.py ADDED
File without changes
mcp_server/server.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SkinProAI MCP Server - Pure JSON-RPC 2.0 stdio server (no mcp library required).
3
+
4
+ Uses sys.executable (venv Python) so all ML packages (torch, transformers, etc.)
5
+ are available. Tools are loaded lazily on first call.
6
+
7
+ Run standalone: python mcp_server/server.py
8
+ (Should start silently, waiting on stdin.)
9
+ """
10
+
11
+ import sys
12
+ import json
13
+ import os
14
+
15
+ # Ensure project root is on path
16
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
17
+
18
+ from mcp_server.tool_registry import get_monet, get_convnext, get_gradcam, get_rag
19
+
20
+
21
+ # ---------------------------------------------------------------------------
22
+ # Tool implementations
23
+ # ---------------------------------------------------------------------------
24
+
25
+ def _monet_analyze(arguments: dict) -> dict:
26
+ from PIL import Image
27
+ image = Image.open(arguments["image_path"]).convert("RGB")
28
+ return get_monet().analyze(image)
29
+
30
+
31
+ def _classify_lesion(arguments: dict) -> dict:
32
+ from PIL import Image
33
+ image = Image.open(arguments["image_path"]).convert("RGB")
34
+ monet_scores = arguments.get("monet_scores")
35
+ return get_convnext().classify(
36
+ clinical_image=image,
37
+ derm_image=None,
38
+ monet_scores=monet_scores,
39
+ )
40
+
41
+
42
+ def _generate_gradcam(arguments: dict) -> dict:
43
+ from PIL import Image
44
+ import tempfile
45
+ image = Image.open(arguments["image_path"]).convert("RGB")
46
+ result = get_gradcam().analyze(image)
47
+
48
+ gradcam_file = tempfile.NamedTemporaryFile(suffix="_gradcam.png", delete=False)
49
+ gradcam_path = gradcam_file.name
50
+ gradcam_file.close()
51
+ result["overlay"].save(gradcam_path)
52
+
53
+ return {
54
+ "gradcam_path": gradcam_path,
55
+ "predicted_class": result["predicted_class"],
56
+ "predicted_class_full": result["predicted_class_full"],
57
+ "confidence": result["confidence"],
58
+ }
59
+
60
+
61
+ def _search_guidelines(arguments: dict) -> dict:
62
+ query = arguments.get("query", "")
63
+ diagnosis = arguments.get("diagnosis") or ""
64
+ rag = get_rag()
65
+ context, references = rag.get_management_context(diagnosis, query)
66
+ references_display = rag.format_references_for_display(references)
67
+ return {
68
+ "context": context,
69
+ "references": references,
70
+ "references_display": references_display,
71
+ }
72
+
73
+
74
+ def _compare_images(arguments: dict) -> dict:
75
+ from PIL import Image
76
+ import tempfile
77
+ image1 = Image.open(arguments["image1_path"]).convert("RGB")
78
+ image2 = Image.open(arguments["image2_path"]).convert("RGB")
79
+
80
+ from models.overlay_tool import get_overlay_tool
81
+ comparison = get_overlay_tool().generate_comparison_overlay(
82
+ image1, image2, label1="Previous", label2="Current"
83
+ )
84
+ comparison_path = comparison["path"]
85
+
86
+ monet = get_monet()
87
+ prev_result = monet.analyze(image1)
88
+ curr_result = monet.analyze(image2)
89
+
90
+ monet_deltas = {}
91
+ for name in curr_result["features"]:
92
+ prev_val = prev_result["features"].get(name, 0.0)
93
+ curr_val = curr_result["features"][name]
94
+ delta = curr_val - prev_val
95
+ if abs(delta) > 0.1:
96
+ monet_deltas[name] = {
97
+ "previous": prev_val,
98
+ "current": curr_val,
99
+ "delta": delta,
100
+ }
101
+
102
+ # Generate GradCAM for both images so the frontend can show a side-by-side comparison
103
+ prev_gradcam_path = None
104
+ curr_gradcam_path = None
105
+ try:
106
+ gradcam = get_gradcam()
107
+ prev_gc = gradcam.analyze(image1)
108
+ curr_gc = gradcam.analyze(image2)
109
+
110
+ f1 = tempfile.NamedTemporaryFile(suffix="_gradcam.png", delete=False)
111
+ prev_gradcam_path = f1.name
112
+ f1.close()
113
+ prev_gc["overlay"].save(prev_gradcam_path)
114
+
115
+ f2 = tempfile.NamedTemporaryFile(suffix="_gradcam.png", delete=False)
116
+ curr_gradcam_path = f2.name
117
+ f2.close()
118
+ curr_gc["overlay"].save(curr_gradcam_path)
119
+ except Exception:
120
+ pass # GradCAM comparison is best-effort
121
+
122
+ return {
123
+ "comparison_path": comparison_path,
124
+ "monet_deltas": monet_deltas,
125
+ "prev_gradcam_path": prev_gradcam_path,
126
+ "curr_gradcam_path": curr_gradcam_path,
127
+ }
128
+
129
+
130
+ TOOLS = {
131
+ "monet_analyze": _monet_analyze,
132
+ "classify_lesion": _classify_lesion,
133
+ "generate_gradcam": _generate_gradcam,
134
+ "search_guidelines": _search_guidelines,
135
+ "compare_images": _compare_images,
136
+ }
137
+
138
+ TOOLS_LIST = [
139
+ {
140
+ "name": "monet_analyze",
141
+ "description": "Extract MONET concept-presence scores from a skin lesion image.",
142
+ "inputSchema": {
143
+ "type": "object",
144
+ "properties": {"image_path": {"type": "string"}},
145
+ "required": ["image_path"],
146
+ },
147
+ },
148
+ {
149
+ "name": "classify_lesion",
150
+ "description": "Classify a skin lesion using ConvNeXt dual-encoder.",
151
+ "inputSchema": {
152
+ "type": "object",
153
+ "properties": {
154
+ "image_path": {"type": "string"},
155
+ "monet_scores": {"type": "array"},
156
+ },
157
+ "required": ["image_path"],
158
+ },
159
+ },
160
+ {
161
+ "name": "generate_gradcam",
162
+ "description": "Generate a Grad-CAM attention overlay for a skin lesion image.",
163
+ "inputSchema": {
164
+ "type": "object",
165
+ "properties": {"image_path": {"type": "string"}},
166
+ "required": ["image_path"],
167
+ },
168
+ },
169
+ {
170
+ "name": "search_guidelines",
171
+ "description": "Search clinical guidelines RAG for management context.",
172
+ "inputSchema": {
173
+ "type": "object",
174
+ "properties": {
175
+ "query": {"type": "string"},
176
+ "diagnosis": {"type": "string"},
177
+ },
178
+ "required": ["query"],
179
+ },
180
+ },
181
+ {
182
+ "name": "compare_images",
183
+ "description": "Generate comparison overlay and MONET deltas for two lesion images.",
184
+ "inputSchema": {
185
+ "type": "object",
186
+ "properties": {
187
+ "image1_path": {"type": "string"},
188
+ "image2_path": {"type": "string"},
189
+ },
190
+ "required": ["image1_path", "image2_path"],
191
+ },
192
+ },
193
+ ]
194
+
195
+
196
+ # ---------------------------------------------------------------------------
197
+ # JSON-RPC 2.0 dispatcher
198
+ # ---------------------------------------------------------------------------
199
+
200
+ def handle_request(request: dict):
201
+ method = request.get("method")
202
+ req_id = request.get("id") # None for notifications
203
+ params = request.get("params", {})
204
+
205
+ if method == "initialize":
206
+ return {
207
+ "jsonrpc": "2.0",
208
+ "id": req_id,
209
+ "result": {
210
+ "protocolVersion": "2024-11-05",
211
+ "capabilities": {"tools": {"listChanged": False}},
212
+ "serverInfo": {"name": "SkinProAI", "version": "1.0.0"},
213
+ },
214
+ }
215
+
216
+ if method in ("notifications/initialized",):
217
+ return None # notification — no response
218
+
219
+ if method == "tools/list":
220
+ return {
221
+ "jsonrpc": "2.0",
222
+ "id": req_id,
223
+ "result": {"tools": TOOLS_LIST},
224
+ }
225
+
226
+ if method == "tools/call":
227
+ name = params.get("name")
228
+ arguments = params.get("arguments", {})
229
+ if name not in TOOLS:
230
+ return {
231
+ "jsonrpc": "2.0",
232
+ "id": req_id,
233
+ "error": {"code": -32601, "message": f"Unknown tool: {name}"},
234
+ }
235
+ try:
236
+ result = TOOLS[name](arguments)
237
+ return {
238
+ "jsonrpc": "2.0",
239
+ "id": req_id,
240
+ "result": {
241
+ "content": [{"type": "text", "text": json.dumps(result)}],
242
+ "isError": False,
243
+ },
244
+ }
245
+ except Exception as e:
246
+ return {
247
+ "jsonrpc": "2.0",
248
+ "id": req_id,
249
+ "result": {
250
+ "content": [{"type": "text", "text": f"Tool error: {e}"}],
251
+ "isError": True,
252
+ },
253
+ }
254
+
255
+ # Unknown method with id → method not found
256
+ if req_id is not None:
257
+ return {
258
+ "jsonrpc": "2.0",
259
+ "id": req_id,
260
+ "error": {"code": -32601, "message": f"Method not found: {method}"},
261
+ }
262
+
263
+ return None # unknown notification — ignore
264
+
265
+
266
+ # ---------------------------------------------------------------------------
267
+ # Main loop
268
+ # ---------------------------------------------------------------------------
269
+
270
+ def main():
271
+ for line in sys.stdin:
272
+ line = line.strip()
273
+ if not line:
274
+ continue
275
+ try:
276
+ request = json.loads(line)
277
+ except json.JSONDecodeError:
278
+ continue
279
+ response = handle_request(request)
280
+ if response is not None:
281
+ sys.stdout.write(json.dumps(response) + "\n")
282
+ sys.stdout.flush()
283
+
284
+
285
+ if __name__ == "__main__":
286
+ main()
mcp_server/tool_registry.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Lazy singleton loader for all 4 ML models used by the MCP server.
3
+ Fixes sys.path so the subprocess can import from models/.
4
+ """
5
+
6
+ import sys
7
+ import os
8
+
9
+ # Ensure project root is on path (this file lives at project_root/mcp_server/tool_registry.py)
10
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
11
+
12
+ _monet = None
13
+ _convnext = None
14
+ _gradcam = None
15
+ _rag = None
16
+
17
+
18
+ def get_monet():
19
+ global _monet
20
+ if _monet is None:
21
+ from models.monet_tool import MonetTool
22
+ _monet = MonetTool()
23
+ _monet.load()
24
+ return _monet
25
+
26
+
27
+ def get_convnext():
28
+ global _convnext
29
+ if _convnext is None:
30
+ from models.convnext_classifier import ConvNeXtClassifier
31
+ root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
32
+ _convnext = ConvNeXtClassifier(
33
+ checkpoint_path=os.path.join(root, "models", "seed42_fold0.pt")
34
+ )
35
+ _convnext.load()
36
+ return _convnext
37
+
38
+
39
+ def get_gradcam():
40
+ global _gradcam
41
+ if _gradcam is None:
42
+ from models.gradcam_tool import GradCAMTool
43
+ _gradcam = GradCAMTool(classifier=get_convnext())
44
+ _gradcam.load()
45
+ return _gradcam
46
+
47
+
48
+ def get_rag():
49
+ global _rag
50
+ if _rag is None:
51
+ from models.guidelines_rag import get_guidelines_rag
52
+ _rag = get_guidelines_rag()
53
+ if not _rag.loaded:
54
+ _rag.load_index()
55
+ return _rag
models/convnext_classifier.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ConvNeXt Classifier Tool - Skin lesion classification using ConvNeXt + MONET features
3
+ Loads seed42_fold0.pt checkpoint and performs classification.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import numpy as np
9
+ from PIL import Image
10
+ from torchvision import transforms
11
+ from typing import Optional, Dict, List, Tuple
12
+ import timm
13
+
14
+
15
+ # Class names for the 11-class skin lesion classification
16
+ CLASS_NAMES = [
17
+ 'AKIEC', 'BCC', 'BEN_OTH', 'BKL', 'DF',
18
+ 'INF', 'MAL_OTH', 'MEL', 'NV', 'SCCKA', 'VASC'
19
+ ]
20
+
21
+ CLASS_FULL_NAMES = {
22
+ 'AKIEC': 'Actinic Keratosis / Intraepithelial Carcinoma',
23
+ 'BCC': 'Basal Cell Carcinoma',
24
+ 'BEN_OTH': 'Benign Other',
25
+ 'BKL': 'Benign Keratosis-like Lesion',
26
+ 'DF': 'Dermatofibroma',
27
+ 'INF': 'Inflammatory',
28
+ 'MAL_OTH': 'Malignant Other',
29
+ 'MEL': 'Melanoma',
30
+ 'NV': 'Melanocytic Nevus',
31
+ 'SCCKA': 'Squamous Cell Carcinoma / Keratoacanthoma',
32
+ 'VASC': 'Vascular Lesion'
33
+ }
34
+
35
+
36
+ class ConvNeXtDualEncoder(nn.Module):
37
+ """
38
+ Dual-image ConvNeXt model matching the trained checkpoint.
39
+ Processes BOTH clinical and dermoscopy images through shared backbone.
40
+
41
+ Metadata input: 19 dimensions
42
+ - age (1): normalized age
43
+ - sex (4): one-hot encoded
44
+ - site (7): one-hot encoded (reduced from 14)
45
+ - MONET (7): 7 MONET feature scores
46
+ """
47
+
48
+ def __init__(
49
+ self,
50
+ model_name: str = 'convnext_base.fb_in22k_ft_in1k',
51
+ metadata_dim: int = 19,
52
+ num_classes: int = 11,
53
+ dropout: float = 0.3
54
+ ):
55
+ super().__init__()
56
+
57
+ self.backbone = timm.create_model(
58
+ model_name,
59
+ pretrained=False,
60
+ num_classes=0
61
+ )
62
+ backbone_dim = self.backbone.num_features # 1024 for convnext_base
63
+
64
+ # Metadata MLP: 19 -> 64
65
+ self.meta_mlp = nn.Sequential(
66
+ nn.Linear(metadata_dim, 64),
67
+ nn.LayerNorm(64),
68
+ nn.GELU(),
69
+ nn.Dropout(dropout)
70
+ )
71
+
72
+ # Classifier: 2112 -> 512 -> 256 -> 11
73
+ # Input: clinical(1024) + derm(1024) + meta(64) = 2112
74
+ fusion_dim = backbone_dim * 2 + 64
75
+ self.classifier = nn.Sequential(
76
+ nn.Linear(fusion_dim, 512),
77
+ nn.LayerNorm(512),
78
+ nn.GELU(),
79
+ nn.Dropout(dropout),
80
+ nn.Linear(512, 256),
81
+ nn.LayerNorm(256),
82
+ nn.GELU(),
83
+ nn.Dropout(dropout),
84
+ nn.Linear(256, num_classes)
85
+ )
86
+
87
+ self.metadata_dim = metadata_dim
88
+ self.num_classes = num_classes
89
+ self.backbone_dim = backbone_dim
90
+
91
+ def forward(
92
+ self,
93
+ clinical_img: torch.Tensor,
94
+ derm_img: Optional[torch.Tensor] = None,
95
+ metadata: Optional[torch.Tensor] = None
96
+ ) -> torch.Tensor:
97
+ """
98
+ Forward pass with dual images.
99
+
100
+ Args:
101
+ clinical_img: [B, 3, H, W] clinical image tensor
102
+ derm_img: [B, 3, H, W] dermoscopy image tensor (uses clinical if None)
103
+ metadata: [B, 19] metadata tensor (zeros if None)
104
+
105
+ Returns:
106
+ logits: [B, 11]
107
+ """
108
+ # Process clinical image
109
+ clinical_features = self.backbone(clinical_img)
110
+
111
+ # Process dermoscopy image
112
+ if derm_img is not None:
113
+ derm_features = self.backbone(derm_img)
114
+ else:
115
+ derm_features = clinical_features
116
+
117
+ # Process metadata
118
+ if metadata is not None:
119
+ meta_features = self.meta_mlp(metadata)
120
+ else:
121
+ batch_size = clinical_features.size(0)
122
+ meta_features = torch.zeros(
123
+ batch_size, 64,
124
+ device=clinical_features.device
125
+ )
126
+
127
+ # Concatenate: [B, 1024] + [B, 1024] + [B, 64] = [B, 2112]
128
+ fused = torch.cat([clinical_features, derm_features, meta_features], dim=1)
129
+ logits = self.classifier(fused)
130
+
131
+ return logits
132
+
133
+
134
+ class ConvNeXtClassifier:
135
+ """
136
+ ConvNeXt classifier tool for skin lesion classification.
137
+ Uses dual images (clinical + dermoscopy) and MONET features.
138
+ """
139
+
140
+ # Site mapping for metadata encoding
141
+ SITE_MAPPING = {
142
+ 'head': 0, 'neck': 0, 'face': 0, # head_neck_face
143
+ 'trunk': 1, 'back': 1, 'chest': 1, 'abdomen': 1,
144
+ 'upper': 2, 'arm': 2, 'hand': 2, # upper extremity
145
+ 'lower': 3, 'leg': 3, 'foot': 3, 'thigh': 3, # lower extremity
146
+ 'genital': 4, 'oral': 5, 'acral': 6,
147
+ }
148
+
149
+ SEX_MAPPING = {'male': 0, 'female': 1, 'other': 2, 'unknown': 3}
150
+
151
+ def __init__(
152
+ self,
153
+ checkpoint_path: str = "models/seed42_fold0.pt",
154
+ device: Optional[str] = None
155
+ ):
156
+ self.checkpoint_path = checkpoint_path
157
+ self.device = device
158
+ self.model = None
159
+ self.loaded = False
160
+
161
+ # Image preprocessing
162
+ self.transform = transforms.Compose([
163
+ transforms.Resize((384, 384)),
164
+ transforms.ToTensor(),
165
+ transforms.Normalize(
166
+ mean=[0.485, 0.456, 0.406],
167
+ std=[0.229, 0.224, 0.225]
168
+ )
169
+ ])
170
+
171
+ def load(self):
172
+ """Load the ConvNeXt model from checkpoint"""
173
+ if self.loaded:
174
+ return
175
+
176
+ # Determine device
177
+ if self.device is None:
178
+ if torch.cuda.is_available():
179
+ self.device = "cuda"
180
+ elif torch.backends.mps.is_available():
181
+ self.device = "mps"
182
+ else:
183
+ self.device = "cpu"
184
+
185
+ # Create model
186
+ self.model = ConvNeXtDualEncoder(
187
+ model_name='convnext_base.fb_in22k_ft_in1k',
188
+ metadata_dim=19,
189
+ num_classes=11,
190
+ dropout=0.3
191
+ )
192
+
193
+ # Load checkpoint
194
+ checkpoint = torch.load(
195
+ self.checkpoint_path,
196
+ map_location=self.device,
197
+ weights_only=False
198
+ )
199
+
200
+ if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
201
+ self.model.load_state_dict(checkpoint['model_state_dict'])
202
+ else:
203
+ self.model.load_state_dict(checkpoint)
204
+
205
+ self.model.to(self.device)
206
+ self.model.eval()
207
+ self.loaded = True
208
+
209
+ def encode_metadata(
210
+ self,
211
+ age: Optional[float] = None,
212
+ sex: Optional[str] = None,
213
+ site: Optional[str] = None,
214
+ monet_scores: Optional[List[float]] = None
215
+ ) -> torch.Tensor:
216
+ """
217
+ Encode metadata into 19-dim vector.
218
+
219
+ Layout: [age(1), sex(4), site(7), monet(7)] = 19
220
+
221
+ Args:
222
+ age: Patient age in years
223
+ sex: 'male', 'female', 'other', or None
224
+ site: Anatomical site string
225
+ monet_scores: List of 7 MONET feature scores
226
+
227
+ Returns:
228
+ torch.Tensor of shape [19]
229
+ """
230
+ features = []
231
+
232
+ # Age (1 dim) - normalized
233
+ age_norm = (age - 50) / 30 if age is not None else 0.0
234
+ features.append(age_norm)
235
+
236
+ # Sex (4 dim) - one-hot
237
+ sex_onehot = [0.0] * 4
238
+ if sex:
239
+ sex_idx = self.SEX_MAPPING.get(sex.lower(), 3)
240
+ sex_onehot[sex_idx] = 1.0
241
+ features.extend(sex_onehot)
242
+
243
+ # Site (7 dim) - one-hot
244
+ site_onehot = [0.0] * 7
245
+ if site:
246
+ site_lower = site.lower()
247
+ for key, idx in self.SITE_MAPPING.items():
248
+ if key in site_lower:
249
+ site_onehot[idx] = 1.0
250
+ break
251
+ features.extend(site_onehot)
252
+
253
+ # MONET (7 dim)
254
+ if monet_scores is not None and len(monet_scores) == 7:
255
+ features.extend(monet_scores)
256
+ else:
257
+ features.extend([0.0] * 7)
258
+
259
+ return torch.tensor(features, dtype=torch.float32)
260
+
261
+ def preprocess_image(self, image: Image.Image) -> torch.Tensor:
262
+ """Preprocess PIL image for model input"""
263
+ if image.mode != "RGB":
264
+ image = image.convert("RGB")
265
+ return self.transform(image).unsqueeze(0)
266
+
267
+ def classify(
268
+ self,
269
+ clinical_image: Image.Image,
270
+ derm_image: Optional[Image.Image] = None,
271
+ age: Optional[float] = None,
272
+ sex: Optional[str] = None,
273
+ site: Optional[str] = None,
274
+ monet_scores: Optional[List[float]] = None,
275
+ top_k: int = 5
276
+ ) -> Dict:
277
+ """
278
+ Classify a skin lesion.
279
+
280
+ Args:
281
+ clinical_image: Clinical (close-up) image
282
+ derm_image: Dermoscopy image (optional, uses clinical if None)
283
+ age: Patient age
284
+ sex: Patient sex
285
+ site: Anatomical site
286
+ monet_scores: 7 MONET feature scores
287
+ top_k: Number of top predictions to return
288
+
289
+ Returns:
290
+ dict with 'predictions', 'probabilities', 'top_class', 'confidence'
291
+ """
292
+ if not self.loaded:
293
+ self.load()
294
+
295
+ # Preprocess images
296
+ clinical_tensor = self.preprocess_image(clinical_image).to(self.device)
297
+
298
+ if derm_image is not None:
299
+ derm_tensor = self.preprocess_image(derm_image).to(self.device)
300
+ else:
301
+ derm_tensor = None
302
+
303
+ # Encode metadata
304
+ metadata = self.encode_metadata(age, sex, site, monet_scores)
305
+ metadata_tensor = metadata.unsqueeze(0).to(self.device)
306
+
307
+ # Run inference
308
+ with torch.no_grad():
309
+ logits = self.model(clinical_tensor, derm_tensor, metadata_tensor)
310
+ probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
311
+
312
+ # Get top-k predictions
313
+ top_indices = np.argsort(probs)[::-1][:top_k]
314
+
315
+ predictions = []
316
+ for idx in top_indices:
317
+ predictions.append({
318
+ 'class': CLASS_NAMES[idx],
319
+ 'full_name': CLASS_FULL_NAMES[CLASS_NAMES[idx]],
320
+ 'probability': float(probs[idx])
321
+ })
322
+
323
+ return {
324
+ 'predictions': predictions,
325
+ 'probabilities': probs.tolist(),
326
+ 'top_class': CLASS_NAMES[top_indices[0]],
327
+ 'confidence': float(probs[top_indices[0]]),
328
+ 'all_classes': CLASS_NAMES,
329
+ }
330
+
331
+ def __call__(
332
+ self,
333
+ clinical_image: Image.Image,
334
+ derm_image: Optional[Image.Image] = None,
335
+ **kwargs
336
+ ) -> Dict:
337
+ """Shorthand for classify()"""
338
+ return self.classify(clinical_image, derm_image, **kwargs)
339
+
340
+
341
+ # Singleton instance
342
+ _convnext_instance = None
343
+
344
+
345
+ def get_convnext_classifier(checkpoint_path: str = "models/seed42_fold0.pt") -> ConvNeXtClassifier:
346
+ """Get or create ConvNeXt classifier instance"""
347
+ global _convnext_instance
348
+ if _convnext_instance is None:
349
+ _convnext_instance = ConvNeXtClassifier(checkpoint_path)
350
+ return _convnext_instance
351
+
352
+
353
+ if __name__ == "__main__":
354
+ import sys
355
+
356
+ print("ConvNeXt Classifier Test")
357
+ print("=" * 50)
358
+
359
+ classifier = ConvNeXtClassifier()
360
+ print("Loading model...")
361
+ classifier.load()
362
+ print("Model loaded!")
363
+
364
+ if len(sys.argv) > 1:
365
+ image_path = sys.argv[1]
366
+ print(f"\nClassifying: {image_path}")
367
+
368
+ image = Image.open(image_path).convert("RGB")
369
+
370
+ # Example with mock MONET scores
371
+ monet_scores = [0.2, 0.1, 0.05, 0.3, 0.7, 0.1, 0.05]
372
+
373
+ result = classifier.classify(
374
+ clinical_image=image,
375
+ age=55,
376
+ sex="male",
377
+ site="back",
378
+ monet_scores=monet_scores
379
+ )
380
+
381
+ print("\nTop Predictions:")
382
+ for pred in result['predictions']:
383
+ print(f" {pred['probability']:.1%} - {pred['class']} ({pred['full_name']})")
models/explainability.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # models/explainability.py
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ import cv2
7
+ from typing import Tuple
8
+ from PIL import Image
9
+
10
+ class GradCAM:
11
+ """
12
+ Gradient-weighted Class Activation Mapping
13
+ Shows which regions of image are important for prediction
14
+ """
15
+
16
+ def __init__(self, model: torch.nn.Module, target_layer: str = None):
17
+ """
18
+ Args:
19
+ model: The neural network
20
+ target_layer: Layer name to compute CAM on (usually last conv layer)
21
+ """
22
+ self.model = model
23
+ self.gradients = None
24
+ self.activations = None
25
+
26
+ # Auto-detect target layer if not specified
27
+ if target_layer is None:
28
+ # Use last ConvNeXt stage
29
+ self.target_layer = model.convnext.stages[-1]
30
+ else:
31
+ self.target_layer = dict(model.named_modules())[target_layer]
32
+
33
+ # Register hooks
34
+ self.target_layer.register_forward_hook(self._save_activation)
35
+ self.target_layer.register_full_backward_hook(self._save_gradient)
36
+
37
+ def _save_activation(self, module, input, output):
38
+ """Save forward activations"""
39
+ self.activations = output.detach()
40
+
41
+ def _save_gradient(self, module, grad_input, grad_output):
42
+ """Save backward gradients"""
43
+ self.gradients = grad_output[0].detach()
44
+
45
+ def generate_cam(
46
+ self,
47
+ image: torch.Tensor,
48
+ target_class: int = None
49
+ ) -> np.ndarray:
50
+ """
51
+ Generate Class Activation Map
52
+
53
+ Args:
54
+ image: Input image [1, 3, H, W]
55
+ target_class: Class to generate CAM for (None = predicted class)
56
+
57
+ Returns:
58
+ cam: Activation map [H, W] normalized to 0-1
59
+ """
60
+ self.model.eval()
61
+
62
+ # Forward pass
63
+ output = self.model(image)
64
+
65
+ # Use predicted class if not specified
66
+ if target_class is None:
67
+ target_class = output.argmax(dim=1).item()
68
+
69
+ # Zero gradients
70
+ self.model.zero_grad()
71
+
72
+ # Backward pass for target class
73
+ output[0, target_class].backward()
74
+
75
+ # Get gradients and activations
76
+ gradients = self.gradients[0] # [C, H, W]
77
+ activations = self.activations[0] # [C, H, W]
78
+
79
+ # Global average pooling of gradients
80
+ weights = gradients.mean(dim=(1, 2)) # [C]
81
+
82
+ # Weighted sum of activations
83
+ cam = torch.zeros(activations.shape[1:], dtype=torch.float32)
84
+ for i, w in enumerate(weights):
85
+ cam += w * activations[i]
86
+
87
+ # ReLU
88
+ cam = F.relu(cam)
89
+
90
+ # Normalize to 0-1
91
+ cam = cam.cpu().numpy()
92
+ cam = cam - cam.min()
93
+ if cam.max() > 0:
94
+ cam = cam / cam.max()
95
+
96
+ return cam
97
+
98
+ def overlay_cam_on_image(
99
+ self,
100
+ image: np.ndarray, # [H, W, 3] RGB
101
+ cam: np.ndarray, # [h, w]
102
+ alpha: float = 0.5,
103
+ colormap: int = cv2.COLORMAP_JET
104
+ ) -> np.ndarray:
105
+ """
106
+ Overlay CAM heatmap on original image
107
+
108
+ Returns:
109
+ overlay: [H, W, 3] RGB image with heatmap
110
+ """
111
+ H, W = image.shape[:2]
112
+
113
+ # Resize CAM to image size
114
+ cam_resized = cv2.resize(cam, (W, H))
115
+
116
+ # Convert to heatmap
117
+ heatmap = cv2.applyColorMap(
118
+ np.uint8(255 * cam_resized),
119
+ colormap
120
+ )
121
+ heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
122
+
123
+ # Blend with original image
124
+ overlay = (alpha * heatmap + (1 - alpha) * image).astype(np.uint8)
125
+
126
+ return overlay
127
+
128
+ class AttentionVisualizer:
129
+ """Visualize MedSigLIP attention maps"""
130
+
131
+ def __init__(self, model):
132
+ self.model = model
133
+
134
+ def get_attention_maps(self, image: torch.Tensor) -> np.ndarray:
135
+ """
136
+ Extract attention maps from MedSigLIP
137
+
138
+ Returns:
139
+ attention: [num_heads, H, W] attention weights
140
+ """
141
+ # Forward pass
142
+ with torch.no_grad():
143
+ _ = self.model(image)
144
+
145
+ # Get last layer attention from MedSigLIP
146
+ # Shape: [batch, num_heads, seq_len, seq_len]
147
+ attention = self.model.medsiglip_features
148
+
149
+ # Average across heads and extract spatial attention
150
+ # This is model-dependent - adjust based on MedSigLIP architecture
151
+
152
+ # Placeholder implementation
153
+ # You'll need to adapt this to your specific MedSigLIP implementation
154
+ return np.random.rand(14, 14) # Placeholder
155
+
156
+ def overlay_attention(
157
+ self,
158
+ image: np.ndarray,
159
+ attention: np.ndarray,
160
+ alpha: float = 0.6
161
+ ) -> np.ndarray:
162
+ """Overlay attention map on image"""
163
+ H, W = image.shape[:2]
164
+
165
+ # Resize attention to image size
166
+ attention_resized = cv2.resize(attention, (W, H))
167
+
168
+ # Normalize
169
+ attention_resized = (attention_resized - attention_resized.min())
170
+ if attention_resized.max() > 0:
171
+ attention_resized = attention_resized / attention_resized.max()
172
+
173
+ # Create colored overlay
174
+ heatmap = cv2.applyColorMap(
175
+ np.uint8(255 * attention_resized),
176
+ cv2.COLORMAP_VIRIDIS
177
+ )
178
+ heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
179
+
180
+ # Blend
181
+ overlay = (alpha * heatmap + (1 - alpha) * image).astype(np.uint8)
182
+
183
+ return overlay
models/gradcam_tool.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Grad-CAM Tool - Visual explanation of ConvNeXt predictions
3
+ Shows which regions of the image the model focuses on.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import numpy as np
9
+ from PIL import Image
10
+ from torchvision import transforms
11
+ from typing import Optional, Tuple
12
+ import cv2
13
+
14
+
15
+ class GradCAM:
16
+ """
17
+ Grad-CAM implementation for ConvNeXt model.
18
+ Generates heatmaps showing model attention.
19
+ """
20
+
21
+ def __init__(self, model, target_layer=None):
22
+ """
23
+ Args:
24
+ model: ConvNeXtDualEncoder model
25
+ target_layer: Layer to extract gradients from (default: last conv layer)
26
+ """
27
+ self.model = model
28
+ self.gradients = None
29
+ self.activations = None
30
+
31
+ # Hook the target layer (last stage of backbone)
32
+ if target_layer is None:
33
+ target_layer = model.backbone.stages[-1]
34
+
35
+ target_layer.register_forward_hook(self._save_activation)
36
+ target_layer.register_full_backward_hook(self._save_gradient)
37
+
38
+ def _save_activation(self, module, input, output):
39
+ """Save activations during forward pass"""
40
+ self.activations = output.detach()
41
+
42
+ def _save_gradient(self, module, grad_input, grad_output):
43
+ """Save gradients during backward pass"""
44
+ self.gradients = grad_output[0].detach()
45
+
46
+ def generate(
47
+ self,
48
+ image_tensor: torch.Tensor,
49
+ target_class: Optional[int] = None,
50
+ derm_tensor: Optional[torch.Tensor] = None,
51
+ metadata: Optional[torch.Tensor] = None
52
+ ) -> np.ndarray:
53
+ """
54
+ Generate Grad-CAM heatmap.
55
+
56
+ Args:
57
+ image_tensor: Input image tensor [1, 3, H, W]
58
+ target_class: Class index to visualize (default: predicted class)
59
+ derm_tensor: Optional dermoscopy image tensor
60
+ metadata: Optional metadata tensor
61
+
62
+ Returns:
63
+ CAM heatmap as numpy array [H, W] normalized to 0-1
64
+ """
65
+ self.model.eval()
66
+
67
+ # Forward pass
68
+ output = self.model(image_tensor, derm_tensor, metadata)
69
+
70
+ if target_class is None:
71
+ target_class = output.argmax(dim=1).item()
72
+
73
+ # Backward pass for target class
74
+ self.model.zero_grad()
75
+ output[0, target_class].backward()
76
+
77
+ # Get gradients and activations
78
+ gradients = self.gradients[0] # [C, H, W]
79
+ activations = self.activations[0] # [C, H, W]
80
+
81
+ # Global average pooling of gradients
82
+ weights = gradients.mean(dim=(1, 2)) # [C]
83
+
84
+ # Weighted combination of activation maps
85
+ cam = torch.zeros(activations.shape[1:], dtype=torch.float32, device=activations.device)
86
+ for i, w in enumerate(weights):
87
+ cam += w * activations[i]
88
+
89
+ # ReLU and normalize
90
+ cam = F.relu(cam)
91
+ cam = cam.cpu().numpy()
92
+
93
+ if cam.max() > 0:
94
+ cam = (cam - cam.min()) / (cam.max() - cam.min())
95
+
96
+ return cam
97
+
98
+ def overlay(
99
+ self,
100
+ image: np.ndarray,
101
+ cam: np.ndarray,
102
+ alpha: float = 0.5,
103
+ colormap: int = cv2.COLORMAP_JET
104
+ ) -> np.ndarray:
105
+ """
106
+ Overlay CAM heatmap on original image.
107
+
108
+ Args:
109
+ image: Original image [H, W, 3] RGB uint8
110
+ cam: CAM heatmap [H, W] float 0-1
111
+ alpha: Overlay transparency
112
+ colormap: OpenCV colormap
113
+
114
+ Returns:
115
+ Overlaid image [H, W, 3] RGB uint8
116
+ """
117
+ H, W = image.shape[:2]
118
+
119
+ # Resize CAM to image size
120
+ cam_resized = cv2.resize(cam, (W, H))
121
+
122
+ # Apply colormap
123
+ heatmap = cv2.applyColorMap(
124
+ np.uint8(255 * cam_resized),
125
+ colormap
126
+ )
127
+ heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
128
+
129
+ # Overlay
130
+ overlay = (alpha * heatmap + (1 - alpha) * image).astype(np.uint8)
131
+
132
+ return overlay
133
+
134
+
135
+ class GradCAMTool:
136
+ """
137
+ High-level Grad-CAM tool for ConvNeXt classifier.
138
+ """
139
+
140
+ def __init__(self, classifier=None):
141
+ """
142
+ Args:
143
+ classifier: ConvNeXtClassifier instance (will create one if None)
144
+ """
145
+ self.classifier = classifier
146
+ self.gradcam = None
147
+ self.loaded = False
148
+
149
+ # Preprocessing
150
+ self.transform = transforms.Compose([
151
+ transforms.Resize((384, 384)),
152
+ transforms.ToTensor(),
153
+ transforms.Normalize(
154
+ mean=[0.485, 0.456, 0.406],
155
+ std=[0.229, 0.224, 0.225]
156
+ )
157
+ ])
158
+
159
+ def load(self):
160
+ """Load classifier and setup Grad-CAM"""
161
+ if self.loaded:
162
+ return
163
+
164
+ if self.classifier is None:
165
+ from models.convnext_classifier import ConvNeXtClassifier
166
+ self.classifier = ConvNeXtClassifier()
167
+ self.classifier.load()
168
+
169
+ self.gradcam = GradCAM(self.classifier.model)
170
+ self.loaded = True
171
+
172
+ def generate_heatmap(
173
+ self,
174
+ image: Image.Image,
175
+ target_class: Optional[int] = None
176
+ ) -> Tuple[np.ndarray, np.ndarray, int, float]:
177
+ """
178
+ Generate Grad-CAM heatmap for an image.
179
+
180
+ Args:
181
+ image: PIL Image
182
+ target_class: Class to visualize (default: predicted)
183
+
184
+ Returns:
185
+ Tuple of (overlay_image, cam_heatmap, predicted_class, confidence)
186
+ """
187
+ if not self.loaded:
188
+ self.load()
189
+
190
+ # Ensure RGB
191
+ if image.mode != "RGB":
192
+ image = image.convert("RGB")
193
+
194
+ # Preprocess
195
+ image_np = np.array(image.resize((384, 384)))
196
+ image_tensor = self.transform(image).unsqueeze(0).to(self.classifier.device)
197
+
198
+ # Get prediction first
199
+ with torch.no_grad():
200
+ logits = self.classifier.model(image_tensor)
201
+ probs = torch.softmax(logits, dim=1)[0]
202
+ pred_class = probs.argmax().item()
203
+ confidence = probs[pred_class].item()
204
+
205
+ # Use predicted class if not specified
206
+ if target_class is None:
207
+ target_class = pred_class
208
+
209
+ # Generate CAM
210
+ cam = self.gradcam.generate(image_tensor, target_class)
211
+
212
+ # Create overlay
213
+ overlay = self.gradcam.overlay(image_np, cam, alpha=0.5)
214
+
215
+ return overlay, cam, pred_class, confidence
216
+
217
+ def analyze(
218
+ self,
219
+ image: Image.Image,
220
+ target_class: Optional[int] = None
221
+ ) -> dict:
222
+ """
223
+ Full analysis with Grad-CAM visualization.
224
+
225
+ Args:
226
+ image: PIL Image
227
+ target_class: Class to visualize
228
+
229
+ Returns:
230
+ Dict with overlay_image, cam, prediction info
231
+ """
232
+ from models.convnext_classifier import CLASS_NAMES, CLASS_FULL_NAMES
233
+
234
+ overlay, cam, pred_class, confidence = self.generate_heatmap(image, target_class)
235
+
236
+ return {
237
+ "overlay": Image.fromarray(overlay),
238
+ "cam": cam,
239
+ "predicted_class": CLASS_NAMES[pred_class],
240
+ "predicted_class_full": CLASS_FULL_NAMES[CLASS_NAMES[pred_class]],
241
+ "confidence": confidence,
242
+ "class_index": pred_class,
243
+ }
244
+
245
+ def __call__(self, image: Image.Image, target_class: Optional[int] = None) -> dict:
246
+ return self.analyze(image, target_class)
247
+
248
+
249
+ # Singleton
250
+ _gradcam_instance = None
251
+
252
+
253
+ def get_gradcam_tool() -> GradCAMTool:
254
+ """Get or create Grad-CAM tool instance"""
255
+ global _gradcam_instance
256
+ if _gradcam_instance is None:
257
+ _gradcam_instance = GradCAMTool()
258
+ return _gradcam_instance
259
+
260
+
261
+ if __name__ == "__main__":
262
+ import sys
263
+
264
+ print("Grad-CAM Tool Test")
265
+ print("=" * 50)
266
+
267
+ tool = GradCAMTool()
268
+ print("Loading model...")
269
+ tool.load()
270
+ print("Model loaded!")
271
+
272
+ if len(sys.argv) > 1:
273
+ image_path = sys.argv[1]
274
+ print(f"\nAnalyzing: {image_path}")
275
+
276
+ image = Image.open(image_path).convert("RGB")
277
+ result = tool.analyze(image)
278
+
279
+ print(f"\nPrediction: {result['predicted_class']} ({result['confidence']:.1%})")
280
+ print(f"Full name: {result['predicted_class_full']}")
281
+
282
+ # Save overlay
283
+ output_path = image_path.rsplit(".", 1)[0] + "_gradcam.png"
284
+ result["overlay"].save(output_path)
285
+ print(f"\nGrad-CAM overlay saved to: {output_path}")
models/guidelines_rag.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Guidelines RAG System - Retrieval-Augmented Generation for clinical guidelines
3
+ Uses FAISS for vector similarity search on chunked guideline PDFs.
4
+ """
5
+
6
+ import os
7
+ import json
8
+ import re
9
+ from pathlib import Path
10
+ from typing import List, Dict, Optional, Tuple
11
+
12
+ import numpy as np
13
+
14
+ # Paths
15
+ GUIDELINES_DIR = Path(__file__).parent.parent / "guidelines"
16
+ INDEX_DIR = GUIDELINES_DIR / "index"
17
+ FAISS_INDEX_PATH = INDEX_DIR / "faiss.index"
18
+ CHUNKS_PATH = INDEX_DIR / "chunks.json"
19
+
20
+ # Chunking parameters
21
+ CHUNK_SIZE = 500 # tokens (approximate)
22
+ CHUNK_OVERLAP = 50 # tokens overlap between chunks
23
+
24
+
25
+ class GuidelinesRAG:
26
+ """
27
+ RAG system for clinical guidelines.
28
+ Extracts text from PDFs, chunks it, creates embeddings, and provides search.
29
+ """
30
+
31
+ def __init__(self):
32
+ self.index = None
33
+ self.chunks = []
34
+ self.embedder = None
35
+ self.loaded = False
36
+
37
+ def _load_embedder(self):
38
+ """Load sentence transformer model for embeddings"""
39
+ if self.embedder is None:
40
+ from sentence_transformers import SentenceTransformer
41
+ self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
42
+
43
+ def _extract_pdf_text(self, pdf_path: Path) -> str:
44
+ """Extract text from a PDF file"""
45
+ try:
46
+ import pdfplumber
47
+ text_parts = []
48
+ with pdfplumber.open(pdf_path) as pdf:
49
+ for page in pdf.pages:
50
+ page_text = page.extract_text()
51
+ if page_text:
52
+ text_parts.append(page_text)
53
+ return "\n\n".join(text_parts)
54
+ except ImportError:
55
+ # Fallback to PyPDF2
56
+ from PyPDF2 import PdfReader
57
+ reader = PdfReader(pdf_path)
58
+ text_parts = []
59
+ for page in reader.pages:
60
+ text = page.extract_text()
61
+ if text:
62
+ text_parts.append(text)
63
+ return "\n\n".join(text_parts)
64
+
65
+ def _clean_text(self, text: str) -> str:
66
+ """Clean extracted text"""
67
+ # Remove excessive whitespace
68
+ text = re.sub(r'\s+', ' ', text)
69
+ # Remove page numbers and headers
70
+ text = re.sub(r'\n\d+\s*\n', '\n', text)
71
+ # Fix broken words from line breaks
72
+ text = re.sub(r'(\w)-\s+(\w)', r'\1\2', text)
73
+ return text.strip()
74
+
75
+ def _extract_pdf_with_pages(self, pdf_path: Path) -> List[Tuple[str, int]]:
76
+ """Extract text from PDF with page numbers"""
77
+ try:
78
+ import pdfplumber
79
+ pages = []
80
+ with pdfplumber.open(pdf_path) as pdf:
81
+ for i, page in enumerate(pdf.pages, 1):
82
+ page_text = page.extract_text()
83
+ if page_text:
84
+ pages.append((page_text, i))
85
+ return pages
86
+ except ImportError:
87
+ from PyPDF2 import PdfReader
88
+ reader = PdfReader(pdf_path)
89
+ pages = []
90
+ for i, page in enumerate(reader.pages, 1):
91
+ text = page.extract_text()
92
+ if text:
93
+ pages.append((text, i))
94
+ return pages
95
+
96
+ def _chunk_text(self, text: str, source: str, page_num: int = 0) -> List[Dict]:
97
+ """
98
+ Chunk text into overlapping segments.
99
+ Returns list of dicts with 'text', 'source', 'chunk_id', 'page'.
100
+ """
101
+ # Approximate tokens by words (rough estimate: 1 token ≈ 0.75 words)
102
+ words = text.split()
103
+ chunk_words = int(CHUNK_SIZE * 0.75)
104
+ overlap_words = int(CHUNK_OVERLAP * 0.75)
105
+
106
+ chunks = []
107
+ start = 0
108
+ chunk_id = 0
109
+
110
+ while start < len(words):
111
+ end = start + chunk_words
112
+ chunk_text = ' '.join(words[start:end])
113
+
114
+ # Try to end at sentence boundary
115
+ if end < len(words):
116
+ last_period = chunk_text.rfind('.')
117
+ if last_period > len(chunk_text) * 0.7:
118
+ chunk_text = chunk_text[:last_period + 1]
119
+
120
+ chunks.append({
121
+ 'text': chunk_text,
122
+ 'source': source,
123
+ 'chunk_id': chunk_id,
124
+ 'page': page_num
125
+ })
126
+
127
+ start = end - overlap_words
128
+ chunk_id += 1
129
+
130
+ return chunks
131
+
132
+ def build_index(self, force_rebuild: bool = False) -> bool:
133
+ """
134
+ Build FAISS index from guideline PDFs.
135
+ Returns True if index was built, False if loaded from cache.
136
+ """
137
+ # Check if index already exists
138
+ if not force_rebuild and FAISS_INDEX_PATH.exists() and CHUNKS_PATH.exists():
139
+ return self.load_index()
140
+
141
+ print("Building guidelines index...")
142
+ self._load_embedder()
143
+
144
+ # Create index directory
145
+ INDEX_DIR.mkdir(parents=True, exist_ok=True)
146
+
147
+ # Extract and chunk all PDFs with page tracking
148
+ all_chunks = []
149
+ pdf_files = list(GUIDELINES_DIR.glob("*.pdf"))
150
+
151
+ for pdf_path in pdf_files:
152
+ print(f" Processing: {pdf_path.name}")
153
+ pages = self._extract_pdf_with_pages(pdf_path)
154
+ pdf_chunks = 0
155
+ for page_text, page_num in pages:
156
+ cleaned = self._clean_text(page_text)
157
+ chunks = self._chunk_text(cleaned, pdf_path.name, page_num)
158
+ all_chunks.extend(chunks)
159
+ pdf_chunks += len(chunks)
160
+ print(f" -> {pdf_chunks} chunks from {len(pages)} pages")
161
+
162
+ if not all_chunks:
163
+ print("No chunks extracted from PDFs!")
164
+ return False
165
+
166
+ self.chunks = all_chunks
167
+ print(f"Total chunks: {len(self.chunks)}")
168
+
169
+ # Generate embeddings
170
+ print("Generating embeddings...")
171
+ texts = [c['text'] for c in self.chunks]
172
+ embeddings = self.embedder.encode(texts, show_progress_bar=True)
173
+ embeddings = np.array(embeddings).astype('float32')
174
+
175
+ # Build FAISS index
176
+ import faiss
177
+ dimension = embeddings.shape[1]
178
+ self.index = faiss.IndexFlatIP(dimension) # Inner product (cosine with normalized vectors)
179
+
180
+ # Normalize embeddings for cosine similarity
181
+ faiss.normalize_L2(embeddings)
182
+ self.index.add(embeddings)
183
+
184
+ # Save index and chunks
185
+ faiss.write_index(self.index, str(FAISS_INDEX_PATH))
186
+ with open(CHUNKS_PATH, 'w') as f:
187
+ json.dump(self.chunks, f)
188
+
189
+ print(f"Index saved to {INDEX_DIR}")
190
+ self.loaded = True
191
+ return True
192
+
193
+ def load_index(self) -> bool:
194
+ """Load persisted FAISS index and chunks"""
195
+ if not FAISS_INDEX_PATH.exists() or not CHUNKS_PATH.exists():
196
+ return False
197
+
198
+ import faiss
199
+ self.index = faiss.read_index(str(FAISS_INDEX_PATH))
200
+
201
+ with open(CHUNKS_PATH, 'r') as f:
202
+ self.chunks = json.load(f)
203
+
204
+ self._load_embedder()
205
+ self.loaded = True
206
+ return True
207
+
208
+ def search(self, query: str, k: int = 5) -> List[Dict]:
209
+ """
210
+ Search for relevant guideline chunks.
211
+ Returns list of chunks with similarity scores.
212
+ """
213
+ if not self.loaded:
214
+ if not self.load_index():
215
+ self.build_index()
216
+
217
+ import faiss
218
+
219
+ # Encode query
220
+ query_embedding = self.embedder.encode([query])
221
+ query_embedding = np.array(query_embedding).astype('float32')
222
+ faiss.normalize_L2(query_embedding)
223
+
224
+ # Search
225
+ scores, indices = self.index.search(query_embedding, k)
226
+
227
+ results = []
228
+ for score, idx in zip(scores[0], indices[0]):
229
+ if idx < len(self.chunks):
230
+ chunk = self.chunks[idx].copy()
231
+ chunk['score'] = float(score)
232
+ results.append(chunk)
233
+
234
+ return results
235
+
236
+ def get_management_context(self, diagnosis: str, features: Optional[str] = None) -> Tuple[str, List[Dict]]:
237
+ """
238
+ Get formatted context from guidelines for management recommendations.
239
+ Returns tuple of (context_string, references_list).
240
+ References can be used for citation hyperlinks.
241
+ """
242
+ # Build search query
243
+ query = f"{diagnosis} management treatment recommendations"
244
+ if features:
245
+ query += f" {features}"
246
+
247
+ chunks = self.search(query, k=5)
248
+
249
+ if not chunks:
250
+ return "No relevant guidelines found.", []
251
+
252
+ # Build context and collect references
253
+ context_parts = []
254
+ references = []
255
+
256
+ # Unicode superscript digits
257
+ superscripts = ['¹', '²', '³', '⁴', '⁵', '⁶', '⁷', '⁸', '⁹']
258
+
259
+ for i, chunk in enumerate(chunks, 1):
260
+ source = chunk['source'].replace('.pdf', '')
261
+ page = chunk.get('page', 0)
262
+ ref_id = f"ref{i}"
263
+ superscript = superscripts[i-1] if i <= len(superscripts) else f"[{i}]"
264
+
265
+ # Add reference marker with superscript
266
+ context_parts.append(f"[Source {superscript}] {chunk['text']}")
267
+
268
+ # Collect reference info
269
+ references.append({
270
+ 'id': ref_id,
271
+ 'source': source,
272
+ 'page': page,
273
+ 'file': chunk['source'],
274
+ 'score': chunk.get('score', 0)
275
+ })
276
+
277
+ context = "\n\n".join(context_parts)
278
+ return context, references
279
+
280
+ def format_references_for_prompt(self, references: List[Dict]) -> str:
281
+ """Format references for inclusion in LLM prompt"""
282
+ if not references:
283
+ return ""
284
+
285
+ lines = ["\n**References:**"]
286
+ for ref in references:
287
+ lines.append(f"[{ref['id']}] {ref['source']}, p.{ref['page']}")
288
+ return "\n".join(lines)
289
+
290
+ def format_references_for_display(self, references: List[Dict]) -> str:
291
+ """
292
+ Format references with markers that frontend can parse into hyperlinks.
293
+ Uses format: [REF:id:source:page:file:superscript]
294
+ """
295
+ if not references:
296
+ return ""
297
+
298
+ # Unicode superscript digits
299
+ superscripts = ['¹', '²', '³', '⁴', '⁵', '⁶', '⁷', '⁸', '⁹']
300
+
301
+ lines = ["\n[REFERENCES]"]
302
+ for i, ref in enumerate(references, 1):
303
+ superscript = superscripts[i-1] if i <= len(superscripts) else f"[{i}]"
304
+ # Format: [REF:ref1:Melanoma Guidelines:5:melanoma.pdf:¹]
305
+ lines.append(f"[REF:{ref['id']}:{ref['source']}:{ref['page']}:{ref['file']}:{superscript}]")
306
+ lines.append("[/REFERENCES]")
307
+ return "\n".join(lines)
308
+
309
+
310
+ # Singleton instance
311
+ _rag_instance = None
312
+
313
+
314
+ def get_guidelines_rag() -> GuidelinesRAG:
315
+ """Get or create RAG instance"""
316
+ global _rag_instance
317
+ if _rag_instance is None:
318
+ _rag_instance = GuidelinesRAG()
319
+ return _rag_instance
320
+
321
+
322
+ if __name__ == "__main__":
323
+ print("=" * 60)
324
+ print(" Guidelines RAG System - Index Builder")
325
+ print("=" * 60)
326
+
327
+ rag = GuidelinesRAG()
328
+
329
+ # Build or rebuild index
330
+ import sys
331
+ force = "--force" in sys.argv
332
+ rag.build_index(force_rebuild=force)
333
+
334
+ # Test search
335
+ print("\n" + "=" * 60)
336
+ print(" Testing Search")
337
+ print("=" * 60)
338
+
339
+ test_queries = [
340
+ "melanoma management",
341
+ "actinic keratosis treatment",
342
+ "surgical excision margins"
343
+ ]
344
+
345
+ for query in test_queries:
346
+ print(f"\nQuery: '{query}'")
347
+ results = rag.search(query, k=2)
348
+ for r in results:
349
+ print(f" [{r['score']:.3f}] {r['source']}: {r['text'][:100]}...")
models/medgemma_agent.py ADDED
@@ -0,0 +1,927 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MedGemma Agent - LLM agent with tool calling and staged thinking feedback
3
+
4
+ Pipeline: MedGemma independent exam → Tools (MONET/ConvNeXt/GradCAM) → MedGemma reconciliation → Management
5
+ """
6
+
7
+ import sys
8
+ import time
9
+ import random
10
+ import json
11
+ import os
12
+ import subprocess
13
+ import threading
14
+ from typing import Optional, Generator, Dict, Any
15
+ from PIL import Image
16
+
17
+
18
+ class MCPClient:
19
+ """
20
+ Minimal MCP client that communicates with a FastMCP subprocess over stdio.
21
+
22
+ Uses raw newline-delimited JSON-RPC 2.0 so the main process (Python 3.9)
23
+ does not need the mcp library. The subprocess is launched with python3.11
24
+ which has mcp installed.
25
+ """
26
+
27
+ def __init__(self):
28
+ self._process = None
29
+ self._lock = threading.Lock()
30
+ self._id_counter = 0
31
+
32
+ def _next_id(self) -> int:
33
+ self._id_counter += 1
34
+ return self._id_counter
35
+
36
+ def _send(self, obj: dict):
37
+ line = json.dumps(obj) + "\n"
38
+ self._process.stdin.write(line)
39
+ self._process.stdin.flush()
40
+
41
+ def _recv(self) -> dict:
42
+ while True:
43
+ line = self._process.stdout.readline()
44
+ if not line:
45
+ raise RuntimeError("MCP server closed connection unexpectedly")
46
+ line = line.strip()
47
+ if not line:
48
+ continue
49
+ msg = json.loads(line)
50
+ # Skip server-initiated notifications (no "id" key)
51
+ if "id" in msg:
52
+ return msg
53
+
54
+ def _initialize(self):
55
+ """Send MCP initialize handshake."""
56
+ req_id = self._next_id()
57
+ self._send({
58
+ "jsonrpc": "2.0",
59
+ "id": req_id,
60
+ "method": "initialize",
61
+ "params": {
62
+ "protocolVersion": "2024-11-05",
63
+ "capabilities": {},
64
+ "clientInfo": {"name": "SkinProAI", "version": "1.0.0"},
65
+ },
66
+ })
67
+ self._recv() # consume initialize response
68
+ # Confirm initialization
69
+ self._send({
70
+ "jsonrpc": "2.0",
71
+ "method": "notifications/initialized",
72
+ "params": {},
73
+ })
74
+
75
+ def start(self):
76
+ """Spawn the MCP server subprocess and complete the handshake."""
77
+ root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
78
+ server_script = os.path.join(root, "mcp_server", "server.py")
79
+ self._process = subprocess.Popen(
80
+ [sys.executable, server_script], # use same venv Python (has all ML packages)
81
+ stdin=subprocess.PIPE,
82
+ stdout=subprocess.PIPE,
83
+ stderr=subprocess.PIPE,
84
+ text=True,
85
+ bufsize=1, # line-buffered
86
+ )
87
+ self._initialize()
88
+
89
+ def call_tool_sync(self, tool_name: str, arguments: dict) -> dict:
90
+ """Call a tool synchronously and return the parsed result dict."""
91
+ with self._lock:
92
+ req_id = self._next_id()
93
+ self._send({
94
+ "jsonrpc": "2.0",
95
+ "id": req_id,
96
+ "method": "tools/call",
97
+ "params": {"name": tool_name, "arguments": arguments},
98
+ })
99
+ response = self._recv()
100
+
101
+ # Protocol-level error (e.g. unknown method)
102
+ if "error" in response:
103
+ raise RuntimeError(
104
+ f"MCP tool '{tool_name}' failed: {response['error']}"
105
+ )
106
+
107
+ result = response["result"]
108
+ content_text = result["content"][0]["text"]
109
+
110
+ # Tool-level error (isError=True means the tool itself raised an exception)
111
+ if result.get("isError"):
112
+ raise RuntimeError(f"MCP tool '{tool_name}' error: {content_text}")
113
+
114
+ return json.loads(content_text)
115
+
116
+ def stop(self):
117
+ """Terminate the MCP server subprocess."""
118
+ if self._process:
119
+ try:
120
+ self._process.stdin.close()
121
+ self._process.terminate()
122
+ self._process.wait(timeout=5)
123
+ except Exception:
124
+ pass
125
+ self._process = None
126
+
127
+
128
+ # Rotating verbs for spinner effect
129
+ ANALYSIS_VERBS = [
130
+ "Analyzing", "Examining", "Processing", "Inspecting", "Evaluating",
131
+ "Scanning", "Assessing", "Reviewing", "Studying", "Interpreting"
132
+ ]
133
+
134
+ # Comprehensive visual exam prompt (combined from 4 separate stages)
135
+ COMPREHENSIVE_EXAM_PROMPT = """Perform a systematic dermoscopic examination of this skin lesion. Assess ALL of the following in a SINGLE concise analysis:
136
+
137
+ 1. PATTERN: Overall architecture, symmetry (symmetric/asymmetric), organization
138
+ 2. COLORS: List all colors present (brown, black, blue, white, red, pink) and distribution
139
+ 3. BORDER: Sharp vs gradual, regular vs irregular, any disruptions
140
+ 4. STRUCTURES: Pigment network, dots/globules, streaks, blue-white veil, regression, vessels
141
+
142
+ Then provide:
143
+ - Top 3 differential diagnoses with brief reasoning
144
+ - Concern level (1-5, where 5=urgent)
145
+ - Single most important feature driving your assessment
146
+
147
+ Be CONCISE - focus on clinically relevant findings only."""
148
+
149
+
150
+ def get_verb():
151
+ """Get a random analysis verb for spinner effect"""
152
+ return random.choice(ANALYSIS_VERBS)
153
+
154
+
155
+ class MedGemmaAgent:
156
+ """
157
+ Medical image analysis agent with:
158
+ - Staged thinking display (no emojis)
159
+ - Tool calling (MONET, ConvNeXt, Grad-CAM)
160
+ - Streaming responses
161
+ """
162
+
163
+ def __init__(self, verbose: bool = True):
164
+ self.verbose = verbose
165
+ self.pipe = None
166
+ self.model_id = "google/medgemma-4b-it"
167
+ self.loaded = False
168
+
169
+ # Tools (legacy direct instances, kept for fallback / non-MCP use)
170
+ self.monet_tool = None
171
+ self.convnext_tool = None
172
+ self.gradcam_tool = None
173
+ self.rag_tool = None
174
+ self.tools_loaded = False
175
+
176
+ # MCP client
177
+ self.mcp_client = None
178
+
179
+ # State for confirmation flow
180
+ self.last_diagnosis = None
181
+ self.last_monet_result = None
182
+ self.last_image = None
183
+ self.last_medgemma_exam = None # Store independent MedGemma findings
184
+ self.last_reconciliation = None
185
+
186
+ def reset_state(self):
187
+ """Reset analysis state for new analysis (keeps models loaded)"""
188
+ self.last_diagnosis = None
189
+ self.last_monet_result = None
190
+ self.last_image = None
191
+ self.last_medgemma_exam = None
192
+ self.last_reconciliation = None
193
+
194
+ def _print(self, message: str):
195
+ """Print if verbose"""
196
+ if self.verbose:
197
+ print(message, flush=True)
198
+
199
+ def load_model(self):
200
+ """Load MedGemma model"""
201
+ if self.loaded:
202
+ return
203
+
204
+ self._print("Initializing MedGemma agent...")
205
+
206
+ import torch
207
+ from transformers import pipeline
208
+
209
+ self._print(f"Loading model: {self.model_id}")
210
+
211
+ if torch.cuda.is_available():
212
+ device = "cuda"
213
+ self._print(f"Using GPU: {torch.cuda.get_device_name(0)}")
214
+ elif torch.backends.mps.is_available():
215
+ device = "mps"
216
+ self._print("Using Apple Silicon (MPS)")
217
+ else:
218
+ device = "cpu"
219
+ self._print("Using CPU")
220
+
221
+ model_kwargs = dict(
222
+ torch_dtype=torch.bfloat16 if device != "cpu" else torch.float32,
223
+ device_map="auto",
224
+ )
225
+
226
+ start = time.time()
227
+ self.pipe = pipeline(
228
+ "image-text-to-text",
229
+ model=self.model_id,
230
+ model_kwargs=model_kwargs
231
+ )
232
+
233
+ self._print(f"Model loaded in {time.time() - start:.1f}s")
234
+ self.loaded = True
235
+
236
+ def load_tools(self):
237
+ """Load tool models (MONET + ConvNeXt + Grad-CAM + RAG)"""
238
+ if self.tools_loaded:
239
+ return
240
+
241
+ from models.monet_tool import MonetTool
242
+ self.monet_tool = MonetTool()
243
+ self.monet_tool.load()
244
+
245
+ from models.convnext_classifier import ConvNeXtClassifier
246
+ self.convnext_tool = ConvNeXtClassifier()
247
+ self.convnext_tool.load()
248
+
249
+ from models.gradcam_tool import GradCAMTool
250
+ self.gradcam_tool = GradCAMTool(classifier=self.convnext_tool)
251
+ self.gradcam_tool.load()
252
+
253
+ from models.guidelines_rag import get_guidelines_rag
254
+ self.rag_tool = get_guidelines_rag()
255
+ if not self.rag_tool.loaded:
256
+ self.rag_tool.load_index()
257
+
258
+ self.tools_loaded = True
259
+
260
+ def load_tools_via_mcp(self):
261
+ """Start the MCP server subprocess and mark tools as loaded."""
262
+ if self.tools_loaded:
263
+ return
264
+ self.mcp_client = MCPClient()
265
+ self.mcp_client.start()
266
+ self.tools_loaded = True
267
+
268
+ def _multi_pass_visual_exam(self, image, question: Optional[str] = None) -> Generator[str, None, Dict[str, str]]:
269
+ """
270
+ MedGemma performs comprehensive visual examination BEFORE tools run.
271
+ Single prompt covers pattern, colors, borders, structures, and differentials.
272
+ Returns findings dict after yielding all output.
273
+ """
274
+ findings = {}
275
+
276
+ yield f"\n[STAGE:medgemma_exam]MedGemma Visual Examination[/STAGE]\n"
277
+ yield f"[THINKING]Performing systematic dermoscopic assessment...[/THINKING]\n"
278
+
279
+ # Build prompt with optional clinical question
280
+ exam_prompt = COMPREHENSIVE_EXAM_PROMPT
281
+ if question:
282
+ exam_prompt += f"\n\nCLINICAL QUESTION: {question}"
283
+
284
+ messages = [
285
+ {
286
+ "role": "user",
287
+ "content": [
288
+ {"type": "image", "image": image},
289
+ {"type": "text", "text": exam_prompt}
290
+ ]
291
+ }
292
+ ]
293
+
294
+ try:
295
+ time.sleep(0.2)
296
+ output = self.pipe(messages, max_new_tokens=400)
297
+ result = output[0]["generated_text"][-1]["content"]
298
+ findings['synthesis'] = result
299
+
300
+ yield f"[RESPONSE]\n"
301
+ words = result.split()
302
+ for i, word in enumerate(words):
303
+ time.sleep(0.015)
304
+ yield word + (" " if i < len(words) - 1 else "")
305
+ yield f"\n[/RESPONSE]\n"
306
+
307
+ except Exception as e:
308
+ findings['synthesis'] = f"Analysis failed: {e}"
309
+ yield f"[ERROR]Visual examination failed: {e}[/ERROR]\n"
310
+
311
+ self.last_medgemma_exam = findings
312
+ return findings
313
+
314
+ def _reconcile_findings(
315
+ self,
316
+ image,
317
+ medgemma_exam: Dict[str, str],
318
+ monet_result: Dict[str, Any],
319
+ convnext_result: Dict[str, Any],
320
+ question: Optional[str] = None
321
+ ) -> Generator[str, None, None]:
322
+ """
323
+ MedGemma reconciles its independent findings with tool outputs.
324
+ Identifies agreements, disagreements, and provides integrated assessment.
325
+ """
326
+ yield f"\n[STAGE:reconciliation]Reconciling MedGemma Findings with Tool Results[/STAGE]\n"
327
+ yield f"[THINKING]Comparing independent visual assessment against AI classification tools...[/THINKING]\n"
328
+
329
+ top = convnext_result['predictions'][0]
330
+ runner_up = convnext_result['predictions'][1] if len(convnext_result['predictions']) > 1 else None
331
+
332
+ # Build MONET features string
333
+ monet_top = sorted(monet_result["features"].items(), key=lambda x: x[1], reverse=True)[:5]
334
+ monet_str = ", ".join([f"{k.replace('MONET_', '').replace('_', ' ')}: {v:.0%}" for k, v in monet_top])
335
+
336
+ reconciliation_prompt = f"""You performed an independent visual examination of this lesion and concluded:
337
+
338
+ YOUR ASSESSMENT:
339
+ {medgemma_exam.get('synthesis', 'Not available')[:600]}
340
+
341
+ The AI classification tools produced these results:
342
+ - ConvNeXt classifier: {top['full_name']} ({top['probability']:.1%} confidence)
343
+ {f"- Runner-up: {runner_up['full_name']} ({runner_up['probability']:.1%})" if runner_up else ""}
344
+ - Key MONET features: {monet_str}
345
+
346
+ {f'CLINICAL QUESTION: {question}' if question else ''}
347
+
348
+ Reconcile your visual findings with the AI classification:
349
+ 1. AGREEMENT/DISAGREEMENT: Do your findings support the AI diagnosis? Any conflicts?
350
+ 2. INTEGRATED ASSESSMENT: Final diagnosis considering all evidence
351
+ 3. CONFIDENCE (1-10): How certain? What would change your assessment?
352
+
353
+ Be concise and specific."""
354
+
355
+ messages = [
356
+ {
357
+ "role": "user",
358
+ "content": [
359
+ {"type": "image", "image": image},
360
+ {"type": "text", "text": reconciliation_prompt}
361
+ ]
362
+ }
363
+ ]
364
+
365
+ try:
366
+ output = self.pipe(messages, max_new_tokens=300)
367
+ reconciliation = output[0]["generated_text"][-1]["content"]
368
+ self.last_reconciliation = reconciliation
369
+
370
+ yield f"[RESPONSE]\n"
371
+ words = reconciliation.split()
372
+ for i, word in enumerate(words):
373
+ time.sleep(0.015)
374
+ yield word + (" " if i < len(words) - 1 else "")
375
+ yield f"\n[/RESPONSE]\n"
376
+
377
+ except Exception as e:
378
+ yield f"[ERROR]Reconciliation failed: {e}[/ERROR]\n"
379
+
380
+ def analyze_image_stream(
381
+ self,
382
+ image_path: str,
383
+ question: Optional[str] = None,
384
+ max_tokens: int = 512,
385
+ use_tools: bool = True
386
+ ) -> Generator[str, None, None]:
387
+ """
388
+ Stream analysis with new pipeline:
389
+ 1. MedGemma independent multi-pass exam
390
+ 2. MONET + ConvNeXt + GradCAM tools
391
+ 3. MedGemma reconciliation
392
+ 4. Confirmation request
393
+ """
394
+ if not self.loaded:
395
+ yield "[STAGE:loading]Initializing MedGemma...[/STAGE]\n"
396
+ self.load_model()
397
+
398
+ yield f"[STAGE:image]{get_verb()} image...[/STAGE]\n"
399
+
400
+ try:
401
+ image = Image.open(image_path).convert("RGB")
402
+ self.last_image = image
403
+ except Exception as e:
404
+ yield f"[ERROR]Failed to load image: {e}[/ERROR]\n"
405
+ return
406
+
407
+ # Load tools early via MCP subprocess
408
+ if use_tools and not self.tools_loaded:
409
+ yield f"[STAGE:tools]Loading analysis tools...[/STAGE]\n"
410
+ self.load_tools_via_mcp()
411
+
412
+ # ===== PHASE 1: MedGemma Independent Visual Examination =====
413
+ medgemma_exam = {}
414
+ for chunk in self._multi_pass_visual_exam(image, question):
415
+ yield chunk
416
+ if isinstance(chunk, dict):
417
+ medgemma_exam = chunk
418
+ medgemma_exam = self.last_medgemma_exam or {}
419
+
420
+ monet_result = None
421
+ convnext_result = None
422
+
423
+ if use_tools:
424
+ # ===== PHASE 2: Run Classification Tools =====
425
+ yield f"\n[STAGE:tools_run]Running AI Classification Tools[/STAGE]\n"
426
+ yield f"[THINKING]Now running MONET and ConvNeXt to compare against visual examination...[/THINKING]\n"
427
+
428
+ # MONET Feature Extraction
429
+ time.sleep(0.2)
430
+ yield f"\n[STAGE:monet]MONET Feature Extraction[/STAGE]\n"
431
+
432
+ try:
433
+ monet_result = self.mcp_client.call_tool_sync(
434
+ "monet_analyze", {"image_path": image_path}
435
+ )
436
+ self.last_monet_result = monet_result
437
+
438
+ yield f"[TOOL_OUTPUT:MONET Features]\n"
439
+ for name, score in monet_result["features"].items():
440
+ short_name = name.replace("MONET_", "").replace("_", " ").title()
441
+ bar_filled = int(score * 10)
442
+ bar = "|" + "=" * bar_filled + "-" * (10 - bar_filled) + "|"
443
+ yield f" {short_name}: {bar} {score:.0%}\n"
444
+ yield f"[/TOOL_OUTPUT]\n"
445
+
446
+ except Exception as e:
447
+ yield f"[ERROR]MONET failed: {e}[/ERROR]\n"
448
+
449
+ # ConvNeXt Classification
450
+ time.sleep(0.2)
451
+ yield f"\n[STAGE:convnext]ConvNeXt Classification[/STAGE]\n"
452
+
453
+ try:
454
+ monet_scores = monet_result["vector"] if monet_result else None
455
+ convnext_result = self.mcp_client.call_tool_sync(
456
+ "classify_lesion",
457
+ {
458
+ "image_path": image_path,
459
+ "monet_scores": monet_scores,
460
+ },
461
+ )
462
+ self.last_diagnosis = convnext_result
463
+
464
+ yield f"[TOOL_OUTPUT:Classification Results]\n"
465
+ for pred in convnext_result["predictions"][:5]:
466
+ prob = pred['probability']
467
+ bar_filled = int(prob * 20)
468
+ bar = "|" + "=" * bar_filled + "-" * (20 - bar_filled) + "|"
469
+ yield f" {pred['class']}: {bar} {prob:.1%}\n"
470
+ yield f" {pred['full_name']}\n"
471
+ yield f"[/TOOL_OUTPUT]\n"
472
+
473
+ top = convnext_result['predictions'][0]
474
+ yield f"[RESULT]ConvNeXt Primary: {top['full_name']} ({top['probability']:.1%})[/RESULT]\n"
475
+
476
+ except Exception as e:
477
+ yield f"[ERROR]ConvNeXt failed: {e}[/ERROR]\n"
478
+
479
+ # Grad-CAM Visualization
480
+ time.sleep(0.2)
481
+ yield f"\n[STAGE:gradcam]Grad-CAM Attention Map[/STAGE]\n"
482
+
483
+ try:
484
+ gradcam_result = self.mcp_client.call_tool_sync(
485
+ "generate_gradcam", {"image_path": image_path}
486
+ )
487
+ gradcam_path = gradcam_result["gradcam_path"]
488
+ yield f"[GRADCAM_IMAGE:{gradcam_path}]\n"
489
+ except Exception as e:
490
+ yield f"[ERROR]Grad-CAM failed: {e}[/ERROR]\n"
491
+
492
+ # ===== PHASE 3: MedGemma Reconciliation =====
493
+ if convnext_result and monet_result and medgemma_exam:
494
+ for chunk in self._reconcile_findings(
495
+ image, medgemma_exam, monet_result, convnext_result, question
496
+ ):
497
+ yield chunk
498
+
499
+ # Yield confirmation request
500
+ if convnext_result:
501
+ top = convnext_result['predictions'][0]
502
+ yield f"\n[CONFIRM:diagnosis]Do you agree with the integrated assessment?[/CONFIRM]\n"
503
+
504
+ def generate_management_guidance(
505
+ self,
506
+ user_confirmed: bool = True,
507
+ user_feedback: Optional[str] = None
508
+ ) -> Generator[str, None, None]:
509
+ """
510
+ Generate LESION-SPECIFIC management guidance using RAG + MedGemma reasoning.
511
+ References specific findings from this analysis, not generic textbook management.
512
+ """
513
+ if not self.last_diagnosis:
514
+ yield "[ERROR]No diagnosis available. Please analyze an image first.[/ERROR]\n"
515
+ return
516
+
517
+ top = self.last_diagnosis['predictions'][0]
518
+ runner_up = self.last_diagnosis['predictions'][1] if len(self.last_diagnosis['predictions']) > 1 else None
519
+ diagnosis = top['full_name']
520
+
521
+ if not user_confirmed and user_feedback:
522
+ yield f"[THINKING]Clinician provided alternative assessment: {user_feedback}[/THINKING]\n"
523
+ diagnosis = user_feedback
524
+
525
+ # Stage: RAG Search
526
+ time.sleep(0.3)
527
+ yield f"\n[STAGE:guidelines]Searching clinical guidelines for {diagnosis}...[/STAGE]\n"
528
+
529
+ # Get RAG context via MCP
530
+ features_desc = self.last_monet_result.get('description', '') if self.last_monet_result else ''
531
+ rag_data = self.mcp_client.call_tool_sync(
532
+ "search_guidelines",
533
+ {"query": features_desc, "diagnosis": diagnosis},
534
+ )
535
+ context = rag_data["context"]
536
+ references = rag_data["references"]
537
+
538
+ # Check guideline relevance
539
+ has_relevant_guidelines = False
540
+ if references:
541
+ diagnosis_lower = diagnosis.lower()
542
+ for ref in references:
543
+ source_lower = ref['source'].lower()
544
+ if any(term in diagnosis_lower for term in ['melanoma']) and 'melanoma' in source_lower:
545
+ has_relevant_guidelines = True
546
+ break
547
+ elif 'actinic' in diagnosis_lower and 'actinic' in source_lower:
548
+ has_relevant_guidelines = True
549
+ break
550
+ elif ref.get('score', 0) > 0.7:
551
+ has_relevant_guidelines = True
552
+ break
553
+
554
+ if not references or not has_relevant_guidelines:
555
+ yield f"[THINKING]No specific published guidelines for {diagnosis}. Using clinical knowledge.[/THINKING]\n"
556
+ context = "No specific clinical guidelines available."
557
+ references = []
558
+
559
+ # Build MONET features for context
560
+ monet_features = ""
561
+ if self.last_monet_result:
562
+ top_features = sorted(self.last_monet_result["features"].items(), key=lambda x: x[1], reverse=True)[:5]
563
+ monet_features = ", ".join([f"{k.replace('MONET_', '').replace('_', ' ')}: {v:.0%}" for k, v in top_features])
564
+
565
+ # Stage: Lesion-Specific Management Reasoning
566
+ time.sleep(0.3)
567
+ yield f"\n[STAGE:management]Generating Lesion-Specific Management Plan[/STAGE]\n"
568
+ yield f"[THINKING]Creating management plan tailored to THIS lesion's specific characteristics...[/THINKING]\n"
569
+
570
+ management_prompt = f"""Generate a CONCISE management plan for this lesion:
571
+
572
+ DIAGNOSIS: {diagnosis} ({top['probability']:.1%})
573
+ {f"Alternative: {runner_up['full_name']} ({runner_up['probability']:.1%})" if runner_up else ""}
574
+ KEY FEATURES: {monet_features}
575
+
576
+ {f"GUIDELINES: {context[:800]}" if context else ""}
577
+
578
+ Provide:
579
+ 1. RECOMMENDED ACTION: Biopsy, excision, monitoring, or discharge - with specific reasoning
580
+ 2. URGENCY: Routine vs urgent vs same-day referral
581
+ 3. KEY CONCERNS: What features drive this recommendation
582
+
583
+ Be specific to THIS lesion. 3-5 sentences maximum."""
584
+
585
+ messages = [
586
+ {
587
+ "role": "user",
588
+ "content": [
589
+ {"type": "image", "image": self.last_image},
590
+ {"type": "text", "text": management_prompt}
591
+ ]
592
+ }
593
+ ]
594
+
595
+ # Generate response
596
+ start = time.time()
597
+ try:
598
+ output = self.pipe(messages, max_new_tokens=250)
599
+ response = output[0]["generated_text"][-1]["content"]
600
+
601
+ yield f"[RESPONSE]\n"
602
+ words = response.split()
603
+ for i, word in enumerate(words):
604
+ time.sleep(0.015)
605
+ yield word + (" " if i < len(words) - 1 else "")
606
+ yield f"\n[/RESPONSE]\n"
607
+
608
+ except Exception as e:
609
+ yield f"[ERROR]Management generation failed: {e}[/ERROR]\n"
610
+
611
+ # Output references (pre-formatted by MCP server)
612
+ if references:
613
+ yield rag_data["references_display"]
614
+
615
+ yield f"\n[COMPLETE]Lesion-specific management plan generated in {time.time() - start:.1f}s[/COMPLETE]\n"
616
+
617
+ # Store response for recommendation extraction
618
+ self.last_management_response = response
619
+
620
+ def extract_recommendation(self) -> Generator[str, None, Dict[str, Any]]:
621
+ """
622
+ Extract structured recommendation from management guidance.
623
+ Determines: BIOPSY, EXCISION, FOLLOWUP, or DISCHARGE
624
+ For BIOPSY/EXCISION, gets coordinates from MedGemma.
625
+ """
626
+ if not self.last_management_response or not self.last_image:
627
+ yield "[ERROR]No management guidance available[/ERROR]\n"
628
+ return {"action": "UNKNOWN"}
629
+
630
+ yield f"\n[STAGE:recommendation]Extracting Clinical Recommendation[/STAGE]\n"
631
+
632
+ # Ask MedGemma to classify the recommendation
633
+ classification_prompt = f"""Based on the management plan you just provided:
634
+
635
+ {self.last_management_response[:1000]}
636
+
637
+ Classify the PRIMARY recommended action into exactly ONE of these categories:
638
+ - BIOPSY: If punch biopsy, shave biopsy, or incisional biopsy is recommended
639
+ - EXCISION: If complete surgical excision is recommended
640
+ - FOLLOWUP: If monitoring with repeat photography/dermoscopy is recommended
641
+ - DISCHARGE: If the lesion is clearly benign and no follow-up needed
642
+
643
+ Respond with ONLY the category name (BIOPSY, EXCISION, FOLLOWUP, or DISCHARGE) on the first line.
644
+ Then on the second line, provide a brief (1 sentence) justification."""
645
+
646
+ messages = [
647
+ {
648
+ "role": "user",
649
+ "content": [
650
+ {"type": "image", "image": self.last_image},
651
+ {"type": "text", "text": classification_prompt}
652
+ ]
653
+ }
654
+ ]
655
+
656
+ try:
657
+ output = self.pipe(messages, max_new_tokens=100)
658
+ response = output[0]["generated_text"][-1]["content"].strip()
659
+ lines = response.split('\n')
660
+ action = lines[0].strip().upper()
661
+ justification = lines[1].strip() if len(lines) > 1 else ""
662
+
663
+ # Validate action
664
+ valid_actions = ["BIOPSY", "EXCISION", "FOLLOWUP", "DISCHARGE"]
665
+ if action not in valid_actions:
666
+ # Try to extract from response
667
+ for valid in valid_actions:
668
+ if valid in response.upper():
669
+ action = valid
670
+ break
671
+ else:
672
+ action = "FOLLOWUP" # Default to safe option
673
+
674
+ yield f"[RESULT]Recommended Action: {action}[/RESULT]\n"
675
+ yield f"[OBSERVATION]{justification}[/OBSERVATION]\n"
676
+
677
+ result = {
678
+ "action": action,
679
+ "justification": justification
680
+ }
681
+
682
+ return result
683
+
684
+ except Exception as e:
685
+ yield f"[ERROR]Failed to extract recommendation: {e}[/ERROR]\n"
686
+ return {"action": "UNKNOWN", "error": str(e)}
687
+
688
+ def compare_followup_images(
689
+ self,
690
+ previous_image_path: str,
691
+ current_image_path: str
692
+ ) -> Generator[str, None, None]:
693
+ """
694
+ Compare a follow-up image with the previous one.
695
+ Runs full analysis pipeline on current image, then compares findings.
696
+ """
697
+ yield f"\n[STAGE:comparison]Follow-up Comparison Analysis[/STAGE]\n"
698
+
699
+ try:
700
+ current_image = Image.open(current_image_path).convert("RGB")
701
+ except Exception as e:
702
+ yield f"[ERROR]Failed to load images: {e}[/ERROR]\n"
703
+ return
704
+
705
+ # Store previous analysis state
706
+ prev_exam = self.last_medgemma_exam
707
+
708
+ # Generate comparison image and MONET deltas via MCP
709
+ yield f"\n[STAGE:current_analysis]Analyzing Current Image[/STAGE]\n"
710
+
711
+ if self.tools_loaded:
712
+ try:
713
+ compare_data = self.mcp_client.call_tool_sync(
714
+ "compare_images",
715
+ {
716
+ "image1_path": previous_image_path,
717
+ "image2_path": current_image_path,
718
+ },
719
+ )
720
+ yield f"[COMPARISON_IMAGE:{compare_data['comparison_path']}]\n"
721
+
722
+ # Side-by-side GradCAM comparison if both paths available
723
+ prev_gc = compare_data.get("prev_gradcam_path")
724
+ curr_gc = compare_data.get("curr_gradcam_path")
725
+ if prev_gc and curr_gc:
726
+ yield f"[GRADCAM_COMPARE:{prev_gc}:{curr_gc}]\n"
727
+
728
+ # Display MONET feature deltas
729
+ if compare_data["monet_deltas"]:
730
+ yield f"[TOOL_OUTPUT:Feature Comparison]\n"
731
+ for name, delta_info in compare_data["monet_deltas"].items():
732
+ prev_val = delta_info["previous"]
733
+ curr_val = delta_info["current"]
734
+ diff = delta_info["delta"]
735
+ short_name = name.replace("MONET_", "").replace("_", " ").title()
736
+ direction = "↑" if diff > 0 else "↓"
737
+ yield f" {short_name}: {prev_val:.0%} → {curr_val:.0%} ({direction}{abs(diff):.0%})\n"
738
+ yield f"[/TOOL_OUTPUT]\n"
739
+
740
+ except Exception as e:
741
+ yield f"[ERROR]MCP comparison failed: {e}[/ERROR]\n"
742
+
743
+ # MedGemma comparison analysis
744
+ comparison_prompt = f"""You are comparing TWO images of the same skin lesion taken at different times.
745
+
746
+ PREVIOUS ANALYSIS:
747
+ {prev_exam.get('synthesis', 'Not available')[:500] if prev_exam else 'Not available'}
748
+
749
+ Now examine the CURRENT image and compare to your memory of the previous findings.
750
+
751
+ Assess for changes in:
752
+ 1. SIZE: Has the lesion grown, shrunk, or stayed the same?
753
+ 2. COLOR: Any new colors appeared? Any colors faded?
754
+ 3. SHAPE/SYMMETRY: Has the shape changed? More or less symmetric?
755
+ 4. BORDERS: Sharper, more irregular, or unchanged?
756
+ 5. STRUCTURES: New dermoscopic structures? Lost structures?
757
+
758
+ Provide your assessment:
759
+ - CHANGE_LEVEL: SIGNIFICANT_CHANGE / MINOR_CHANGE / STABLE / IMPROVED
760
+ - Specific changes observed
761
+ - Clinical recommendation based on changes"""
762
+
763
+ messages = [
764
+ {
765
+ "role": "user",
766
+ "content": [
767
+ {"type": "image", "image": current_image},
768
+ {"type": "text", "text": comparison_prompt}
769
+ ]
770
+ }
771
+ ]
772
+
773
+ try:
774
+ yield f"[THINKING]Comparing current image to previous findings...[/THINKING]\n"
775
+ output = self.pipe(messages, max_new_tokens=400)
776
+ comparison_result = output[0]["generated_text"][-1]["content"]
777
+
778
+ yield f"[RESPONSE]\n"
779
+ words = comparison_result.split()
780
+ for i, word in enumerate(words):
781
+ time.sleep(0.02)
782
+ yield word + (" " if i < len(words) - 1 else "")
783
+ yield f"\n[/RESPONSE]\n"
784
+
785
+ # Extract change level
786
+ change_level = "UNKNOWN"
787
+ for level in ["SIGNIFICANT_CHANGE", "MINOR_CHANGE", "STABLE", "IMPROVED"]:
788
+ if level in comparison_result.upper():
789
+ change_level = level
790
+ break
791
+
792
+ if change_level == "SIGNIFICANT_CHANGE":
793
+ yield f"[RESULT]⚠️ SIGNIFICANT CHANGES DETECTED - Further evaluation recommended[/RESULT]\n"
794
+ elif change_level == "IMPROVED":
795
+ yield f"[RESULT]✓ LESION IMPROVED - Continue monitoring[/RESULT]\n"
796
+ elif change_level == "STABLE":
797
+ yield f"[RESULT]✓ LESION STABLE - Continue scheduled follow-up[/RESULT]\n"
798
+ else:
799
+ yield f"[RESULT]Minor changes noted - Clinical correlation recommended[/RESULT]\n"
800
+
801
+ except Exception as e:
802
+ yield f"[ERROR]Comparison analysis failed: {e}[/ERROR]\n"
803
+
804
+ yield f"\n[COMPLETE]Follow-up comparison complete[/COMPLETE]\n"
805
+
806
+ def chat(self, message: str, image_path: Optional[str] = None) -> str:
807
+ """Simple chat interface"""
808
+ if not self.loaded:
809
+ self.load_model()
810
+
811
+ content = []
812
+ if image_path:
813
+ image = Image.open(image_path).convert("RGB")
814
+ content.append({"type": "image", "image": image})
815
+ content.append({"type": "text", "text": message})
816
+
817
+ messages = [{"role": "user", "content": content}]
818
+ output = self.pipe(messages, max_new_tokens=512)
819
+ return output[0]["generated_text"][-1]["content"]
820
+
821
+ def chat_followup(self, message: str) -> Generator[str, None, None]:
822
+ """
823
+ Handle follow-up questions using the stored analysis context.
824
+ Uses the last analyzed image and diagnosis to provide contextual responses.
825
+ """
826
+ if not self.loaded:
827
+ yield "[ERROR]Model not loaded[/ERROR]\n"
828
+ return
829
+
830
+ if not self.last_diagnosis or not self.last_image:
831
+ yield "[ERROR]No previous analysis context. Please analyze an image first.[/ERROR]\n"
832
+ return
833
+
834
+ # Build context from previous analysis
835
+ top_diagnosis = self.last_diagnosis['predictions'][0]
836
+ differentials = ", ".join([
837
+ f"{p['class']} ({p['probability']:.0%})"
838
+ for p in self.last_diagnosis['predictions'][:3]
839
+ ])
840
+
841
+ monet_desc = ""
842
+ if self.last_monet_result:
843
+ monet_desc = self.last_monet_result.get('description', '')
844
+
845
+ context_prompt = f"""You are a dermatology assistant helping with skin lesion analysis.
846
+
847
+ PREVIOUS ANALYSIS CONTEXT:
848
+ - Primary diagnosis: {top_diagnosis['full_name']} ({top_diagnosis['probability']:.1%} confidence)
849
+ - Differential diagnoses: {differentials}
850
+ - Visual features: {monet_desc}
851
+
852
+ The user has a follow-up question about this lesion. Please provide a helpful, medically accurate response.
853
+
854
+ USER QUESTION: {message}
855
+
856
+ Provide a concise, informative response. If the question is outside your expertise or requires in-person examination, say so."""
857
+
858
+ messages = [
859
+ {
860
+ "role": "user",
861
+ "content": [
862
+ {"type": "image", "image": self.last_image},
863
+ {"type": "text", "text": context_prompt}
864
+ ]
865
+ }
866
+ ]
867
+
868
+ try:
869
+ yield f"[THINKING]Considering your question in context of the previous analysis...[/THINKING]\n"
870
+ time.sleep(0.2)
871
+
872
+ output = self.pipe(messages, max_new_tokens=400)
873
+ response = output[0]["generated_text"][-1]["content"]
874
+
875
+ yield f"[RESPONSE]\n"
876
+ # Stream word by word for typewriter effect
877
+ words = response.split()
878
+ for i, word in enumerate(words):
879
+ time.sleep(0.02)
880
+ yield word + (" " if i < len(words) - 1 else "")
881
+ yield f"\n[/RESPONSE]\n"
882
+
883
+ except Exception as e:
884
+ yield f"[ERROR]Failed to generate response: {e}[/ERROR]\n"
885
+
886
+
887
+ def main():
888
+ """Interactive terminal interface"""
889
+ print("=" * 60)
890
+ print(" MedGemma Agent - Medical Image Analysis")
891
+ print("=" * 60)
892
+
893
+ agent = MedGemmaAgent(verbose=True)
894
+ agent.load_model()
895
+
896
+ print("\nCommands: analyze <path>, chat <message>, quit")
897
+
898
+ while True:
899
+ try:
900
+ user_input = input("\n> ").strip()
901
+ if not user_input:
902
+ continue
903
+
904
+ if user_input.lower() in ["quit", "exit", "q"]:
905
+ break
906
+
907
+ parts = user_input.split(maxsplit=1)
908
+ cmd = parts[0].lower()
909
+
910
+ if cmd == "analyze" and len(parts) > 1:
911
+ for chunk in agent.analyze_image_stream(parts[1].strip()):
912
+ print(chunk, end="", flush=True)
913
+
914
+ elif cmd == "chat" and len(parts) > 1:
915
+ print(agent.chat(parts[1]))
916
+
917
+ else:
918
+ print("Unknown command")
919
+
920
+ except KeyboardInterrupt:
921
+ break
922
+ except Exception as e:
923
+ print(f"Error: {e}")
924
+
925
+
926
+ if __name__ == "__main__":
927
+ main()
models/medsiglip_convnext_fusion.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # models/medsiglip_convnext_fusion.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from typing import Dict, List, Tuple, Optional
6
+ import numpy as np
7
+ import timm
8
+ from transformers import AutoModel, AutoProcessor
9
+
10
+ class MedSigLIPConvNeXtFusion(nn.Module):
11
+ """
12
+ Your trained MedSigLIP-ConvNeXt fusion model from MILK10 challenge
13
+ Supports 11-class skin lesion classification
14
+ """
15
+
16
+ # Class names from your training
17
+ CLASS_NAMES = [
18
+ 'AKIEC', # Actinic Keratoses and Intraepithelial Carcinoma
19
+ 'BCC', # Basal Cell Carcinoma
20
+ 'BEN_OTH', # Benign Other
21
+ 'BKL', # Benign Keratosis-like Lesions
22
+ 'DF', # Dermatofibroma
23
+ 'INF', # Inflammatory
24
+ 'MAL_OTH', # Malignant Other
25
+ 'MEL', # Melanoma
26
+ 'NV', # Melanocytic Nevi
27
+ 'SCCKA', # Squamous Cell Carcinoma and Keratoacanthoma
28
+ 'VASC' # Vascular Lesions
29
+ ]
30
+
31
+ def __init__(
32
+ self,
33
+ num_classes: int = 11,
34
+ medsiglip_model: str = "google/medsiglip-base",
35
+ convnext_variant: str = "convnext_base",
36
+ fusion_dim: int = 512,
37
+ dropout: float = 0.3,
38
+ metadata_dim: int = 20 # For metadata features
39
+ ):
40
+ super().__init__()
41
+
42
+ self.num_classes = num_classes
43
+
44
+ # MedSigLIP Vision Encoder
45
+ print(f"Loading MedSigLIP: {medsiglip_model}")
46
+ self.medsiglip = AutoModel.from_pretrained(medsiglip_model)
47
+ self.medsiglip_processor = AutoProcessor.from_pretrained(medsiglip_model)
48
+
49
+ # ConvNeXt Backbone
50
+ print(f"Loading ConvNeXt: {convnext_variant}")
51
+ self.convnext = timm.create_model(
52
+ convnext_variant,
53
+ pretrained=True,
54
+ num_classes=0,
55
+ global_pool='avg'
56
+ )
57
+
58
+ # Feature dimensions
59
+ self.medsiglip_dim = self.medsiglip.config.hidden_size # 768
60
+ self.convnext_dim = self.convnext.num_features # 1024
61
+
62
+ # Optional metadata branch
63
+ self.use_metadata = metadata_dim > 0
64
+ if self.use_metadata:
65
+ self.metadata_encoder = nn.Sequential(
66
+ nn.Linear(metadata_dim, 64),
67
+ nn.LayerNorm(64),
68
+ nn.GELU(),
69
+ nn.Dropout(0.2),
70
+ nn.Linear(64, 32)
71
+ )
72
+ total_dim = self.medsiglip_dim + self.convnext_dim + 32
73
+ else:
74
+ total_dim = self.medsiglip_dim + self.convnext_dim
75
+
76
+ # Fusion layers
77
+ self.fusion = nn.Sequential(
78
+ nn.Linear(total_dim, fusion_dim),
79
+ nn.LayerNorm(fusion_dim),
80
+ nn.GELU(),
81
+ nn.Dropout(dropout),
82
+ nn.Linear(fusion_dim, fusion_dim // 2),
83
+ nn.LayerNorm(fusion_dim // 2),
84
+ nn.GELU(),
85
+ nn.Dropout(dropout)
86
+ )
87
+
88
+ # Classification head
89
+ self.classifier = nn.Linear(fusion_dim // 2, num_classes)
90
+
91
+ # Store intermediate features for Grad-CAM
92
+ self.convnext_features = None
93
+ self.medsiglip_features = None
94
+
95
+ # Register hooks
96
+ self.convnext.stages[-1].register_forward_hook(self._save_convnext_features)
97
+
98
+ def _save_convnext_features(self, module, input, output):
99
+ """Hook to save ConvNeXt feature maps for Grad-CAM"""
100
+ self.convnext_features = output
101
+
102
+ def forward(
103
+ self,
104
+ image: torch.Tensor,
105
+ metadata: Optional[torch.Tensor] = None
106
+ ) -> torch.Tensor:
107
+ """
108
+ Forward pass
109
+
110
+ Args:
111
+ image: [B, 3, H, W] tensor
112
+ metadata: [B, metadata_dim] optional metadata features
113
+
114
+ Returns:
115
+ logits: [B, num_classes]
116
+ """
117
+ # MedSigLIP features
118
+ medsiglip_out = self.medsiglip.vision_model(image)
119
+ medsiglip_features = medsiglip_out.pooler_output # [B, 768]
120
+
121
+ # ConvNeXt features
122
+ convnext_features = self.convnext(image) # [B, 1024]
123
+
124
+ # Concatenate vision features
125
+ fused = torch.cat([medsiglip_features, convnext_features], dim=1)
126
+
127
+ # Add metadata if available
128
+ if self.use_metadata and metadata is not None:
129
+ metadata_features = self.metadata_encoder(metadata)
130
+ fused = torch.cat([fused, metadata_features], dim=1)
131
+
132
+ # Fusion layers
133
+ fused = self.fusion(fused)
134
+
135
+ # Classification
136
+ logits = self.classifier(fused)
137
+
138
+ return logits
139
+
140
+ def predict(
141
+ self,
142
+ image: torch.Tensor,
143
+ metadata: Optional[torch.Tensor] = None,
144
+ top_k: int = 5
145
+ ) -> Dict:
146
+ """
147
+ Get predictions with probabilities
148
+
149
+ Args:
150
+ image: [B, 3, H, W] or [3, H, W]
151
+ metadata: Optional metadata features
152
+ top_k: Number of top predictions
153
+
154
+ Returns:
155
+ Dictionary with predictions and features
156
+ """
157
+ if image.dim() == 3:
158
+ image = image.unsqueeze(0)
159
+
160
+ self.eval()
161
+ with torch.no_grad():
162
+ logits = self.forward(image, metadata)
163
+ probs = torch.softmax(logits, dim=1)
164
+
165
+ # Top-k predictions
166
+ top_probs, top_indices = torch.topk(
167
+ probs,
168
+ k=min(top_k, self.num_classes),
169
+ dim=1
170
+ )
171
+
172
+ # Format results
173
+ predictions = []
174
+ for i in range(top_probs.size(1)):
175
+ predictions.append({
176
+ 'class': self.CLASS_NAMES[top_indices[0, i].item()],
177
+ 'probability': top_probs[0, i].item(),
178
+ 'class_idx': top_indices[0, i].item()
179
+ })
180
+
181
+ return {
182
+ 'predictions': predictions,
183
+ 'all_probabilities': probs[0].cpu().numpy(),
184
+ 'logits': logits[0].cpu().numpy(),
185
+ 'convnext_features': self.convnext_features,
186
+ 'medsiglip_features': self.medsiglip_features
187
+ }
188
+
189
+ @classmethod
190
+ def load_from_checkpoint(
191
+ cls,
192
+ medsiglip_path: str,
193
+ convnext_path: Optional[str] = None,
194
+ ensemble_weights: tuple = (0.6, 0.4),
195
+ device: str = 'cpu'
196
+ ):
197
+ """
198
+ Load model from your training checkpoints
199
+
200
+ Args:
201
+ medsiglip_path: Path to MedSigLIP model weights
202
+ convnext_path: Path to ConvNeXt model weights (optional)
203
+ ensemble_weights: (w_medsiglip, w_convnext)
204
+ device: Device to load on
205
+ """
206
+ model = cls(num_classes=11)
207
+
208
+ # Load MedSigLIP weights
209
+ print(f"Loading MedSigLIP from: {medsiglip_path}")
210
+ medsiglip_state = torch.load(medsiglip_path, map_location=device)
211
+
212
+ # Handle different checkpoint formats
213
+ if 'model_state_dict' in medsiglip_state:
214
+ model.load_state_dict(medsiglip_state['model_state_dict'])
215
+ else:
216
+ model.load_state_dict(medsiglip_state)
217
+
218
+ # Store ensemble weights for prediction fusion
219
+ model.ensemble_weights = ensemble_weights
220
+
221
+ model.to(device)
222
+ model.eval()
223
+
224
+ return model
models/monet_concepts.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # models/monet_concepts.py
2
+
3
+ import torch
4
+ import numpy as np
5
+ from typing import Dict, List
6
+ from dataclasses import dataclass
7
+
8
+ @dataclass
9
+ class ConceptScore:
10
+ """Single MONET concept with score and evidence"""
11
+ name: str
12
+ score: float
13
+ confidence: float
14
+ description: str
15
+ clinical_relevance: str # How this affects diagnosis
16
+
17
+ class MONETConceptScorer:
18
+ """
19
+ MONET concept scoring using your trained metadata patterns
20
+ Integrates the boosting logic from your ensemble code
21
+ """
22
+
23
+ # MONET concepts used in your training
24
+ CONCEPT_DEFINITIONS = {
25
+ 'MONET_ulceration_crust': {
26
+ 'description': 'Ulceration or crusting present',
27
+ 'high_in': ['SCCKA', 'BCC', 'MAL_OTH'],
28
+ 'low_in': ['NV', 'BKL'],
29
+ 'threshold_high': 0.50
30
+ },
31
+ 'MONET_erythema': {
32
+ 'description': 'Redness or inflammation',
33
+ 'high_in': ['INF', 'BCC', 'SCCKA'],
34
+ 'low_in': ['MEL', 'NV'],
35
+ 'threshold_high': 0.40
36
+ },
37
+ 'MONET_pigmented': {
38
+ 'description': 'Pigmentation present',
39
+ 'high_in': ['MEL', 'NV', 'BKL'],
40
+ 'low_in': ['BCC', 'SCCKA', 'INF'],
41
+ 'threshold_high': 0.55
42
+ },
43
+ 'MONET_vasculature_vessels': {
44
+ 'description': 'Vascular structures visible',
45
+ 'high_in': ['VASC', 'BCC'],
46
+ 'low_in': ['MEL', 'NV'],
47
+ 'threshold_high': 0.35
48
+ },
49
+ 'MONET_hair': {
50
+ 'description': 'Hair follicles present',
51
+ 'high_in': ['NV', 'BKL'],
52
+ 'low_in': ['BCC', 'MEL'],
53
+ 'threshold_high': 0.30
54
+ },
55
+ 'MONET_gel_water_drop_fluid_dermoscopy_liquid': {
56
+ 'description': 'Gel/fluid artifacts',
57
+ 'high_in': [],
58
+ 'low_in': [],
59
+ 'threshold_high': 0.40
60
+ },
61
+ 'MONET_skin_markings_pen_ink_purple_pen': {
62
+ 'description': 'Pen markings present',
63
+ 'high_in': [],
64
+ 'low_in': [],
65
+ 'threshold_high': 0.40
66
+ }
67
+ }
68
+
69
+ # Class-specific patterns from your metadata boosting
70
+ CLASS_PATTERNS = {
71
+ 'MAL_OTH': {
72
+ 'sex': 'male', # 88.9% male
73
+ 'site_preference': 'trunk',
74
+ 'age_range': (60, 80),
75
+ 'key_concepts': {'MONET_ulceration_crust': 0.35}
76
+ },
77
+ 'INF': {
78
+ 'key_concepts': {
79
+ 'MONET_erythema': 0.42,
80
+ 'MONET_pigmented': (None, 0.30) # Low pigmentation
81
+ }
82
+ },
83
+ 'BEN_OTH': {
84
+ 'site_preference': ['head', 'neck', 'face'], # 47.7%
85
+ 'key_concepts': {'MONET_pigmented': (0.30, 0.50)}
86
+ },
87
+ 'DF': {
88
+ 'site_preference': ['lower', 'leg', 'ankle', 'foot'], # 65.4%
89
+ 'age_range': (40, 65)
90
+ },
91
+ 'SCCKA': {
92
+ 'age_range': (65, None),
93
+ 'key_concepts': {
94
+ 'MONET_ulceration_crust': 0.50,
95
+ 'MONET_pigmented': (None, 0.15)
96
+ }
97
+ },
98
+ 'MEL': {
99
+ 'age_range': (55, None), # 61.8 years average
100
+ 'key_concepts': {'MONET_pigmented': 0.55}
101
+ },
102
+ 'NV': {
103
+ 'age_range': (None, 45), # 42.0 years average
104
+ 'key_concepts': {'MONET_pigmented': 0.55}
105
+ }
106
+ }
107
+
108
+ def __init__(self):
109
+ """Initialize MONET scorer with class patterns"""
110
+ self.class_names = [
111
+ 'AKIEC', 'BCC', 'BEN_OTH', 'BKL', 'DF',
112
+ 'INF', 'MAL_OTH', 'MEL', 'NV', 'SCCKA', 'VASC'
113
+ ]
114
+
115
+ def compute_concept_scores(
116
+ self,
117
+ metadata: Dict[str, float]
118
+ ) -> Dict[str, ConceptScore]:
119
+ """
120
+ Compute MONET concept scores from metadata
121
+
122
+ Args:
123
+ metadata: Dictionary with MONET scores, age, sex, site, etc.
124
+
125
+ Returns:
126
+ Dictionary of concept scores
127
+ """
128
+ concept_scores = {}
129
+
130
+ for concept_name, definition in self.CONCEPT_DEFINITIONS.items():
131
+ score = metadata.get(concept_name, 0.0)
132
+
133
+ # Determine confidence based on how extreme the score is
134
+ if score > definition['threshold_high']:
135
+ confidence = min((score - definition['threshold_high']) / 0.2, 1.0)
136
+ level = "HIGH"
137
+ elif score < 0.2:
138
+ confidence = min((0.2 - score) / 0.2, 1.0)
139
+ level = "LOW"
140
+ else:
141
+ confidence = 0.5
142
+ level = "MODERATE"
143
+
144
+ # Clinical relevance
145
+ if level == "HIGH":
146
+ relevant_classes = definition['high_in']
147
+ clinical_relevance = f"Supports: {', '.join(relevant_classes)}"
148
+ elif level == "LOW":
149
+ excluded_classes = definition['low_in']
150
+ clinical_relevance = f"Against: {', '.join(excluded_classes)}"
151
+ else:
152
+ clinical_relevance = "Non-specific"
153
+
154
+ concept_scores[concept_name] = ConceptScore(
155
+ name=concept_name.replace('MONET_', '').replace('_', ' ').title(),
156
+ score=score,
157
+ confidence=confidence,
158
+ description=f"{definition['description']} ({level})",
159
+ clinical_relevance=clinical_relevance
160
+ )
161
+
162
+ return concept_scores
163
+
164
+ def apply_metadata_boosting(
165
+ self,
166
+ probs: np.ndarray,
167
+ metadata: Dict
168
+ ) -> np.ndarray:
169
+ """
170
+ Apply your metadata boosting logic
171
+ This is directly from your ensemble optimization code
172
+
173
+ Args:
174
+ probs: [11] probability array
175
+ metadata: Dictionary with age, sex, site, MONET scores
176
+
177
+ Returns:
178
+ boosted_probs: [11] adjusted probabilities
179
+ """
180
+ boosted_probs = probs.copy()
181
+
182
+ # 1. MAL_OTH boosting
183
+ if metadata.get('sex') == 'male':
184
+ site = str(metadata.get('site', '')).lower()
185
+ if 'trunk' in site:
186
+ age = metadata.get('age_approx', 60)
187
+ ulceration = metadata.get('MONET_ulceration_crust', 0)
188
+
189
+ score = 0
190
+ score += 3 if metadata.get('sex') == 'male' else 0
191
+ score += 2 if 'trunk' in site else 0
192
+ score += 1 if 60 <= age <= 80 else 0
193
+ score += 2 if ulceration > 0.35 else 0
194
+
195
+ confidence = score / 8.0
196
+ if confidence > 0.5:
197
+ boosted_probs[6] *= (1.0 + confidence) # MAL_OTH index
198
+
199
+ # 2. INF boosting
200
+ erythema = metadata.get('MONET_erythema', 0)
201
+ pigmentation = metadata.get('MONET_pigmented', 0)
202
+
203
+ if erythema > 0.42 and pigmentation < 0.30:
204
+ confidence = min((erythema - 0.42) / 0.10 + 0.5, 1.0)
205
+ boosted_probs[5] *= (1.0 + confidence * 0.8) # INF index
206
+
207
+ # 3. BEN_OTH boosting
208
+ site = str(metadata.get('site', '')).lower()
209
+ is_head_neck = any(x in site for x in ['head', 'neck', 'face'])
210
+
211
+ if is_head_neck and 0.30 < pigmentation < 0.50:
212
+ ulceration = metadata.get('MONET_ulceration_crust', 0)
213
+ confidence = 0.7 if ulceration < 0.30 else 0.4
214
+ boosted_probs[2] *= (1.0 + confidence * 0.5) # BEN_OTH index
215
+
216
+ # 4. DF boosting
217
+ is_lower_ext = any(x in site for x in ['lower', 'leg', 'ankle', 'foot'])
218
+
219
+ if is_lower_ext:
220
+ age = metadata.get('age_approx', 60)
221
+ if 40 <= age <= 65:
222
+ boosted_probs[4] *= 1.8 # DF index
223
+ elif 30 <= age <= 75:
224
+ boosted_probs[4] *= 1.5
225
+
226
+ # 5. SCCKA boosting
227
+ ulceration = metadata.get('MONET_ulceration_crust', 0)
228
+ age = metadata.get('age_approx', 60)
229
+
230
+ if ulceration > 0.50 and age >= 65 and pigmentation < 0.15:
231
+ boosted_probs[9] *= 1.9 # SCCKA index
232
+ elif ulceration > 0.45 and age >= 60 and pigmentation < 0.20:
233
+ boosted_probs[9] *= 1.5
234
+
235
+ # 6. MEL vs NV age separation
236
+ if pigmentation > 0.55:
237
+ if age >= 55:
238
+ age_score = min((age - 55) / 20.0, 1.0)
239
+ boosted_probs[7] *= (1.0 + age_score * 0.5) # MEL
240
+ boosted_probs[8] *= (1.0 - age_score * 0.3) # NV
241
+ elif age <= 45:
242
+ age_score = min((45 - age) / 30.0, 1.0)
243
+ boosted_probs[7] *= (1.0 - age_score * 0.3) # MEL
244
+ boosted_probs[8] *= (1.0 + age_score * 0.5) # NV
245
+
246
+ # 7. Exclusions based on pigmentation/erythema
247
+ if pigmentation > 0.50:
248
+ boosted_probs[0] *= 0.7 # AKIEC
249
+ boosted_probs[1] *= 0.6 # BCC
250
+ boosted_probs[5] *= 0.5 # INF
251
+ boosted_probs[9] *= 0.3 # SCCKA
252
+
253
+ if erythema > 0.40:
254
+ boosted_probs[7] *= 0.7 # MEL
255
+ boosted_probs[8] *= 0.7 # NV
256
+
257
+ if pigmentation < 0.20:
258
+ boosted_probs[7] *= 0.5 # MEL
259
+ boosted_probs[8] *= 0.5 # NV
260
+
261
+ # Renormalize
262
+ return boosted_probs / boosted_probs.sum()
263
+
264
+ def explain_prediction(
265
+ self,
266
+ probs: np.ndarray,
267
+ concept_scores: Dict[str, ConceptScore],
268
+ metadata: Dict
269
+ ) -> str:
270
+ """
271
+ Generate natural language explanation
272
+
273
+ Args:
274
+ probs: Class probabilities
275
+ concept_scores: MONET concept scores
276
+ metadata: Clinical metadata
277
+
278
+ Returns:
279
+ Natural language explanation
280
+ """
281
+ predicted_idx = np.argmax(probs)
282
+ predicted_class = self.class_names[predicted_idx]
283
+ confidence = probs[predicted_idx]
284
+
285
+ explanation = f"**Primary Diagnosis: {predicted_class}**\n"
286
+ explanation += f"Confidence: {confidence:.1%}\n\n"
287
+
288
+ # Key MONET features
289
+ explanation += "**Key Dermoscopic Features:**\n"
290
+
291
+ sorted_concepts = sorted(
292
+ concept_scores.values(),
293
+ key=lambda x: x.score * x.confidence,
294
+ reverse=True
295
+ )
296
+
297
+ for i, concept in enumerate(sorted_concepts[:5], 1):
298
+ if concept.score > 0.3 or concept.score < 0.2:
299
+ explanation += f"{i}. {concept.name}: {concept.score:.2f} - {concept.description}\n"
300
+ if concept.clinical_relevance != "Non-specific":
301
+ explanation += f" → {concept.clinical_relevance}\n"
302
+
303
+ # Clinical context
304
+ explanation += "\n**Clinical Context:**\n"
305
+ if 'age_approx' in metadata:
306
+ explanation += f"• Age: {metadata['age_approx']} years\n"
307
+ if 'sex' in metadata:
308
+ explanation += f"• Sex: {metadata['sex']}\n"
309
+ if 'site' in metadata:
310
+ explanation += f"• Location: {metadata['site']}\n"
311
+
312
+ return explanation
313
+
314
+ def get_top_concepts(
315
+ self,
316
+ concept_scores: Dict[str, ConceptScore],
317
+ top_k: int = 5,
318
+ min_score: float = 0.3
319
+ ) -> List[ConceptScore]:
320
+ """Get top-k most important concepts"""
321
+ filtered = [
322
+ cs for cs in concept_scores.values()
323
+ if cs.score >= min_score or cs.score < 0.2 # High or low
324
+ ]
325
+
326
+ sorted_concepts = sorted(
327
+ filtered,
328
+ key=lambda x: x.score * x.confidence,
329
+ reverse=True
330
+ )
331
+
332
+ return sorted_concepts[:top_k]
models/monet_tool.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MONET Tool - Skin lesion feature extraction using MONET model
3
+ Correct implementation based on MONET tutorial: automatic_concept_annotation.ipynb
4
+ """
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import numpy as np
9
+ import scipy.special
10
+ from PIL import Image
11
+ from typing import Optional, Dict, List
12
+ import torchvision.transforms as T
13
+
14
+
15
+ # The 7 MONET feature columns expected by ConvNeXt
16
+ MONET_FEATURES = [
17
+ "MONET_ulceration_crust",
18
+ "MONET_hair",
19
+ "MONET_vasculature_vessels",
20
+ "MONET_erythema",
21
+ "MONET_pigmented",
22
+ "MONET_gel_water_drop_fluid_dermoscopy_liquid",
23
+ "MONET_skin_markings_pen_ink_purple_pen",
24
+ ]
25
+
26
+ # Concept terms for each MONET feature (multiple synonyms improve detection)
27
+ MONET_CONCEPT_TERMS = {
28
+ "MONET_ulceration_crust": ["ulceration", "crust", "crusting", "ulcer"],
29
+ "MONET_hair": ["hair", "hairy"],
30
+ "MONET_vasculature_vessels": ["blood vessels", "vasculature", "vessels", "telangiectasia"],
31
+ "MONET_erythema": ["erythema", "redness", "red"],
32
+ "MONET_pigmented": ["pigmented", "pigmentation", "melanin", "brown"],
33
+ "MONET_gel_water_drop_fluid_dermoscopy_liquid": ["dermoscopy gel", "fluid", "water drop", "immersion fluid"],
34
+ "MONET_skin_markings_pen_ink_purple_pen": ["pen marking", "ink", "surgical marking", "purple pen"],
35
+ }
36
+
37
+ # Prompt templates (from MONET paper)
38
+ PROMPT_TEMPLATES = [
39
+ "This is skin image of {}",
40
+ "This is dermatology image of {}",
41
+ "This is image of {}",
42
+ ]
43
+
44
+ # Reference prompts (baseline for contrastive scoring)
45
+ PROMPT_REFS = [
46
+ ["This is skin image"],
47
+ ["This is dermatology image"],
48
+ ["This is image"],
49
+ ]
50
+
51
+
52
+ def get_transform(n_px=224):
53
+ """Get MONET preprocessing transform"""
54
+ def convert_image_to_rgb(image):
55
+ return image.convert("RGB")
56
+
57
+ return T.Compose([
58
+ T.Resize(n_px, interpolation=T.InterpolationMode.BICUBIC),
59
+ T.CenterCrop(n_px),
60
+ convert_image_to_rgb,
61
+ T.ToTensor(),
62
+ T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
63
+ ])
64
+
65
+
66
+ class MonetTool:
67
+ """
68
+ MONET tool for extracting concept presence scores from skin lesion images.
69
+ Uses the proper contrastive scoring method from the MONET paper.
70
+ """
71
+
72
+ def __init__(self, device: Optional[str] = None, use_hf: bool = True):
73
+ """
74
+ Args:
75
+ device: Device to run on (cuda, mps, cpu)
76
+ use_hf: Use HuggingFace implementation (True) or original CLIP (False)
77
+ """
78
+ self.model = None
79
+ self.processor = None
80
+ self.device = device
81
+ self.use_hf = use_hf
82
+ self.loaded = False
83
+ self.transform = get_transform(224)
84
+
85
+ # Cache for concept embeddings
86
+ self._concept_embeddings = {}
87
+
88
+ def load(self):
89
+ """Load MONET model"""
90
+ if self.loaded:
91
+ return
92
+
93
+ # Determine device
94
+ if self.device is None:
95
+ if torch.cuda.is_available():
96
+ self.device = "cuda:0"
97
+ elif torch.backends.mps.is_available():
98
+ self.device = "mps"
99
+ else:
100
+ self.device = "cpu"
101
+
102
+ if self.use_hf:
103
+ # HuggingFace implementation
104
+ from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
105
+
106
+ self.processor = AutoProcessor.from_pretrained("chanwkim/monet")
107
+ self.model = AutoModelForZeroShotImageClassification.from_pretrained("chanwkim/monet")
108
+ self.model.to(self.device)
109
+ self.model.eval()
110
+ else:
111
+ # Original CLIP implementation
112
+ import clip
113
+
114
+ self.model, _ = clip.load("ViT-L/14", device=self.device, jit=False)
115
+ self.model.load_state_dict(
116
+ torch.hub.load_state_dict_from_url(
117
+ "https://aimslab.cs.washington.edu/MONET/weight_clip.pt"
118
+ )
119
+ )
120
+ self.model.eval()
121
+
122
+ self.loaded = True
123
+
124
+ # Pre-compute concept embeddings for all MONET features
125
+ self._precompute_concept_embeddings()
126
+
127
+ def _encode_text(self, text_list: List[str]) -> torch.Tensor:
128
+ """Encode text to embeddings"""
129
+ if self.use_hf:
130
+ inputs = self.processor(text=text_list, return_tensors="pt", padding=True)
131
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
132
+ with torch.no_grad():
133
+ embeddings = self.model.get_text_features(**inputs)
134
+ else:
135
+ import clip
136
+ tokens = clip.tokenize(text_list, truncate=True).to(self.device)
137
+ with torch.no_grad():
138
+ embeddings = self.model.encode_text(tokens)
139
+
140
+ return embeddings.cpu()
141
+
142
+ def _encode_image(self, image: Image.Image) -> torch.Tensor:
143
+ """Encode image to embedding"""
144
+ image_tensor = self.transform(image).unsqueeze(0).to(self.device)
145
+
146
+ if self.use_hf:
147
+ with torch.no_grad():
148
+ embedding = self.model.get_image_features(image_tensor)
149
+ else:
150
+ with torch.no_grad():
151
+ embedding = self.model.encode_image(image_tensor)
152
+
153
+ return embedding.cpu()
154
+
155
+ def _precompute_concept_embeddings(self):
156
+ """Pre-compute embeddings for all MONET concepts"""
157
+ for feature_name, concept_terms in MONET_CONCEPT_TERMS.items():
158
+ self._concept_embeddings[feature_name] = self._get_concept_embedding(concept_terms)
159
+
160
+ def _get_concept_embedding(self, concept_terms: List[str]) -> Dict:
161
+ """
162
+ Generate prompt embeddings for a concept using multiple templates.
163
+
164
+ Args:
165
+ concept_terms: List of synonymous terms for the concept
166
+
167
+ Returns:
168
+ dict with target and reference embeddings
169
+ """
170
+ # Target prompts: "This is skin image of {term}"
171
+ prompt_target = [
172
+ [template.format(term) for term in concept_terms]
173
+ for template in PROMPT_TEMPLATES
174
+ ]
175
+
176
+ # Flatten and encode
177
+ prompt_target_flat = [p for template_prompts in prompt_target for p in template_prompts]
178
+ target_embeddings = self._encode_text(prompt_target_flat)
179
+
180
+ # Reshape to [num_templates, num_terms, embed_dim]
181
+ num_templates = len(PROMPT_TEMPLATES)
182
+ num_terms = len(concept_terms)
183
+ embed_dim = target_embeddings.shape[-1]
184
+ target_embeddings = target_embeddings.view(num_templates, num_terms, embed_dim)
185
+
186
+ # Normalize
187
+ target_embeddings_norm = F.normalize(target_embeddings, dim=2)
188
+
189
+ # Reference prompts: "This is skin image"
190
+ prompt_ref_flat = [p for ref_list in PROMPT_REFS for p in ref_list]
191
+ ref_embeddings = self._encode_text(prompt_ref_flat)
192
+ ref_embeddings = ref_embeddings.view(num_templates, -1, embed_dim)
193
+ ref_embeddings_norm = F.normalize(ref_embeddings, dim=2)
194
+
195
+ return {
196
+ "target_embedding_norm": target_embeddings_norm,
197
+ "ref_embedding_norm": ref_embeddings_norm,
198
+ }
199
+
200
+ def _calculate_concept_score(
201
+ self,
202
+ image_features_norm: torch.Tensor,
203
+ concept_embedding: Dict,
204
+ temp: float = 1 / np.exp(4.5944)
205
+ ) -> float:
206
+ """
207
+ Calculate concept presence score using contrastive comparison.
208
+
209
+ Args:
210
+ image_features_norm: Normalized image embedding [1, embed_dim]
211
+ concept_embedding: Dict with target and reference embeddings
212
+ temp: Temperature for softmax
213
+
214
+ Returns:
215
+ Concept presence score (0-1)
216
+ """
217
+ target_emb = concept_embedding["target_embedding_norm"].float()
218
+ ref_emb = concept_embedding["ref_embedding_norm"].float()
219
+
220
+ # Similarity: [num_templates, num_terms] @ [embed_dim, 1] -> [num_templates, num_terms, 1]
221
+ target_similarity = target_emb @ image_features_norm.T.float()
222
+ ref_similarity = ref_emb @ image_features_norm.T.float()
223
+
224
+ # Mean over terms for each template
225
+ target_mean = target_similarity.mean(dim=1).squeeze() # [num_templates]
226
+ ref_mean = ref_similarity.mean(dim=1).squeeze() # [num_templates]
227
+
228
+ # Softmax between target and reference (contrastive scoring)
229
+ scores = scipy.special.softmax(
230
+ np.array([target_mean.numpy() / temp, ref_mean.numpy() / temp]),
231
+ axis=0
232
+ )
233
+
234
+ # Return mean of target scores across templates
235
+ return float(scores[0].mean())
236
+
237
+ def extract_features(self, image: Image.Image) -> Dict[str, float]:
238
+ """
239
+ Extract MONET feature scores from a skin lesion image.
240
+
241
+ Args:
242
+ image: PIL Image to analyze
243
+
244
+ Returns:
245
+ dict with 7 MONET feature scores (0-1 range)
246
+ """
247
+ if not self.loaded:
248
+ self.load()
249
+
250
+ # Ensure RGB
251
+ if image.mode != "RGB":
252
+ image = image.convert("RGB")
253
+
254
+ # Get image embedding
255
+ image_features = self._encode_image(image)
256
+ image_features_norm = F.normalize(image_features, dim=1)
257
+
258
+ # Calculate score for each MONET feature
259
+ features = {}
260
+ for feature_name in MONET_FEATURES:
261
+ concept_emb = self._concept_embeddings[feature_name]
262
+ score = self._calculate_concept_score(image_features_norm, concept_emb)
263
+ features[feature_name] = score
264
+
265
+ return features
266
+
267
+ def get_feature_vector(self, image: Image.Image) -> List[float]:
268
+ """Get MONET features as a list in the expected order."""
269
+ features = self.extract_features(image)
270
+ return [features[f] for f in MONET_FEATURES]
271
+
272
+ def get_feature_tensor(self, image: Image.Image) -> torch.Tensor:
273
+ """Get MONET features as a PyTorch tensor."""
274
+ return torch.tensor(self.get_feature_vector(image), dtype=torch.float32)
275
+
276
+ def describe_features(self, features: Dict[str, float], threshold: float = 0.6) -> str:
277
+ """Generate a natural language description of the MONET features."""
278
+ descriptions = {
279
+ "MONET_ulceration_crust": "ulceration or crusting",
280
+ "MONET_hair": "visible hair",
281
+ "MONET_vasculature_vessels": "visible blood vessels",
282
+ "MONET_erythema": "erythema (redness)",
283
+ "MONET_pigmented": "pigmentation",
284
+ "MONET_gel_water_drop_fluid_dermoscopy_liquid": "dermoscopy gel/fluid",
285
+ "MONET_skin_markings_pen_ink_purple_pen": "pen markings",
286
+ }
287
+
288
+ present = []
289
+ for feature, score in features.items():
290
+ if score >= threshold:
291
+ desc = descriptions.get(feature, feature)
292
+ present.append(f"{desc} ({score:.0%})")
293
+
294
+ if not present:
295
+ # Show top features even if below threshold
296
+ sorted_features = sorted(features.items(), key=lambda x: x[1], reverse=True)[:3]
297
+ present = [f"{descriptions.get(f, f)} ({s:.0%})" for f, s in sorted_features]
298
+
299
+ return "Detected features: " + ", ".join(present)
300
+
301
+ def analyze(self, image: Image.Image) -> Dict:
302
+ """Full analysis returning features, vector, and description."""
303
+ features = self.extract_features(image)
304
+ vector = [features[f] for f in MONET_FEATURES]
305
+ description = self.describe_features(features)
306
+
307
+ return {
308
+ "features": features,
309
+ "vector": vector,
310
+ "description": description,
311
+ "feature_names": MONET_FEATURES,
312
+ }
313
+
314
+ def __call__(self, image: Image.Image) -> Dict:
315
+ """Shorthand for analyze()"""
316
+ return self.analyze(image)
317
+
318
+
319
+ # Singleton instance
320
+ _monet_instance = None
321
+
322
+
323
+ def get_monet_tool() -> MonetTool:
324
+ """Get or create MONET tool instance"""
325
+ global _monet_instance
326
+ if _monet_instance is None:
327
+ _monet_instance = MonetTool()
328
+ return _monet_instance
329
+
330
+
331
+ if __name__ == "__main__":
332
+ import sys
333
+
334
+ print("MONET Tool Test (Correct Implementation)")
335
+ print("=" * 50)
336
+
337
+ tool = MonetTool(use_hf=True)
338
+ print("Loading model...")
339
+ tool.load()
340
+ print("Model loaded!")
341
+
342
+ if len(sys.argv) > 1:
343
+ image_path = sys.argv[1]
344
+ print(f"\nAnalyzing: {image_path}")
345
+ image = Image.open(image_path).convert("RGB")
346
+ result = tool.analyze(image)
347
+
348
+ print("\nMONET Features (Contrastive Scores):")
349
+ for name, score in result["features"].items():
350
+ bar = "█" * int(score * 20)
351
+ print(f" {name}: {score:.3f} {bar}")
352
+
353
+ print(f"\nDescription: {result['description']}")
354
+ print(f"\nVector: {[f'{v:.3f}' for v in result['vector']]}")
models/overlay_tool.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Overlay Tool - Generates visual markers for biopsy sites and excision margins
3
+ """
4
+
5
+ import io
6
+ import tempfile
7
+ from typing import Tuple, Optional, Dict, Any
8
+ from PIL import Image, ImageDraw, ImageFont
9
+
10
+
11
+ class OverlayTool:
12
+ """
13
+ Generates image overlays for clinical decision visualization:
14
+ - Biopsy site markers (circles)
15
+ - Excision margins (dashed outlines with margin indicators)
16
+ """
17
+
18
+ # Colors for different marker types
19
+ COLORS = {
20
+ 'biopsy': (255, 69, 0, 200), # Orange-red with alpha
21
+ 'excision': (220, 20, 60, 200), # Crimson with alpha
22
+ 'margin': (255, 215, 0, 180), # Gold for margin line
23
+ 'text': (255, 255, 255, 255), # White text
24
+ 'text_bg': (0, 0, 0, 180), # Semi-transparent black bg
25
+ }
26
+
27
+ def __init__(self):
28
+ self.loaded = True
29
+
30
+ def generate_biopsy_overlay(
31
+ self,
32
+ image: Image.Image,
33
+ center_x: float,
34
+ center_y: float,
35
+ radius: float = 0.05,
36
+ label: str = "Biopsy Site"
37
+ ) -> Dict[str, Any]:
38
+ """
39
+ Generate biopsy site overlay with circle marker.
40
+
41
+ Args:
42
+ image: PIL Image
43
+ center_x: X coordinate as fraction (0-1) of image width
44
+ center_y: Y coordinate as fraction (0-1) of image height
45
+ radius: Radius as fraction of image width
46
+ label: Text label for the marker
47
+
48
+ Returns:
49
+ Dict with overlay image and metadata
50
+ """
51
+ # Convert to RGBA for transparency
52
+ img = image.convert("RGBA")
53
+ width, height = img.size
54
+
55
+ # Create overlay layer
56
+ overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
57
+ draw = ImageDraw.Draw(overlay)
58
+
59
+ # Calculate pixel coordinates
60
+ cx = int(center_x * width)
61
+ cy = int(center_y * height)
62
+ r = int(radius * width)
63
+
64
+ # Draw outer circle (thicker)
65
+ for offset in range(3):
66
+ draw.ellipse(
67
+ [cx - r - offset, cy - r - offset, cx + r + offset, cy + r + offset],
68
+ outline=self.COLORS['biopsy'],
69
+ width=2
70
+ )
71
+
72
+ # Draw crosshairs
73
+ line_len = r // 2
74
+ draw.line([(cx - line_len, cy), (cx + line_len, cy)],
75
+ fill=self.COLORS['biopsy'], width=2)
76
+ draw.line([(cx, cy - line_len), (cx, cy + line_len)],
77
+ fill=self.COLORS['biopsy'], width=2)
78
+
79
+ # Draw label with background
80
+ try:
81
+ font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 14)
82
+ except:
83
+ font = ImageFont.load_default()
84
+
85
+ text_bbox = draw.textbbox((0, 0), label, font=font)
86
+ text_width = text_bbox[2] - text_bbox[0]
87
+ text_height = text_bbox[3] - text_bbox[1]
88
+
89
+ text_x = cx - text_width // 2
90
+ text_y = cy + r + 10
91
+
92
+ # Background rectangle for text
93
+ padding = 4
94
+ draw.rectangle(
95
+ [text_x - padding, text_y - padding,
96
+ text_x + text_width + padding, text_y + text_height + padding],
97
+ fill=self.COLORS['text_bg']
98
+ )
99
+ draw.text((text_x, text_y), label, fill=self.COLORS['text'], font=font)
100
+
101
+ # Composite
102
+ result = Image.alpha_composite(img, overlay)
103
+
104
+ # Save to temp file
105
+ temp_file = tempfile.NamedTemporaryFile(suffix="_biopsy_overlay.png", delete=False)
106
+ result.save(temp_file.name, "PNG")
107
+ temp_file.close()
108
+
109
+ return {
110
+ "overlay": result,
111
+ "path": temp_file.name,
112
+ "type": "biopsy",
113
+ "coordinates": {
114
+ "center_x": center_x,
115
+ "center_y": center_y,
116
+ "radius": radius
117
+ },
118
+ "label": label
119
+ }
120
+
121
+ def generate_excision_overlay(
122
+ self,
123
+ image: Image.Image,
124
+ center_x: float,
125
+ center_y: float,
126
+ lesion_radius: float,
127
+ margin_mm: int = 5,
128
+ pixels_per_mm: float = 10.0,
129
+ label: str = "Excision Margin"
130
+ ) -> Dict[str, Any]:
131
+ """
132
+ Generate excision margin overlay with inner (lesion) and outer (margin) boundaries.
133
+
134
+ Args:
135
+ image: PIL Image
136
+ center_x: X coordinate as fraction (0-1)
137
+ center_y: Y coordinate as fraction (0-1)
138
+ lesion_radius: Lesion radius as fraction of image width
139
+ margin_mm: Excision margin in millimeters
140
+ pixels_per_mm: Estimated pixels per mm (for margin calculation)
141
+ label: Text label
142
+
143
+ Returns:
144
+ Dict with overlay image and metadata
145
+ """
146
+ img = image.convert("RGBA")
147
+ width, height = img.size
148
+
149
+ overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
150
+ draw = ImageDraw.Draw(overlay)
151
+
152
+ # Calculate coordinates
153
+ cx = int(center_x * width)
154
+ cy = int(center_y * height)
155
+ inner_r = int(lesion_radius * width)
156
+
157
+ # Calculate margin in pixels
158
+ margin_px = int(margin_mm * pixels_per_mm)
159
+ outer_r = inner_r + margin_px
160
+
161
+ # Draw outer margin (dashed effect using multiple arcs)
162
+ dash_length = 10
163
+ for angle in range(0, 360, dash_length * 2):
164
+ draw.arc(
165
+ [cx - outer_r, cy - outer_r, cx + outer_r, cy + outer_r],
166
+ start=angle,
167
+ end=angle + dash_length,
168
+ fill=self.COLORS['margin'],
169
+ width=3
170
+ )
171
+
172
+ # Draw inner lesion boundary (solid)
173
+ draw.ellipse(
174
+ [cx - inner_r, cy - inner_r, cx + inner_r, cy + inner_r],
175
+ outline=self.COLORS['excision'],
176
+ width=2
177
+ )
178
+
179
+ # Draw margin indicator lines (radial)
180
+ for angle in [0, 90, 180, 270]:
181
+ import math
182
+ rad = math.radians(angle)
183
+ inner_x = cx + int(inner_r * math.cos(rad))
184
+ inner_y = cy + int(inner_r * math.sin(rad))
185
+ outer_x = cx + int(outer_r * math.cos(rad))
186
+ outer_y = cy + int(outer_r * math.sin(rad))
187
+ draw.line([(inner_x, inner_y), (outer_x, outer_y)],
188
+ fill=self.COLORS['margin'], width=2)
189
+
190
+ # Draw labels
191
+ try:
192
+ font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 12)
193
+ font_small = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 10)
194
+ except:
195
+ font = ImageFont.load_default()
196
+ font_small = font
197
+
198
+ # Main label
199
+ text_bbox = draw.textbbox((0, 0), label, font=font)
200
+ text_width = text_bbox[2] - text_bbox[0]
201
+ text_height = text_bbox[3] - text_bbox[1]
202
+
203
+ text_x = cx - text_width // 2
204
+ text_y = cy + outer_r + 15
205
+
206
+ padding = 4
207
+ draw.rectangle(
208
+ [text_x - padding, text_y - padding,
209
+ text_x + text_width + padding, text_y + text_height + padding],
210
+ fill=self.COLORS['text_bg']
211
+ )
212
+ draw.text((text_x, text_y), label, fill=self.COLORS['text'], font=font)
213
+
214
+ # Margin measurement label
215
+ margin_label = f"{margin_mm}mm margin"
216
+ margin_bbox = draw.textbbox((0, 0), margin_label, font=font_small)
217
+ margin_width = margin_bbox[2] - margin_bbox[0]
218
+
219
+ margin_text_x = cx + outer_r + 5
220
+ margin_text_y = cy - 6
221
+
222
+ draw.rectangle(
223
+ [margin_text_x - 2, margin_text_y - 2,
224
+ margin_text_x + margin_width + 2, margin_text_y + 12],
225
+ fill=self.COLORS['text_bg']
226
+ )
227
+ draw.text((margin_text_x, margin_text_y), margin_label,
228
+ fill=self.COLORS['margin'], font=font_small)
229
+
230
+ # Composite
231
+ result = Image.alpha_composite(img, overlay)
232
+
233
+ temp_file = tempfile.NamedTemporaryFile(suffix="_excision_overlay.png", delete=False)
234
+ result.save(temp_file.name, "PNG")
235
+ temp_file.close()
236
+
237
+ return {
238
+ "overlay": result,
239
+ "path": temp_file.name,
240
+ "type": "excision",
241
+ "coordinates": {
242
+ "center_x": center_x,
243
+ "center_y": center_y,
244
+ "lesion_radius": lesion_radius,
245
+ "margin_mm": margin_mm,
246
+ "total_radius": outer_r / width
247
+ },
248
+ "label": label
249
+ }
250
+
251
+ def generate_comparison_overlay(
252
+ self,
253
+ image1: Image.Image,
254
+ image2: Image.Image,
255
+ label1: str = "Previous",
256
+ label2: str = "Current"
257
+ ) -> Dict[str, Any]:
258
+ """
259
+ Generate side-by-side comparison of two images for follow-up.
260
+
261
+ Args:
262
+ image1: First (previous) image
263
+ image2: Second (current) image
264
+ label1: Label for first image
265
+ label2: Label for second image
266
+
267
+ Returns:
268
+ Dict with comparison image and metadata
269
+ """
270
+ # Resize to same height
271
+ max_height = 400
272
+
273
+ # Calculate sizes maintaining aspect ratio
274
+ w1, h1 = image1.size
275
+ w2, h2 = image2.size
276
+
277
+ ratio1 = max_height / h1
278
+ ratio2 = max_height / h2
279
+
280
+ new_w1 = int(w1 * ratio1)
281
+ new_w2 = int(w2 * ratio2)
282
+
283
+ img1 = image1.resize((new_w1, max_height), Image.Resampling.LANCZOS)
284
+ img2 = image2.resize((new_w2, max_height), Image.Resampling.LANCZOS)
285
+
286
+ # Create comparison canvas
287
+ gap = 20
288
+ total_width = new_w1 + gap + new_w2
289
+ header_height = 30
290
+ total_height = max_height + header_height
291
+
292
+ canvas = Image.new("RGB", (total_width, total_height), (255, 255, 255))
293
+ draw = ImageDraw.Draw(canvas)
294
+
295
+ # Draw labels
296
+ try:
297
+ font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 14)
298
+ except:
299
+ font = ImageFont.load_default()
300
+
301
+ # Previous label
302
+ draw.rectangle([0, 0, new_w1, header_height], fill=(70, 130, 180))
303
+ bbox1 = draw.textbbox((0, 0), label1, font=font)
304
+ text_w1 = bbox1[2] - bbox1[0]
305
+ draw.text(((new_w1 - text_w1) // 2, 8), label1, fill=(255, 255, 255), font=font)
306
+
307
+ # Current label
308
+ draw.rectangle([new_w1 + gap, 0, total_width, header_height], fill=(60, 179, 113))
309
+ bbox2 = draw.textbbox((0, 0), label2, font=font)
310
+ text_w2 = bbox2[2] - bbox2[0]
311
+ draw.text((new_w1 + gap + (new_w2 - text_w2) // 2, 8), label2,
312
+ fill=(255, 255, 255), font=font)
313
+
314
+ # Paste images
315
+ canvas.paste(img1, (0, header_height))
316
+ canvas.paste(img2, (new_w1 + gap, header_height))
317
+
318
+ # Draw divider
319
+ draw.line([(new_w1 + gap // 2, header_height), (new_w1 + gap // 2, total_height)],
320
+ fill=(200, 200, 200), width=2)
321
+
322
+ temp_file = tempfile.NamedTemporaryFile(suffix="_comparison.png", delete=False)
323
+ canvas.save(temp_file.name, "PNG")
324
+ temp_file.close()
325
+
326
+ return {
327
+ "comparison": canvas,
328
+ "path": temp_file.name,
329
+ "type": "comparison"
330
+ }
331
+
332
+
333
+ def get_overlay_tool() -> OverlayTool:
334
+ """Get overlay tool instance"""
335
+ return OverlayTool()
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # requirements.txt
2
+
3
+ torch>=2.0.0
4
+ torchvision>=0.15.0
5
+ transformers>=4.40.0
6
+ timm>=0.9.0
7
+ gradio==4.44.0
8
+ gradio-client==1.3.0
9
+ opencv-python>=4.8.0
10
+ numpy>=1.24.0
11
+ Pillow>=10.0.0
12
+ sentencepiece>=0.1.99
13
+ accelerate>=0.25.0
14
+ protobuf>=4.0.0
15
+ mcp>=1.0.0 # installed via python3.11 (requires Python >=3.10)
test_models.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Test script to verify model loading"""
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import timm
7
+ from transformers import AutoModel, AutoProcessor
8
+ import numpy as np
9
+
10
+ DEVICE = "cpu"
11
+ print(f"Device: {DEVICE}")
12
+
13
+ # ConvNeXt model definition (matching checkpoint)
14
+ class ConvNeXtDualEncoder(nn.Module):
15
+ def __init__(self, model_name="convnext_base.fb_in22k_ft_in1k",
16
+ metadata_dim=19, num_classes=11, dropout=0.3):
17
+ super().__init__()
18
+ self.backbone = timm.create_model(model_name, pretrained=False, num_classes=0)
19
+ backbone_dim = self.backbone.num_features
20
+ self.meta_mlp = nn.Sequential(
21
+ nn.Linear(metadata_dim, 64), nn.LayerNorm(64), nn.GELU(), nn.Dropout(dropout)
22
+ )
23
+ fusion_dim = backbone_dim * 2 + 64
24
+ self.classifier = nn.Sequential(
25
+ nn.Linear(fusion_dim, 512), nn.LayerNorm(512), nn.GELU(), nn.Dropout(dropout),
26
+ nn.Linear(512, 256), nn.LayerNorm(256), nn.GELU(), nn.Dropout(dropout),
27
+ nn.Linear(256, num_classes)
28
+ )
29
+
30
+ def forward(self, clinical_img, derm_img=None, metadata=None):
31
+ clinical_features = self.backbone(clinical_img)
32
+ derm_features = self.backbone(derm_img) if derm_img is not None else clinical_features
33
+ if metadata is not None:
34
+ meta_features = self.meta_mlp(metadata)
35
+ else:
36
+ meta_features = torch.zeros(clinical_features.size(0), 64, device=clinical_features.device)
37
+ fused = torch.cat([clinical_features, derm_features, meta_features], dim=1)
38
+ return self.classifier(fused)
39
+
40
+
41
+ # MedSigLIP model definition
42
+ class MedSigLIPClassifier(nn.Module):
43
+ def __init__(self, num_classes=11, model_name="google/siglip-base-patch16-384"):
44
+ super().__init__()
45
+ self.siglip = AutoModel.from_pretrained(model_name)
46
+ self.processor = AutoProcessor.from_pretrained(model_name)
47
+ hidden_dim = self.siglip.config.vision_config.hidden_size
48
+ self.classifier = nn.Sequential(
49
+ nn.Linear(hidden_dim, 512), nn.LayerNorm(512), nn.GELU(), nn.Dropout(0.3),
50
+ nn.Linear(512, num_classes)
51
+ )
52
+ for param in self.siglip.parameters():
53
+ param.requires_grad = False
54
+
55
+ def forward(self, pixel_values):
56
+ vision_outputs = self.siglip.vision_model(pixel_values=pixel_values)
57
+ pooled_features = vision_outputs.pooler_output
58
+ return self.classifier(pooled_features)
59
+
60
+
61
+ if __name__ == "__main__":
62
+ print("\n[1/2] Loading ConvNeXt...")
63
+ convnext_model = ConvNeXtDualEncoder()
64
+ ckpt = torch.load("models/seed42_fold0.pt", map_location=DEVICE, weights_only=False)
65
+ convnext_model.load_state_dict(ckpt)
66
+ convnext_model.eval()
67
+ print(" ConvNeXt loaded!")
68
+
69
+ print("\n[2/2] Loading MedSigLIP...")
70
+ medsiglip_model = MedSigLIPClassifier()
71
+ medsiglip_model.eval()
72
+ print(" MedSigLIP loaded!")
73
+
74
+ # Quick inference test
75
+ print("\nTesting inference...")
76
+ dummy_img = torch.randn(1, 3, 384, 384)
77
+ with torch.no_grad():
78
+ convnext_out = convnext_model(dummy_img)
79
+ print(f" ConvNeXt output: {convnext_out.shape}")
80
+
81
+ dummy_pil = np.random.randint(0, 255, (384, 384, 3), dtype=np.uint8)
82
+ siglip_input = medsiglip_model.processor(images=[dummy_pil], return_tensors="pt")
83
+ siglip_out = medsiglip_model(siglip_input["pixel_values"])
84
+ print(f" MedSigLIP output: {siglip_out.shape}")
85
+
86
+ print("\nAll tests passed!")
web/index.html ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8" />
5
+ <link rel="icon" type="image/svg+xml" href="/vite.svg" />
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
+ <title>SkinProAI</title>
8
+ <link rel="preconnect" href="https://fonts.googleapis.com">
9
+ <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
10
+ <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap" rel="stylesheet">
11
+ </head>
12
+ <body>
13
+ <div id="root"></div>
14
+ <script type="module" src="/src/main.tsx"></script>
15
+ </body>
16
+ </html>
web/package-lock.json ADDED
The diff for this file is too large to render. See raw diff
 
web/package.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "skinproai-web",
3
+ "private": true,
4
+ "version": "1.0.0",
5
+ "type": "module",
6
+ "scripts": {
7
+ "dev": "vite",
8
+ "build": "tsc && vite build",
9
+ "preview": "vite preview"
10
+ },
11
+ "dependencies": {
12
+ "react": "^18.2.0",
13
+ "react-dom": "^18.2.0",
14
+ "react-markdown": "^10.1.0",
15
+ "react-router-dom": "^6.20.0"
16
+ },
17
+ "devDependencies": {
18
+ "@types/react": "^18.2.0",
19
+ "@types/react-dom": "^18.2.0",
20
+ "@vitejs/plugin-react": "^4.2.0",
21
+ "typescript": "^5.3.0",
22
+ "vite": "^5.0.0"
23
+ }
24
+ }
web/src/App.tsx ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { BrowserRouter, Routes, Route } from 'react-router-dom';
2
+ import { PatientsPage } from './pages/PatientsPage';
3
+ import { ChatPage } from './pages/ChatPage';
4
+
5
+ export function App() {
6
+ return (
7
+ <BrowserRouter>
8
+ <Routes>
9
+ <Route path="/" element={<PatientsPage />} />
10
+ <Route path="/chat/:patientId" element={<ChatPage />} />
11
+ </Routes>
12
+ </BrowserRouter>
13
+ );
14
+ }
web/src/components/MessageContent.css ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* ─── Root ───────────────────────────────────────────────────────────────── */
2
+ .mc-root {
3
+ display: flex;
4
+ flex-direction: column;
5
+ gap: 6px;
6
+ font-size: 0.9375rem;
7
+ line-height: 1.6;
8
+ color: var(--gray-800);
9
+ width: 100%;
10
+ }
11
+
12
+ /* ─── Stage header ───────────────────────────────────────────────────────── */
13
+ .mc-stage {
14
+ font-size: 0.75rem;
15
+ font-weight: 600;
16
+ color: var(--primary);
17
+ text-transform: uppercase;
18
+ letter-spacing: 0.06em;
19
+ padding: 8px 0 2px;
20
+ border-top: 1px solid var(--gray-100);
21
+ margin-top: 4px;
22
+ }
23
+
24
+ .mc-stage:first-child {
25
+ border-top: none;
26
+ margin-top: 0;
27
+ padding-top: 0;
28
+ }
29
+
30
+ /* ─── Thinking text ──────────────────────────────────────────────────────── */
31
+ .mc-thinking {
32
+ font-size: 0.8125rem;
33
+ color: var(--gray-500);
34
+ font-style: italic;
35
+ }
36
+
37
+ /* ─── Response block (markdown) ──────────────────────────────────────────── */
38
+ .mc-response {
39
+ color: var(--gray-800);
40
+ }
41
+
42
+ .mc-response p {
43
+ margin: 0 0 8px;
44
+ }
45
+
46
+ .mc-response p:last-child {
47
+ margin-bottom: 0;
48
+ }
49
+
50
+ .mc-response strong {
51
+ font-weight: 600;
52
+ color: var(--gray-900);
53
+ }
54
+
55
+ .mc-response em {
56
+ font-style: italic;
57
+ }
58
+
59
+ .mc-response ul,
60
+ .mc-response ol {
61
+ margin: 4px 0 8px 20px;
62
+ padding: 0;
63
+ }
64
+
65
+ .mc-response li {
66
+ margin-bottom: 2px;
67
+ }
68
+
69
+ .mc-response h1,
70
+ .mc-response h2,
71
+ .mc-response h3,
72
+ .mc-response h4 {
73
+ font-size: 0.9375rem;
74
+ font-weight: 600;
75
+ color: var(--gray-900);
76
+ margin: 10px 0 4px;
77
+ }
78
+
79
+ .mc-response code {
80
+ font-family: monospace;
81
+ font-size: 0.875em;
82
+ background: var(--gray-100);
83
+ padding: 1px 5px;
84
+ border-radius: 4px;
85
+ }
86
+
87
+ /* ─── Tool output (monospace block) ─────────────────────────────────────── */
88
+ .mc-tool-output {
89
+ background: var(--gray-900);
90
+ border-radius: 8px;
91
+ overflow: hidden;
92
+ }
93
+
94
+ .mc-tool-output-label {
95
+ font-size: 0.6875rem;
96
+ font-weight: 600;
97
+ color: var(--gray-400);
98
+ text-transform: uppercase;
99
+ letter-spacing: 0.05em;
100
+ padding: 6px 12px 4px;
101
+ background: rgba(255, 255, 255, 0.05);
102
+ border-bottom: 1px solid rgba(255, 255, 255, 0.08);
103
+ }
104
+
105
+ .mc-tool-output pre {
106
+ margin: 0;
107
+ padding: 10px 12px;
108
+ font-family: 'SF Mono', 'Fira Code', monospace;
109
+ font-size: 0.8rem;
110
+ line-height: 1.5;
111
+ color: #e2e8f0;
112
+ white-space: pre;
113
+ overflow-x: auto;
114
+ }
115
+
116
+ /* ─── Image blocks (GradCAM, comparison) ────────────────────────────────── */
117
+ .mc-image-block {
118
+ margin-top: 4px;
119
+ }
120
+
121
+ .mc-image-label {
122
+ font-size: 0.75rem;
123
+ font-weight: 600;
124
+ color: var(--gray-500);
125
+ text-transform: uppercase;
126
+ letter-spacing: 0.05em;
127
+ margin-bottom: 6px;
128
+ }
129
+
130
+ .mc-gradcam-img {
131
+ width: 100%;
132
+ max-width: 380px;
133
+ border-radius: 10px;
134
+ border: 1px solid var(--gray-200);
135
+ display: block;
136
+ }
137
+
138
+ .mc-comparison-img {
139
+ width: 100%;
140
+ max-width: 560px;
141
+ border-radius: 10px;
142
+ border: 1px solid var(--gray-200);
143
+ display: block;
144
+ }
145
+
146
+ /* ─── GradCAM side-by-side comparison ───────────────────────────────────── */
147
+ .mc-gradcam-compare {
148
+ display: grid;
149
+ grid-template-columns: 1fr 1fr;
150
+ gap: 10px;
151
+ max-width: 560px;
152
+ }
153
+
154
+ .mc-gradcam-compare-item {
155
+ display: flex;
156
+ flex-direction: column;
157
+ gap: 4px;
158
+ }
159
+
160
+ .mc-gradcam-compare-title {
161
+ font-size: 0.75rem;
162
+ font-weight: 600;
163
+ color: var(--gray-600);
164
+ text-align: center;
165
+ }
166
+
167
+ .mc-gradcam-compare-img {
168
+ width: 100%;
169
+ border-radius: 8px;
170
+ border: 1px solid var(--gray-200);
171
+ display: block;
172
+ }
173
+
174
+ /* ─── Result / error / complete / observation ───────────────────────────── */
175
+ .mc-result {
176
+ background: linear-gradient(135deg, #f0fdf4, #dcfce7);
177
+ border: 1px solid #86efac;
178
+ border-radius: 8px;
179
+ padding: 8px 12px;
180
+ font-size: 0.875rem;
181
+ font-weight: 500;
182
+ color: #15803d;
183
+ }
184
+
185
+ .mc-error {
186
+ background: #fef2f2;
187
+ border: 1px solid #fca5a5;
188
+ border-radius: 8px;
189
+ padding: 8px 12px;
190
+ font-size: 0.875rem;
191
+ color: #dc2626;
192
+ }
193
+
194
+ .mc-complete {
195
+ font-size: 0.8rem;
196
+ color: var(--gray-400);
197
+ text-align: right;
198
+ }
199
+
200
+ .mc-observation {
201
+ font-size: 0.875rem;
202
+ color: var(--gray-600);
203
+ font-style: italic;
204
+ }
205
+
206
+ /* ─── Plain streaming text ───────────────────────────────────────────────── */
207
+ .mc-text {
208
+ white-space: pre-wrap;
209
+ word-break: break-word;
210
+ color: var(--gray-700);
211
+ font-size: 0.875rem;
212
+ }
213
+
214
+ /* ─── References ─────────────────────────────────────────────────────────── */
215
+ .mc-references {
216
+ border-top: 1px solid var(--gray-100);
217
+ padding-top: 8px;
218
+ margin-top: 4px;
219
+ }
220
+
221
+ .mc-references-title {
222
+ font-size: 0.75rem;
223
+ font-weight: 600;
224
+ color: var(--gray-500);
225
+ text-transform: uppercase;
226
+ letter-spacing: 0.05em;
227
+ margin-bottom: 4px;
228
+ }
229
+
230
+ .mc-ref-item {
231
+ font-size: 0.8125rem;
232
+ color: var(--gray-600);
233
+ line-height: 1.5;
234
+ }
235
+
236
+ .mc-ref-sup {
237
+ font-size: 0.6875rem;
238
+ vertical-align: super;
239
+ margin-right: 4px;
240
+ color: var(--primary);
241
+ font-weight: 600;
242
+ }
243
+
244
+ .mc-ref-source {
245
+ font-style: italic;
246
+ }
247
+
248
+ .mc-ref-page {
249
+ color: var(--gray-400);
250
+ }
web/src/components/MessageContent.tsx ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ReactMarkdown from 'react-markdown';
2
+ import './MessageContent.css';
3
+
4
+ // Serve any temp visualization image (GradCAM, comparison) through the API
5
+ const TEMP_IMG_URL = (path: string) =>
6
+ `/api/patients/gradcam?path=${encodeURIComponent(path)}`;
7
+
8
+ // ─── Types ─────────────────────────────────────────────────────────────────
9
+
10
+ type Segment =
11
+ | { type: 'stage'; label: string }
12
+ | { type: 'thinking'; content: string }
13
+ | { type: 'response'; content: string }
14
+ | { type: 'tool_output'; label: string; content: string }
15
+ | { type: 'gradcam'; path: string }
16
+ | { type: 'comparison'; path: string }
17
+ | { type: 'gradcam_compare'; path1: string; path2: string }
18
+ | { type: 'result'; content: string }
19
+ | { type: 'error'; content: string }
20
+ | { type: 'complete'; content: string }
21
+ | { type: 'references'; content: string }
22
+ | { type: 'observation'; content: string }
23
+ | { type: 'text'; content: string };
24
+
25
+ // ─── Parser ────────────────────────────────────────────────────────────────
26
+
27
+ // Splits raw text by all known complete tag patterns (capturing group preserves them)
28
+ const TAG_SPLIT_RE = new RegExp(
29
+ '(' +
30
+ [
31
+ '\\[STAGE:[^\\]]*\\][\\s\\S]*?\\[\\/STAGE\\]',
32
+ '\\[THINKING\\][\\s\\S]*?\\[\\/THINKING\\]',
33
+ '\\[RESPONSE\\][\\s\\S]*?\\[\\/RESPONSE\\]',
34
+ '\\[TOOL_OUTPUT:[^\\]]*\\][\\s\\S]*?\\[\\/TOOL_OUTPUT\\]',
35
+ '\\[GRADCAM_IMAGE:[^\\]]+\\]',
36
+ '\\[COMPARISON_IMAGE:[^\\]]+\\]',
37
+ '\\[GRADCAM_COMPARE:[^:\\]]+:[^\\]]+\\]',
38
+ '\\[RESULT\\][\\s\\S]*?\\[\\/RESULT\\]',
39
+ '\\[ERROR\\][\\s\\S]*?\\[\\/ERROR\\]',
40
+ '\\[COMPLETE\\][\\s\\S]*?\\[\\/COMPLETE\\]',
41
+ '\\[REFERENCES\\][\\s\\S]*?\\[\\/REFERENCES\\]',
42
+ '\\[OBSERVATION\\][\\s\\S]*?\\[\\/OBSERVATION\\]',
43
+ '\\[CONFIRM:[^\\]]*\\][\\s\\S]*?\\[\\/CONFIRM\\]',
44
+ ].join('|') +
45
+ ')',
46
+ 'g',
47
+ );
48
+
49
+ // Strips known opening tags that haven't yet been closed (mid-stream partial content)
50
+ function cleanStreamingText(text: string): string {
51
+ return text.replace(
52
+ /\[(STAGE:[^\]]*|THINKING|RESPONSE|TOOL_OUTPUT:[^\]]*|RESULT|ERROR|COMPLETE|REFERENCES|OBSERVATION|CONFIRM:[^\]]*)\]/g,
53
+ '',
54
+ );
55
+ }
56
+
57
+ function parseContent(raw: string): Segment[] {
58
+ const segments: Segment[] = [];
59
+
60
+ for (const part of raw.split(TAG_SPLIT_RE)) {
61
+ if (!part) continue;
62
+
63
+ let m: RegExpMatchArray | null;
64
+
65
+ if ((m = part.match(/^\[STAGE:([^\]]*)\]([\s\S]*)\[\/STAGE\]$/))) {
66
+ const label = m[2].trim();
67
+ if (label) segments.push({ type: 'stage', label });
68
+
69
+ } else if ((m = part.match(/^\[THINKING\]([\s\S]*)\[\/THINKING\]$/))) {
70
+ const c = m[1].trim();
71
+ if (c) segments.push({ type: 'thinking', content: c });
72
+
73
+ } else if ((m = part.match(/^\[RESPONSE\]([\s\S]*)\[\/RESPONSE\]$/))) {
74
+ const c = m[1].trim();
75
+ if (c) segments.push({ type: 'response', content: c });
76
+
77
+ } else if ((m = part.match(/^\[TOOL_OUTPUT:([^\]]*)\]([\s\S]*)\[\/TOOL_OUTPUT\]$/))) {
78
+ segments.push({ type: 'tool_output', label: m[1], content: m[2] });
79
+
80
+ } else if ((m = part.match(/^\[GRADCAM_IMAGE:([^\]]+)\]$/))) {
81
+ segments.push({ type: 'gradcam', path: m[1] });
82
+
83
+ } else if ((m = part.match(/^\[COMPARISON_IMAGE:([^\]]+)\]$/))) {
84
+ segments.push({ type: 'comparison', path: m[1] });
85
+
86
+ } else if ((m = part.match(/^\[GRADCAM_COMPARE:([^:\]]+):([^\]]+)\]$/))) {
87
+ segments.push({ type: 'gradcam_compare', path1: m[1], path2: m[2] });
88
+
89
+ } else if ((m = part.match(/^\[RESULT\]([\s\S]*)\[\/RESULT\]$/))) {
90
+ const c = m[1].trim();
91
+ if (c) segments.push({ type: 'result', content: c });
92
+
93
+ } else if ((m = part.match(/^\[ERROR\]([\s\S]*)\[\/ERROR\]$/))) {
94
+ const c = m[1].trim();
95
+ if (c) segments.push({ type: 'error', content: c });
96
+
97
+ } else if ((m = part.match(/^\[COMPLETE\]([\s\S]*)\[\/COMPLETE\]$/))) {
98
+ const c = m[1].trim();
99
+ if (c) segments.push({ type: 'complete', content: c });
100
+
101
+ } else if ((m = part.match(/^\[REFERENCES\]([\s\S]*)\[\/REFERENCES\]$/))) {
102
+ segments.push({ type: 'references', content: m[1].trim() });
103
+
104
+ } else if ((m = part.match(/^\[OBSERVATION\]([\s\S]*)\[\/OBSERVATION\]$/))) {
105
+ const c = m[1].trim();
106
+ if (c) segments.push({ type: 'observation', content: c });
107
+
108
+ } else if ((m = part.match(/^\[CONFIRM:[^\]]*\]([\s\S]*)\[\/CONFIRM\]$/))) {
109
+ const c = m[1].trim();
110
+ if (c) segments.push({ type: 'result', content: c });
111
+
112
+ } else {
113
+ // Plain text (may be mid-stream with incomplete opening tags)
114
+ const cleaned = cleanStreamingText(part);
115
+ if (cleaned.trim()) segments.push({ type: 'text', content: cleaned });
116
+ }
117
+ }
118
+
119
+ return segments;
120
+ }
121
+
122
+ // ─── References renderer ──────────────────��────────────────────────────────
123
+
124
+ function References({ content }: { content: string }) {
125
+ const refs = content.match(/\[REF:[^\]]+\]/g) ?? [];
126
+ if (!refs.length) return null;
127
+
128
+ return (
129
+ <div className="mc-references">
130
+ <div className="mc-references-title">References</div>
131
+ {refs.map((ref, i) => {
132
+ // [REF:id:source:page:file:superscript]
133
+ const parts = ref.slice(1, -1).split(':');
134
+ const source = parts[2] ?? '';
135
+ const page = parts[3] ?? '';
136
+ const sup = parts[5] ?? `[${i + 1}]`;
137
+ return (
138
+ <div key={i} className="mc-ref-item">
139
+ <span className="mc-ref-sup">{sup}</span>
140
+ <span className="mc-ref-source">{source}</span>
141
+ {page && <span className="mc-ref-page">, p.{page}</span>}
142
+ </div>
143
+ );
144
+ })}
145
+ </div>
146
+ );
147
+ }
148
+
149
+ // ─── Main component ────────────────────────────────────────────────────────
150
+
151
+ export function MessageContent({ text }: { text: string }) {
152
+ const segments = parseContent(text);
153
+
154
+ return (
155
+ <div className="mc-root">
156
+ {segments.map((seg, i) => {
157
+ switch (seg.type) {
158
+ case 'stage':
159
+ return <div key={i} className="mc-stage">{seg.label}</div>;
160
+
161
+ case 'thinking':
162
+ return <div key={i} className="mc-thinking">{seg.content}</div>;
163
+
164
+ case 'response':
165
+ return (
166
+ <div key={i} className="mc-response">
167
+ <ReactMarkdown>{seg.content}</ReactMarkdown>
168
+ </div>
169
+ );
170
+
171
+ case 'tool_output':
172
+ return (
173
+ <div key={i} className="mc-tool-output">
174
+ {seg.label && <div className="mc-tool-output-label">{seg.label}</div>}
175
+ <pre>{seg.content}</pre>
176
+ </div>
177
+ );
178
+
179
+ case 'gradcam':
180
+ return (
181
+ <div key={i} className="mc-image-block">
182
+ <div className="mc-image-label">Grad-CAM Attention Map</div>
183
+ <img
184
+ src={TEMP_IMG_URL(seg.path)}
185
+ className="mc-gradcam-img"
186
+ alt="Grad-CAM attention map"
187
+ />
188
+ </div>
189
+ );
190
+
191
+ case 'comparison':
192
+ return (
193
+ <div key={i} className="mc-image-block">
194
+ <div className="mc-image-label">Lesion Comparison</div>
195
+ <img
196
+ src={TEMP_IMG_URL(seg.path)}
197
+ className="mc-comparison-img"
198
+ alt="Side-by-side lesion comparison"
199
+ />
200
+ </div>
201
+ );
202
+
203
+ case 'gradcam_compare':
204
+ return (
205
+ <div key={i} className="mc-image-block">
206
+ <div className="mc-image-label">Grad-CAM Comparison</div>
207
+ <div className="mc-gradcam-compare">
208
+ <div className="mc-gradcam-compare-item">
209
+ <div className="mc-gradcam-compare-title">Previous</div>
210
+ <img
211
+ src={TEMP_IMG_URL(seg.path1)}
212
+ className="mc-gradcam-compare-img"
213
+ alt="Previous GradCAM"
214
+ />
215
+ </div>
216
+ <div className="mc-gradcam-compare-item">
217
+ <div className="mc-gradcam-compare-title">Current</div>
218
+ <img
219
+ src={TEMP_IMG_URL(seg.path2)}
220
+ className="mc-gradcam-compare-img"
221
+ alt="Current GradCAM"
222
+ />
223
+ </div>
224
+ </div>
225
+ </div>
226
+ );
227
+
228
+ case 'result':
229
+ return <div key={i} className="mc-result">{seg.content}</div>;
230
+
231
+ case 'error':
232
+ return <div key={i} className="mc-error">{seg.content}</div>;
233
+
234
+ case 'complete':
235
+ return <div key={i} className="mc-complete">{seg.content}</div>;
236
+
237
+ case 'references':
238
+ return <References key={i} content={seg.content} />;
239
+
240
+ case 'observation':
241
+ return <div key={i} className="mc-observation">{seg.content}</div>;
242
+
243
+ case 'text':
244
+ return seg.content.trim() ? (
245
+ <div key={i} className="mc-text">{seg.content}</div>
246
+ ) : null;
247
+
248
+ default:
249
+ return null;
250
+ }
251
+ })}
252
+ </div>
253
+ );
254
+ }
web/src/components/ToolCallCard.css ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* ─── Card container ─────────────────────────────────────────────────────── */
2
+ .tool-card {
3
+ border: 1px solid var(--gray-200);
4
+ border-left: 3px solid var(--primary);
5
+ border-radius: 10px;
6
+ overflow: hidden;
7
+ background: var(--gray-50);
8
+ margin-top: 8px;
9
+ }
10
+
11
+ .tool-card.loading {
12
+ border-left-color: var(--gray-400);
13
+ }
14
+
15
+ .tool-card.error {
16
+ border-left-color: #ef4444;
17
+ }
18
+
19
+ /* ─── Header (collapsed row) ─────────────────────────────────────────────── */
20
+ .tool-card-header {
21
+ width: 100%;
22
+ display: flex;
23
+ align-items: center;
24
+ gap: 8px;
25
+ padding: 10px 14px;
26
+ background: transparent;
27
+ border: none;
28
+ cursor: pointer;
29
+ text-align: left;
30
+ transition: background 0.15s;
31
+ }
32
+
33
+ .tool-card-header:hover:not(:disabled) {
34
+ background: var(--gray-100);
35
+ }
36
+
37
+ .tool-card-header:disabled {
38
+ cursor: default;
39
+ }
40
+
41
+ .tool-icon {
42
+ font-size: 1rem;
43
+ flex-shrink: 0;
44
+ }
45
+
46
+ .tool-label {
47
+ flex: 1;
48
+ font-size: 0.875rem;
49
+ font-weight: 500;
50
+ color: var(--gray-700);
51
+ text-transform: capitalize;
52
+ }
53
+
54
+ .tool-status {
55
+ font-size: 0.8125rem;
56
+ flex-shrink: 0;
57
+ }
58
+
59
+ .tool-status.done {
60
+ color: var(--success, #22c55e);
61
+ font-weight: 600;
62
+ }
63
+
64
+ .tool-status.calling {
65
+ color: var(--gray-500);
66
+ display: flex;
67
+ align-items: center;
68
+ gap: 5px;
69
+ }
70
+
71
+ .tool-status.error-text {
72
+ color: #ef4444;
73
+ }
74
+
75
+ .tool-header-summary {
76
+ font-size: 0.8125rem;
77
+ color: var(--gray-500);
78
+ font-weight: 400;
79
+ white-space: nowrap;
80
+ overflow: hidden;
81
+ text-overflow: ellipsis;
82
+ max-width: 200px;
83
+ }
84
+
85
+ .tool-chevron {
86
+ font-size: 0.625rem;
87
+ color: var(--gray-400);
88
+ margin-left: 2px;
89
+ flex-shrink: 0;
90
+ }
91
+
92
+ /* ─── Spinner ────────────────────────────────────────────────────────────── */
93
+ .spinner {
94
+ display: inline-block;
95
+ width: 12px;
96
+ height: 12px;
97
+ border: 2px solid var(--gray-300);
98
+ border-top-color: var(--gray-600);
99
+ border-radius: 50%;
100
+ animation: spin 0.8s linear infinite;
101
+ }
102
+
103
+ @keyframes spin {
104
+ to { transform: rotate(360deg); }
105
+ }
106
+
107
+ /* ─── Card body ──────────────────────────────────────────────────────────── */
108
+ .tool-card-body {
109
+ padding: 14px;
110
+ border-top: 1px solid var(--gray-200);
111
+ background: white;
112
+ }
113
+
114
+ /* ─── analyze_image ──────────────────────────────────────────────────────── */
115
+ .analyze-result {
116
+ display: flex;
117
+ flex-direction: column;
118
+ gap: 12px;
119
+ }
120
+
121
+ .analyze-top {
122
+ display: flex;
123
+ gap: 14px;
124
+ align-items: flex-start;
125
+ }
126
+
127
+ .analyze-thumb {
128
+ width: 72px;
129
+ height: 72px;
130
+ object-fit: cover;
131
+ border-radius: 8px;
132
+ border: 1px solid var(--gray-200);
133
+ flex-shrink: 0;
134
+ }
135
+
136
+ .analyze-info {
137
+ flex: 1;
138
+ min-width: 0;
139
+ }
140
+
141
+ .diagnosis-name {
142
+ font-size: 0.9375rem;
143
+ font-weight: 600;
144
+ color: var(--gray-900);
145
+ margin: 0 0 4px;
146
+ line-height: 1.3;
147
+ }
148
+
149
+ .confidence-label {
150
+ font-size: 0.8125rem;
151
+ font-weight: 500;
152
+ margin: 0 0 6px;
153
+ }
154
+
155
+ .confidence-bar-track {
156
+ height: 6px;
157
+ background: var(--gray-200);
158
+ border-radius: 999px;
159
+ overflow: hidden;
160
+ }
161
+
162
+ .confidence-bar-fill {
163
+ height: 100%;
164
+ border-radius: 999px;
165
+ transition: width 0.3s ease;
166
+ }
167
+
168
+ .analyze-summary {
169
+ font-size: 0.875rem;
170
+ color: var(--gray-700);
171
+ line-height: 1.6;
172
+ margin: 0;
173
+ border-top: 1px solid var(--gray-100);
174
+ padding-top: 10px;
175
+ white-space: pre-wrap;
176
+ }
177
+
178
+ .other-predictions {
179
+ list-style: none;
180
+ padding: 0;
181
+ margin: 0;
182
+ display: flex;
183
+ flex-direction: column;
184
+ gap: 6px;
185
+ border-top: 1px solid var(--gray-100);
186
+ padding-top: 10px;
187
+ }
188
+
189
+ .prediction-row {
190
+ display: flex;
191
+ justify-content: space-between;
192
+ font-size: 0.8125rem;
193
+ }
194
+
195
+ .pred-name {
196
+ color: var(--gray-600);
197
+ }
198
+
199
+ .pred-pct {
200
+ color: var(--gray-500);
201
+ font-variant-numeric: tabular-nums;
202
+ }
203
+
204
+ /* ─── compare_images ─────────────────────────────────────────────────────── */
205
+ .compare-result {
206
+ display: flex;
207
+ flex-direction: column;
208
+ gap: 12px;
209
+ }
210
+
211
+ .carousel {
212
+ position: relative;
213
+ display: flex;
214
+ align-items: center;
215
+ justify-content: center;
216
+ gap: 8px;
217
+ }
218
+
219
+ .carousel-image-wrap {
220
+ position: relative;
221
+ display: inline-block;
222
+ }
223
+
224
+ .carousel-image {
225
+ width: 200px;
226
+ height: 160px;
227
+ object-fit: cover;
228
+ border-radius: 10px;
229
+ border: 1px solid var(--gray-200);
230
+ display: block;
231
+ }
232
+
233
+ .carousel-label {
234
+ position: absolute;
235
+ bottom: 8px;
236
+ left: 50%;
237
+ transform: translateX(-50%);
238
+ background: rgba(0, 0, 0, 0.55);
239
+ color: white;
240
+ font-size: 0.75rem;
241
+ font-weight: 600;
242
+ padding: 3px 10px;
243
+ border-radius: 999px;
244
+ white-space: nowrap;
245
+ }
246
+
247
+ .carousel-btn {
248
+ background: white;
249
+ border: 1px solid var(--gray-300);
250
+ border-radius: 6px;
251
+ width: 28px;
252
+ height: 28px;
253
+ cursor: pointer;
254
+ font-size: 0.75rem;
255
+ color: var(--gray-600);
256
+ display: flex;
257
+ align-items: center;
258
+ justify-content: center;
259
+ flex-shrink: 0;
260
+ }
261
+
262
+ .carousel-btn:hover {
263
+ background: var(--gray-100);
264
+ }
265
+
266
+ .carousel-dots {
267
+ position: absolute;
268
+ bottom: -18px;
269
+ left: 50%;
270
+ transform: translateX(-50%);
271
+ display: flex;
272
+ gap: 5px;
273
+ }
274
+
275
+ .carousel-dot {
276
+ width: 6px;
277
+ height: 6px;
278
+ border-radius: 50%;
279
+ background: var(--gray-300);
280
+ cursor: pointer;
281
+ transition: background 0.15s;
282
+ }
283
+
284
+ .carousel-dot.active {
285
+ background: var(--primary);
286
+ }
287
+
288
+ .compare-status {
289
+ font-size: 0.9375rem;
290
+ margin-top: 6px;
291
+ }
292
+
293
+ .feature-changes {
294
+ list-style: none;
295
+ padding: 0;
296
+ margin: 0;
297
+ display: flex;
298
+ flex-direction: column;
299
+ gap: 6px;
300
+ }
301
+
302
+ .feature-row {
303
+ display: flex;
304
+ justify-content: space-between;
305
+ font-size: 0.8125rem;
306
+ }
307
+
308
+ .feature-name {
309
+ color: var(--gray-600);
310
+ text-transform: capitalize;
311
+ }
312
+
313
+ .feature-delta {
314
+ font-variant-numeric: tabular-nums;
315
+ font-weight: 500;
316
+ }
317
+
318
+ .compare-summary {
319
+ font-size: 0.875rem;
320
+ color: var(--gray-600);
321
+ line-height: 1.5;
322
+ margin: 0;
323
+ border-top: 1px solid var(--gray-100);
324
+ padding-top: 10px;
325
+ }
326
+
327
+ /* ─── Generic fallback ───────────────────────────────────────────────────── */
328
+ .generic-result {
329
+ font-size: 0.75rem;
330
+ background: var(--gray-50);
331
+ border-radius: 6px;
332
+ padding: 10px;
333
+ overflow-x: auto;
334
+ color: var(--gray-700);
335
+ margin: 0;
336
+ white-space: pre-wrap;
337
+ word-break: break-all;
338
+ }
web/src/components/ToolCallCard.tsx ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useEffect, useState } from 'react';
2
+ import { ToolCall } from '../types';
3
+ import './ToolCallCard.css';
4
+
5
+ interface ToolCallCardProps {
6
+ toolCall: ToolCall;
7
+ }
8
+
9
+ /** One-line summary shown in the collapsed header so results are visible at a glance */
10
+ function CollapsedSummary({ toolCall }: { toolCall: ToolCall }) {
11
+ const r = toolCall.result;
12
+ if (!r) return null;
13
+
14
+ if (toolCall.tool === 'analyze_image') {
15
+ const name = r.full_name ?? r.diagnosis;
16
+ const pct = r.confidence != null ? `${Math.round(r.confidence * 100)}%` : null;
17
+ if (name) return (
18
+ <span className="tool-header-summary">
19
+ {name}{pct ? ` — ${pct}` : ''}
20
+ </span>
21
+ );
22
+ }
23
+
24
+ if (toolCall.tool === 'compare_images') {
25
+ const key = r.status_label ?? 'STABLE';
26
+ const cfg = STATUS_CONFIG[key] ?? { emoji: '⚪', label: key };
27
+ return (
28
+ <span className="tool-header-summary">
29
+ {cfg.emoji} {cfg.label}
30
+ </span>
31
+ );
32
+ }
33
+
34
+ return null;
35
+ }
36
+
37
+ export function ToolCallCard({ toolCall }: ToolCallCardProps) {
38
+ // Auto-expand when the tool completes so results are immediately visible.
39
+ // User can collapse manually afterwards.
40
+ const [expanded, setExpanded] = useState(false);
41
+
42
+ useEffect(() => {
43
+ if (toolCall.status === 'complete') setExpanded(true);
44
+ }, [toolCall.status]);
45
+
46
+ const isLoading = toolCall.status === 'calling';
47
+ const isError = toolCall.status === 'error';
48
+
49
+ const icon = toolCall.tool === 'compare_images' ? '🔄' : '🔬';
50
+ const label = toolCall.tool.replace(/_/g, ' ');
51
+
52
+ return (
53
+ <div className={`tool-card ${isLoading ? 'loading' : ''} ${isError ? 'error' : ''}`}>
54
+ <button
55
+ className="tool-card-header"
56
+ onClick={() => !isLoading && setExpanded(e => !e)}
57
+ disabled={isLoading}
58
+ >
59
+ <span className="tool-icon">{icon}</span>
60
+ <span className="tool-label">{label}</span>
61
+ {isLoading ? (
62
+ <span className="tool-status calling">
63
+ <span className="spinner" /> running…
64
+ </span>
65
+ ) : isError ? (
66
+ <span className="tool-status error-text">error</span>
67
+ ) : (
68
+ <>
69
+ <span className="tool-status done">✓</span>
70
+ {!expanded && <CollapsedSummary toolCall={toolCall} />}
71
+ </>
72
+ )}
73
+ {!isLoading && (
74
+ <span className="tool-chevron">{expanded ? '▲' : '▼'}</span>
75
+ )}
76
+ </button>
77
+
78
+ {expanded && !isLoading && toolCall.result && (
79
+ <div className="tool-card-body">
80
+ {toolCall.tool === 'analyze_image' && (
81
+ <AnalyzeImageResult result={toolCall.result} />
82
+ )}
83
+ {toolCall.tool === 'compare_images' && (
84
+ <CompareImagesResult result={toolCall.result} />
85
+ )}
86
+ {toolCall.tool !== 'analyze_image' && toolCall.tool !== 'compare_images' && (
87
+ <GenericResult result={toolCall.result} />
88
+ )}
89
+ </div>
90
+ )}
91
+ </div>
92
+ );
93
+ }
94
+
95
+ /* ─── analyze_image renderer ─────────────────────────────────────────────── */
96
+
97
+ function AnalyzeImageResult({ result }: { result: ToolCall['result'] }) {
98
+ if (!result) return null;
99
+
100
+ const hasClassifier = result.diagnosis != null;
101
+ const topPrediction = result.all_predictions?.[0];
102
+ const otherPredictions = result.all_predictions?.slice(1) ?? [];
103
+ const confidence = result.confidence ?? topPrediction?.probability ?? 0;
104
+ const pct = Math.round(confidence * 100);
105
+ const statusColor = pct >= 70 ? '#ef4444' : pct >= 40 ? '#f59e0b' : '#22c55e';
106
+
107
+ return (
108
+ <div className="analyze-result">
109
+ <div className="analyze-top">
110
+ {result.image_url && (
111
+ <img
112
+ src={result.image_url}
113
+ alt="Analyzed lesion"
114
+ className="analyze-thumb"
115
+ />
116
+ )}
117
+ <div className="analyze-info">
118
+ {hasClassifier ? (
119
+ <>
120
+ <p className="diagnosis-name">{result.full_name ?? result.diagnosis}</p>
121
+ <p className="confidence-label" style={{ color: statusColor }}>
122
+ Confidence: {pct}%
123
+ </p>
124
+ <div className="confidence-bar-track">
125
+ <div
126
+ className="confidence-bar-fill"
127
+ style={{ width: `${pct}%`, background: statusColor }}
128
+ />
129
+ </div>
130
+ </>
131
+ ) : (
132
+ <p className="diagnosis-name" style={{ color: 'var(--gray-500)', fontWeight: 400, fontSize: '0.875rem' }}>
133
+ Visual assessment complete — classifier unavailable
134
+ </p>
135
+ )}
136
+ </div>
137
+ </div>
138
+
139
+ {hasClassifier && otherPredictions.length > 0 && (
140
+ <ul className="other-predictions">
141
+ {otherPredictions.map(p => (
142
+ <li key={p.class} className="prediction-row">
143
+ <span className="pred-name">{p.full_name ?? p.class}</span>
144
+ <span className="pred-pct">{Math.round(p.probability * 100)}%</span>
145
+ </li>
146
+ ))}
147
+ </ul>
148
+ )}
149
+ </div>
150
+ );
151
+ }
152
+
153
+ /* ─── compare_images renderer ────────────────────────────────────────────── */
154
+
155
+ const STATUS_CONFIG: Record<string, { label: string; color: string; emoji: string }> = {
156
+ STABLE: { label: 'Stable', color: '#22c55e', emoji: '🟢' },
157
+ MINOR_CHANGE: { label: 'Minor Change', color: '#f59e0b', emoji: '🟡' },
158
+ SIGNIFICANT_CHANGE: { label: 'Significant Change', color: '#ef4444', emoji: '🔴' },
159
+ IMPROVED: { label: 'Improved', color: '#3b82f6', emoji: '🔵' },
160
+ };
161
+
162
+ function CompareImagesResult({ result }: { result: ToolCall['result'] }) {
163
+ if (!result) return null;
164
+
165
+ const statusKey = result.status_label ?? 'STABLE';
166
+ const status = STATUS_CONFIG[statusKey] ?? { label: statusKey, color: '#6b7280', emoji: '⚪' };
167
+ const featureChanges = Object.entries(result.feature_changes ?? {});
168
+
169
+ return (
170
+ <div className="compare-result">
171
+ <div className="compare-status" style={{ color: status.color }}>
172
+ <strong>Status: {status.label} {status.emoji}</strong>
173
+ </div>
174
+
175
+ {featureChanges.length > 0 && (
176
+ <ul className="feature-changes">
177
+ {featureChanges.map(([name, vals]) => {
178
+ const delta = vals.curr - vals.prev;
179
+ const sign = delta > 0 ? '+' : '';
180
+ return (
181
+ <li key={name} className="feature-row">
182
+ <span className="feature-name">{name}</span>
183
+ <span className="feature-delta" style={{ color: Math.abs(delta) > 0.1 ? '#f59e0b' : '#6b7280' }}>
184
+ {sign}{(delta * 100).toFixed(1)}%
185
+ </span>
186
+ </li>
187
+ );
188
+ })}
189
+ </ul>
190
+ )}
191
+
192
+ {result.summary && (
193
+ <p className="compare-summary">{result.summary}</p>
194
+ )}
195
+ </div>
196
+ );
197
+ }
198
+
199
+ /* ─── Generic (unknown tool) renderer ───────────────────────────────────── */
200
+
201
+ function GenericResult({ result }: { result: ToolCall['result'] }) {
202
+ return (
203
+ <pre className="generic-result">
204
+ {JSON.stringify(result, null, 2)}
205
+ </pre>
206
+ );
207
+ }
web/src/index.css ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ * {
2
+ margin: 0;
3
+ padding: 0;
4
+ box-sizing: border-box;
5
+ }
6
+
7
+ body {
8
+ font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
9
+ background: #f8fafc;
10
+ color: #1e293b;
11
+ line-height: 1.5;
12
+ }
13
+
14
+ button {
15
+ font-family: inherit;
16
+ cursor: pointer;
17
+ }
18
+
19
+ input {
20
+ font-family: inherit;
21
+ }
22
+
23
+ :root {
24
+ --primary: #6366f1;
25
+ --primary-hover: #4f46e5;
26
+ --success: #10b981;
27
+ --error: #ef4444;
28
+ --gray-50: #f8fafc;
29
+ --gray-100: #f1f5f9;
30
+ --gray-200: #e2e8f0;
31
+ --gray-300: #cbd5e1;
32
+ --gray-400: #94a3b8;
33
+ --gray-500: #64748b;
34
+ --gray-600: #475569;
35
+ --gray-700: #334155;
36
+ --gray-800: #1e293b;
37
+ --gray-900: #0f172a;
38
+ }
web/src/main.tsx ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import React from 'react'
2
+ import ReactDOM from 'react-dom/client'
3
+ import { App } from './App'
4
+ import './index.css'
5
+
6
+ ReactDOM.createRoot(document.getElementById('root')!).render(
7
+ <React.StrictMode>
8
+ <App />
9
+ </React.StrictMode>,
10
+ )
web/src/pages/ChatPage.css ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* ─── Layout ───────────────────────────────────────────────────────────── */
2
+ .chat-page {
3
+ display: flex;
4
+ flex-direction: column;
5
+ height: 100vh;
6
+ background: var(--gray-50);
7
+ overflow: hidden;
8
+ }
9
+
10
+ /* ─── Header ────────────────────────────────────────────────────────────── */
11
+ .chat-header {
12
+ display: flex;
13
+ align-items: center;
14
+ gap: 12px;
15
+ padding: 0 16px;
16
+ height: 56px;
17
+ background: white;
18
+ border-bottom: 1px solid var(--gray-200);
19
+ flex-shrink: 0;
20
+ z-index: 10;
21
+ }
22
+
23
+ .header-back-btn {
24
+ width: 36px;
25
+ height: 36px;
26
+ display: flex;
27
+ align-items: center;
28
+ justify-content: center;
29
+ border: none;
30
+ background: transparent;
31
+ cursor: pointer;
32
+ color: var(--gray-600);
33
+ border-radius: 8px;
34
+ transition: background 0.15s;
35
+ flex-shrink: 0;
36
+ }
37
+
38
+ .header-back-btn:hover {
39
+ background: var(--gray-100);
40
+ }
41
+
42
+ .header-back-btn svg {
43
+ width: 20px;
44
+ height: 20px;
45
+ }
46
+
47
+ .header-center {
48
+ flex: 1;
49
+ display: flex;
50
+ flex-direction: column;
51
+ min-width: 0;
52
+ }
53
+
54
+ .header-app-name {
55
+ font-size: 0.7rem;
56
+ font-weight: 600;
57
+ color: var(--primary);
58
+ text-transform: uppercase;
59
+ letter-spacing: 0.05em;
60
+ line-height: 1;
61
+ }
62
+
63
+ .header-patient-name {
64
+ font-size: 1rem;
65
+ font-weight: 600;
66
+ color: var(--gray-900);
67
+ white-space: nowrap;
68
+ overflow: hidden;
69
+ text-overflow: ellipsis;
70
+ line-height: 1.3;
71
+ }
72
+
73
+ .header-clear-btn {
74
+ border: 1px solid var(--gray-300);
75
+ background: transparent;
76
+ border-radius: 8px;
77
+ padding: 6px 14px;
78
+ font-size: 0.8125rem;
79
+ color: var(--gray-600);
80
+ cursor: pointer;
81
+ transition: all 0.15s;
82
+ flex-shrink: 0;
83
+ }
84
+
85
+ .header-clear-btn:hover {
86
+ background: var(--gray-100);
87
+ border-color: var(--gray-400);
88
+ }
89
+
90
+ /* ─── Messages ──────────────────────────────────────────────────────────── */
91
+ .chat-messages {
92
+ flex: 1;
93
+ overflow-y: auto;
94
+ padding: 20px 16px;
95
+ display: flex;
96
+ flex-direction: column;
97
+ gap: 12px;
98
+ }
99
+
100
+ .chat-empty {
101
+ flex: 1;
102
+ display: flex;
103
+ flex-direction: column;
104
+ align-items: center;
105
+ justify-content: center;
106
+ color: var(--gray-400);
107
+ text-align: center;
108
+ gap: 12px;
109
+ margin: auto;
110
+ }
111
+
112
+ .chat-empty-icon svg {
113
+ width: 40px;
114
+ height: 40px;
115
+ color: var(--gray-300);
116
+ }
117
+
118
+ .chat-empty p {
119
+ font-size: 0.9375rem;
120
+ max-width: 280px;
121
+ line-height: 1.5;
122
+ }
123
+
124
+ .message-row {
125
+ display: flex;
126
+ max-width: 720px;
127
+ width: 100%;
128
+ }
129
+
130
+ .message-row.user {
131
+ align-self: flex-end;
132
+ justify-content: flex-end;
133
+ }
134
+
135
+ .message-row.assistant {
136
+ align-self: flex-start;
137
+ justify-content: flex-start;
138
+ }
139
+
140
+ /* ─── Bubbles ────────────────────────────────────────────────────────────── */
141
+ .bubble {
142
+ max-width: 85%;
143
+ border-radius: 16px;
144
+ padding: 12px 16px;
145
+ display: flex;
146
+ flex-direction: column;
147
+ gap: 8px;
148
+ }
149
+
150
+ .user-bubble {
151
+ background: var(--primary);
152
+ color: white;
153
+ border-bottom-right-radius: 4px;
154
+ }
155
+
156
+ .assistant-bubble {
157
+ background: white;
158
+ border: 1px solid var(--gray-200);
159
+ border-bottom-left-radius: 4px;
160
+ box-shadow: 0 1px 3px rgba(0, 0, 0, 0.06);
161
+ max-width: 90%;
162
+ }
163
+
164
+ .bubble-text {
165
+ font-size: 0.9375rem;
166
+ line-height: 1.6;
167
+ white-space: pre-wrap;
168
+ word-break: break-word;
169
+ margin: 0;
170
+ }
171
+
172
+ .user-bubble .bubble-text {
173
+ color: white;
174
+ }
175
+
176
+ .assistant-text {
177
+ color: var(--gray-800);
178
+ }
179
+
180
+ /* Image in user bubble */
181
+ .message-image {
182
+ width: 100%;
183
+ max-width: 260px;
184
+ border-radius: 10px;
185
+ display: block;
186
+ }
187
+
188
+ /* ─── Thinking indicator ─────────────────────────────────────────────────── */
189
+ .thinking {
190
+ display: flex;
191
+ gap: 4px;
192
+ padding: 4px 0;
193
+ }
194
+
195
+ .dot {
196
+ width: 7px;
197
+ height: 7px;
198
+ background: var(--gray-400);
199
+ border-radius: 50%;
200
+ animation: bounce 1.2s infinite;
201
+ }
202
+
203
+ .dot:nth-child(2) { animation-delay: 0.2s; }
204
+ .dot:nth-child(3) { animation-delay: 0.4s; }
205
+
206
+ @keyframes bounce {
207
+ 0%, 60%, 100% { transform: translateY(0); }
208
+ 30% { transform: translateY(-6px); }
209
+ }
210
+
211
+ /* ─── Input bar ──────────────────────────────────────────────────────────── */
212
+ .chat-input-bar {
213
+ background: white;
214
+ border-top: 1px solid var(--gray-200);
215
+ padding: 12px 16px;
216
+ flex-shrink: 0;
217
+ }
218
+
219
+ .image-preview-container {
220
+ position: relative;
221
+ display: inline-block;
222
+ margin-bottom: 10px;
223
+ }
224
+
225
+ .image-preview-thumb {
226
+ width: 72px;
227
+ height: 72px;
228
+ object-fit: cover;
229
+ border-radius: 10px;
230
+ border: 2px solid var(--gray-200);
231
+ display: block;
232
+ }
233
+
234
+ .remove-image-btn {
235
+ position: absolute;
236
+ top: -8px;
237
+ right: -8px;
238
+ width: 22px;
239
+ height: 22px;
240
+ background: var(--gray-700);
241
+ color: white;
242
+ border: none;
243
+ border-radius: 50%;
244
+ font-size: 0.875rem;
245
+ line-height: 1;
246
+ cursor: pointer;
247
+ display: flex;
248
+ align-items: center;
249
+ justify-content: center;
250
+ }
251
+
252
+ .input-row {
253
+ display: flex;
254
+ align-items: flex-end;
255
+ gap: 8px;
256
+ }
257
+
258
+ .attach-btn {
259
+ width: 38px;
260
+ height: 38px;
261
+ border: 1px solid var(--gray-300);
262
+ background: transparent;
263
+ border-radius: 10px;
264
+ cursor: pointer;
265
+ color: var(--gray-500);
266
+ display: flex;
267
+ align-items: center;
268
+ justify-content: center;
269
+ flex-shrink: 0;
270
+ transition: all 0.15s;
271
+ }
272
+
273
+ .attach-btn:hover:not(:disabled) {
274
+ background: var(--gray-100);
275
+ border-color: var(--gray-400);
276
+ color: var(--gray-700);
277
+ }
278
+
279
+ .attach-btn:disabled {
280
+ opacity: 0.4;
281
+ cursor: not-allowed;
282
+ }
283
+
284
+ .attach-btn svg {
285
+ width: 18px;
286
+ height: 18px;
287
+ }
288
+
289
+ .chat-input {
290
+ flex: 1;
291
+ border: 1px solid var(--gray-300);
292
+ border-radius: 10px;
293
+ padding: 9px 14px;
294
+ font-size: 0.9375rem;
295
+ font-family: inherit;
296
+ resize: none;
297
+ line-height: 1.5;
298
+ max-height: 160px;
299
+ overflow-y: auto;
300
+ transition: border-color 0.15s;
301
+ }
302
+
303
+ .chat-input:focus {
304
+ outline: none;
305
+ border-color: var(--primary);
306
+ }
307
+
308
+ .chat-input:disabled {
309
+ background: var(--gray-50);
310
+ color: var(--gray-400);
311
+ }
312
+
313
+ .send-btn {
314
+ width: 38px;
315
+ height: 38px;
316
+ background: var(--primary);
317
+ color: white;
318
+ border: none;
319
+ border-radius: 10px;
320
+ cursor: pointer;
321
+ display: flex;
322
+ align-items: center;
323
+ justify-content: center;
324
+ flex-shrink: 0;
325
+ transition: background 0.15s;
326
+ }
327
+
328
+ .send-btn:hover:not(:disabled) {
329
+ background: var(--primary-hover);
330
+ }
331
+
332
+ .send-btn:disabled {
333
+ background: var(--gray-300);
334
+ cursor: not-allowed;
335
+ }
336
+
337
+ .send-btn svg {
338
+ width: 18px;
339
+ height: 18px;
340
+ }