Commit ·
86f402d
0
Parent(s):
Initial commit — SkinProAI dermoscopic analysis platform
Browse filesPatient-level chat UI with streaming AI analysis pipeline:
- MedGemma visual examination with stage-by-stage cascade
- ConvNeXt classifier with confidence scores and differential
- MONET feature extraction and Grad-CAM attention maps
- Temporal comparison between sequential lesion images
- Persistent chat history with full cascade replay on reload
- FastAPI backend with SSE streaming + React/TypeScript frontend
- Docker build ready for Hugging Face Spaces deployment
This view is limited to 50 files because it contains too many changes.
See raw diff
- .dockerignore +28 -0
- .gitattributes +1 -0
- .gitignore +45 -0
- Dockerfile +40 -0
- README.md +44 -0
- backend/__init__.py +0 -0
- backend/main.py +63 -0
- backend/requirements.txt +3 -0
- backend/routes/__init__.py +1 -0
- backend/routes/analysis.py +181 -0
- backend/routes/chat.py +97 -0
- backend/routes/lesions.py +241 -0
- backend/routes/patients.py +72 -0
- backend/services/__init__.py +0 -0
- backend/services/analysis_service.py +146 -0
- backend/services/chat_service.py +197 -0
- data/case_store.py +507 -0
- frontend/app.py +532 -0
- frontend/components/__init__.py +0 -0
- frontend/components/analysis_view.py +214 -0
- frontend/components/patient_select.py +48 -0
- frontend/components/sidebar.py +55 -0
- frontend/components/styles.py +517 -0
- guidelines/index/chunks.json +0 -0
- guidelines/index/faiss.index +3 -0
- mcp_server/__init__.py +0 -0
- mcp_server/server.py +286 -0
- mcp_server/tool_registry.py +55 -0
- models/convnext_classifier.py +383 -0
- models/explainability.py +183 -0
- models/gradcam_tool.py +285 -0
- models/guidelines_rag.py +349 -0
- models/medgemma_agent.py +927 -0
- models/medsiglip_convnext_fusion.py +224 -0
- models/monet_concepts.py +332 -0
- models/monet_tool.py +354 -0
- models/overlay_tool.py +335 -0
- requirements.txt +15 -0
- test_models.py +86 -0
- web/index.html +16 -0
- web/package-lock.json +0 -0
- web/package.json +24 -0
- web/src/App.tsx +14 -0
- web/src/components/MessageContent.css +250 -0
- web/src/components/MessageContent.tsx +254 -0
- web/src/components/ToolCallCard.css +338 -0
- web/src/components/ToolCallCard.tsx +207 -0
- web/src/index.css +38 -0
- web/src/main.tsx +10 -0
- web/src/pages/ChatPage.css +340 -0
.dockerignore
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
venv/
|
| 2 |
+
.venv/
|
| 3 |
+
__pycache__/
|
| 4 |
+
*.pyc
|
| 5 |
+
*.pyo
|
| 6 |
+
*.pyd
|
| 7 |
+
.Python
|
| 8 |
+
.env
|
| 9 |
+
|
| 10 |
+
# Frontend dev artifacts (only dist is needed)
|
| 11 |
+
web/node_modules/
|
| 12 |
+
web/.vite/
|
| 13 |
+
|
| 14 |
+
# Local data — don't ship patient records or uploads
|
| 15 |
+
data/uploads/
|
| 16 |
+
data/patient_chats/
|
| 17 |
+
data/lesions/
|
| 18 |
+
data/patients.json
|
| 19 |
+
|
| 20 |
+
# Misc
|
| 21 |
+
.git/
|
| 22 |
+
.gitignore
|
| 23 |
+
*.md
|
| 24 |
+
test*.py
|
| 25 |
+
test*.jpg
|
| 26 |
+
test*.png
|
| 27 |
+
frontend/
|
| 28 |
+
mcp_server/
|
.gitattributes
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
guidelines/index/faiss.index filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.pyc
|
| 4 |
+
*.pyo
|
| 5 |
+
*.pyd
|
| 6 |
+
*.egg-info/
|
| 7 |
+
.Python
|
| 8 |
+
venv/
|
| 9 |
+
.venv/
|
| 10 |
+
*.egg
|
| 11 |
+
|
| 12 |
+
# Environment
|
| 13 |
+
.env
|
| 14 |
+
.env.*
|
| 15 |
+
|
| 16 |
+
# Node
|
| 17 |
+
web/node_modules/
|
| 18 |
+
web/dist/
|
| 19 |
+
web/.vite/
|
| 20 |
+
|
| 21 |
+
# Patient data — never commit
|
| 22 |
+
data/uploads/
|
| 23 |
+
data/patient_chats/
|
| 24 |
+
data/lesions/
|
| 25 |
+
data/patients.json
|
| 26 |
+
|
| 27 |
+
# Model weights (large binaries — store separately)
|
| 28 |
+
models/*.pt
|
| 29 |
+
models/*.pth
|
| 30 |
+
models/*.bin
|
| 31 |
+
models/*.safetensors
|
| 32 |
+
|
| 33 |
+
# macOS
|
| 34 |
+
.DS_Store
|
| 35 |
+
|
| 36 |
+
# Test artifacts
|
| 37 |
+
test*.jpg
|
| 38 |
+
test*.png
|
| 39 |
+
*.log
|
| 40 |
+
|
| 41 |
+
# Clinical guidelines PDFs (copyrighted — obtain separately)
|
| 42 |
+
guidelines/*.pdf
|
| 43 |
+
|
| 44 |
+
# Temp
|
| 45 |
+
/tmp/
|
Dockerfile
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# syntax=docker/dockerfile:1

# ---- Stage 1: build the React frontend -------------------------------------
# Node is only needed to produce web/dist. Building it in a separate stage
# keeps nodejs, npm and curl entirely out of the runtime image (the original
# installed Node via `curl | bash` into the final image).
FROM node:20-slim AS frontend
WORKDIR /build
# Copy lockfiles first so the npm-ci layer is cached until deps change.
COPY web/package.json web/package-lock.json ./
RUN npm ci
COPY web/ .
RUN npm run build

# ---- Stage 2: Python runtime ------------------------------------------------
FROM python:3.10-slim

WORKDIR /app

# Install Python dependencies first — this layer stays cached until one of
# the requirements files changes.
COPY requirements.txt ml-requirements.txt
COPY backend/requirements.txt api-requirements.txt
RUN pip install --no-cache-dir -r ml-requirements.txt -r api-requirements.txt

# Copy application source
COPY models/ models/
COPY backend/ backend/
COPY data/case_store.py data/case_store.py
COPY guidelines/ guidelines/

# Bring in only the built frontend assets (served by backend.main from
# web/dist) — no node_modules, no toolchain.
COPY --from=frontend /build/dist web/dist

# Runtime directories (writable by the app)
RUN mkdir -p data/uploads data/patient_chats data/lesions && \
    echo '{"patients": []}' > data/patients.json

# HF Spaces runs the container as a non-root user whose UID is not known at
# build time, so the data tree must be writable by anyone.
# NOTE(review): if the Space pins UID 1000, prefer `chown -R 1000:1000 data/`
# over world-writable permissions.
RUN chmod -R 777 data/

# HF Spaces uses port 7860 (EXPOSE is documentation only)
EXPOSE 7860

CMD ["uvicorn", "backend.main:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: SkinProAI
|
| 3 |
+
emoji: 🔬
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
pinned: false
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
# SkinProAI
|
| 12 |
+
|
| 13 |
+
AI-assisted dermoscopic lesion analysis for clinical decision support.
|
| 14 |
+
|
| 15 |
+
## Features
|
| 16 |
+
|
| 17 |
+
- **Patient management** — create and select patient profiles
|
| 18 |
+
- **Image analysis** — upload dermoscopic images for automated assessment via MedGemma visual examination, MONET feature extraction, and ConvNeXt classification
|
| 19 |
+
- **Temporal comparison** — sequential images are automatically compared to detect change over time
|
| 20 |
+
- **Grad-CAM visualisation** — attention maps highlight regions driving the classification
|
| 21 |
+
- **Persistent chat history** — full analysis cascade is stored and replayed on reload
|
| 22 |
+
|
| 23 |
+
## Architecture
|
| 24 |
+
|
| 25 |
+
| Layer | Technology |
|
| 26 |
+
|-------|-----------|
|
| 27 |
+
| Frontend | React 18 + TypeScript (Vite) |
|
| 28 |
+
| Backend | FastAPI + uvicorn |
|
| 29 |
+
| Vision-language model | MedGemma (Google) via Hugging Face |
|
| 30 |
+
| Classifier | ConvNeXt fine-tuned on ISIC HAM10000 |
|
| 31 |
+
| Feature extraction | MONET skin concept probes |
|
| 32 |
+
| Explainability | Grad-CAM |
|
| 33 |
+
|
| 34 |
+
## Usage
|
| 35 |
+
|
| 36 |
+
1. Open the app and create a patient record
|
| 37 |
+
2. Click the patient card to open the chat
|
| 38 |
+
3. Attach a dermoscopic image and send — analysis runs automatically
|
| 39 |
+
4. Upload further images for the same patient to trigger temporal comparison
|
| 40 |
+
5. Ask follow-up questions in text to query the AI about the findings
|
| 41 |
+
|
| 42 |
+
## Disclaimer
|
| 43 |
+
|
| 44 |
+
SkinProAI is a research prototype intended for educational and investigational use only. It is **not** a certified medical device and must not be used as a substitute for professional clinical judgement.
|
backend/__init__.py
ADDED
|
File without changes
|
backend/main.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
SkinProAI FastAPI Backend

Wires up the API routers, serves uploaded images and the built React
frontend, and releases ML-service resources on shutdown via the lifespan
handler (the previous `@app.on_event("shutdown")` hook is deprecated in
FastAPI >= 0.93; this file pins fastapi>=0.100.0).
"""

from contextlib import asynccontextmanager
from pathlib import Path
import sys

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles

# Add project root to path for model imports
sys.path.insert(0, str(Path(__file__).parent.parent))

from backend.routes import patients, lesions, analysis, chat


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup is lazy (services built on first use); shutdown stops the
    MCP client subprocess if the analysis service ever started one."""
    yield
    # Deferred import so the heavy ML stack is not loaded just to shut down.
    from backend.services.analysis_service import get_analysis_service
    svc = get_analysis_service()
    if svc.agent.mcp_client:
        svc.agent.mcp_client.stop()


app = FastAPI(title="SkinProAI API", version="1.0.0", lifespan=lifespan)

# CORS middleware — wide open; NOTE(review): restrict allow_origins for
# production deployments handling patient data.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# API routes — analysis must be registered BEFORE patients so the literal
# /gradcam route is not shadowed by the parameterised /{patient_id} route.
app.include_router(analysis.router, prefix="/api/patients", tags=["analysis"])
app.include_router(chat.router, prefix="/api/patients", tags=["chat"])
app.include_router(patients.router, prefix="/api/patients", tags=["patients"])
app.include_router(lesions.router, prefix="/api/patients", tags=["lesions"])


@app.get("/api/health")
def health_check():
    """Liveness probe. Registered BEFORE the root static mount below:
    Starlette matches routes in registration order, and a mount at "/"
    matches every path, so anything registered after it is unreachable."""
    return {"status": "ok"}


# Ensure upload directories exist (mkdir with exist_ok guarantees presence,
# so no separate existence check is needed before mounting).
UPLOADS_DIR = Path(__file__).parent.parent / "data" / "uploads"
UPLOADS_DIR.mkdir(parents=True, exist_ok=True)

# Serve uploaded images
app.mount("/uploads", StaticFiles(directory=str(UPLOADS_DIR)), name="uploads")

# Serve React build (production). Mounted LAST so it cannot shadow the API
# and upload routes registered above.
BUILD_DIR = Path(__file__).parent.parent / "web" / "dist"
if BUILD_DIR.exists():
    app.mount("/", StaticFiles(directory=str(BUILD_DIR), html=True), name="static")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
backend/requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi>=0.100.0
|
| 2 |
+
uvicorn[standard]>=0.23.0
|
| 3 |
+
python-multipart>=0.0.6
|
backend/routes/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Re-export all route modules so `from backend.routes import <name>` works
# for every router registered in backend.main (`chat` was previously
# missing from this list even though backend.main imports it).
from . import patients, lesions, analysis, chat
|
backend/routes/analysis.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Analysis Routes - Image analysis with SSE streaming
"""

from fastapi import APIRouter, Query, HTTPException
from fastapi.responses import StreamingResponse, FileResponse
from pathlib import Path
import json
import tempfile

from backend.services.analysis_service import get_analysis_service
from data.case_store import get_case_store

router = APIRouter()

# Shared headers for every SSE endpoint below.
_SSE_HEADERS = {
    "Cache-Control": "no-cache",
    "Connection": "keep-alive",
}


def _sse_response(sync_gen):
    """Wrap a synchronous chunk generator in an SSE StreamingResponse.

    Each chunk is JSON-encoded into a `data:` event. Errors raised while
    iterating are reported in-band as an `[ERROR]...[/ERROR]` string, and
    the stream always terminates with a `[DONE]` event so clients can
    close cleanly. Shared by all four streaming endpoints (previously
    duplicated verbatim in each).
    """
    async def generate():
        try:
            for chunk in sync_gen:
                yield f"data: {json.dumps(chunk)}\n\n"
            yield "data: [DONE]\n\n"
        except Exception as e:
            yield f"data: {json.dumps(f'[ERROR]{str(e)}[/ERROR]')}\n\n"
            yield "data: [DONE]\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers=_SSE_HEADERS,
    )


@router.get("/gradcam")
def get_gradcam_by_path(path: str = Query(...)):
    """Serve a temp visualization image (GradCAM or comparison overlay).

    Only files inside the system temp directory whose names carry a known
    visualization suffix are served.
    """
    if not path:
        raise HTTPException(status_code=400, detail="No path provided")

    temp_dir = Path(tempfile.gettempdir()).resolve()
    resolved_path = Path(path).resolve()
    # is_relative_to avoids the string-prefix pitfall of startswith():
    # "/tmpevil/x" startswith "/tmp" but is NOT inside /tmp.
    if not resolved_path.is_relative_to(temp_dir):
        raise HTTPException(status_code=403, detail="Access denied")

    allowed_suffixes = ("_gradcam.png", "_comparison.png")
    if not any(resolved_path.name.endswith(s) for s in allowed_suffixes):
        raise HTTPException(status_code=400, detail="Invalid image path")

    if resolved_path.exists():
        return FileResponse(str(resolved_path), media_type="image/png")
    raise HTTPException(status_code=404, detail="Image not found")


@router.post("/{patient_id}/lesions/{lesion_id}/images/{image_id}/analyze")
async def analyze_image(
    patient_id: str,
    lesion_id: str,
    image_id: str,
    question: str = Query(None)
):
    """Analyze an image with SSE streaming"""
    store = get_case_store()

    # Verify the image exists and has an uploaded file before streaming.
    img = store.get_image(patient_id, lesion_id, image_id)
    if not img:
        raise HTTPException(status_code=404, detail="Image not found")
    if not img.image_path:
        raise HTTPException(status_code=400, detail="Image has no file uploaded")

    service = get_analysis_service()
    return _sse_response(service.analyze(patient_id, lesion_id, image_id, question))


@router.post("/{patient_id}/lesions/{lesion_id}/images/{image_id}/confirm")
async def confirm_diagnosis(
    patient_id: str,
    lesion_id: str,
    image_id: str,
    confirmed: bool = Query(...),
    feedback: str = Query(None)
):
    """Confirm or reject diagnosis and get management guidance"""
    service = get_analysis_service()
    return _sse_response(
        service.confirm(patient_id, lesion_id, image_id, confirmed, feedback)
    )


@router.post("/{patient_id}/lesions/{lesion_id}/images/{image_id}/compare")
async def compare_to_previous(
    patient_id: str,
    lesion_id: str,
    image_id: str
):
    """Compare this image to the previous one in the timeline"""
    store = get_case_store()

    # Both the current image and a predecessor must exist.
    current_img = store.get_image(patient_id, lesion_id, image_id)
    if not current_img:
        raise HTTPException(status_code=404, detail="Image not found")

    previous_img = store.get_previous_image(patient_id, lesion_id, image_id)
    if not previous_img:
        raise HTTPException(status_code=400, detail="No previous image to compare")

    service = get_analysis_service()
    return _sse_response(service.compare_images(
        patient_id, lesion_id,
        previous_img.image_path,
        current_img.image_path,
        image_id
    ))


@router.post("/{patient_id}/lesions/{lesion_id}/chat")
async def chat_message(
    patient_id: str,
    lesion_id: str,
    message: dict
):
    """Send a chat message with SSE streaming response"""
    store = get_case_store()

    lesion = store.get_lesion(patient_id, lesion_id)
    if not lesion:
        raise HTTPException(status_code=404, detail="Lesion not found")

    service = get_analysis_service()
    content = message.get("content", "")
    return _sse_response(service.chat_followup(patient_id, lesion_id, content))
|
backend/routes/chat.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Chat Routes - Patient-level chat with image analysis tools
"""

import asyncio
import json
import threading
from typing import Optional

from fastapi import APIRouter, HTTPException, UploadFile, File, Form
from fastapi.responses import StreamingResponse

from data.case_store import get_case_store
from backend.services.chat_service import get_chat_service

router = APIRouter()


@router.get("/{patient_id}/chat")
def get_chat_history(patient_id: str):
    """Get patient-level chat history"""
    store = get_case_store()
    if not store.get_patient(patient_id):
        raise HTTPException(status_code=404, detail="Patient not found")
    messages = store.get_patient_chat_history(patient_id)
    return {"messages": messages}


@router.delete("/{patient_id}/chat")
def clear_chat(patient_id: str):
    """Clear patient-level chat history"""
    store = get_case_store()
    if not store.get_patient(patient_id):
        raise HTTPException(status_code=404, detail="Patient not found")
    store.clear_patient_chat_history(patient_id)
    return {"success": True}


@router.post("/{patient_id}/chat")
async def post_chat_message(
    patient_id: str,
    content: str = Form(""),
    image: Optional[UploadFile] = File(None),
):
    """Send a chat message, optionally with an image — SSE streaming response.

    The sync ML generator runs in a background thread so it never blocks the
    event loop. Events flow through an asyncio.Queue, so each SSE event is
    flushed to the browser the moment it is produced (spinner shows instantly).
    """
    store = get_case_store()
    if not store.get_patient(patient_id):
        raise HTTPException(status_code=404, detail="Patient not found")

    # Read the upload up front; the background thread must not touch the
    # UploadFile (its async API belongs to the event loop).
    image_bytes = None
    if image and image.filename:
        image_bytes = await image.read()

    chat_service = get_chat_service()

    async def generate():
        # get_running_loop() is the correct call inside a coroutine;
        # get_event_loop() is deprecated for this use since Python 3.10.
        loop = asyncio.get_running_loop()
        queue: asyncio.Queue = asyncio.Queue()

        # Unique marker object signalling end-of-stream from the worker.
        _SENTINEL = object()

        def run_sync():
            """Drain the sync generator, forwarding each event to the loop."""
            try:
                for event in chat_service.stream_chat(patient_id, content, image_bytes):
                    loop.call_soon_threadsafe(queue.put_nowait, event)
            except Exception as e:
                loop.call_soon_threadsafe(
                    queue.put_nowait,
                    {"type": "error", "message": str(e)},
                )
            finally:
                # Always unblock the consumer, even after an error.
                loop.call_soon_threadsafe(queue.put_nowait, _SENTINEL)

        thread = threading.Thread(target=run_sync, daemon=True)
        thread.start()

        while True:
            event = await queue.get()
            if event is _SENTINEL:
                break
            yield f"data: {json.dumps(event)}\n\n"

        yield f"data: {json.dumps({'type': 'done'})}\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )
|
backend/routes/lesions.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Lesion Routes - CRUD for lesions and images
"""

from fastapi import APIRouter, HTTPException, UploadFile, File
from fastapi.responses import FileResponse
from pydantic import BaseModel
from dataclasses import asdict
from pathlib import Path
from typing import Optional
from PIL import Image
import io

from data.case_store import get_case_store

router = APIRouter()


class CreateLesionRequest(BaseModel):
    # Human-readable lesion label; body location is optional free text.
    name: str
    location: str = ""


class UpdateLesionRequest(BaseModel):
    # Fields are optional: None means "leave unchanged". Declared as
    # Optional[str] — the previous `name: str = None` is invalid under
    # Pydantic v2's stricter default validation.
    name: Optional[str] = None
    location: Optional[str] = None


# -------------------------------------------------------------------------
# Lesion CRUD
# -------------------------------------------------------------------------

@router.get("/{patient_id}/lesions")
def list_lesions(patient_id: str):
    """List all lesions for a patient, each with its image count and the
    most recent image (used as a thumbnail)."""
    store = get_case_store()

    patient = store.get_patient(patient_id)
    if not patient:
        raise HTTPException(status_code=404, detail="Patient not found")

    lesions = store.list_lesions(patient_id)

    result = []
    for lesion in lesions:
        images = store.list_images(patient_id, lesion.id)
        # Get the most recent image as thumbnail
        latest_image = images[-1] if images else None

        result.append({
            "id": lesion.id,
            "patient_id": lesion.patient_id,
            "name": lesion.name,
            "location": lesion.location,
            "created_at": lesion.created_at,
            "image_count": len(images),
            "latest_image": asdict(latest_image) if latest_image else None
        })

    return {"lesions": result}


@router.post("/{patient_id}/lesions")
def create_lesion(patient_id: str, req: CreateLesionRequest):
    """Create a new lesion for a patient"""
    store = get_case_store()

    patient = store.get_patient(patient_id)
    if not patient:
        raise HTTPException(status_code=404, detail="Patient not found")

    lesion = store.create_lesion(patient_id, req.name, req.location)
    return {
        "lesion": {
            **asdict(lesion),
            "image_count": 0,
            "images": []
        }
    }


@router.get("/{patient_id}/lesions/{lesion_id}")
def get_lesion(patient_id: str, lesion_id: str):
    """Get a lesion with all its images"""
    store = get_case_store()

    lesion = store.get_lesion(patient_id, lesion_id)
    if not lesion:
        raise HTTPException(status_code=404, detail="Lesion not found")

    images = store.list_images(patient_id, lesion_id)

    return {
        "lesion": {
            **asdict(lesion),
            "image_count": len(images),
            "images": [asdict(img) for img in images]
        }
    }


@router.patch("/{patient_id}/lesions/{lesion_id}")
def update_lesion(patient_id: str, lesion_id: str, req: UpdateLesionRequest):
    """Update a lesion's name or location (None fields are left unchanged)."""
    store = get_case_store()

    lesion = store.get_lesion(patient_id, lesion_id)
    if not lesion:
        raise HTTPException(status_code=404, detail="Lesion not found")

    store.update_lesion(patient_id, lesion_id, req.name, req.location)

    # Re-read so the response reflects the persisted state.
    lesion = store.get_lesion(patient_id, lesion_id)
    images = store.list_images(patient_id, lesion_id)

    return {
        "lesion": {
            **asdict(lesion),
            "image_count": len(images),
            "images": [asdict(img) for img in images]
        }
    }


@router.delete("/{patient_id}/lesions/{lesion_id}")
def delete_lesion(patient_id: str, lesion_id: str):
    """Delete a lesion and all its images"""
    store = get_case_store()

    lesion = store.get_lesion(patient_id, lesion_id)
    if not lesion:
        raise HTTPException(status_code=404, detail="Lesion not found")

    store.delete_lesion(patient_id, lesion_id)
    return {"success": True}


# -------------------------------------------------------------------------
# Image CRUD
# -------------------------------------------------------------------------

@router.post("/{patient_id}/lesions/{lesion_id}/images")
async def upload_image(patient_id: str, lesion_id: str, image: UploadFile = File(...)):
    """Upload a new image to a lesion's timeline.

    Creates the metadata record first, then decodes the upload to RGB and
    persists the file; the record is updated with the stored path.
    """
    store = get_case_store()

    lesion = store.get_lesion(patient_id, lesion_id)
    if not lesion:
        raise HTTPException(status_code=404, detail="Lesion not found")

    try:
        # Create image record
        img_record = store.add_image(patient_id, lesion_id)

        # Save the actual image file (normalised to RGB)
        pil_image = Image.open(io.BytesIO(await image.read())).convert("RGB")
        image_path = store.save_lesion_image(patient_id, lesion_id, img_record.id, pil_image)

        # Update image record with path
        store.update_image(patient_id, lesion_id, img_record.id, image_path=image_path)

        # Return updated record
        img_record = store.get_image(patient_id, lesion_id, img_record.id)
        return {"image": asdict(img_record)}

    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to upload image: {e}")


@router.get("/{patient_id}/lesions/{lesion_id}/images/{image_id}")
def get_image_record(patient_id: str, lesion_id: str, image_id: str):
    """Get an image record"""
    store = get_case_store()

    img = store.get_image(patient_id, lesion_id, image_id)
    if not img:
        raise HTTPException(status_code=404, detail="Image not found")

    return {"image": asdict(img)}


@router.get("/{patient_id}/lesions/{lesion_id}/images/{image_id}/file")
def get_image_file(patient_id: str, lesion_id: str, image_id: str):
    """Get the actual image file"""
    store = get_case_store()

    img = store.get_image(patient_id, lesion_id, image_id)
    if not img or not img.image_path:
        raise HTTPException(status_code=404, detail="Image not found")

    path = Path(img.image_path)
    if not path.exists():
        raise HTTPException(status_code=404, detail="Image file not found")

    return FileResponse(str(path), media_type="image/png")


@router.get("/{patient_id}/lesions/{lesion_id}/images/{image_id}/gradcam")
def get_gradcam_file(patient_id: str, lesion_id: str, image_id: str):
    """Get the GradCAM visualization for an image"""
    store = get_case_store()

    img = store.get_image(patient_id, lesion_id, image_id)
    if not img or not img.gradcam_path:
        raise HTTPException(status_code=404, detail="GradCAM not found")

    path = Path(img.gradcam_path)
    if not path.exists():
        raise HTTPException(status_code=404, detail="GradCAM file not found")

    return FileResponse(str(path), media_type="image/png")


# -------------------------------------------------------------------------
|
| 215 |
+
# Chat
|
| 216 |
+
# -------------------------------------------------------------------------
|
| 217 |
+
|
| 218 |
+
@router.get("/{patient_id}/lesions/{lesion_id}/chat")
|
| 219 |
+
def get_chat_history(patient_id: str, lesion_id: str):
|
| 220 |
+
"""Get chat history for a lesion"""
|
| 221 |
+
store = get_case_store()
|
| 222 |
+
|
| 223 |
+
lesion = store.get_lesion(patient_id, lesion_id)
|
| 224 |
+
if not lesion:
|
| 225 |
+
raise HTTPException(status_code=404, detail="Lesion not found")
|
| 226 |
+
|
| 227 |
+
messages = store.get_chat_history(patient_id, lesion_id)
|
| 228 |
+
return {"messages": [asdict(m) for m in messages]}
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
@router.delete("/{patient_id}/lesions/{lesion_id}/chat")
|
| 232 |
+
def clear_chat_history(patient_id: str, lesion_id: str):
|
| 233 |
+
"""Clear chat history for a lesion"""
|
| 234 |
+
store = get_case_store()
|
| 235 |
+
|
| 236 |
+
lesion = store.get_lesion(patient_id, lesion_id)
|
| 237 |
+
if not lesion:
|
| 238 |
+
raise HTTPException(status_code=404, detail="Lesion not found")
|
| 239 |
+
|
| 240 |
+
store.clear_chat_history(patient_id, lesion_id)
|
| 241 |
+
return {"success": True}
|
backend/routes/patients.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Patient Routes - CRUD for patients
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from fastapi import APIRouter, HTTPException
|
| 6 |
+
from pydantic import BaseModel
|
| 7 |
+
from dataclasses import asdict
|
| 8 |
+
|
| 9 |
+
from data.case_store import get_case_store
|
| 10 |
+
|
| 11 |
+
router = APIRouter()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class CreatePatientRequest(BaseModel):
    """Request body for POST /patients — only a display name is required."""
    name: str
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@router.get("")
|
| 19 |
+
def list_patients():
|
| 20 |
+
"""List all patients with lesion counts"""
|
| 21 |
+
store = get_case_store()
|
| 22 |
+
patients = store.list_patients()
|
| 23 |
+
|
| 24 |
+
result = []
|
| 25 |
+
for p in patients:
|
| 26 |
+
result.append({
|
| 27 |
+
**asdict(p),
|
| 28 |
+
"lesion_count": store.get_patient_lesion_count(p.id)
|
| 29 |
+
})
|
| 30 |
+
|
| 31 |
+
return {"patients": result}
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@router.post("")
|
| 35 |
+
def create_patient(req: CreatePatientRequest):
|
| 36 |
+
"""Create a new patient"""
|
| 37 |
+
store = get_case_store()
|
| 38 |
+
patient = store.create_patient(req.name)
|
| 39 |
+
return {
|
| 40 |
+
"patient": {
|
| 41 |
+
**asdict(patient),
|
| 42 |
+
"lesion_count": 0
|
| 43 |
+
}
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@router.get("/{patient_id}")
|
| 48 |
+
def get_patient(patient_id: str):
|
| 49 |
+
"""Get a patient by ID"""
|
| 50 |
+
store = get_case_store()
|
| 51 |
+
patient = store.get_patient(patient_id)
|
| 52 |
+
if not patient:
|
| 53 |
+
raise HTTPException(status_code=404, detail="Patient not found")
|
| 54 |
+
|
| 55 |
+
return {
|
| 56 |
+
"patient": {
|
| 57 |
+
**asdict(patient),
|
| 58 |
+
"lesion_count": store.get_patient_lesion_count(patient_id)
|
| 59 |
+
}
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@router.delete("/{patient_id}")
|
| 64 |
+
def delete_patient(patient_id: str):
|
| 65 |
+
"""Delete a patient and all their lesions"""
|
| 66 |
+
store = get_case_store()
|
| 67 |
+
patient = store.get_patient(patient_id)
|
| 68 |
+
if not patient:
|
| 69 |
+
raise HTTPException(status_code=404, detail="Patient not found")
|
| 70 |
+
|
| 71 |
+
store.delete_patient(patient_id)
|
| 72 |
+
return {"success": True}
|
backend/services/__init__.py
ADDED
|
File without changes
|
backend/services/analysis_service.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Analysis Service - Wraps MedGemmaAgent for API use
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from dataclasses import asdict
|
| 7 |
+
from typing import Optional, Generator
|
| 8 |
+
|
| 9 |
+
from models.medgemma_agent import MedGemmaAgent
|
| 10 |
+
from data.case_store import get_case_store
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class AnalysisService:
    """Singleton service for managing analysis operations.

    Wraps a MedGemmaAgent plus the case store: runs streaming image
    analysis, diagnosis confirmation, follow-up chat, and temporal image
    comparison, persisting results on the image records as it goes.
    Obtain the shared instance via get_analysis_service().
    """

    # Process-wide singleton slot, managed by get_analysis_service().
    _instance = None

    def __init__(self):
        self.agent = MedGemmaAgent(verbose=True)
        self.store = get_case_store()
        self._loaded = False  # heavy model weights are loaded lazily

    def _ensure_loaded(self):
        """Lazy load the ML models (no-op after the first call)."""
        if not self._loaded:
            self.agent.load_model()
            self._loaded = True

    def analyze(
        self,
        patient_id: str,
        lesion_id: str,
        image_id: str,
        question: Optional[str] = None
    ) -> Generator[str, None, None]:
        """Run analysis on an image, yielding streaming text chunks.

        Side effects: sets the image record's stage to "analyzing" while the
        cascade runs, then to "awaiting_confirmation" with the classifier's
        top prediction stored under `analysis` (when a diagnosis was made).
        Yields an [ERROR] tag and returns early if the image has no file.
        """
        self._ensure_loaded()

        image = self.store.get_image(patient_id, lesion_id, image_id)
        if not image or not image.image_path:
            yield "[ERROR]No image uploaded[/ERROR]"
            return

        # Mark the record as in-flight so the UI can show progress.
        self.store.update_image(patient_id, lesion_id, image_id, stage="analyzing")

        # Reset agent state for a fresh analysis run.
        self.agent.reset_state()

        # Stream the analysis cascade (optionally steered by a user question).
        for chunk in self.agent.analyze_image_stream(image.image_path, question=question or ""):
            yield chunk

        # Persist the diagnosis after analysis.
        # Fix: guard against a missing/empty prediction list — the original
        # indexed last_diagnosis["predictions"][0] blindly, which raised
        # IndexError/KeyError on a degenerate classifier result.
        if self.agent.last_diagnosis:
            predictions = self.agent.last_diagnosis.get("predictions") or []
            if predictions:
                top = predictions[0]
                analysis_data = {
                    "diagnosis": top["class"],
                    "full_name": top["full_name"],
                    "confidence": top["probability"],
                    "all_predictions": predictions
                }

                # Attach MONET concept features when the tool produced them.
                if self.agent.last_monet_result:
                    analysis_data["monet_features"] = self.agent.last_monet_result.get("features", {})

                self.store.update_image(
                    patient_id, lesion_id, image_id,
                    stage="awaiting_confirmation",
                    analysis=analysis_data
                )

    def confirm(
        self,
        patient_id: str,
        lesion_id: str,
        image_id: str,
        confirmed: bool,
        feedback: Optional[str] = None
    ) -> Generator[str, None, None]:
        """Confirm diagnosis and stream management guidance.

        Marks the image record "complete" once guidance has finished.
        """
        # Fix: ensure models are loaded even if this endpoint is hit after a
        # process restart (previously assumed analyze() had already run).
        self._ensure_loaded()

        for chunk in self.agent.generate_management_guidance(confirmed, feedback):
            yield chunk

        self.store.update_image(patient_id, lesion_id, image_id, stage="complete")

    def chat_followup(
        self,
        patient_id: str,
        lesion_id: str,
        message: str
    ) -> Generator[str, None, None]:
        """Handle a follow-up chat message, streaming the agent's reply.

        Persists both the user message and the full assistant response in
        the lesion's chat history.
        """
        # Fix: same lazy-load guard as confirm() — see note there.
        self._ensure_loaded()

        self.store.add_chat_message(patient_id, lesion_id, "user", message)

        response = ""
        for chunk in self.agent.chat_followup(message):
            response += chunk
            yield chunk

        self.store.add_chat_message(patient_id, lesion_id, "assistant", response)

    def get_chat_history(self, patient_id: str, lesion_id: str):
        """Return the lesion's chat history as a list of plain dicts."""
        messages = self.store.get_chat_history(patient_id, lesion_id)
        return [asdict(m) for m in messages]

    def compare_images(
        self,
        patient_id: str,
        lesion_id: str,
        previous_image_path: str,
        current_image_path: str,
        current_image_id: str
    ) -> Generator[str, None, None]:
        """Compare two images of the same lesion and assess change over time.

        Streams the agent's comparison narrative, then stores a comparison
        summary on the current image record.
        """
        self._ensure_loaded()

        # Stream the comparison (removed unused `comparison_result` local).
        for chunk in self.agent.compare_followup_images(previous_image_path, current_image_path):
            yield chunk

        # The agent does not expose a structured verdict here; default to
        # STABLE. NOTE(review): wire a real status through once the agent
        # surfaces one.
        comparison_data = {
            "status": "STABLE",
            "summary": "Comparison complete"
        }

        self.store.update_image(
            patient_id, lesion_id, current_image_id,
            comparison=comparison_data
        )
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def get_analysis_service() -> AnalysisService:
    """Return the process-wide AnalysisService, creating it on first use."""
    instance = AnalysisService._instance
    if instance is None:
        instance = AnalysisService()
        AnalysisService._instance = instance
    return instance
|
backend/services/chat_service.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Chat Service - Patient-level chat with tool dispatch and streaming
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import io
|
| 6 |
+
import re
|
| 7 |
+
import uuid
|
| 8 |
+
from typing import Generator, Optional
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from PIL import Image as PILImage
|
| 11 |
+
|
| 12 |
+
from data.case_store import get_case_store
|
| 13 |
+
from backend.services.analysis_service import get_analysis_service
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _extract_response_text(raw: str) -> str:
|
| 17 |
+
"""Pull clean text out of [RESPONSE]...[/RESPONSE]; strip all other tags."""
|
| 18 |
+
# Grab the RESPONSE block first
|
| 19 |
+
match = re.search(r'\[RESPONSE\](.*?)\[/RESPONSE\]', raw, re.DOTALL)
|
| 20 |
+
if match:
|
| 21 |
+
return match.group(1).strip()
|
| 22 |
+
# Fallback: strip every known markup tag
|
| 23 |
+
clean = re.sub(
|
| 24 |
+
r'\[(STAGE:[^\]]+|THINKING|RESPONSE|/RESPONSE|/THINKING|/STAGE'
|
| 25 |
+
r'|ERROR|/ERROR|RESULT|/RESULT|CONFIRM:\d+|/CONFIRM)\]',
|
| 26 |
+
'', raw
|
| 27 |
+
)
|
| 28 |
+
return clean.strip()
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class ChatService:
    """Patient-level chat orchestrator.

    Streams SSE-style event dicts (`tool_start`, `text`, `tool_result`) for
    the frontend while persisting user/assistant messages and tool-call
    artefacts through the case store. Obtain via get_chat_service().
    """

    # Process-wide singleton slot, managed by get_chat_service().
    _instance = None

    def __init__(self):
        self.store = get_case_store()

    def _get_image_url(self, patient_id: str, lesion_id: str, image_id: str) -> str:
        # URL shape must match the static /uploads mount served by the backend
        # (NOTE(review): assumed — confirm against backend/main.py).
        return f"/uploads/{patient_id}/{lesion_id}/{image_id}/image.png"

    def stream_chat(
        self,
        patient_id: str,
        content: str,
        image_bytes: Optional[bytes] = None,
    ) -> Generator[dict, None, None]:
        """Main chat handler — yields SSE event dicts.

        With `image_bytes`: stores the image, runs the analysis tool (and a
        comparison tool when a previous image exists), emitting tool events.
        Without: plain text follow-up chat streamed from the agent.
        """
        analysis_service = get_analysis_service()

        if image_bytes:
            # ----------------------------------------------------------------
            # Image path: analyze (and optionally compare).
            # We do NOT stream the raw verbose analysis text to the chat bubble —
            # the tool card IS the display artefact. We accumulate the text
            # internally, extract the clean [RESPONSE] block, and put it in
            # tool_result.summary so the expanded card can show it.
            # ----------------------------------------------------------------
            # All ad-hoc chat uploads go onto one synthetic "chat" lesion.
            lesion = self.store.get_or_create_chat_lesion(patient_id)

            # Persist the upload: record first, then the decoded PNG file.
            img_record = self.store.add_image(patient_id, lesion.id)
            pil_image = PILImage.open(io.BytesIO(image_bytes)).convert("RGB")
            abs_path = self.store.save_lesion_image(
                patient_id, lesion.id, img_record.id, pil_image
            )
            self.store.update_image(patient_id, lesion.id, img_record.id, image_path=abs_path)

            # Save the user's message with a link to the stored image.
            user_image_url = self._get_image_url(patient_id, lesion.id, img_record.id)
            self.store.add_patient_chat_message(
                patient_id, "user", content, image_url=user_image_url
            )

            # ---- tool: analyze_image ----------------------------------------
            call_id = f"tc-{uuid.uuid4().hex[:6]}"
            yield {"type": "tool_start", "tool": "analyze_image", "call_id": call_id}

            # Stream the raw analysis text while accumulating it for the card.
            analysis_text = ""
            for chunk in analysis_service.analyze(patient_id, lesion.id, img_record.id):
                yield {"type": "text", "content": chunk}
                analysis_text += chunk

            # Re-read the record: analyze() stores the diagnosis on it.
            updated_img = self.store.get_image(patient_id, lesion.id, img_record.id)
            analysis_result: dict = {
                "image_url": user_image_url,
                "summary": _extract_response_text(analysis_text),
                "diagnosis": None,
                "full_name": None,
                "confidence": None,
                "all_predictions": [],
            }
            if updated_img and updated_img.analysis:
                a = updated_img.analysis
                analysis_result.update({
                    "diagnosis": a.get("diagnosis"),
                    "full_name": a.get("full_name"),
                    "confidence": a.get("confidence"),
                    "all_predictions": a.get("all_predictions", []),
                })

            yield {
                "type": "tool_result",
                "tool": "analyze_image",
                "call_id": call_id,
                "result": analysis_result,
            }

            # ---- tool: compare_images (if a previous image exists) ----------
            previous_img = self.store.get_previous_image(patient_id, lesion.id, img_record.id)
            compare_call_id = None
            compare_result = None
            compare_text = ""

            # Only compare when the prior image's file is still on disk.
            if (
                previous_img
                and previous_img.image_path
                and Path(previous_img.image_path).exists()
            ):
                compare_call_id = f"tc-{uuid.uuid4().hex[:6]}"
                yield {
                    "type": "tool_start",
                    "tool": "compare_images",
                    "call_id": compare_call_id,
                }

                for chunk in analysis_service.compare_images(
                    patient_id,
                    lesion.id,
                    previous_img.image_path,
                    abs_path,
                    img_record.id,
                ):
                    yield {"type": "text", "content": chunk}
                    compare_text += chunk

                # compare_images() stores a comparison dict on the record.
                updated_img2 = self.store.get_image(patient_id, lesion.id, img_record.id)
                compare_result = {
                    "prev_image_url": self._get_image_url(patient_id, lesion.id, previous_img.id),
                    "curr_image_url": user_image_url,
                    "status_label": "STABLE",
                    "feature_changes": {},
                    "summary": _extract_response_text(compare_text),
                }
                if updated_img2 and updated_img2.comparison:
                    c = updated_img2.comparison
                    compare_result.update({
                        "status_label": c.get("status", "STABLE"),
                        "feature_changes": c.get("feature_changes", {}),
                    })
                    # Prefer the stored summary over the streamed text.
                    if c.get("summary"):
                        compare_result["summary"] = c["summary"]

                yield {
                    "type": "tool_result",
                    "tool": "compare_images",
                    "call_id": compare_call_id,
                    "result": compare_result,
                }

            # Save assistant message (raw streamed text) together with the
            # structured tool-call cards so history replay can rebuild the UI.
            tool_calls_data = [{
                "id": call_id,
                "tool": "analyze_image",
                "status": "complete",
                "result": analysis_result,
            }]
            if compare_call_id and compare_result:
                tool_calls_data.append({
                    "id": compare_call_id,
                    "tool": "compare_images",
                    "status": "complete",
                    "result": compare_result,
                })

            self.store.add_patient_chat_message(
                patient_id, "assistant", analysis_text + compare_text,
                tool_calls=tool_calls_data,
            )

        else:
            # ----------------------------------------------------------------
            # Text-only chat — stream chunks; tags are stripped on the frontend
            # ----------------------------------------------------------------
            self.store.add_patient_chat_message(patient_id, "user", content)

            analysis_service._ensure_loaded()
            response_text = ""
            for chunk in analysis_service.agent.chat_followup(content):
                yield {"type": "text", "content": chunk}
                response_text += chunk

            # Persist only the cleaned response (unlike the image path above,
            # which stores the raw streamed text alongside tool cards).
            self.store.add_patient_chat_message(
                patient_id, "assistant", _extract_response_text(response_text)
            )
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def get_chat_service() -> ChatService:
    """Return the process-wide ChatService, creating it on first use."""
    instance = ChatService._instance
    if instance is None:
        instance = ChatService()
        ChatService._instance = instance
    return instance
|
data/case_store.py
ADDED
|
@@ -0,0 +1,507 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Case Store - JSON-based persistence for patients, lesions, and images
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import uuid
|
| 7 |
+
import shutil
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
from typing import List, Dict, Optional, Any
|
| 11 |
+
from dataclasses import dataclass, field, asdict
|
| 12 |
+
from PIL import Image as PILImage
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
class ChatMessage:
    """A single message in a lesion's chat transcript."""
    role: str  # "user" or "assistant"
    content: str
    # ISO-8601 capture time. NOTE(review): datetime.utcnow() is deprecated in
    # Python 3.12+; datetime.now(timezone.utc) would change the string format,
    # so any migration must consider already-persisted records.
    timestamp: str = field(default_factory=lambda: datetime.utcnow().isoformat())
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
class LesionImage:
    """A single image capture of a lesion at a point in time"""
    id: str
    lesion_id: str  # owning Lesion.id
    # ISO-8601 capture time (see deprecation note on ChatMessage.timestamp).
    timestamp: str = field(default_factory=lambda: datetime.utcnow().isoformat())
    image_path: Optional[str] = None    # absolute path of the saved PNG, set after upload
    gradcam_path: Optional[str] = None  # path of the GradCAM overlay, set by analysis
    analysis: Optional[Dict[str, Any]] = None  # {diagnosis, confidence, monet_features}
    comparison: Optional[Dict[str, Any]] = None  # {status, feature_changes, summary}
    is_original: bool = False  # True for the first/baseline capture of the lesion
    stage: str = "pending"  # pending, analyzing, complete, error
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
class Lesion:
    """A tracked lesion that can have multiple images over time"""
    id: str
    patient_id: str  # owning Patient.id
    name: str  # User-provided label (e.g., "Left shoulder mole")
    location: str = ""  # Body location
    # ISO-8601 creation time (see deprecation note on ChatMessage.timestamp).
    created_at: str = field(default_factory=lambda: datetime.utcnow().isoformat())
    # Serialized ChatMessage dicts; kept as plain dicts for JSON round-trips.
    chat_history: List[Dict] = field(default_factory=list)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@dataclass
class Patient:
    """A patient who can have multiple lesions"""
    id: str  # "patient-<8 hex chars>", assigned by CaseStore.create_patient
    name: str
    # ISO-8601 creation time (see deprecation note on ChatMessage.timestamp).
    created_at: str = field(default_factory=lambda: datetime.utcnow().isoformat())
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class CaseStore:
|
| 56 |
+
"""JSON-based persistence for patients, lesions, and images"""
|
| 57 |
+
|
| 58 |
+
    def __init__(self, data_dir: Optional[str] = None):
        """Set up the on-disk layout under `data_dir`.

        Defaults to the directory containing this module. Creates the
        lesions/ and uploads/ subdirectories and an empty patients.json
        when they do not already exist.
        """
        if data_dir is None:
            data_dir = Path(__file__).parent
        self.data_dir = Path(data_dir)
        self.patients_file = self.data_dir / "patients.json"
        self.lesions_dir = self.data_dir / "lesions"
        self.uploads_dir = self.data_dir / "uploads"

        # Ensure directories exist
        self.lesions_dir.mkdir(parents=True, exist_ok=True)
        self.uploads_dir.mkdir(parents=True, exist_ok=True)

        # Initialize patients file if needed
        if not self.patients_file.exists():
            self._init_patients_file()
|
| 73 |
+
|
| 74 |
+
def _init_patients_file(self):
|
| 75 |
+
"""Initialize patients file"""
|
| 76 |
+
data = {"patients": []}
|
| 77 |
+
with open(self.patients_file, 'w') as f:
|
| 78 |
+
json.dump(data, f, indent=2)
|
| 79 |
+
|
| 80 |
+
def _load_patients_data(self) -> Dict:
|
| 81 |
+
"""Load patients JSON file"""
|
| 82 |
+
with open(self.patients_file, 'r') as f:
|
| 83 |
+
return json.load(f)
|
| 84 |
+
|
| 85 |
+
def _save_patients_data(self, data: Dict):
|
| 86 |
+
"""Save patients JSON file"""
|
| 87 |
+
with open(self.patients_file, 'w') as f:
|
| 88 |
+
json.dump(data, f, indent=2)
|
| 89 |
+
|
| 90 |
+
# -------------------------------------------------------------------------
|
| 91 |
+
# Patient Methods
|
| 92 |
+
# -------------------------------------------------------------------------
|
| 93 |
+
|
| 94 |
+
def list_patients(self) -> List[Patient]:
|
| 95 |
+
"""List all patients"""
|
| 96 |
+
data = self._load_patients_data()
|
| 97 |
+
return [Patient(**p) for p in data.get("patients", [])]
|
| 98 |
+
|
| 99 |
+
def get_patient(self, patient_id: str) -> Optional[Patient]:
|
| 100 |
+
"""Get a patient by ID"""
|
| 101 |
+
data = self._load_patients_data()
|
| 102 |
+
for p in data.get("patients", []):
|
| 103 |
+
if p["id"] == patient_id:
|
| 104 |
+
return Patient(**p)
|
| 105 |
+
return None
|
| 106 |
+
|
| 107 |
+
def create_patient(self, name: str) -> Patient:
|
| 108 |
+
"""Create a new patient"""
|
| 109 |
+
patient = Patient(
|
| 110 |
+
id=f"patient-{uuid.uuid4().hex[:8]}",
|
| 111 |
+
name=name
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
data = self._load_patients_data()
|
| 115 |
+
data["patients"].append(asdict(patient))
|
| 116 |
+
self._save_patients_data(data)
|
| 117 |
+
|
| 118 |
+
# Create directory for this patient's lesions
|
| 119 |
+
(self.lesions_dir / patient.id).mkdir(exist_ok=True)
|
| 120 |
+
|
| 121 |
+
return patient
|
| 122 |
+
|
| 123 |
+
def delete_patient(self, patient_id: str):
|
| 124 |
+
"""Delete a patient and all their lesions"""
|
| 125 |
+
data = self._load_patients_data()
|
| 126 |
+
data["patients"] = [p for p in data["patients"] if p["id"] != patient_id]
|
| 127 |
+
self._save_patients_data(data)
|
| 128 |
+
|
| 129 |
+
# Delete lesion files
|
| 130 |
+
patient_lesions_dir = self.lesions_dir / patient_id
|
| 131 |
+
if patient_lesions_dir.exists():
|
| 132 |
+
shutil.rmtree(patient_lesions_dir)
|
| 133 |
+
|
| 134 |
+
# Delete uploads
|
| 135 |
+
patient_uploads_dir = self.uploads_dir / patient_id
|
| 136 |
+
if patient_uploads_dir.exists():
|
| 137 |
+
shutil.rmtree(patient_uploads_dir)
|
| 138 |
+
|
| 139 |
+
# Delete patient chat history
|
| 140 |
+
patient_chat_file = self.data_dir / "patient_chats" / f"{patient_id}.json"
|
| 141 |
+
if patient_chat_file.exists():
|
| 142 |
+
patient_chat_file.unlink()
|
| 143 |
+
|
| 144 |
+
def get_patient_lesion_count(self, patient_id: str) -> int:
|
| 145 |
+
"""Get number of lesions for a patient"""
|
| 146 |
+
return len(self.list_lesions(patient_id))
|
| 147 |
+
|
| 148 |
+
# -------------------------------------------------------------------------
|
| 149 |
+
# Lesion Methods
|
| 150 |
+
# -------------------------------------------------------------------------
|
| 151 |
+
|
| 152 |
+
def _get_lesion_path(self, patient_id: str, lesion_id: str) -> Path:
|
| 153 |
+
"""Get path to lesion JSON file"""
|
| 154 |
+
return self.lesions_dir / patient_id / f"{lesion_id}.json"
|
| 155 |
+
|
| 156 |
+
def list_lesions(self, patient_id: str) -> List[Lesion]:
|
| 157 |
+
"""List all lesions for a patient"""
|
| 158 |
+
patient_dir = self.lesions_dir / patient_id
|
| 159 |
+
if not patient_dir.exists():
|
| 160 |
+
return []
|
| 161 |
+
|
| 162 |
+
lesions = []
|
| 163 |
+
for f in sorted(patient_dir.glob("*.json")):
|
| 164 |
+
with open(f, 'r') as fp:
|
| 165 |
+
data = json.load(fp)
|
| 166 |
+
# Only load lesion data, not images
|
| 167 |
+
lesion_data = {k: v for k, v in data.items() if k != 'images'}
|
| 168 |
+
lesions.append(Lesion(**lesion_data))
|
| 169 |
+
|
| 170 |
+
lesions.sort(key=lambda x: x.created_at)
|
| 171 |
+
return lesions
|
| 172 |
+
|
| 173 |
+
def get_lesion(self, patient_id: str, lesion_id: str) -> Optional[Lesion]:
|
| 174 |
+
"""Get a lesion by ID"""
|
| 175 |
+
path = self._get_lesion_path(patient_id, lesion_id)
|
| 176 |
+
if not path.exists():
|
| 177 |
+
return None
|
| 178 |
+
|
| 179 |
+
with open(path, 'r') as f:
|
| 180 |
+
data = json.load(f)
|
| 181 |
+
lesion_data = {k: v for k, v in data.items() if k != 'images'}
|
| 182 |
+
return Lesion(**lesion_data)
|
| 183 |
+
|
| 184 |
+
def create_lesion(self, patient_id: str, name: str, location: str = "") -> Lesion:
    """Create and persist a new lesion record for a patient.

    A short random ID is generated, the patient's lesion directory is
    created on demand, and the record is saved with an empty image list.
    """
    lesion = Lesion(
        id=f"lesion-{uuid.uuid4().hex[:8]}",
        patient_id=patient_id,
        name=name,
        location=location
    )

    # The per-patient directory may not exist yet for a new patient.
    (self.lesions_dir / patient_id).mkdir(exist_ok=True)

    # Persist immediately so the lesion is visible to list_lesions().
    record = asdict(lesion)
    record["images"] = []
    self._save_lesion_data(patient_id, lesion.id, record)

    return lesion
|
| 204 |
+
|
| 205 |
+
def _save_lesion_data(self, patient_id: str, lesion_id: str, data: Dict):
|
| 206 |
+
"""Save lesion data to JSON file"""
|
| 207 |
+
path = self._get_lesion_path(patient_id, lesion_id)
|
| 208 |
+
with open(path, 'w') as f:
|
| 209 |
+
json.dump(data, f, indent=2)
|
| 210 |
+
|
| 211 |
+
def _load_lesion_data(self, patient_id: str, lesion_id: str) -> Optional[Dict]:
|
| 212 |
+
"""Load full lesion data including images"""
|
| 213 |
+
path = self._get_lesion_path(patient_id, lesion_id)
|
| 214 |
+
if not path.exists():
|
| 215 |
+
return None
|
| 216 |
+
|
| 217 |
+
with open(path, 'r') as f:
|
| 218 |
+
return json.load(f)
|
| 219 |
+
|
| 220 |
+
def delete_lesion(self, patient_id: str, lesion_id: str):
    """Remove a lesion's JSON record and every uploaded file belonging to it."""
    record = self._get_lesion_path(patient_id, lesion_id)
    if record.exists():
        record.unlink()

    # Uploaded images live in a parallel tree under uploads_dir.
    uploads = self.uploads_dir / patient_id / lesion_id
    if uploads.exists():
        shutil.rmtree(uploads)
|
| 230 |
+
|
| 231 |
+
def update_lesion(self, patient_id: str, lesion_id: str, name: str = None, location: str = None):
    """Rename and/or relocate a lesion; silently ignores unknown lesions.

    Only fields passed as non-None are touched, so callers can update
    either attribute independently.
    """
    record = self._load_lesion_data(patient_id, lesion_id)
    if record is None:
        return

    for field, value in (("name", name), ("location", location)):
        if value is not None:
            record[field] = value

    self._save_lesion_data(patient_id, lesion_id, record)
|
| 243 |
+
|
| 244 |
+
# -------------------------------------------------------------------------
|
| 245 |
+
# LesionImage Methods
|
| 246 |
+
# -------------------------------------------------------------------------
|
| 247 |
+
|
| 248 |
+
def list_images(self, patient_id: str, lesion_id: str) -> List[LesionImage]:
    """Return a lesion's images sorted by capture timestamp ([] if the lesion is unknown)."""
    record = self._load_lesion_data(patient_id, lesion_id)
    if record is None:
        return []

    hydrated = [LesionImage(**entry) for entry in record.get("images", [])]
    return sorted(hydrated, key=lambda image: image.timestamp)
|
| 257 |
+
|
| 258 |
+
def get_image(self, patient_id: str, lesion_id: str, image_id: str) -> Optional[LesionImage]:
    """Look up one image of a lesion by ID; None if the lesion or image is missing."""
    record = self._load_lesion_data(patient_id, lesion_id)
    if record is None:
        return None

    match = next(
        (entry for entry in record.get("images", []) if entry["id"] == image_id),
        None
    )
    return LesionImage(**match) if match is not None else None
|
| 268 |
+
|
| 269 |
+
def add_image(self, patient_id: str, lesion_id: str) -> LesionImage:
    """Append a new (empty) image entry to a lesion's timeline.

    The very first image of a lesion is flagged is_original=True so later
    uploads can be compared against it.

    Raises:
        ValueError: if the lesion does not exist.
    """
    record = self._load_lesion_data(patient_id, lesion_id)
    if record is None:
        raise ValueError(f"Lesion {lesion_id} not found")

    existing = record.setdefault("images", [])
    image = LesionImage(
        id=f"img-{uuid.uuid4().hex[:8]}",
        lesion_id=lesion_id,
        is_original=(len(existing) == 0)
    )
    existing.append(asdict(image))
    self._save_lesion_data(patient_id, lesion_id, record)

    return image
|
| 290 |
+
|
| 291 |
+
def update_image(
    self,
    patient_id: str,
    lesion_id: str,
    image_id: str,
    image_path: str = None,
    gradcam_path: str = None,
    analysis: Dict = None,
    comparison: Dict = None,
    stage: str = None
):
    """Patch selected fields of one image entry in a lesion record.

    Only arguments passed as non-None are written. The record is re-saved
    even when the image ID is not found (a harmless no-op rewrite).
    """
    record = self._load_lesion_data(patient_id, lesion_id)
    if record is None:
        return

    updates = {
        "image_path": image_path,
        "gradcam_path": gradcam_path,
        "analysis": analysis,
        "comparison": comparison,
        "stage": stage,
    }
    for entry in record.get("images", []):
        if entry["id"] == image_id:
            for field, value in updates.items():
                if value is not None:
                    entry[field] = value
            break

    self._save_lesion_data(patient_id, lesion_id, record)
|
| 322 |
+
|
| 323 |
+
def save_lesion_image(
    self,
    patient_id: str,
    lesion_id: str,
    image_id: str,
    image: PILImage.Image,
    filename: str = "image.png"
) -> str:
    """Write an uploaded PIL image under uploads/<patient>/<lesion>/<image>/ and return its path."""
    destination_dir = self.uploads_dir / patient_id / lesion_id / image_id
    destination_dir.mkdir(parents=True, exist_ok=True)

    destination = destination_dir / filename
    image.save(destination)
    return str(destination)
|
| 339 |
+
|
| 340 |
+
def get_previous_image(
|
| 341 |
+
self,
|
| 342 |
+
patient_id: str,
|
| 343 |
+
lesion_id: str,
|
| 344 |
+
current_image_id: str
|
| 345 |
+
) -> Optional[LesionImage]:
|
| 346 |
+
"""Get the image before the current one (for comparison)"""
|
| 347 |
+
images = self.list_images(patient_id, lesion_id)
|
| 348 |
+
|
| 349 |
+
for i, img in enumerate(images):
|
| 350 |
+
if img.id == current_image_id and i > 0:
|
| 351 |
+
return images[i - 1]
|
| 352 |
+
return None
|
| 353 |
+
|
| 354 |
+
# -------------------------------------------------------------------------
|
| 355 |
+
# Chat Methods (scoped to lesion)
|
| 356 |
+
# -------------------------------------------------------------------------
|
| 357 |
+
|
| 358 |
+
def add_chat_message(self, patient_id: str, lesion_id: str, role: str, content: str):
    """Append one chat message to a lesion's stored history (no-op for unknown lesions)."""
    record = self._load_lesion_data(patient_id, lesion_id)
    if record is None:
        return

    history = record.setdefault("chat_history", [])
    history.append(asdict(ChatMessage(role=role, content=content)))
    self._save_lesion_data(patient_id, lesion_id, record)
|
| 369 |
+
|
| 370 |
+
def get_chat_history(self, patient_id: str, lesion_id: str) -> List[ChatMessage]:
    """Return a lesion's chat messages in stored order ([] for unknown lesions)."""
    record = self._load_lesion_data(patient_id, lesion_id)
    if record is None:
        return []
    return [ChatMessage(**message) for message in record.get("chat_history", [])]
|
| 377 |
+
|
| 378 |
+
def clear_chat_history(self, patient_id: str, lesion_id: str):
    """Drop all chat messages for a lesion while keeping the rest of its record."""
    record = self._load_lesion_data(patient_id, lesion_id)
    if record is None:
        return

    record["chat_history"] = []
    self._save_lesion_data(patient_id, lesion_id, record)
|
| 386 |
+
|
| 387 |
+
# -------------------------------------------------------------------------
|
| 388 |
+
# Patient-level Chat Methods
|
| 389 |
+
# -------------------------------------------------------------------------
|
| 390 |
+
|
| 391 |
+
def _get_patient_chat_file(self, patient_id: str) -> Path:
|
| 392 |
+
"""Get path to patient-level chat JSON file"""
|
| 393 |
+
chat_dir = self.data_dir / "patient_chats"
|
| 394 |
+
chat_dir.mkdir(exist_ok=True)
|
| 395 |
+
return chat_dir / f"{patient_id}.json"
|
| 396 |
+
|
| 397 |
+
def get_patient_chat_history(self, patient_id: str) -> List[dict]:
    """Return the patient-level chat messages, or [] when no chat file exists yet."""
    chat_file = self._get_patient_chat_file(patient_id)
    if not chat_file.exists():
        return []

    with open(chat_file, 'r') as fh:
        stored = json.load(fh)
    return stored.get("messages", [])
|
| 405 |
+
|
| 406 |
+
def add_patient_chat_message(
    self,
    patient_id: str,
    role: str,
    content: str,
    image_url: Optional[str] = None,
    tool_calls: Optional[list] = None
):
    """Append a message to the patient-level chat history file.

    Args:
        patient_id: Owner of the chat history.
        role: Message author ("user", "assistant", ...).
        content: Message text.
        image_url: Optional attached image reference; stored only when given.
        tool_calls: Optional record of tool invocations; stored only when given.
    """
    # datetime.utcnow() is deprecated (Python 3.12); derive the same
    # naive-UTC ISO string from an aware "now" so existing consumers of the
    # stored timestamp see an unchanged format.
    from datetime import datetime, timezone

    chat_file = self._get_patient_chat_file(patient_id)
    if chat_file.exists():
        with open(chat_file, 'r') as f:
            data = json.load(f)
    else:
        data = {"messages": []}

    message: Dict[str, Any] = {
        "id": f"msg-{uuid.uuid4().hex[:8]}",
        "role": role,
        "content": content,
        "timestamp": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
    }
    if image_url is not None:
        message["image_url"] = image_url
    if tool_calls is not None:
        message["tool_calls"] = tool_calls

    data["messages"].append(message)
    with open(chat_file, 'w') as f:
        json.dump(data, f, indent=2)
|
| 436 |
+
|
| 437 |
+
def clear_patient_chat_history(self, patient_id: str):
    """Reset the patient-level chat history to an empty message list."""
    chat_file = self._get_patient_chat_file(patient_id)
    empty_history = {"messages": []}
    with open(chat_file, 'w') as fh:
        json.dump(empty_history, fh)
|
| 442 |
+
|
| 443 |
+
def get_or_create_chat_lesion(self, patient_id: str) -> 'Lesion':
    """Return the hidden lesion that stores a patient's chat-uploaded images, creating it on first use."""
    existing = next(
        (lesion for lesion in self.list_lesions(patient_id) if lesion.name == "__chat_images__"),
        None
    )
    if existing is not None:
        return existing
    return self.create_lesion(patient_id, "__chat_images__", "internal")
|
| 449 |
+
|
| 450 |
+
def get_latest_chat_image(self, patient_id: str) -> Optional['LesionImage']:
    """Return the newest chat-uploaded image that has completed analysis, if any."""
    chat_lesion = self.get_or_create_chat_lesion(patient_id)
    timeline = self.list_images(patient_id, chat_lesion.id)

    # Walk backwards so the most recent analyzed image wins.
    analyzed = (image for image in reversed(timeline) if image.analysis is not None)
    return next(analyzed, None)
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
# Singleton instance
# Lazily created on first get_case_store() call; shared process-wide.
_store_instance = None


def get_case_store() -> CaseStore:
    """Get or create CaseStore singleton.

    Returns the shared CaseStore, constructing it on the first call.

    NOTE(review): no locking — assumes the first call happens before any
    concurrent use.
    """
    global _store_instance
    if _store_instance is None:
        _store_instance = CaseStore()
    return _store_instance
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
if __name__ == "__main__":
    # Manual smoke test: create a patient with one lesion and two images,
    # inspect the results, then remove everything again.
    store = CaseStore()

    print("Patients:")
    for existing in store.list_patients():
        print(f"  - {existing.id}: {existing.name}")

    print("\nCreating test patient...")
    test_patient = store.create_patient("Test Patient")
    print(f"  Created: {test_patient.id}")

    print("\nCreating lesion...")
    shoulder_mole = store.create_lesion(test_patient.id, "Left shoulder mole", "Left shoulder")
    print(f"  Created: {shoulder_mole.id}")

    # The first image of a lesion should come back flagged is_original=True.
    print("\nAdding image...")
    first_image = store.add_image(test_patient.id, shoulder_mole.id)
    print(f"  Created: {first_image.id} (is_original={first_image.is_original})")

    second_image = store.add_image(test_patient.id, shoulder_mole.id)
    print(f"  Created: {second_image.id} (is_original={second_image.is_original})")

    print(f"\nImages for lesion {shoulder_mole.id}:")
    for entry in store.list_images(test_patient.id, shoulder_mole.id):
        print(f"  - {entry.id}: original={entry.is_original}, stage={entry.stage}")

    print("\nCleaning up test patient...")
    store.delete_patient(test_patient.id)
    print("Done!")
|
frontend/app.py
ADDED
|
@@ -0,0 +1,532 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SkinProAI Frontend - Modular Gradio application
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import gradio as gr
|
| 6 |
+
from typing import Dict, Generator, Optional
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
import sys
|
| 9 |
+
import os
|
| 10 |
+
import re
|
| 11 |
+
import base64
|
| 12 |
+
|
| 13 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 14 |
+
|
| 15 |
+
from data.case_store import get_case_store
|
| 16 |
+
from frontend.components.styles import MAIN_CSS
|
| 17 |
+
from frontend.components.analysis_view import format_output
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# =============================================================================
|
| 21 |
+
# CONFIG
|
| 22 |
+
# =============================================================================
|
| 23 |
+
|
| 24 |
+
class Config:
    """Application-level settings resolved from the environment at import time."""

    # Window / page title.
    APP_TITLE = "SkinProAI"
    # Port Gradio binds to; GRADIO_SERVER_PORT overrides the 7860 default.
    SERVER_PORT = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
    # Hugging Face Spaces sets SPACE_ID; its presence detects that environment.
    HF_SPACES = "SPACE_ID" in os.environ
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# =============================================================================
|
| 31 |
+
# AGENT
|
| 32 |
+
# =============================================================================
|
| 33 |
+
|
| 34 |
+
class AnalysisAgent:
    """Lazy-loading wrapper around the MedGemma analysis agent.

    The underlying model is only constructed on first use so that importing
    this module stays cheap.
    """

    def __init__(self):
        # Underlying MedGemmaAgent instance; None until load() runs.
        self.model = None
        # Guards against constructing/loading the model twice.
        self.loaded = False

    def load(self):
        """Construct and load the MedGemma model exactly once (idempotent)."""
        if self.loaded:
            return
        from models.medgemma_agent import MedGemmaAgent
        self.model = MedGemmaAgent(verbose=True)
        self.model.load_model()
        self.loaded = True

    def analyze(self, image_path: str, question: str = "") -> Generator[str, None, None]:
        """Stream analysis chunks for an image, loading the model on demand."""
        if not self.loaded:
            # Emit a stage marker so the UI can show progress during the slow load.
            yield "[STAGE:loading]Loading AI models...[/STAGE]\n"
            self.load()
        yield from self.model.analyze_image_stream(image_path, question=question)

    def management_guidance(self, confirmed: bool, feedback: str = None) -> Generator[str, None, None]:
        """Stream management guidance after the user confirmed or disputed the diagnosis."""
        yield from self.model.generate_management_guidance(confirmed, feedback)

    def followup(self, message: str) -> Generator[str, None, None]:
        """Stream a reply to a follow-up question; errors out if no analysis ran yet."""
        if not self.loaded or not self.model.last_diagnosis:
            yield "[ERROR]No analysis context available.[/ERROR]\n"
            return
        yield from self.model.chat_followup(message)

    def reset(self):
        """Clear the agent's conversational state, if a model exists."""
        if self.model:
            self.model.reset_state()
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# Process-wide singletons shared by every Gradio event handler below.
agent = AnalysisAgent()
case_store = get_case_store()
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# =============================================================================
|
| 78 |
+
# APP
|
| 79 |
+
# =============================================================================
|
| 80 |
+
|
| 81 |
+
with gr.Blocks(title=Config.APP_TITLE, css=MAIN_CSS, theme=gr.themes.Soft()) as app:
|
| 82 |
+
|
| 83 |
+
# =========================================================================
|
| 84 |
+
# STATE
|
| 85 |
+
# =========================================================================
|
| 86 |
+
state = gr.State({
|
| 87 |
+
"page": "patient_select", # patient_select | analysis
|
| 88 |
+
"case_id": None,
|
| 89 |
+
"instance_id": None,
|
| 90 |
+
"output": "",
|
| 91 |
+
"gradcam_base64": None
|
| 92 |
+
})
|
| 93 |
+
|
| 94 |
+
# =========================================================================
|
| 95 |
+
# PAGE 1: PATIENT SELECTION
|
| 96 |
+
# =========================================================================
|
| 97 |
+
with gr.Group(visible=True, elem_classes=["patient-select-container"]) as page_patient:
|
| 98 |
+
gr.Markdown("# SkinProAI", elem_classes=["patient-select-title"])
|
| 99 |
+
gr.Markdown("Select a patient to continue or create a new case", elem_classes=["patient-select-subtitle"])
|
| 100 |
+
|
| 101 |
+
with gr.Row(elem_classes=["patient-grid"]):
|
| 102 |
+
btn_demo_melanoma = gr.Button("Demo: Melanocytic Lesion", elem_classes=["patient-card"])
|
| 103 |
+
btn_demo_ak = gr.Button("Demo: Actinic Keratosis", elem_classes=["patient-card"])
|
| 104 |
+
btn_new_patient = gr.Button("+ New Patient", variant="primary", elem_classes=["new-patient-btn"])
|
| 105 |
+
|
| 106 |
+
# =========================================================================
|
| 107 |
+
# PAGE 2: ANALYSIS
|
| 108 |
+
# =========================================================================
|
| 109 |
+
with gr.Group(visible=False) as page_analysis:
|
| 110 |
+
|
| 111 |
+
# Header
|
| 112 |
+
with gr.Row(elem_classes=["app-header"]):
|
| 113 |
+
gr.Markdown(f"**{Config.APP_TITLE}**", elem_classes=["app-title"])
|
| 114 |
+
btn_back = gr.Button("< Back to Patients", elem_classes=["back-btn"])
|
| 115 |
+
|
| 116 |
+
with gr.Row(elem_classes=["analysis-container"]):
|
| 117 |
+
|
| 118 |
+
# Sidebar (previous queries)
|
| 119 |
+
with gr.Column(scale=0, min_width=260, visible=False, elem_classes=["query-sidebar"]) as sidebar:
|
| 120 |
+
gr.Markdown("### Previous Queries", elem_classes=["sidebar-header"])
|
| 121 |
+
sidebar_list = gr.Column(elem_id="sidebar-queries")
|
| 122 |
+
btn_new_query = gr.Button("+ New Query", size="sm", variant="primary")
|
| 123 |
+
|
| 124 |
+
# Main content
|
| 125 |
+
with gr.Column(scale=4, elem_classes=["main-content"]):
|
| 126 |
+
|
| 127 |
+
# Input view (greeting style)
|
| 128 |
+
with gr.Group(visible=True, elem_classes=["input-greeting"]) as view_input:
|
| 129 |
+
gr.Markdown("What would you like to analyze?", elem_classes=["greeting-title"])
|
| 130 |
+
gr.Markdown("Upload an image and describe what you'd like to know", elem_classes=["greeting-subtitle"])
|
| 131 |
+
|
| 132 |
+
with gr.Column(elem_classes=["input-box-container"]):
|
| 133 |
+
input_message = gr.Textbox(
|
| 134 |
+
placeholder="Describe the lesion or ask a question...",
|
| 135 |
+
show_label=False,
|
| 136 |
+
lines=2,
|
| 137 |
+
elem_classes=["message-input"]
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
input_image = gr.Image(
|
| 141 |
+
type="pil",
|
| 142 |
+
height=180,
|
| 143 |
+
show_label=False,
|
| 144 |
+
elem_classes=["image-preview"]
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
with gr.Row(elem_classes=["input-actions"]):
|
| 148 |
+
gr.Markdown("*Upload a skin lesion image*")
|
| 149 |
+
btn_analyze = gr.Button("Analyze", elem_classes=["send-btn"], interactive=False)
|
| 150 |
+
|
| 151 |
+
# Results view (shown after analysis)
|
| 152 |
+
with gr.Group(visible=False, elem_classes=["chat-view"]) as view_results:
|
| 153 |
+
output_html = gr.HTML(
|
| 154 |
+
value='<div class="analysis-output">Starting...</div>',
|
| 155 |
+
elem_classes=["results-area"]
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
# Confirmation
|
| 159 |
+
with gr.Group(visible=False, elem_classes=["confirm-buttons"]) as confirm_box:
|
| 160 |
+
gr.Markdown("**Do you agree with this diagnosis?**")
|
| 161 |
+
with gr.Row():
|
| 162 |
+
btn_confirm_yes = gr.Button("Yes, continue", variant="primary", size="sm")
|
| 163 |
+
btn_confirm_no = gr.Button("No, I disagree", variant="secondary", size="sm")
|
| 164 |
+
input_feedback = gr.Textbox(label="Your assessment", placeholder="Enter diagnosis...", visible=False)
|
| 165 |
+
btn_submit_feedback = gr.Button("Submit", visible=False, size="sm")
|
| 166 |
+
|
| 167 |
+
# Follow-up
|
| 168 |
+
with gr.Row(elem_classes=["chat-input-area"]):
|
| 169 |
+
input_followup = gr.Textbox(placeholder="Ask a follow-up question...", show_label=False, lines=1, scale=4)
|
| 170 |
+
btn_followup = gr.Button("Send", size="sm", scale=1)
|
| 171 |
+
|
| 172 |
+
# =========================================================================
|
| 173 |
+
# DYNAMIC SIDEBAR RENDERING
|
| 174 |
+
# =========================================================================
|
| 175 |
+
@gr.render(inputs=[state], triggers=[state.change])
def render_sidebar(s):
    # Dynamically re-renders the "Previous Queries" sidebar whenever the app
    # state changes: one button per saved query instance of the active case.
    case_id = s.get("case_id")
    if not case_id or s.get("page") != "analysis":
        return

    instances = case_store.list_instances(case_id)
    current = s.get("instance_id")

    for i, inst in enumerate(instances, 1):
        diagnosis = "Pending"
        if inst.analysis and inst.analysis.get("diagnosis"):
            d = inst.analysis["diagnosis"]
            diagnosis = d.get("class", "?")

        label = f"#{i}: {diagnosis}"
        # Highlight the currently open query.
        variant = "primary" if inst.id == current else "secondary"
        btn = gr.Button(label, size="sm", variant=variant, elem_classes=["query-item"])

        # Attach click handler to load this instance
        # NOTE: the default arguments (inst_id=inst.id, c_id=case_id) freeze
        # the loop variables per iteration — without them every button would
        # close over the last instance.
        def load_instance(inst_id=inst.id, c_id=case_id):
            def _load(current_state):
                current_state["instance_id"] = inst_id
                instance = case_store.get_instance(c_id, inst_id)

                # Load saved output if available
                output_html = '<div class="analysis-output"><div class="result">Previous analysis loaded</div></div>'
                if instance and instance.analysis:
                    diag = instance.analysis.get("diagnosis", {})
                    output_html = f'<div class="analysis-output"><div class="result">Diagnosis: {diag.get("full_name", diag.get("class", "Unknown"))}</div></div>'

                return (
                    current_state,
                    gr.update(visible=False),  # view_input
                    gr.update(visible=True),  # view_results
                    output_html,
                    gr.update(visible=False)  # confirm_box
                )
            return _load

        btn.click(
            load_instance(),
            inputs=[state],
            outputs=[state, view_input, view_results, output_html, confirm_box]
        )
|
| 220 |
+
|
| 221 |
+
# =========================================================================
|
| 222 |
+
# EVENT HANDLERS
|
| 223 |
+
# =========================================================================
|
| 224 |
+
|
| 225 |
+
def select_patient(case_id: str, s: Dict):
    """Handle patient selection.

    Opens the analysis page for the chosen case: re-opens the most recent
    query if one exists, otherwise creates a fresh instance. Returns the
    8-tuple of UI updates expected by the click wiring:
    (state, page_patient, page_analysis, sidebar, view_input, view_results,
    output_html, confirm_box).
    """
    s["case_id"] = case_id
    s["page"] = "analysis"

    instances = case_store.list_instances(case_id)
    has_queries = len(instances) > 0

    if has_queries:
        # Load most recent
        inst = instances[-1]
        s["instance_id"] = inst.id

        # Load image if exists
        # NOTE(review): `img` is opened here but never returned or stored —
        # dead code or a missing output; confirm intent.
        img = None
        if inst.image_path and os.path.exists(inst.image_path):
            from PIL import Image
            img = Image.open(inst.image_path)

        return (
            s,
            gr.update(visible=False),  # page_patient
            gr.update(visible=True),  # page_analysis
            gr.update(visible=True),  # sidebar
            gr.update(visible=False),  # view_input
            gr.update(visible=True),  # view_results
            '<div class="analysis-output"><div class="result">Previous analysis loaded</div></div>',
            gr.update(visible=False)  # confirm_box
        )
    else:
        # New instance
        inst = case_store.create_instance(case_id)
        s["instance_id"] = inst.id
        s["output"] = ""

        return (
            s,
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False),  # sidebar hidden for new patient
            gr.update(visible=True),  # view_input
            gr.update(visible=False),  # view_results
            "",
            gr.update(visible=False)
        )
|
| 270 |
+
|
| 271 |
+
def new_patient(s: Dict):
    """Create a freshly-named patient case and open it via select_patient()."""
    timestamp = datetime.now().strftime('%Y-%m-%d %H:%M')
    case = case_store.create_case(f"Patient {timestamp}")
    return select_patient(case.id, s)
|
| 275 |
+
|
| 276 |
+
def go_back(s: Dict):
    """Reset navigation state and switch back to the patient-selection page."""
    s.update(page="patient_select", case_id=None, instance_id=None, output="")

    return (
        s,
        gr.update(visible=True),   # page_patient
        gr.update(visible=False),  # page_analysis
        gr.update(visible=False),  # sidebar
        gr.update(visible=True),   # view_input
        gr.update(visible=False),  # view_results
        "",
        gr.update(visible=False)   # confirm_box
    )
|
| 293 |
+
|
| 294 |
+
def new_query(s: Dict):
    """Open a blank query (instance) for the currently selected patient."""
    case_id = s.get("case_id")
    if not case_id:
        # No patient selected: leave every component untouched.
        return s, gr.update(), gr.update(), gr.update(), "", gr.update()

    inst = case_store.create_instance(case_id)
    s.update(instance_id=inst.id, output="", gradcam_base64=None)

    # Drop any conversational context carried over from the previous query.
    agent.reset()

    return (
        s,
        gr.update(visible=True),   # view_input
        gr.update(visible=False),  # view_results
        None,                      # clear image
        "",                        # clear output
        gr.update(visible=False)   # confirm_box
    )
|
| 315 |
+
|
| 316 |
+
def enable_analyze(img):
    """Make the Analyze button clickable only once an image is present."""
    has_image = img is not None
    return gr.update(interactive=has_image)
|
| 319 |
+
|
| 320 |
+
def run_analysis(image, message, s: Dict):
    """Stream a full analysis of the uploaded image into the results view.

    Yields UI update tuples (state, view_input, view_results, output_html,
    confirm_box): the first yield switches to the results view, then each
    streamed chunk re-renders the accumulated output. Saves the final
    diagnosis on the case instance once streaming finishes.
    """
    if image is None:
        yield s, gr.update(), gr.update(), gr.update(), gr.update()
        return

    case_id = s["case_id"]
    instance_id = s["instance_id"]

    # Persist the upload and mark the instance as in-progress.
    image_path = case_store.save_image(case_id, instance_id, image)
    case_store.update_analysis(case_id, instance_id, stage="analyzing", image_path=image_path)

    agent.reset()
    s["output"] = ""
    gradcam_base64 = None
    has_confirm = False

    # Switch to results view
    yield (
        s,
        gr.update(visible=False),  # view_input
        gr.update(visible=True),   # view_results
        '<div class="analysis-output">Starting analysis...</div>',
        gr.update(visible=False)   # confirm_box
    )

    partial = ""
    for chunk in agent.analyze(image_path, message or ""):
        partial += chunk

        # The stream embeds a [GRADCAM_IMAGE:<path>] marker once the heatmap
        # is written; inline it as base64 the first time it appears.
        if gradcam_base64 is None:
            match = re.search(r'\[GRADCAM_IMAGE:([^\]]+)\]', partial)
            if match:
                path = match.group(1)
                if os.path.exists(path):
                    try:
                        with open(path, "rb") as f:
                            gradcam_base64 = base64.b64encode(f.read()).decode('utf-8')
                        s["gradcam_base64"] = gradcam_base64
                    except Exception:
                        # Best-effort: an unreadable heatmap must not abort
                        # the streamed analysis. (Fixed: was a bare `except:`
                        # that also swallowed KeyboardInterrupt/SystemExit.)
                        pass

        # The model asks for user confirmation via a [CONFIRM:...] marker.
        if '[CONFIRM:' in partial:
            has_confirm = True

        s["output"] = partial

        yield (
            s,
            gr.update(visible=False),
            gr.update(visible=True),
            format_output(partial, gradcam_base64),
            gr.update(visible=has_confirm)
        )

    # Save analysis
    if agent.model and agent.model.last_diagnosis:
        diag = agent.model.last_diagnosis["predictions"][0]
        case_store.update_analysis(
            case_id, instance_id,
            stage="awaiting_confirmation",
            analysis={"diagnosis": diag}
        )
|
| 385 |
+
|
| 386 |
+
def confirm_yes(s: Dict):
    """Stream management guidance after the user accepts the diagnosis.

    Yields (state, rendered results HTML, confirm-box visibility update)
    tuples as chunks arrive, then marks the analysis instance complete
    in the case store.
    """
    accumulated = s.get("output", "")
    gradcam_b64 = s.get("gradcam_base64")

    for piece in agent.management_guidance(confirmed=True):
        accumulated = accumulated + piece
        s["output"] = accumulated
        yield s, format_output(accumulated, gradcam_b64), gr.update(visible=False)

    # Persist the final stage only after streaming has finished.
    case_store.update_analysis(s["case_id"], s["instance_id"], stage="complete")
|
| 397 |
+
|
| 398 |
+
def confirm_no():
    """Reveal the feedback controls when the user rejects the diagnosis.

    Returns two visibility updates: one for the feedback textbox and one
    for its submit button.
    """
    show_feedback = gr.update(visible=True)
    show_submit = gr.update(visible=True)
    return show_feedback, show_submit
|
| 401 |
+
|
| 402 |
+
def submit_feedback(feedback: str, s: Dict):
    """Stream revised management guidance after the user disagrees.

    Yields 6-tuples matching the wired outputs: (state, rendered HTML,
    hide confirm box, hide feedback box, hide submit button, cleared
    feedback text). Marks the instance complete once the stream ends.
    """
    text_so_far = s.get("output", "")
    gradcam_b64 = s.get("gradcam_base64")

    for piece in agent.management_guidance(confirmed=False, feedback=feedback):
        text_so_far = text_so_far + piece
        s["output"] = text_so_far
        yield (
            s,
            format_output(text_so_far, gradcam_b64),
            gr.update(visible=False),  # confirm box
            gr.update(visible=False),  # feedback textbox
            gr.update(visible=False),  # submit button
            ""                         # clear the feedback field
        )

    # Record completion only after all guidance chunks were delivered.
    case_store.update_analysis(s["case_id"], s["instance_id"], stage="complete")
|
| 420 |
+
|
| 421 |
+
def send_followup(message: str, s: Dict):
    """Stream the agent's answer to a follow-up question.

    Args:
        message: The user's follow-up text; ignored if blank/whitespace.
        s: Session state dict holding "case_id", "instance_id", "output"
           (accumulated HTML) and "gradcam_base64".

    Yields:
        (state, rendered results HTML, cleared follow-up textbox) tuples.
    """
    if not message.strip():
        # Bug fix: this function is a generator (it contains `yield`), so a
        # `return <tuple>` here only sets StopIteration.value and the UI
        # outputs never receive the no-op update. Yield it instead.
        yield s, gr.update(), ""
        return

    case_store.add_chat_message(s["case_id"], s["instance_id"], "user", message)

    partial = s.get("output", "")
    gradcam = s.get("gradcam_base64")

    # NOTE(review): `message` is interpolated into HTML unescaped; consider
    # html.escape() if user input may contain markup.
    partial += f'\n<div class="chat-message user">You: {message}</div>\n'

    response = ""
    for chunk in agent.followup(message):
        response += chunk
        s["output"] = partial + response
        yield s, format_output(partial + response, gradcam), ""

    # Persist the assistant's full reply once streaming is done.
    case_store.add_chat_message(s["case_id"], s["instance_id"], "assistant", response)
|
| 440 |
+
|
| 441 |
+
# =========================================================================
# WIRE EVENTS
# =========================================================================

# Patient selection: all three entry points share the same 8-slot output
# signature so the page-level visibility swap is uniform.
btn_demo_melanoma.click(
    lambda s: select_patient("demo-melanoma", s),
    inputs=[state],
    outputs=[state, page_patient, page_analysis, sidebar, view_input, view_results, output_html, confirm_box]
)

btn_demo_ak.click(
    lambda s: select_patient("demo-ak", s),
    inputs=[state],
    outputs=[state, page_patient, page_analysis, sidebar, view_input, view_results, output_html, confirm_box]
)

btn_new_patient.click(
    new_patient,
    inputs=[state],
    outputs=[state, page_patient, page_analysis, sidebar, view_input, view_results, output_html, confirm_box]
)

# Navigation: back returns to the patient-selection page.
btn_back.click(
    go_back,
    inputs=[state],
    outputs=[state, page_patient, page_analysis, sidebar, view_input, view_results, output_html, confirm_box]
)

# Start a fresh query within the current patient (resets input + results).
btn_new_query.click(
    new_query,
    inputs=[state],
    outputs=[state, view_input, view_results, input_image, output_html, confirm_box]
)

# Analysis: the Analyze button is enabled only once an image is present.
input_image.change(enable_analyze, inputs=[input_image], outputs=[btn_analyze])

btn_analyze.click(
    run_analysis,
    inputs=[input_image, input_message, state],
    outputs=[state, view_input, view_results, output_html, confirm_box]
)

# Confirmation flow: "yes" streams guidance; "no" reveals feedback inputs.
btn_confirm_yes.click(
    confirm_yes,
    inputs=[state],
    outputs=[state, output_html, confirm_box]
)

btn_confirm_no.click(
    confirm_no,
    outputs=[input_feedback, btn_submit_feedback]
)

# NOTE(review): input_feedback appears twice in the outputs below — once to
# hide it (gr.update(visible=False)) and once to clear its value ("").
# Confirm the installed Gradio version allows duplicate output components.
btn_submit_feedback.click(
    submit_feedback,
    inputs=[input_feedback, state],
    outputs=[state, output_html, confirm_box, input_feedback, btn_submit_feedback, input_feedback]
)

# Follow-up: button click and textbox Enter both trigger the same handler.
btn_followup.click(
    send_followup,
    inputs=[input_followup, state],
    outputs=[state, output_html, input_followup]
)

input_followup.submit(
    send_followup,
    inputs=[input_followup, state],
    outputs=[state, output_html, input_followup]
)
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
# =============================================================================
|
| 519 |
+
# MAIN
|
| 520 |
+
# =============================================================================
|
| 521 |
+
|
| 522 |
+
if __name__ == "__main__":
    # Startup banner.
    rule = '=' * 50
    print(f"\n{rule}")
    print(f" {Config.APP_TITLE}")
    print(f"{rule}\n")

    # Bind to all interfaces only when running inside a Hugging Face Space;
    # otherwise stay on loopback. queue() enables streaming generators.
    host = "0.0.0.0" if Config.HF_SPACES else "127.0.0.1"
    app.queue().launch(
        server_name=host,
        server_port=Config.SERVER_PORT,
        share=False,
        show_error=True
    )
|
frontend/components/__init__.py
ADDED
|
File without changes
|
frontend/components/analysis_view.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Analysis View Component - Main analysis interface with input and results
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import gradio as gr
|
| 6 |
+
import re
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def parse_markdown(text: str) -> str:
    """Convert a small markdown subset to HTML.

    Supports **bold** / __bold__, *italic*, and bullet lists introduced by
    "* " or "- ". All other text passes through verbatim.

    Args:
        text: Raw markdown-ish text.

    Returns:
        HTML string with <strong>/<em>/<ul>/<li> substituted in.
    """
    # Bold must run before italic so "**x**" is not consumed as two italics.
    text = re.sub(r'\*\*(.+?)\*\*', r'<strong>\1</strong>', text)
    text = re.sub(r'__(.+?)__', r'<strong>\1</strong>', text)
    # Bug fix: disallow whitespace just inside the markers so a bullet line
    # such as "* a * b" is not misread as emphasis (the old r'\*(.+?)\*'
    # pattern turned it into "<em> a </em> b" before list handling ran).
    text = re.sub(r'\*(?!\s)(.+?)(?<!\s)\*', r'<em>\1</em>', text)

    # Bullet lists: wrap consecutive "* " / "- " lines in a single <ul>.
    lines = text.split('\n')
    in_list = False
    result = []
    for line in lines:
        stripped = line.strip()
        if re.match(r'^[\*\-] ', stripped):
            if not in_list:
                result.append('<ul>')
                in_list = True
            item = re.sub(r'^[\*\-] ', '', stripped)
            result.append(f'<li>{item}</li>')
        else:
            if in_list:
                result.append('</ul>')
                in_list = False
            result.append(line)
    # Close a list that runs to the end of the text.
    if in_list:
        result.append('</ul>')

    return '\n'.join(result)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# Regex patterns for the agent's tagged streaming protocol.
_STAGE_RE = re.compile(r'\[STAGE:(\w+)\](.*?)\[/STAGE\]')
_THINKING_RE = re.compile(r'\[THINKING\](.*?)\[/THINKING\]')
_OBSERVATION_RE = re.compile(r'\[OBSERVATION\](.*?)\[/OBSERVATION\]')
_TOOL_OUTPUT_RE = re.compile(r'\[TOOL_OUTPUT:(.*?)\]\n(.*?)\[/TOOL_OUTPUT\]', re.DOTALL)
_RESULT_RE = re.compile(r'\[RESULT\](.*?)\[/RESULT\]')
_ERROR_RE = re.compile(r'\[ERROR\](.*?)\[/ERROR\]')
_GRADCAM_RE = re.compile(r'\[GRADCAM_IMAGE:[^\]]+\]\n?')
_RESPONSE_RE = re.compile(r'\[RESPONSE\]\n(.*?)\n\[/RESPONSE\]', re.DOTALL)
_COMPLETE_RE = re.compile(r'\[COMPLETE\](.*?)\[/COMPLETE\]')
_CONFIRM_RE = re.compile(r'\[CONFIRM:(\w+)\](.*?)\[/CONFIRM\]')
_REFERENCES_RE = re.compile(r'\[REFERENCES\](.*?)\[/REFERENCES\]', re.DOTALL)
# REF tag layout: [REF:<id>:<source>:<page>:<filename>:<superscript>]
_REF_RE = re.compile(r'\[REF:([^:]+):([^:]+):([^:]+):([^:]+):([^\]]+)\]')


def format_output(raw_text: str, gradcam_base64: Optional[str] = None) -> str:
    """Convert the agent's tagged streaming output into styled HTML.

    Args:
        raw_text: Accumulated agent output containing protocol tags
            ([STAGE:...], [RESULT], [REFERENCES], ...).
        gradcam_base64: Base64-encoded PNG for the attention map; when None,
            any [GRADCAM_IMAGE:path] tag is stripped instead of rendered.

    Returns:
        HTML wrapped in a <div class="analysis-output"> container.
    """
    html = raw_text

    # Stage headers (group 1 is the stage id, intentionally unused here).
    html = _STAGE_RE.sub(
        r'<div class="stage"><span class="stage-indicator"></span><span class="stage-text">\2</span></div>',
        html
    )

    # Thinking
    html = _THINKING_RE.sub(r'<div class="thinking">\1</div>', html)

    # Observations
    html = _OBSERVATION_RE.sub(r'<div class="observation">\1</div>', html)

    # Tool outputs: header = tool name, body rendered verbatim in <pre>.
    html = _TOOL_OUTPUT_RE.sub(
        r'<div class="tool-output"><div class="tool-header">\1</div><pre class="tool-content">\2</pre></div>',
        html
    )

    # Results
    html = _RESULT_RE.sub(r'<div class="result">\1</div>', html)

    # Errors
    html = _ERROR_RE.sub(r'<div class="error">\1</div>', html)

    # GradCAM image: render inline when data is available, else drop the tag.
    if gradcam_base64:
        img_html = f'<div class="gradcam-inline"><div class="gradcam-header">Attention Map</div><img src="data:image/png;base64,{gradcam_base64}" alt="Grad-CAM"></div>'
        html = _GRADCAM_RE.sub(img_html, html)
    else:
        html = _GRADCAM_RE.sub('', html)

    # Response section: markdown-render the body and split paragraphs.
    def format_response(match):
        content = match.group(1)
        parsed = parse_markdown(content)
        parsed = re.sub(r'\n\n+', '</p><p>', parsed)
        parsed = parsed.replace('\n', '<br>')
        return f'<div class="response"><p>{parsed}</p></div>'

    html = _RESPONSE_RE.sub(format_response, html)

    # Complete
    html = _COMPLETE_RE.sub(r'<div class="complete">\1</div>', html)

    # Confirmation (group 1 = confirm id, unused; group 2 = prompt text).
    html = _CONFIRM_RE.sub(
        r'<div class="confirm-box"><div class="confirm-text">\2</div></div>',
        html
    )

    # References block: one <li> link per [REF:...] tag.
    def format_references(match):
        ref_content = match.group(1)
        refs_html = ['<div class="references"><div class="references-header">References</div><ul>']
        for ref_match in _REF_RE.finditer(ref_content):
            _, source, page, filename, superscript = ref_match.groups()
            # Bug fix: link to the filename captured in the REF tag; it was
            # previously unpacked but never used and the href pointed at a
            # hard-coded placeholder path.
            refs_html.append(
                f'<li><a href="guidelines/{filename}#page={page}" target="_blank" class="ref-link">'
                f'<sup>{superscript}</sup> {source}, p.{page}</a></li>'
            )
        refs_html.append('</ul></div>')
        return '\n'.join(refs_html)

    html = _REFERENCES_RE.sub(format_references, html)

    # Convert any remaining newlines for HTML display.
    html = html.replace('\n', '<br>')

    return f'<div class="analysis-output">{html}</div>'
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def create_analysis_view():
    """
    Create the analysis view component.

    Builds two mutually exclusive sub-views inside one (initially hidden)
    container: an input "greeting" form (message textbox + image upload)
    and a chat/results view (streamed HTML output, diagnosis-confirmation
    controls, follow-up input). Visibility toggling and event wiring are
    done by the caller via the returned components dict.

    Returns:
        Tuple of (container, components dict)
    """
    with gr.Group(visible=False, elem_classes=["analysis-container"]) as container:

        with gr.Row():
            # Main content area
            with gr.Column(elem_classes=["main-content"]):

                # Input greeting (shown when no analysis yet)
                with gr.Group(visible=True, elem_classes=["input-greeting"]) as input_greeting:
                    gr.Markdown("What would you like to analyze?", elem_classes=["greeting-title"])
                    gr.Markdown("Upload an image and describe what you'd like to know", elem_classes=["greeting-subtitle"])

                    with gr.Column(elem_classes=["input-box-container"]):
                        # Free-text description / question for the agent.
                        message_input = gr.Textbox(
                            placeholder="Describe the lesion or ask a question...",
                            show_label=False,
                            lines=3,
                            elem_classes=["message-input"]
                        )

                        # Image upload (compact); PIL type matches what the
                        # analysis pipeline expects downstream.
                        image_input = gr.Image(
                            label="",
                            type="pil",
                            height=180,
                            elem_classes=["image-preview"],
                            show_label=False
                        )

                        with gr.Row(elem_classes=["input-actions"]):
                            upload_hint = gr.Markdown("*Upload a skin lesion image above*", visible=True)
                            # Disabled until an image is present (enabled by
                            # the caller's image_input.change handler).
                            send_btn = gr.Button("Analyze", elem_classes=["send-btn"], interactive=False)

                # Chat/results view (shown after analysis starts)
                with gr.Group(visible=False, elem_classes=["chat-view"]) as chat_view:
                    # Streaming target: repeatedly replaced with formatted
                    # cascade output while the agent runs.
                    results_output = gr.HTML(
                        value='<div class="analysis-output">Starting analysis...</div>',
                        elem_classes=["results-area"]
                    )

                    # Confirmation buttons (revealed when the agent asks the
                    # clinician to confirm the diagnosis)
                    with gr.Group(visible=False, elem_classes=["confirm-buttons"]) as confirm_group:
                        gr.Markdown("**Do you agree with this diagnosis?**")
                        with gr.Row():
                            confirm_yes_btn = gr.Button("Yes, continue", variant="primary", size="sm")
                            confirm_no_btn = gr.Button("No, I disagree", variant="secondary", size="sm")
                        # Hidden until the user clicks "No, I disagree".
                        feedback_input = gr.Textbox(
                            label="Your assessment",
                            placeholder="Enter your diagnosis...",
                            visible=False
                        )
                        submit_feedback_btn = gr.Button("Submit", visible=False, size="sm")

                    # Follow-up input
                    with gr.Row(elem_classes=["chat-input-area"]):
                        followup_input = gr.Textbox(
                            placeholder="Ask a follow-up question...",
                            show_label=False,
                            lines=1
                        )
                        followup_btn = gr.Button("Send", size="sm", elem_classes=["send-btn"])

    # Expose every interactive element so the caller can wire events.
    components = {
        "input_greeting": input_greeting,
        "chat_view": chat_view,
        "message_input": message_input,
        "image_input": image_input,
        "send_btn": send_btn,
        "results_output": results_output,
        "confirm_group": confirm_group,
        "confirm_yes_btn": confirm_yes_btn,
        "confirm_no_btn": confirm_no_btn,
        "feedback_input": feedback_input,
        "submit_feedback_btn": submit_feedback_btn,
        "followup_input": followup_input,
        "followup_btn": followup_btn,
        "upload_hint": upload_hint
    }

    return container, components
|
frontend/components/patient_select.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Patient Selection Component - Landing page for selecting/creating patients
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import gradio as gr
|
| 6 |
+
from typing import Callable, List
|
| 7 |
+
from data.case_store import get_case_store, Case
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def create_patient_select(on_patient_selected: Callable[[str], None]) -> gr.Group:
    """
    Create the patient selection page component.

    Args:
        on_patient_selected: Callback when a patient is selected (receives
            case_id). NOTE(review): currently unused inside this builder —
            click handlers are wired by the caller; parameter kept for
            interface compatibility.

    Returns:
        Tuple of (container, demo_melanoma_btn, demo_ak_btn, new_patient_btn).
        (The return annotation above predates the tuple return; callers
        unpack four values.)
    """
    # Removed: an unused `case_store = get_case_store()` local — this builder
    # only constructs UI and never touches the store.
    with gr.Group(visible=True, elem_classes=["patient-select-container"]) as container:
        gr.Markdown("# SkinProAI", elem_classes=["patient-select-title"])
        gr.Markdown("Select a patient to continue or create a new case", elem_classes=["patient-select-subtitle"])

        with gr.Column(elem_classes=["patient-grid"]):
            # Demo cases
            demo_melanoma_btn = gr.Button(
                "Demo: Melanocytic Lesion",
                elem_classes=["patient-card"]
            )
            demo_ak_btn = gr.Button(
                "Demo: Actinic Keratosis",
                elem_classes=["patient-card"]
            )

        # New patient button
        new_patient_btn = gr.Button(
            "+ New Patient",
            elem_classes=["new-patient-btn"]
        )

    return container, demo_melanoma_btn, demo_ak_btn, new_patient_btn
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_patient_cases() -> List[Case]:
    """Return every stored patient case from the shared case store."""
    store = get_case_store()
    return store.list_cases()
|
frontend/components/sidebar.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sidebar Component - Shows previous queries for a patient
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import gradio as gr
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from typing import List, Optional
|
| 8 |
+
from data.case_store import get_case_store, Instance
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def format_query_item(instance: "Instance", index: int) -> str:
    """Render a one-line sidebar label for a prior query.

    Args:
        instance: Stored analysis instance; only `.analysis` (dict or None)
            and `.created_at` (ISO-8601 string) are read.
        index: 1-based position shown to the user.

    Returns:
        A label such as 'Query #2: Melanoma (Mar 05, 14:30)'; the diagnosis
        falls back to 'Pending' and the timestamp to 'Unknown'.
    """
    diagnosis = "Pending"
    if instance.analysis and instance.analysis.get("diagnosis"):
        diag = instance.analysis["diagnosis"]
        # Prefer the human-readable name; fall back to the raw class label.
        diagnosis = diag.get("full_name", diag.get("class", "Unknown"))

    try:
        # created_at is ISO-8601; normalize a trailing 'Z' so fromisoformat
        # accepts it on older Python versions.
        dt = datetime.fromisoformat(instance.created_at.replace('Z', '+00:00'))
        date_str = dt.strftime("%b %d, %H:%M")
    except (AttributeError, TypeError, ValueError):
        # Bug fix: was a bare `except:` which also swallowed
        # KeyboardInterrupt/SystemExit; catch only parse/type failures.
        date_str = "Unknown"

    return f"Query #{index}: {diagnosis} ({date_str})"
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def create_sidebar():
    """
    Create the sidebar component for showing previous queries.

    The sidebar starts hidden; the caller toggles visibility when a patient
    is selected and wires the new-query button.

    Returns:
        Tuple of (container, components dict)
    """
    with gr.Column(visible=False, elem_classes=["query-sidebar"]) as container:
        gr.Markdown("### Previous Queries", elem_classes=["sidebar-header"])

        # Dynamic list of query buttons
        # NOTE(review): the column is created empty here; query items are
        # presumably populated elsewhere (e.g. via re-render by the caller)
        # — confirm against the wiring code.
        query_list = gr.Column(elem_id="query-list")

        # New query button
        new_query_btn = gr.Button("+ New Query", size="sm", variant="primary")

    # Expose interactive elements for event wiring by the caller.
    components = {
        "query_list": query_list,
        "new_query_btn": new_query_btn
    }

    return container, components
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_queries_for_case(case_id: str) -> List[Instance]:
    """Return all query instances recorded for `case_id` (empty if falsy)."""
    return get_case_store().list_instances(case_id) if case_id else []
|
frontend/components/styles.py
ADDED
|
@@ -0,0 +1,517 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CSS Styles for SkinProAI components
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
MAIN_CSS = """
|
| 6 |
+
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600&display=swap');
|
| 7 |
+
|
| 8 |
+
* {
|
| 9 |
+
font-family: 'Inter', sans-serif !important;
|
| 10 |
+
}
|
| 11 |
+
|
| 12 |
+
.gradio-container {
|
| 13 |
+
max-width: 1200px !important;
|
| 14 |
+
margin: 0 auto !important;
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
/* Hide Gradio footer */
|
| 18 |
+
.gradio-container footer { display: none !important; }
|
| 19 |
+
|
| 20 |
+
/* ============================================
|
| 21 |
+
PATIENT SELECTION PAGE
|
| 22 |
+
============================================ */
|
| 23 |
+
|
| 24 |
+
.patient-select-container {
|
| 25 |
+
min-height: 80vh;
|
| 26 |
+
display: flex;
|
| 27 |
+
flex-direction: column;
|
| 28 |
+
align-items: center;
|
| 29 |
+
justify-content: center;
|
| 30 |
+
padding: 40px 20px;
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
.patient-select-title {
|
| 34 |
+
font-size: 32px;
|
| 35 |
+
font-weight: 600;
|
| 36 |
+
color: #111827;
|
| 37 |
+
margin-bottom: 8px;
|
| 38 |
+
text-align: center;
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
.patient-select-subtitle {
|
| 42 |
+
font-size: 16px;
|
| 43 |
+
color: #6b7280;
|
| 44 |
+
margin-bottom: 40px;
|
| 45 |
+
text-align: center;
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
.patient-grid {
|
| 49 |
+
display: flex;
|
| 50 |
+
gap: 20px;
|
| 51 |
+
flex-wrap: wrap;
|
| 52 |
+
justify-content: center;
|
| 53 |
+
max-width: 800px;
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
.patient-card {
|
| 57 |
+
background: white !important;
|
| 58 |
+
border: 2px solid #e5e7eb !important;
|
| 59 |
+
border-radius: 16px !important;
|
| 60 |
+
padding: 24px 32px !important;
|
| 61 |
+
min-width: 200px !important;
|
| 62 |
+
cursor: pointer;
|
| 63 |
+
transition: all 0.2s ease !important;
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
.patient-card:hover {
|
| 67 |
+
border-color: #6366f1 !important;
|
| 68 |
+
box-shadow: 0 8px 25px rgba(99, 102, 241, 0.15) !important;
|
| 69 |
+
transform: translateY(-2px);
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
.new-patient-btn {
|
| 73 |
+
background: #6366f1 !important;
|
| 74 |
+
color: white !important;
|
| 75 |
+
border: none !important;
|
| 76 |
+
border-radius: 12px !important;
|
| 77 |
+
padding: 16px 32px !important;
|
| 78 |
+
font-weight: 500 !important;
|
| 79 |
+
margin-top: 24px;
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
.new-patient-btn:hover {
|
| 83 |
+
background: #4f46e5 !important;
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
/* ============================================
|
| 87 |
+
ANALYSIS PAGE - MAIN LAYOUT
|
| 88 |
+
============================================ */
|
| 89 |
+
|
| 90 |
+
.analysis-container {
|
| 91 |
+
display: flex;
|
| 92 |
+
height: calc(100vh - 80px);
|
| 93 |
+
min-height: 600px;
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
/* Sidebar */
|
| 97 |
+
.query-sidebar {
|
| 98 |
+
width: 280px;
|
| 99 |
+
background: #f9fafb;
|
| 100 |
+
border-right: 1px solid #e5e7eb;
|
| 101 |
+
padding: 20px;
|
| 102 |
+
overflow-y: auto;
|
| 103 |
+
flex-shrink: 0;
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
.sidebar-header {
|
| 107 |
+
font-size: 14px;
|
| 108 |
+
font-weight: 600;
|
| 109 |
+
color: #374151;
|
| 110 |
+
margin-bottom: 16px;
|
| 111 |
+
padding-bottom: 12px;
|
| 112 |
+
border-bottom: 1px solid #e5e7eb;
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
.query-item {
|
| 116 |
+
background: white;
|
| 117 |
+
border: 1px solid #e5e7eb;
|
| 118 |
+
border-radius: 8px;
|
| 119 |
+
padding: 12px;
|
| 120 |
+
margin-bottom: 8px;
|
| 121 |
+
cursor: pointer;
|
| 122 |
+
transition: all 0.15s;
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
.query-item:hover {
|
| 126 |
+
border-color: #6366f1;
|
| 127 |
+
background: #f5f3ff;
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
.query-item-title {
|
| 131 |
+
font-size: 13px;
|
| 132 |
+
font-weight: 500;
|
| 133 |
+
color: #111827;
|
| 134 |
+
margin-bottom: 4px;
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
.query-item-meta {
|
| 138 |
+
font-size: 11px;
|
| 139 |
+
color: #6b7280;
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
/* Main content area */
|
| 143 |
+
.main-content {
|
| 144 |
+
flex: 1;
|
| 145 |
+
display: flex;
|
| 146 |
+
flex-direction: column;
|
| 147 |
+
padding: 24px;
|
| 148 |
+
overflow: hidden;
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
/* ============================================
|
| 152 |
+
INPUT AREA (Greeting style)
|
| 153 |
+
============================================ */
|
| 154 |
+
|
| 155 |
+
.input-greeting {
|
| 156 |
+
flex: 1;
|
| 157 |
+
display: flex;
|
| 158 |
+
flex-direction: column;
|
| 159 |
+
align-items: center;
|
| 160 |
+
justify-content: center;
|
| 161 |
+
padding: 40px;
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
.greeting-title {
|
| 165 |
+
font-size: 24px;
|
| 166 |
+
font-weight: 600;
|
| 167 |
+
color: #111827;
|
| 168 |
+
margin-bottom: 8px;
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
.greeting-subtitle {
|
| 172 |
+
font-size: 14px;
|
| 173 |
+
color: #6b7280;
|
| 174 |
+
margin-bottom: 32px;
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
.input-box-container {
|
| 178 |
+
width: 100%;
|
| 179 |
+
max-width: 600px;
|
| 180 |
+
background: white;
|
| 181 |
+
border: 2px solid #e5e7eb;
|
| 182 |
+
border-radius: 16px;
|
| 183 |
+
padding: 20px;
|
| 184 |
+
transition: border-color 0.2s;
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
.input-box-container:focus-within {
|
| 188 |
+
border-color: #6366f1;
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
.message-input textarea {
|
| 192 |
+
border: none !important;
|
| 193 |
+
resize: none !important;
|
| 194 |
+
font-size: 15px !important;
|
| 195 |
+
line-height: 1.5 !important;
|
| 196 |
+
padding: 0 !important;
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
.message-input textarea:focus {
|
| 200 |
+
box-shadow: none !important;
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
.input-actions {
|
| 204 |
+
display: flex;
|
| 205 |
+
align-items: center;
|
| 206 |
+
justify-content: space-between;
|
| 207 |
+
margin-top: 16px;
|
| 208 |
+
padding-top: 16px;
|
| 209 |
+
border-top: 1px solid #f3f4f6;
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
.upload-btn {
|
| 213 |
+
background: #f3f4f6 !important;
|
| 214 |
+
color: #374151 !important;
|
| 215 |
+
border: 1px solid #e5e7eb !important;
|
| 216 |
+
border-radius: 8px !important;
|
| 217 |
+
padding: 8px 16px !important;
|
| 218 |
+
font-size: 13px !important;
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
.upload-btn:hover {
|
| 222 |
+
background: #e5e7eb !important;
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
.send-btn {
|
| 226 |
+
background: #6366f1 !important;
|
| 227 |
+
color: white !important;
|
| 228 |
+
border: none !important;
|
| 229 |
+
border-radius: 8px !important;
|
| 230 |
+
padding: 10px 24px !important;
|
| 231 |
+
font-weight: 500 !important;
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
.send-btn:hover {
|
| 235 |
+
background: #4f46e5 !important;
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
.send-btn:disabled {
|
| 239 |
+
background: #d1d5db !important;
|
| 240 |
+
cursor: not-allowed;
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
/* Image preview */
|
| 244 |
+
.image-preview {
|
| 245 |
+
margin-top: 16px;
|
| 246 |
+
border-radius: 12px;
|
| 247 |
+
overflow: hidden;
|
| 248 |
+
max-height: 200px;
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
.image-preview img {
|
| 252 |
+
max-height: 200px;
|
| 253 |
+
object-fit: contain;
|
| 254 |
+
}
|
| 255 |
+
|
| 256 |
+
/* ============================================
|
| 257 |
+
CHAT/RESULTS VIEW
|
| 258 |
+
============================================ */
|
| 259 |
+
|
| 260 |
+
.chat-view {
|
| 261 |
+
flex: 1;
|
| 262 |
+
display: flex;
|
| 263 |
+
flex-direction: column;
|
| 264 |
+
overflow: hidden;
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
.results-area {
|
| 268 |
+
flex: 1;
|
| 269 |
+
overflow-y: auto;
|
| 270 |
+
padding: 20px;
|
| 271 |
+
background: #ffffff;
|
| 272 |
+
border: 1px solid #e5e7eb;
|
| 273 |
+
border-radius: 12px;
|
| 274 |
+
margin-bottom: 16px;
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
/* Analysis output styling */
|
| 278 |
+
.analysis-output {
|
| 279 |
+
line-height: 1.6;
|
| 280 |
+
color: #333;
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
.stage {
|
| 284 |
+
display: flex;
|
| 285 |
+
align-items: center;
|
| 286 |
+
gap: 10px;
|
| 287 |
+
padding: 8px 0;
|
| 288 |
+
font-weight: 500;
|
| 289 |
+
color: #1a1a1a;
|
| 290 |
+
margin-top: 12px;
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
.stage-indicator {
|
| 294 |
+
width: 8px;
|
| 295 |
+
height: 8px;
|
| 296 |
+
background: #6366f1;
|
| 297 |
+
border-radius: 50%;
|
| 298 |
+
animation: pulse 1.5s ease-in-out infinite;
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
@keyframes pulse {
|
| 302 |
+
0%, 100% { opacity: 1; transform: scale(1); }
|
| 303 |
+
50% { opacity: 0.5; transform: scale(0.8); }
|
| 304 |
+
}
|
| 305 |
+
|
| 306 |
+
.thinking {
|
| 307 |
+
color: #6b7280;
|
| 308 |
+
font-style: italic;
|
| 309 |
+
font-size: 13px;
|
| 310 |
+
padding: 4px 0 4px 16px;
|
| 311 |
+
border-left: 2px solid #e5e7eb;
|
| 312 |
+
margin: 4px 0;
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
.observation {
|
| 316 |
+
color: #374151;
|
| 317 |
+
font-size: 13px;
|
| 318 |
+
padding: 4px 0 4px 16px;
|
| 319 |
+
}
|
| 320 |
+
|
| 321 |
+
.tool-output {
|
| 322 |
+
background: #f8fafc;
|
| 323 |
+
border-radius: 8px;
|
| 324 |
+
margin: 12px 0;
|
| 325 |
+
overflow: hidden;
|
| 326 |
+
border: 1px solid #e2e8f0;
|
| 327 |
+
}
|
| 328 |
+
|
| 329 |
+
.tool-header {
|
| 330 |
+
background: #f1f5f9;
|
| 331 |
+
padding: 8px 12px;
|
| 332 |
+
font-weight: 500;
|
| 333 |
+
font-size: 13px;
|
| 334 |
+
color: #475569;
|
| 335 |
+
border-bottom: 1px solid #e2e8f0;
|
| 336 |
+
}
|
| 337 |
+
|
| 338 |
+
.tool-content {
|
| 339 |
+
padding: 12px;
|
| 340 |
+
margin: 0;
|
| 341 |
+
font-family: 'SF Mono', Monaco, monospace !important;
|
| 342 |
+
font-size: 12px;
|
| 343 |
+
line-height: 1.5;
|
| 344 |
+
white-space: pre-wrap;
|
| 345 |
+
color: #334155;
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
.result {
|
| 349 |
+
background: #ecfdf5;
|
| 350 |
+
border: 1px solid #a7f3d0;
|
| 351 |
+
border-radius: 8px;
|
| 352 |
+
padding: 12px 16px;
|
| 353 |
+
margin: 12px 0;
|
| 354 |
+
font-weight: 500;
|
| 355 |
+
color: #065f46;
|
| 356 |
+
}
|
| 357 |
+
|
| 358 |
+
.error {
|
| 359 |
+
background: #fef2f2;
|
| 360 |
+
border: 1px solid #fecaca;
|
| 361 |
+
border-radius: 8px;
|
| 362 |
+
padding: 12px 16px;
|
| 363 |
+
margin: 8px 0;
|
| 364 |
+
color: #b91c1c;
|
| 365 |
+
}
|
| 366 |
+
|
| 367 |
+
.response {
|
| 368 |
+
background: #ffffff;
|
| 369 |
+
border: 1px solid #e5e7eb;
|
| 370 |
+
border-radius: 8px;
|
| 371 |
+
padding: 16px;
|
| 372 |
+
margin: 16px 0;
|
| 373 |
+
line-height: 1.7;
|
| 374 |
+
}
|
| 375 |
+
|
| 376 |
+
.response ul, .response ol {
|
| 377 |
+
margin: 8px 0;
|
| 378 |
+
padding-left: 24px;
|
| 379 |
+
}
|
| 380 |
+
|
| 381 |
+
.response li {
|
| 382 |
+
margin: 4px 0;
|
| 383 |
+
}
|
| 384 |
+
|
| 385 |
+
.complete {
|
| 386 |
+
color: #6b7280;
|
| 387 |
+
font-size: 12px;
|
| 388 |
+
padding: 8px 0;
|
| 389 |
+
text-align: center;
|
| 390 |
+
}
|
| 391 |
+
|
| 392 |
+
/* Confirmation */
|
| 393 |
+
.confirm-box {
|
| 394 |
+
background: #eff6ff;
|
| 395 |
+
border: 1px solid #bfdbfe;
|
| 396 |
+
border-radius: 8px;
|
| 397 |
+
padding: 16px;
|
| 398 |
+
margin: 16px 0;
|
| 399 |
+
text-align: center;
|
| 400 |
+
}
|
| 401 |
+
|
| 402 |
+
.confirm-buttons {
|
| 403 |
+
background: #f0f9ff;
|
| 404 |
+
border: 1px solid #bae6fd;
|
| 405 |
+
border-radius: 8px;
|
| 406 |
+
padding: 12px;
|
| 407 |
+
margin-top: 12px;
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
/* References */
|
| 411 |
+
.references {
|
| 412 |
+
background: #f9fafb;
|
| 413 |
+
border: 1px solid #e5e7eb;
|
| 414 |
+
border-radius: 8px;
|
| 415 |
+
margin: 16px 0;
|
| 416 |
+
overflow: hidden;
|
| 417 |
+
}
|
| 418 |
+
|
| 419 |
+
.references-header {
|
| 420 |
+
background: #f3f4f6;
|
| 421 |
+
padding: 8px 12px;
|
| 422 |
+
font-weight: 500;
|
| 423 |
+
font-size: 13px;
|
| 424 |
+
border-bottom: 1px solid #e5e7eb;
|
| 425 |
+
}
|
| 426 |
+
|
| 427 |
+
.references ul {
|
| 428 |
+
list-style: none;
|
| 429 |
+
padding: 12px;
|
| 430 |
+
margin: 0;
|
| 431 |
+
}
|
| 432 |
+
|
| 433 |
+
.ref-link {
|
| 434 |
+
color: #6366f1;
|
| 435 |
+
text-decoration: none;
|
| 436 |
+
font-size: 13px;
|
| 437 |
+
}
|
| 438 |
+
|
| 439 |
+
.ref-link:hover {
|
| 440 |
+
text-decoration: underline;
|
| 441 |
+
}
|
| 442 |
+
|
| 443 |
+
/* GradCAM */
|
| 444 |
+
.gradcam-inline {
|
| 445 |
+
margin: 16px 0;
|
| 446 |
+
background: #f8fafc;
|
| 447 |
+
border-radius: 8px;
|
| 448 |
+
overflow: hidden;
|
| 449 |
+
border: 1px solid #e2e8f0;
|
| 450 |
+
}
|
| 451 |
+
|
| 452 |
+
.gradcam-header {
|
| 453 |
+
background: #f1f5f9;
|
| 454 |
+
padding: 8px 12px;
|
| 455 |
+
font-weight: 500;
|
| 456 |
+
font-size: 13px;
|
| 457 |
+
border-bottom: 1px solid #e2e8f0;
|
| 458 |
+
}
|
| 459 |
+
|
| 460 |
+
.gradcam-inline img {
|
| 461 |
+
max-width: 100%;
|
| 462 |
+
max-height: 300px;
|
| 463 |
+
display: block;
|
| 464 |
+
margin: 12px auto;
|
| 465 |
+
}
|
| 466 |
+
|
| 467 |
+
/* Chat input at bottom */
|
| 468 |
+
.chat-input-area {
|
| 469 |
+
background: white;
|
| 470 |
+
border: 1px solid #e5e7eb;
|
| 471 |
+
border-radius: 12px;
|
| 472 |
+
padding: 12px 16px;
|
| 473 |
+
display: flex;
|
| 474 |
+
gap: 12px;
|
| 475 |
+
align-items: flex-end;
|
| 476 |
+
}
|
| 477 |
+
|
| 478 |
+
.chat-input-area textarea {
|
| 479 |
+
flex: 1;
|
| 480 |
+
border: none !important;
|
| 481 |
+
resize: none !important;
|
| 482 |
+
font-size: 14px !important;
|
| 483 |
+
}
|
| 484 |
+
|
| 485 |
+
/* ============================================
|
| 486 |
+
HEADER
|
| 487 |
+
============================================ */
|
| 488 |
+
|
| 489 |
+
.app-header {
|
| 490 |
+
display: flex;
|
| 491 |
+
align-items: center;
|
| 492 |
+
justify-content: space-between;
|
| 493 |
+
padding: 16px 24px;
|
| 494 |
+
border-bottom: 1px solid #e5e7eb;
|
| 495 |
+
background: white;
|
| 496 |
+
}
|
| 497 |
+
|
| 498 |
+
.app-title {
|
| 499 |
+
font-size: 20px;
|
| 500 |
+
font-weight: 600;
|
| 501 |
+
color: #111827;
|
| 502 |
+
}
|
| 503 |
+
|
| 504 |
+
.back-btn {
|
| 505 |
+
background: transparent !important;
|
| 506 |
+
color: #6b7280 !important;
|
| 507 |
+
border: 1px solid #e5e7eb !important;
|
| 508 |
+
border-radius: 8px !important;
|
| 509 |
+
padding: 8px 16px !important;
|
| 510 |
+
font-size: 13px !important;
|
| 511 |
+
}
|
| 512 |
+
|
| 513 |
+
.back-btn:hover {
|
| 514 |
+
background: #f9fafb !important;
|
| 515 |
+
color: #111827 !important;
|
| 516 |
+
}
|
| 517 |
+
"""
|
guidelines/index/chunks.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
guidelines/index/faiss.index
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:faf9cd914f52b84a55d486a156b4756f28dbc1a92abeafc121077402e1fa53f4
|
| 3 |
+
size 145965
|
mcp_server/__init__.py
ADDED
|
File without changes
|
mcp_server/server.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SkinProAI MCP Server - Pure JSON-RPC 2.0 stdio server (no mcp library required).
|
| 3 |
+
|
| 4 |
+
Uses sys.executable (venv Python) so all ML packages (torch, transformers, etc.)
|
| 5 |
+
are available. Tools are loaded lazily on first call.
|
| 6 |
+
|
| 7 |
+
Run standalone: python mcp_server/server.py
|
| 8 |
+
(Should start silently, waiting on stdin.)
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import sys
|
| 12 |
+
import json
|
| 13 |
+
import os
|
| 14 |
+
|
| 15 |
+
# Ensure project root is on path
|
| 16 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 17 |
+
|
| 18 |
+
from mcp_server.tool_registry import get_monet, get_convnext, get_gradcam, get_rag
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# ---------------------------------------------------------------------------
|
| 22 |
+
# Tool implementations
|
| 23 |
+
# ---------------------------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
def _monet_analyze(arguments: dict) -> dict:
    """Run MONET concept extraction on the image at arguments["image_path"]."""
    from PIL import Image

    lesion = Image.open(arguments["image_path"]).convert("RGB")
    return get_monet().analyze(lesion)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _classify_lesion(arguments: dict) -> dict:
    """Classify the lesion image at arguments["image_path"].

    Optional "monet_scores" are forwarded to the classifier; no dermoscopy
    image is supplied on this path.
    """
    from PIL import Image

    lesion = Image.open(arguments["image_path"]).convert("RGB")
    return get_convnext().classify(
        clinical_image=lesion,
        derm_image=None,
        monet_scores=arguments.get("monet_scores"),
    )
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _generate_gradcam(arguments: dict) -> dict:
    """Produce a Grad-CAM overlay PNG for the image and report the prediction.

    The overlay is written to a named temp file (delete=False: the caller
    reads the file after this function returns).
    """
    from PIL import Image
    import tempfile

    source = Image.open(arguments["image_path"]).convert("RGB")
    analysis = get_gradcam().analyze(source)

    # Reserve a temp filename, then let PIL write the overlay into it.
    with tempfile.NamedTemporaryFile(suffix="_gradcam.png", delete=False) as handle:
        overlay_path = handle.name
    analysis["overlay"].save(overlay_path)

    return {
        "gradcam_path": overlay_path,
        "predicted_class": analysis["predicted_class"],
        "predicted_class_full": analysis["predicted_class_full"],
        "confidence": analysis["confidence"],
    }
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _search_guidelines(arguments: dict) -> dict:
    """Query the clinical-guidelines RAG index for management context.

    "query" is required; "diagnosis" is optional and defaults to "".
    """
    rag = get_rag()
    question = arguments.get("query", "")
    dx = arguments.get("diagnosis") or ""
    context, references = rag.get_management_context(dx, question)
    return {
        "context": context,
        "references": references,
        "references_display": rag.format_references_for_display(references),
    }
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _compare_images(arguments: dict) -> dict:
    """Build a labelled comparison overlay of two lesion images plus MONET deltas.

    Expects "image1_path" (previous scan) and "image2_path" (current scan) in
    ``arguments``. Returns the overlay path, per-concept MONET deltas whose
    absolute change exceeds 0.1, and — best-effort — Grad-CAM overlay paths
    for both images (None on failure).
    """
    from PIL import Image
    import tempfile
    image1 = Image.open(arguments["image1_path"]).convert("RGB")
    image2 = Image.open(arguments["image2_path"]).convert("RGB")

    from models.overlay_tool import get_overlay_tool
    comparison = get_overlay_tool().generate_comparison_overlay(
        image1, image2, label1="Previous", label2="Current"
    )
    comparison_path = comparison["path"]

    # Score both images with MONET so concept-level change can be reported.
    monet = get_monet()
    prev_result = monet.analyze(image1)
    curr_result = monet.analyze(image2)

    monet_deltas = {}
    for name in curr_result["features"]:
        # A concept missing from the previous scan is treated as 0.0 baseline.
        prev_val = prev_result["features"].get(name, 0.0)
        curr_val = curr_result["features"][name]
        delta = curr_val - prev_val
        if abs(delta) > 0.1:  # only surface meaningful shifts
            monet_deltas[name] = {
                "previous": prev_val,
                "current": curr_val,
                "delta": delta,
            }

    # Generate GradCAM for both images so the frontend can show a side-by-side comparison
    prev_gradcam_path = None
    curr_gradcam_path = None
    try:
        gradcam = get_gradcam()
        prev_gc = gradcam.analyze(image1)
        curr_gc = gradcam.analyze(image2)

        # delete=False: the caller reads these files after we return.
        f1 = tempfile.NamedTemporaryFile(suffix="_gradcam.png", delete=False)
        prev_gradcam_path = f1.name
        f1.close()
        prev_gc["overlay"].save(prev_gradcam_path)

        f2 = tempfile.NamedTemporaryFile(suffix="_gradcam.png", delete=False)
        curr_gradcam_path = f2.name
        f2.close()
        curr_gc["overlay"].save(curr_gradcam_path)
    except Exception:
        pass  # GradCAM comparison is best-effort

    return {
        "comparison_path": comparison_path,
        "monet_deltas": monet_deltas,
        "prev_gradcam_path": prev_gradcam_path,
        "curr_gradcam_path": curr_gradcam_path,
    }
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# Tool-name -> handler dispatch table consumed by handle_request ("tools/call").
TOOLS = {
    "monet_analyze": _monet_analyze,
    "classify_lesion": _classify_lesion,
    "generate_gradcam": _generate_gradcam,
    "search_guidelines": _search_guidelines,
    "compare_images": _compare_images,
}
|
| 137 |
+
|
| 138 |
+
# Tool manifest returned by "tools/list". Each entry follows the MCP tool
# schema: name / description / inputSchema (JSON Schema). Names here must
# stay in sync with the TOOLS dispatch table keys.
TOOLS_LIST = [
    {
        "name": "monet_analyze",
        "description": "Extract MONET concept-presence scores from a skin lesion image.",
        "inputSchema": {
            "type": "object",
            "properties": {"image_path": {"type": "string"}},
            "required": ["image_path"],
        },
    },
    {
        "name": "classify_lesion",
        "description": "Classify a skin lesion using ConvNeXt dual-encoder.",
        "inputSchema": {
            "type": "object",
            "properties": {
                "image_path": {"type": "string"},
                "monet_scores": {"type": "array"},
            },
            "required": ["image_path"],
        },
    },
    {
        "name": "generate_gradcam",
        "description": "Generate a Grad-CAM attention overlay for a skin lesion image.",
        "inputSchema": {
            "type": "object",
            "properties": {"image_path": {"type": "string"}},
            "required": ["image_path"],
        },
    },
    {
        "name": "search_guidelines",
        "description": "Search clinical guidelines RAG for management context.",
        "inputSchema": {
            "type": "object",
            "properties": {
                "query": {"type": "string"},
                "diagnosis": {"type": "string"},
            },
            "required": ["query"],
        },
    },
    {
        "name": "compare_images",
        "description": "Generate comparison overlay and MONET deltas for two lesion images.",
        "inputSchema": {
            "type": "object",
            "properties": {
                "image1_path": {"type": "string"},
                "image2_path": {"type": "string"},
            },
            "required": ["image1_path", "image2_path"],
        },
    },
]
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
# ---------------------------------------------------------------------------
|
| 197 |
+
# JSON-RPC 2.0 dispatcher
|
| 198 |
+
# ---------------------------------------------------------------------------
|
| 199 |
+
|
| 200 |
+
def handle_request(request: dict):
    """Dispatch one JSON-RPC 2.0 message.

    Returns a response dict for requests, or None for notifications and
    unrecognized notification methods.
    """
    method = request.get("method")
    req_id = request.get("id")  # None for notifications
    params = request.get("params", {})

    def reply(payload: dict) -> dict:
        """Wrap a result payload in a JSON-RPC success envelope."""
        return {"jsonrpc": "2.0", "id": req_id, "result": payload}

    def fail(code: int, message: str) -> dict:
        """Wrap an error in a JSON-RPC error envelope."""
        return {"jsonrpc": "2.0", "id": req_id, "error": {"code": code, "message": message}}

    if method == "initialize":
        return reply({
            "protocolVersion": "2024-11-05",
            "capabilities": {"tools": {"listChanged": False}},
            "serverInfo": {"name": "SkinProAI", "version": "1.0.0"},
        })

    if method in ("notifications/initialized",):
        return None  # notification — no response

    if method == "tools/list":
        return reply({"tools": TOOLS_LIST})

    if method == "tools/call":
        name = params.get("name")
        arguments = params.get("arguments", {})
        if name not in TOOLS:
            return fail(-32601, f"Unknown tool: {name}")
        try:
            result = TOOLS[name](arguments)
        except Exception as e:
            # Tool failures are reported in-band (isError=True), not as
            # JSON-RPC protocol errors.
            return reply({
                "content": [{"type": "text", "text": f"Tool error: {e}"}],
                "isError": True,
            })
        return reply({
            "content": [{"type": "text", "text": json.dumps(result)}],
            "isError": False,
        })

    # Unknown method with id → method not found
    if req_id is not None:
        return fail(-32601, f"Method not found: {method}")

    return None  # unknown notification — ignore
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
# ---------------------------------------------------------------------------
|
| 267 |
+
# Main loop
|
| 268 |
+
# ---------------------------------------------------------------------------
|
| 269 |
+
|
| 270 |
+
def main():
    """Blocking stdio loop: one JSON-RPC message per line, one response per line."""
    for raw in sys.stdin:
        raw = raw.strip()
        if not raw:
            continue
        try:
            message = json.loads(raw)
        except json.JSONDecodeError:
            continue  # skip malformed lines rather than crash the server
        reply = handle_request(message)
        if reply is None:
            continue
        sys.stdout.write(json.dumps(reply) + "\n")
        sys.stdout.flush()


if __name__ == "__main__":
    main()
|
mcp_server/tool_registry.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Lazy singleton loader for all 4 ML models used by the MCP server.
|
| 3 |
+
Fixes sys.path so the subprocess can import from models/.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import sys
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
# Ensure project root is on path (this file lives at project_root/mcp_server/tool_registry.py)
|
| 10 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 11 |
+
|
| 12 |
+
_monet = None
|
| 13 |
+
_convnext = None
|
| 14 |
+
_gradcam = None
|
| 15 |
+
_rag = None
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_monet():
    """Return the process-wide MonetTool singleton, loading it on first call.

    The instance is cached only after load() succeeds, so a failed load is
    retried on the next call instead of permanently serving a half-initialized
    tool (the previous code assigned the global before loading).
    """
    global _monet
    if _monet is None:
        from models.monet_tool import MonetTool
        tool = MonetTool()
        tool.load()
        _monet = tool  # cache only after a successful load
    return _monet
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_convnext():
    """Return the shared ConvNeXtClassifier, loading its checkpoint on first call.

    The singleton is cached only after load() succeeds, so a failed checkpoint
    load is retried on the next call rather than caching a broken instance.
    """
    global _convnext
    if _convnext is None:
        from models.convnext_classifier import ConvNeXtClassifier
        # Resolve the checkpoint relative to the project root
        # (this file lives at project_root/mcp_server/tool_registry.py).
        root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        classifier = ConvNeXtClassifier(
            checkpoint_path=os.path.join(root, "models", "seed42_fold0.pt")
        )
        classifier.load()
        _convnext = classifier  # cache only after a successful load
    return _convnext
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def get_gradcam():
    """Return the shared GradCAMTool (built on top of the ConvNeXt classifier).

    Cached only after load() succeeds so a failed load is retried next call.
    """
    global _gradcam
    if _gradcam is None:
        from models.gradcam_tool import GradCAMTool
        tool = GradCAMTool(classifier=get_convnext())
        tool.load()
        _gradcam = tool  # cache only after a successful load
    return _gradcam
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def get_rag():
    """Return the shared guidelines RAG, ensuring its index is loaded.

    Cached only after the index is ready, so an index-load failure is retried
    on the next call instead of caching a RAG with no index.
    """
    global _rag
    if _rag is None:
        from models.guidelines_rag import get_guidelines_rag
        rag = get_guidelines_rag()
        if not rag.loaded:
            rag.load_index()
        _rag = rag  # cache only once the index is ready
    return _rag
|
models/convnext_classifier.py
ADDED
|
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ConvNeXt Classifier Tool - Skin lesion classification using ConvNeXt + MONET features
|
| 3 |
+
Loads seed42_fold0.pt checkpoint and performs classification.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import numpy as np
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from torchvision import transforms
|
| 11 |
+
from typing import Optional, Dict, List, Tuple
|
| 12 |
+
import timm
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# Class names for the 11-class skin lesion classification.
# NOTE(review): index positions are assumed to match the trained checkpoint's
# output head ordering — confirm against training code before reordering.
CLASS_NAMES = [
    'AKIEC', 'BCC', 'BEN_OTH', 'BKL', 'DF',
    'INF', 'MAL_OTH', 'MEL', 'NV', 'SCCKA', 'VASC'
]

# Human-readable expansions of the abbreviated class codes above.
CLASS_FULL_NAMES = {
    'AKIEC': 'Actinic Keratosis / Intraepithelial Carcinoma',
    'BCC': 'Basal Cell Carcinoma',
    'BEN_OTH': 'Benign Other',
    'BKL': 'Benign Keratosis-like Lesion',
    'DF': 'Dermatofibroma',
    'INF': 'Inflammatory',
    'MAL_OTH': 'Malignant Other',
    'MEL': 'Melanoma',
    'NV': 'Melanocytic Nevus',
    'SCCKA': 'Squamous Cell Carcinoma / Keratoacanthoma',
    'VASC': 'Vascular Lesion'
}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class ConvNeXtDualEncoder(nn.Module):
    """
    Dual-image ConvNeXt model matching the trained checkpoint.
    Processes BOTH clinical and dermoscopy images through shared backbone.

    Metadata input: 19 dimensions
        - age (1): normalized age
        - sex (4): one-hot encoded
        - site (7): one-hot encoded (reduced from 14)
        - MONET (7): 7 MONET feature scores

    Attribute names (backbone / meta_mlp / classifier) are state_dict keys and
    must not be renamed without re-exporting the checkpoint.
    """

    def __init__(
        self,
        model_name: str = 'convnext_base.fb_in22k_ft_in1k',
        metadata_dim: int = 19,
        num_classes: int = 11,
        dropout: float = 0.3
    ):
        super().__init__()

        # Shared feature extractor. pretrained=False: weights come from the
        # task checkpoint, not from timm's hub. num_classes=0 strips the head.
        self.backbone = timm.create_model(
            model_name,
            pretrained=False,
            num_classes=0
        )
        backbone_dim = self.backbone.num_features  # 1024 for convnext_base

        # Metadata MLP: 19 -> 64
        self.meta_mlp = nn.Sequential(
            nn.Linear(metadata_dim, 64),
            nn.LayerNorm(64),
            nn.GELU(),
            nn.Dropout(dropout)
        )

        # Classifier: 2112 -> 512 -> 256 -> 11
        # Input: clinical(1024) + derm(1024) + meta(64) = 2112
        fusion_dim = backbone_dim * 2 + 64
        self.classifier = nn.Sequential(
            nn.Linear(fusion_dim, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, num_classes)
        )

        self.metadata_dim = metadata_dim
        self.num_classes = num_classes
        self.backbone_dim = backbone_dim

    def forward(
        self,
        clinical_img: torch.Tensor,
        derm_img: Optional[torch.Tensor] = None,
        metadata: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Forward pass with dual images.

        Args:
            clinical_img: [B, 3, H, W] clinical image tensor
            derm_img: [B, 3, H, W] dermoscopy image tensor (uses clinical if None)
            metadata: [B, 19] metadata tensor (zeros if None)

        Returns:
            logits: [B, 11]
        """
        # Process clinical image
        clinical_features = self.backbone(clinical_img)

        # Process dermoscopy image.
        # When no dermoscopy image exists, reuse the clinical features so the
        # fusion layer always receives both slots filled.
        if derm_img is not None:
            derm_features = self.backbone(derm_img)
        else:
            derm_features = clinical_features

        # Process metadata.
        # Missing metadata becomes a zero vector directly in the 64-dim
        # projected space (meta_mlp is bypassed entirely).
        if metadata is not None:
            meta_features = self.meta_mlp(metadata)
        else:
            batch_size = clinical_features.size(0)
            meta_features = torch.zeros(
                batch_size, 64,
                device=clinical_features.device
            )

        # Concatenate: [B, 1024] + [B, 1024] + [B, 64] = [B, 2112]
        fused = torch.cat([clinical_features, derm_features, meta_features], dim=1)
        logits = self.classifier(fused)

        return logits
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class ConvNeXtClassifier:
|
| 135 |
+
"""
|
| 136 |
+
ConvNeXt classifier tool for skin lesion classification.
|
| 137 |
+
Uses dual images (clinical + dermoscopy) and MONET features.
|
| 138 |
+
"""
|
| 139 |
+
|
| 140 |
+
    # Site mapping for metadata encoding.
    # Keyword -> site-group index (7 groups). Matching in encode_metadata is
    # substring-based, so several keywords collapse into one one-hot slot.
    SITE_MAPPING = {
        'head': 0, 'neck': 0, 'face': 0,  # head_neck_face
        'trunk': 1, 'back': 1, 'chest': 1, 'abdomen': 1,
        'upper': 2, 'arm': 2, 'hand': 2,  # upper extremity
        'lower': 3, 'leg': 3, 'foot': 3, 'thigh': 3,  # lower extremity
        'genital': 4, 'oral': 5, 'acral': 6,
    }

    # Sex -> one-hot index; unrecognized values fall back to 'unknown' (3).
    SEX_MAPPING = {'male': 0, 'female': 1, 'other': 2, 'unknown': 3}
|
| 150 |
+
|
| 151 |
+
    def __init__(
        self,
        checkpoint_path: str = "models/seed42_fold0.pt",
        device: Optional[str] = None
    ):
        """Store config and build the eval-time transform.

        The model itself is loaded lazily by load() (or on first classify()).

        Args:
            checkpoint_path: Path to the trained seed42_fold0.pt checkpoint.
            device: Torch device string ('cuda'/'mps'/'cpu'); auto-detected
                in load() if None.
        """
        self.checkpoint_path = checkpoint_path
        self.device = device
        self.model = None    # populated by load()
        self.loaded = False  # guards against double-loading

        # Image preprocessing: 384x384 resize + ImageNet mean/std normalization.
        # NOTE(review): assumes the checkpoint was trained at 384px — confirm.
        self.transform = transforms.Compose([
            transforms.Resize((384, 384)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])
|
| 170 |
+
|
| 171 |
+
    def load(self):
        """Load the ConvNeXt model from checkpoint (idempotent)."""
        if self.loaded:
            return

        # Determine device: prefer CUDA, then Apple MPS, else CPU.
        if self.device is None:
            if torch.cuda.is_available():
                self.device = "cuda"
            elif torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"

        # Create model — architecture must mirror the training-time config so
        # the checkpoint's state_dict keys line up exactly.
        self.model = ConvNeXtDualEncoder(
            model_name='convnext_base.fb_in22k_ft_in1k',
            metadata_dim=19,
            num_classes=11,
            dropout=0.3
        )

        # Load checkpoint.
        # SECURITY: weights_only=False unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(
            self.checkpoint_path,
            map_location=self.device,
            weights_only=False
        )

        # Checkpoint may be a raw state_dict or a dict wrapping one.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            self.model.load_state_dict(checkpoint)

        self.model.to(self.device)
        self.model.eval()  # inference mode: disables dropout
        self.loaded = True
|
| 209 |
+
def encode_metadata(
|
| 210 |
+
self,
|
| 211 |
+
age: Optional[float] = None,
|
| 212 |
+
sex: Optional[str] = None,
|
| 213 |
+
site: Optional[str] = None,
|
| 214 |
+
monet_scores: Optional[List[float]] = None
|
| 215 |
+
) -> torch.Tensor:
|
| 216 |
+
"""
|
| 217 |
+
Encode metadata into 19-dim vector.
|
| 218 |
+
|
| 219 |
+
Layout: [age(1), sex(4), site(7), monet(7)] = 19
|
| 220 |
+
|
| 221 |
+
Args:
|
| 222 |
+
age: Patient age in years
|
| 223 |
+
sex: 'male', 'female', 'other', or None
|
| 224 |
+
site: Anatomical site string
|
| 225 |
+
monet_scores: List of 7 MONET feature scores
|
| 226 |
+
|
| 227 |
+
Returns:
|
| 228 |
+
torch.Tensor of shape [19]
|
| 229 |
+
"""
|
| 230 |
+
features = []
|
| 231 |
+
|
| 232 |
+
# Age (1 dim) - normalized
|
| 233 |
+
age_norm = (age - 50) / 30 if age is not None else 0.0
|
| 234 |
+
features.append(age_norm)
|
| 235 |
+
|
| 236 |
+
# Sex (4 dim) - one-hot
|
| 237 |
+
sex_onehot = [0.0] * 4
|
| 238 |
+
if sex:
|
| 239 |
+
sex_idx = self.SEX_MAPPING.get(sex.lower(), 3)
|
| 240 |
+
sex_onehot[sex_idx] = 1.0
|
| 241 |
+
features.extend(sex_onehot)
|
| 242 |
+
|
| 243 |
+
# Site (7 dim) - one-hot
|
| 244 |
+
site_onehot = [0.0] * 7
|
| 245 |
+
if site:
|
| 246 |
+
site_lower = site.lower()
|
| 247 |
+
for key, idx in self.SITE_MAPPING.items():
|
| 248 |
+
if key in site_lower:
|
| 249 |
+
site_onehot[idx] = 1.0
|
| 250 |
+
break
|
| 251 |
+
features.extend(site_onehot)
|
| 252 |
+
|
| 253 |
+
# MONET (7 dim)
|
| 254 |
+
if monet_scores is not None and len(monet_scores) == 7:
|
| 255 |
+
features.extend(monet_scores)
|
| 256 |
+
else:
|
| 257 |
+
features.extend([0.0] * 7)
|
| 258 |
+
|
| 259 |
+
return torch.tensor(features, dtype=torch.float32)
|
| 260 |
+
|
| 261 |
+
def preprocess_image(self, image: Image.Image) -> torch.Tensor:
|
| 262 |
+
"""Preprocess PIL image for model input"""
|
| 263 |
+
if image.mode != "RGB":
|
| 264 |
+
image = image.convert("RGB")
|
| 265 |
+
return self.transform(image).unsqueeze(0)
|
| 266 |
+
|
| 267 |
+
def classify(
|
| 268 |
+
self,
|
| 269 |
+
clinical_image: Image.Image,
|
| 270 |
+
derm_image: Optional[Image.Image] = None,
|
| 271 |
+
age: Optional[float] = None,
|
| 272 |
+
sex: Optional[str] = None,
|
| 273 |
+
site: Optional[str] = None,
|
| 274 |
+
monet_scores: Optional[List[float]] = None,
|
| 275 |
+
top_k: int = 5
|
| 276 |
+
) -> Dict:
|
| 277 |
+
"""
|
| 278 |
+
Classify a skin lesion.
|
| 279 |
+
|
| 280 |
+
Args:
|
| 281 |
+
clinical_image: Clinical (close-up) image
|
| 282 |
+
derm_image: Dermoscopy image (optional, uses clinical if None)
|
| 283 |
+
age: Patient age
|
| 284 |
+
sex: Patient sex
|
| 285 |
+
site: Anatomical site
|
| 286 |
+
monet_scores: 7 MONET feature scores
|
| 287 |
+
top_k: Number of top predictions to return
|
| 288 |
+
|
| 289 |
+
Returns:
|
| 290 |
+
dict with 'predictions', 'probabilities', 'top_class', 'confidence'
|
| 291 |
+
"""
|
| 292 |
+
if not self.loaded:
|
| 293 |
+
self.load()
|
| 294 |
+
|
| 295 |
+
# Preprocess images
|
| 296 |
+
clinical_tensor = self.preprocess_image(clinical_image).to(self.device)
|
| 297 |
+
|
| 298 |
+
if derm_image is not None:
|
| 299 |
+
derm_tensor = self.preprocess_image(derm_image).to(self.device)
|
| 300 |
+
else:
|
| 301 |
+
derm_tensor = None
|
| 302 |
+
|
| 303 |
+
# Encode metadata
|
| 304 |
+
metadata = self.encode_metadata(age, sex, site, monet_scores)
|
| 305 |
+
metadata_tensor = metadata.unsqueeze(0).to(self.device)
|
| 306 |
+
|
| 307 |
+
# Run inference
|
| 308 |
+
with torch.no_grad():
|
| 309 |
+
logits = self.model(clinical_tensor, derm_tensor, metadata_tensor)
|
| 310 |
+
probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
|
| 311 |
+
|
| 312 |
+
# Get top-k predictions
|
| 313 |
+
top_indices = np.argsort(probs)[::-1][:top_k]
|
| 314 |
+
|
| 315 |
+
predictions = []
|
| 316 |
+
for idx in top_indices:
|
| 317 |
+
predictions.append({
|
| 318 |
+
'class': CLASS_NAMES[idx],
|
| 319 |
+
'full_name': CLASS_FULL_NAMES[CLASS_NAMES[idx]],
|
| 320 |
+
'probability': float(probs[idx])
|
| 321 |
+
})
|
| 322 |
+
|
| 323 |
+
return {
|
| 324 |
+
'predictions': predictions,
|
| 325 |
+
'probabilities': probs.tolist(),
|
| 326 |
+
'top_class': CLASS_NAMES[top_indices[0]],
|
| 327 |
+
'confidence': float(probs[top_indices[0]]),
|
| 328 |
+
'all_classes': CLASS_NAMES,
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
    def __call__(
        self,
        clinical_image: Image.Image,
        derm_image: Optional[Image.Image] = None,
        **kwargs
    ) -> Dict:
        """Shorthand for classify(): lets the instance be called directly.

        All keyword arguments (age, sex, site, monet_scores, top_k) are
        forwarded to classify() unchanged.
        """
        return self.classify(clinical_image, derm_image, **kwargs)
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
# Singleton instance
_convnext_instance = None


def get_convnext_classifier(checkpoint_path: str = "models/seed42_fold0.pt") -> ConvNeXtClassifier:
    """Lazily construct and cache the process-wide ConvNeXt classifier.

    Note: checkpoint_path only takes effect on the first call; later calls
    return the already-created instance.
    """
    global _convnext_instance
    if _convnext_instance is not None:
        return _convnext_instance
    _convnext_instance = ConvNeXtClassifier(checkpoint_path)
    return _convnext_instance
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
if __name__ == "__main__":
    # Manual smoke test: `python -m models.convnext_classifier <image_path>`
    import sys

    print("ConvNeXt Classifier Test")
    print("=" * 50)

    # Instantiate with the default checkpoint path and load weights eagerly.
    classifier = ConvNeXtClassifier()
    print("Loading model...")
    classifier.load()
    print("Model loaded!")

    if len(sys.argv) > 1:
        image_path = sys.argv[1]
        print(f"\nClassifying: {image_path}")

        image = Image.open(image_path).convert("RGB")

        # Example with mock MONET scores
        monet_scores = [0.2, 0.1, 0.05, 0.3, 0.7, 0.1, 0.05]

        # Demo metadata is hard-coded; real callers pass patient data.
        result = classifier.classify(
            clinical_image=image,
            age=55,
            sex="male",
            site="back",
            monet_scores=monet_scores
        )

        print("\nTop Predictions:")
        for pred in result['predictions']:
            print(f" {pred['probability']:.1%} - {pred['class']} ({pred['full_name']})")
|
models/explainability.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# models/explainability.py
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import numpy as np
|
| 6 |
+
import cv2
|
| 7 |
+
from typing import Tuple
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
class GradCAM:
    """
    Gradient-weighted Class Activation Mapping.

    Hooks a convolutional layer, then combines its forward activations with
    gradient-pooled channel weights to show which image regions are important
    for a prediction.
    """

    def __init__(self, model: torch.nn.Module, target_layer: str = None):
        """
        Args:
            model: The neural network
            target_layer: Module name to compute CAM on (usually last conv
                layer). Default: ``model.convnext.stages[-1]``.
        """
        self.model = model
        self.gradients = None    # filled by the backward hook
        self.activations = None  # filled by the forward hook

        # Auto-detect target layer if not specified
        if target_layer is None:
            # Use last ConvNeXt stage
            self.target_layer = model.convnext.stages[-1]
        else:
            self.target_layer = dict(model.named_modules())[target_layer]

        # Register hooks
        self.target_layer.register_forward_hook(self._save_activation)
        self.target_layer.register_full_backward_hook(self._save_gradient)

    def _save_activation(self, module, input, output):
        """Save forward activations"""
        self.activations = output.detach()

    def _save_gradient(self, module, grad_input, grad_output):
        """Save backward gradients"""
        self.gradients = grad_output[0].detach()

    def generate_cam(
        self,
        image: torch.Tensor,
        target_class: int = None
    ) -> np.ndarray:
        """
        Generate Class Activation Map

        Args:
            image: Input image [1, 3, H, W]
            target_class: Class to generate CAM for (None = predicted class)

        Returns:
            cam: Activation map [h, w] normalized to 0-1 (spatial size of the
            hooked layer, not of the input image)
        """
        self.model.eval()

        # Forward pass (gradients are required for the backward step below)
        output = self.model(image)

        # Use predicted class if not specified
        if target_class is None:
            target_class = output.argmax(dim=1).item()

        # Zero gradients
        self.model.zero_grad()

        # Backward pass for target class
        output[0, target_class].backward()

        # Get gradients and activations
        gradients = self.gradients[0]      # [C, h, w]
        activations = self.activations[0]  # [C, h, w]

        # Global average pooling of gradients -> per-channel weights
        weights = gradients.mean(dim=(1, 2))  # [C]

        # Weighted sum of activations.
        # BUGFIX: allocate the accumulator on the activations' device; the
        # previous CPU-allocated tensor raised a device mismatch on GPU runs
        # (the sibling implementation in models/gradcam_tool.py already does
        # this correctly).
        cam = torch.zeros(activations.shape[1:], dtype=torch.float32,
                          device=activations.device)
        for i, w in enumerate(weights):
            cam += w * activations[i]

        # ReLU keeps only positively-contributing regions
        cam = F.relu(cam)

        # Normalize to 0-1 (an all-zero map stays all zeros)
        cam = cam.cpu().numpy()
        cam = cam - cam.min()
        if cam.max() > 0:
            cam = cam / cam.max()

        return cam

    def overlay_cam_on_image(
        self,
        image: np.ndarray,  # [H, W, 3] RGB
        cam: np.ndarray,  # [h, w]
        alpha: float = 0.5,
        colormap: int = cv2.COLORMAP_JET
    ) -> np.ndarray:
        """
        Overlay CAM heatmap on original image

        Args:
            image: [H, W, 3] RGB uint8 original image
            cam: [h, w] float map in 0-1 (resized up to the image size here)
            alpha: Heatmap blend weight
            colormap: OpenCV colormap id

        Returns:
            overlay: [H, W, 3] RGB image with heatmap
        """
        H, W = image.shape[:2]

        # Resize CAM to image size
        cam_resized = cv2.resize(cam, (W, H))

        # Convert to heatmap (OpenCV emits BGR, so convert back to RGB)
        heatmap = cv2.applyColorMap(
            np.uint8(255 * cam_resized),
            colormap
        )
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

        # Blend with original image
        overlay = (alpha * heatmap + (1 - alpha) * image).astype(np.uint8)

        return overlay
|
| 127 |
+
|
| 128 |
+
class AttentionVisualizer:
    """Visualize MedSigLIP attention maps."""

    def __init__(self, model):
        self.model = model

    def get_attention_maps(self, image: torch.Tensor) -> np.ndarray:
        """
        Extract attention maps from MedSigLIP.

        Returns:
            attention: [num_heads, H, W] attention weights
        """
        # Run the model once so its attention state is populated.
        with torch.no_grad():
            _ = self.model(image)

        # Last-layer attention from MedSigLIP
        # Shape: [batch, num_heads, seq_len, seq_len]
        attention = self.model.medsiglip_features

        # Spatial extraction is model-dependent and not implemented yet;
        # adapt this to the concrete MedSigLIP architecture.
        # Placeholder implementation: random 14x14 map.
        return np.random.rand(14, 14)

    def overlay_attention(
        self,
        image: np.ndarray,
        attention: np.ndarray,
        alpha: float = 0.6
    ) -> np.ndarray:
        """Overlay attention map on image"""
        height, width = image.shape[:2]

        # Resize attention map up to the image resolution.
        resized = cv2.resize(attention, (width, height))

        # Shift to zero, then scale to 0-1 when the map is non-constant.
        shifted = resized - resized.min()
        if shifted.max() > 0:
            shifted = shifted / shifted.max()

        # Colorize (OpenCV gives BGR; convert to RGB) and blend.
        heatmap = cv2.applyColorMap(
            np.uint8(255 * shifted),
            cv2.COLORMAP_VIRIDIS
        )
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

        return (alpha * heatmap + (1 - alpha) * image).astype(np.uint8)
|
models/gradcam_tool.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Grad-CAM Tool - Visual explanation of ConvNeXt predictions
|
| 3 |
+
Shows which regions of the image the model focuses on.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
import numpy as np
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from torchvision import transforms
|
| 11 |
+
from typing import Optional, Tuple
|
| 12 |
+
import cv2
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class GradCAM:
    """
    Grad-CAM implementation for ConvNeXt model.
    Generates heatmaps showing model attention.
    """

    def __init__(self, model, target_layer=None):
        """
        Args:
            model: ConvNeXtDualEncoder model
            target_layer: Layer to extract gradients from (default: last conv
                layer, ``model.backbone.stages[-1]``)
        """
        self.model = model
        self.gradients = None    # set by the backward hook
        self.activations = None  # set by the forward hook

        # Hook the target layer (last stage of backbone)
        if target_layer is None:
            target_layer = model.backbone.stages[-1]

        target_layer.register_forward_hook(self._save_activation)
        target_layer.register_full_backward_hook(self._save_gradient)

    def _save_activation(self, module, input, output):
        """Save activations during forward pass"""
        self.activations = output.detach()

    def _save_gradient(self, module, grad_input, grad_output):
        """Save gradients during backward pass"""
        self.gradients = grad_output[0].detach()

    def generate(
        self,
        image_tensor: torch.Tensor,
        target_class: Optional[int] = None,
        derm_tensor: Optional[torch.Tensor] = None,
        metadata: Optional[torch.Tensor] = None
    ) -> np.ndarray:
        """
        Generate Grad-CAM heatmap.

        Args:
            image_tensor: Input image tensor [1, 3, H, W]
            target_class: Class index to visualize (default: predicted class)
            derm_tensor: Optional dermoscopy image tensor
            metadata: Optional metadata tensor

        Returns:
            CAM heatmap as numpy array [h, w] normalized to 0-1
        """
        self.model.eval()

        # Forward pass (gradients needed for the class-score backward below)
        output = self.model(image_tensor, derm_tensor, metadata)

        if target_class is None:
            target_class = output.argmax(dim=1).item()

        # Backward pass for target class
        self.model.zero_grad()
        output[0, target_class].backward()

        # Get gradients and activations
        gradients = self.gradients[0]      # [C, h, w]
        activations = self.activations[0]  # [C, h, w]

        # Global average pooling of gradients -> per-channel weights
        weights = gradients.mean(dim=(1, 2))  # [C]

        # Weighted combination of activation maps (on the activations' device)
        cam = torch.zeros(activations.shape[1:], dtype=torch.float32,
                          device=activations.device)
        for i, w in enumerate(weights):
            cam += w * activations[i]

        # ReLU and normalize
        cam = F.relu(cam)
        cam = cam.cpu().numpy()

        # BUGFIX: guard on the value *range*, not just the max. A constant
        # positive map previously divided by zero (max - min == 0) and
        # produced NaNs; such maps carry no localization signal, so return
        # zeros instead.
        value_range = cam.max() - cam.min()
        if value_range > 0:
            cam = (cam - cam.min()) / value_range
        else:
            cam = np.zeros_like(cam)

        return cam

    def overlay(
        self,
        image: np.ndarray,
        cam: np.ndarray,
        alpha: float = 0.5,
        colormap: Optional[int] = None
    ) -> np.ndarray:
        """
        Overlay CAM heatmap on original image.

        Args:
            image: Original image [H, W, 3] RGB uint8
            cam: CAM heatmap [H, W] float 0-1
            alpha: Overlay transparency
            colormap: OpenCV colormap (default: cv2.COLORMAP_JET)

        Returns:
            Overlaid image [H, W, 3] RGB uint8
        """
        # Resolve the default lazily so cv2 is only needed at call time, not
        # when the class is defined (behavior is unchanged for callers).
        if colormap is None:
            colormap = cv2.COLORMAP_JET

        H, W = image.shape[:2]

        # Resize CAM to image size
        cam_resized = cv2.resize(cam, (W, H))

        # Apply colormap (BGR) and convert back to RGB
        heatmap = cv2.applyColorMap(
            np.uint8(255 * cam_resized),
            colormap
        )
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

        # Overlay
        overlay = (alpha * heatmap + (1 - alpha) * image).astype(np.uint8)

        return overlay
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class GradCAMTool:
    """
    High-level Grad-CAM tool for the ConvNeXt classifier: wraps model
    loading, preprocessing, CAM generation, and overlay rendering.
    """

    def __init__(self, classifier=None):
        """
        Args:
            classifier: ConvNeXtClassifier instance (will create one if None)
        """
        self.classifier = classifier
        self.gradcam = None
        self.loaded = False

        # Preprocessing — mirrors the classifier's 384x384 ImageNet-normalized
        # transform so the CAM is computed on identically prepared input.
        self.transform = transforms.Compose([
            transforms.Resize((384, 384)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    def load(self):
        """Load classifier and setup Grad-CAM (idempotent)."""
        if self.loaded:
            return

        if self.classifier is None:
            from models.convnext_classifier import ConvNeXtClassifier
            self.classifier = ConvNeXtClassifier()

        # BUGFIX: always ensure the classifier's weights are loaded, even when
        # a classifier instance was supplied by the caller (previously only a
        # freshly created classifier was loaded).
        if not self.classifier.loaded:
            self.classifier.load()

        self.gradcam = GradCAM(self.classifier.model)
        self.loaded = True

    def generate_heatmap(
        self,
        image: Image.Image,
        target_class: Optional[int] = None
    ) -> Tuple[np.ndarray, np.ndarray, int, float]:
        """
        Generate Grad-CAM heatmap for an image.

        Args:
            image: PIL Image
            target_class: Class to visualize (default: predicted)

        Returns:
            Tuple of (overlay_image, cam_heatmap, predicted_class, confidence)
        """
        if not self.loaded:
            self.load()

        # Ensure RGB
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Keep an un-normalized 384x384 copy for the visual overlay.
        image_np = np.array(image.resize((384, 384)))
        image_tensor = self.transform(image).unsqueeze(0).to(self.classifier.device)

        # Get prediction first.
        # NOTE(review): the model is called with only the clinical image here;
        # assumes the forward signature defaults derm/metadata — confirm.
        with torch.no_grad():
            logits = self.classifier.model(image_tensor)
            probs = torch.softmax(logits, dim=1)[0]
            pred_class = probs.argmax().item()
            confidence = probs[pred_class].item()

        # Use predicted class if not specified
        if target_class is None:
            target_class = pred_class

        # Generate CAM (needs gradients, so outside the no_grad block)
        cam = self.gradcam.generate(image_tensor, target_class)

        # Create overlay
        overlay = self.gradcam.overlay(image_np, cam, alpha=0.5)

        return overlay, cam, pred_class, confidence

    def analyze(
        self,
        image: Image.Image,
        target_class: Optional[int] = None
    ) -> dict:
        """
        Full analysis with Grad-CAM visualization.

        Args:
            image: PIL Image
            target_class: Class to visualize

        Returns:
            Dict with overlay_image, cam, prediction info
        """
        from models.convnext_classifier import CLASS_NAMES, CLASS_FULL_NAMES

        overlay, cam, pred_class, confidence = self.generate_heatmap(image, target_class)

        return {
            "overlay": Image.fromarray(overlay),
            "cam": cam,
            "predicted_class": CLASS_NAMES[pred_class],
            "predicted_class_full": CLASS_FULL_NAMES[CLASS_NAMES[pred_class]],
            "confidence": confidence,
            "class_index": pred_class,
        }

    def __call__(self, image: Image.Image, target_class: Optional[int] = None) -> dict:
        """Shorthand for analyze()."""
        return self.analyze(image, target_class)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
# Singleton
_gradcam_instance = None


def get_gradcam_tool() -> GradCAMTool:
    """Return the shared GradCAMTool, building it on first access."""
    global _gradcam_instance
    if _gradcam_instance is not None:
        return _gradcam_instance
    _gradcam_instance = GradCAMTool()
    return _gradcam_instance
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
if __name__ == "__main__":
    # Manual smoke test: `python -m models.gradcam_tool <image_path>`
    import sys

    print("Grad-CAM Tool Test")
    print("=" * 50)

    # Build the tool and eagerly load the classifier + hooks.
    tool = GradCAMTool()
    print("Loading model...")
    tool.load()
    print("Model loaded!")

    if len(sys.argv) > 1:
        image_path = sys.argv[1]
        print(f"\nAnalyzing: {image_path}")

        image = Image.open(image_path).convert("RGB")
        result = tool.analyze(image)

        print(f"\nPrediction: {result['predicted_class']} ({result['confidence']:.1%})")
        print(f"Full name: {result['predicted_class_full']}")

        # Save overlay next to the input image (same stem, _gradcam.png suffix)
        output_path = image_path.rsplit(".", 1)[0] + "_gradcam.png"
        result["overlay"].save(output_path)
        print(f"\nGrad-CAM overlay saved to: {output_path}")
|
models/guidelines_rag.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Guidelines RAG System - Retrieval-Augmented Generation for clinical guidelines
|
| 3 |
+
Uses FAISS for vector similarity search on chunked guideline PDFs.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
import re
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import List, Dict, Optional, Tuple
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
# Paths — index artifacts live under guidelines/index/ next to the source PDFs
GUIDELINES_DIR = Path(__file__).parent.parent / "guidelines"
INDEX_DIR = GUIDELINES_DIR / "index"
FAISS_INDEX_PATH = INDEX_DIR / "faiss.index"   # persisted FAISS vector index
CHUNKS_PATH = INDEX_DIR / "chunks.json"        # chunk metadata aligned with index rows
 
# Chunking parameters (word counts are derived from these in _chunk_text)
CHUNK_SIZE = 500  # tokens (approximate)
CHUNK_OVERLAP = 50  # tokens overlap between chunks
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class GuidelinesRAG:
|
| 26 |
+
"""
|
| 27 |
+
RAG system for clinical guidelines.
|
| 28 |
+
Extracts text from PDFs, chunks it, creates embeddings, and provides search.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
def __init__(self):
|
| 32 |
+
self.index = None
|
| 33 |
+
self.chunks = []
|
| 34 |
+
self.embedder = None
|
| 35 |
+
self.loaded = False
|
| 36 |
+
|
| 37 |
+
def _load_embedder(self):
|
| 38 |
+
"""Load sentence transformer model for embeddings"""
|
| 39 |
+
if self.embedder is None:
|
| 40 |
+
from sentence_transformers import SentenceTransformer
|
| 41 |
+
self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
|
| 42 |
+
|
| 43 |
+
def _extract_pdf_text(self, pdf_path: Path) -> str:
|
| 44 |
+
"""Extract text from a PDF file"""
|
| 45 |
+
try:
|
| 46 |
+
import pdfplumber
|
| 47 |
+
text_parts = []
|
| 48 |
+
with pdfplumber.open(pdf_path) as pdf:
|
| 49 |
+
for page in pdf.pages:
|
| 50 |
+
page_text = page.extract_text()
|
| 51 |
+
if page_text:
|
| 52 |
+
text_parts.append(page_text)
|
| 53 |
+
return "\n\n".join(text_parts)
|
| 54 |
+
except ImportError:
|
| 55 |
+
# Fallback to PyPDF2
|
| 56 |
+
from PyPDF2 import PdfReader
|
| 57 |
+
reader = PdfReader(pdf_path)
|
| 58 |
+
text_parts = []
|
| 59 |
+
for page in reader.pages:
|
| 60 |
+
text = page.extract_text()
|
| 61 |
+
if text:
|
| 62 |
+
text_parts.append(text)
|
| 63 |
+
return "\n\n".join(text_parts)
|
| 64 |
+
|
| 65 |
+
def _clean_text(self, text: str) -> str:
|
| 66 |
+
"""Clean extracted text"""
|
| 67 |
+
# Remove excessive whitespace
|
| 68 |
+
text = re.sub(r'\s+', ' ', text)
|
| 69 |
+
# Remove page numbers and headers
|
| 70 |
+
text = re.sub(r'\n\d+\s*\n', '\n', text)
|
| 71 |
+
# Fix broken words from line breaks
|
| 72 |
+
text = re.sub(r'(\w)-\s+(\w)', r'\1\2', text)
|
| 73 |
+
return text.strip()
|
| 74 |
+
|
| 75 |
+
def _extract_pdf_with_pages(self, pdf_path: Path) -> List[Tuple[str, int]]:
|
| 76 |
+
"""Extract text from PDF with page numbers"""
|
| 77 |
+
try:
|
| 78 |
+
import pdfplumber
|
| 79 |
+
pages = []
|
| 80 |
+
with pdfplumber.open(pdf_path) as pdf:
|
| 81 |
+
for i, page in enumerate(pdf.pages, 1):
|
| 82 |
+
page_text = page.extract_text()
|
| 83 |
+
if page_text:
|
| 84 |
+
pages.append((page_text, i))
|
| 85 |
+
return pages
|
| 86 |
+
except ImportError:
|
| 87 |
+
from PyPDF2 import PdfReader
|
| 88 |
+
reader = PdfReader(pdf_path)
|
| 89 |
+
pages = []
|
| 90 |
+
for i, page in enumerate(reader.pages, 1):
|
| 91 |
+
text = page.extract_text()
|
| 92 |
+
if text:
|
| 93 |
+
pages.append((text, i))
|
| 94 |
+
return pages
|
| 95 |
+
|
| 96 |
+
    def _chunk_text(self, text: str, source: str, page_num: int = 0) -> List[Dict]:
        """
        Chunk text into overlapping word-window segments.

        Args:
            text: Cleaned page text to split.
            source: Originating file name, stored on every chunk.
            page_num: 1-based page number the text came from (0 if unknown).

        Returns:
            List of dicts with 'text', 'source', 'chunk_id', 'page'.
        """
        # Approximate tokens by words (rough estimate: 1 token ≈ 0.75 words)
        words = text.split()
        chunk_words = int(CHUNK_SIZE * 0.75)
        overlap_words = int(CHUNK_OVERLAP * 0.75)

        chunks = []
        start = 0
        chunk_id = 0

        # Terminates because chunk_words > overlap_words, so `start`
        # strictly increases every iteration.
        while start < len(words):
            end = start + chunk_words
            chunk_text = ' '.join(words[start:end])

            # Try to end at sentence boundary — only for non-final chunks,
            # and only if a period falls in the last 30% of the chunk.
            # NOTE(review): words trimmed off here are NOT pushed into the
            # next chunk; if the trimmed tail exceeds the overlap window a
            # few words can be dropped entirely — confirm acceptable.
            if end < len(words):
                last_period = chunk_text.rfind('.')
                if last_period > len(chunk_text) * 0.7:
                    chunk_text = chunk_text[:last_period + 1]

            chunks.append({
                'text': chunk_text,
                'source': source,
                'chunk_id': chunk_id,
                'page': page_num
            })

            # Advance with a fixed word overlap between consecutive chunks.
            start = end - overlap_words
            chunk_id += 1

        return chunks
|
| 131 |
+
|
| 132 |
+
    def build_index(self, force_rebuild: bool = False) -> bool:
        """
        Build the FAISS index from guideline PDFs.

        Args:
            force_rebuild: When True, re-extract/re-embed even if a cached
                index exists on disk.

        Returns:
            True if an index was built, False if loaded from cache failed or
            no chunks could be extracted. (When the cache exists, this
            returns load_index()'s success flag.)
        """
        # Check if index already exists on disk; reuse it unless forced.
        if not force_rebuild and FAISS_INDEX_PATH.exists() and CHUNKS_PATH.exists():
            return self.load_index()

        print("Building guidelines index...")
        self._load_embedder()

        # Create index directory
        INDEX_DIR.mkdir(parents=True, exist_ok=True)

        # Extract and chunk all PDFs with page tracking
        all_chunks = []
        pdf_files = list(GUIDELINES_DIR.glob("*.pdf"))

        for pdf_path in pdf_files:
            print(f" Processing: {pdf_path.name}")
            pages = self._extract_pdf_with_pages(pdf_path)
            pdf_chunks = 0
            # Chunk page-by-page so every chunk keeps its page number.
            for page_text, page_num in pages:
                cleaned = self._clean_text(page_text)
                chunks = self._chunk_text(cleaned, pdf_path.name, page_num)
                all_chunks.extend(chunks)
                pdf_chunks += len(chunks)
            print(f" -> {pdf_chunks} chunks from {len(pages)} pages")

        if not all_chunks:
            print("No chunks extracted from PDFs!")
            return False

        self.chunks = all_chunks
        print(f"Total chunks: {len(self.chunks)}")

        # Generate embeddings for every chunk text in one batch call.
        print("Generating embeddings...")
        texts = [c['text'] for c in self.chunks]
        embeddings = self.embedder.encode(texts, show_progress_bar=True)
        embeddings = np.array(embeddings).astype('float32')

        # Build FAISS index
        import faiss
        dimension = embeddings.shape[1]
        self.index = faiss.IndexFlatIP(dimension)  # Inner product (cosine with normalized vectors)

        # Normalize embeddings for cosine similarity (in place)
        faiss.normalize_L2(embeddings)
        self.index.add(embeddings)

        # Persist index and chunk metadata so later runs can load_index().
        faiss.write_index(self.index, str(FAISS_INDEX_PATH))
        with open(CHUNKS_PATH, 'w') as f:
            json.dump(self.chunks, f)

        print(f"Index saved to {INDEX_DIR}")
        self.loaded = True
        return True
|
| 192 |
+
|
| 193 |
+
def load_index(self) -> bool:
|
| 194 |
+
"""Load persisted FAISS index and chunks"""
|
| 195 |
+
if not FAISS_INDEX_PATH.exists() or not CHUNKS_PATH.exists():
|
| 196 |
+
return False
|
| 197 |
+
|
| 198 |
+
import faiss
|
| 199 |
+
self.index = faiss.read_index(str(FAISS_INDEX_PATH))
|
| 200 |
+
|
| 201 |
+
with open(CHUNKS_PATH, 'r') as f:
|
| 202 |
+
self.chunks = json.load(f)
|
| 203 |
+
|
| 204 |
+
self._load_embedder()
|
| 205 |
+
self.loaded = True
|
| 206 |
+
return True
|
| 207 |
+
|
| 208 |
+
def search(self, query: str, k: int = 5) -> List[Dict]:
|
| 209 |
+
"""
|
| 210 |
+
Search for relevant guideline chunks.
|
| 211 |
+
Returns list of chunks with similarity scores.
|
| 212 |
+
"""
|
| 213 |
+
if not self.loaded:
|
| 214 |
+
if not self.load_index():
|
| 215 |
+
self.build_index()
|
| 216 |
+
|
| 217 |
+
import faiss
|
| 218 |
+
|
| 219 |
+
# Encode query
|
| 220 |
+
query_embedding = self.embedder.encode([query])
|
| 221 |
+
query_embedding = np.array(query_embedding).astype('float32')
|
| 222 |
+
faiss.normalize_L2(query_embedding)
|
| 223 |
+
|
| 224 |
+
# Search
|
| 225 |
+
scores, indices = self.index.search(query_embedding, k)
|
| 226 |
+
|
| 227 |
+
results = []
|
| 228 |
+
for score, idx in zip(scores[0], indices[0]):
|
| 229 |
+
if idx < len(self.chunks):
|
| 230 |
+
chunk = self.chunks[idx].copy()
|
| 231 |
+
chunk['score'] = float(score)
|
| 232 |
+
results.append(chunk)
|
| 233 |
+
|
| 234 |
+
return results
|
| 235 |
+
|
| 236 |
+
def get_management_context(self, diagnosis: str, features: Optional[str] = None) -> Tuple[str, List[Dict]]:
|
| 237 |
+
"""
|
| 238 |
+
Get formatted context from guidelines for management recommendations.
|
| 239 |
+
Returns tuple of (context_string, references_list).
|
| 240 |
+
References can be used for citation hyperlinks.
|
| 241 |
+
"""
|
| 242 |
+
# Build search query
|
| 243 |
+
query = f"{diagnosis} management treatment recommendations"
|
| 244 |
+
if features:
|
| 245 |
+
query += f" {features}"
|
| 246 |
+
|
| 247 |
+
chunks = self.search(query, k=5)
|
| 248 |
+
|
| 249 |
+
if not chunks:
|
| 250 |
+
return "No relevant guidelines found.", []
|
| 251 |
+
|
| 252 |
+
# Build context and collect references
|
| 253 |
+
context_parts = []
|
| 254 |
+
references = []
|
| 255 |
+
|
| 256 |
+
# Unicode superscript digits
|
| 257 |
+
superscripts = ['¹', '²', '³', '⁴', '⁵', '⁶', '⁷', '⁸', '⁹']
|
| 258 |
+
|
| 259 |
+
for i, chunk in enumerate(chunks, 1):
|
| 260 |
+
source = chunk['source'].replace('.pdf', '')
|
| 261 |
+
page = chunk.get('page', 0)
|
| 262 |
+
ref_id = f"ref{i}"
|
| 263 |
+
superscript = superscripts[i-1] if i <= len(superscripts) else f"[{i}]"
|
| 264 |
+
|
| 265 |
+
# Add reference marker with superscript
|
| 266 |
+
context_parts.append(f"[Source {superscript}] {chunk['text']}")
|
| 267 |
+
|
| 268 |
+
# Collect reference info
|
| 269 |
+
references.append({
|
| 270 |
+
'id': ref_id,
|
| 271 |
+
'source': source,
|
| 272 |
+
'page': page,
|
| 273 |
+
'file': chunk['source'],
|
| 274 |
+
'score': chunk.get('score', 0)
|
| 275 |
+
})
|
| 276 |
+
|
| 277 |
+
context = "\n\n".join(context_parts)
|
| 278 |
+
return context, references
|
| 279 |
+
|
| 280 |
+
def format_references_for_prompt(self, references: List[Dict]) -> str:
|
| 281 |
+
"""Format references for inclusion in LLM prompt"""
|
| 282 |
+
if not references:
|
| 283 |
+
return ""
|
| 284 |
+
|
| 285 |
+
lines = ["\n**References:**"]
|
| 286 |
+
for ref in references:
|
| 287 |
+
lines.append(f"[{ref['id']}] {ref['source']}, p.{ref['page']}")
|
| 288 |
+
return "\n".join(lines)
|
| 289 |
+
|
| 290 |
+
def format_references_for_display(self, references: List[Dict]) -> str:
|
| 291 |
+
"""
|
| 292 |
+
Format references with markers that frontend can parse into hyperlinks.
|
| 293 |
+
Uses format: [REF:id:source:page:file:superscript]
|
| 294 |
+
"""
|
| 295 |
+
if not references:
|
| 296 |
+
return ""
|
| 297 |
+
|
| 298 |
+
# Unicode superscript digits
|
| 299 |
+
superscripts = ['¹', '²', '³', '⁴', '⁵', '⁶', '⁷', '⁸', '⁹']
|
| 300 |
+
|
| 301 |
+
lines = ["\n[REFERENCES]"]
|
| 302 |
+
for i, ref in enumerate(references, 1):
|
| 303 |
+
superscript = superscripts[i-1] if i <= len(superscripts) else f"[{i}]"
|
| 304 |
+
# Format: [REF:ref1:Melanoma Guidelines:5:melanoma.pdf:¹]
|
| 305 |
+
lines.append(f"[REF:{ref['id']}:{ref['source']}:{ref['page']}:{ref['file']}:{superscript}]")
|
| 306 |
+
lines.append("[/REFERENCES]")
|
| 307 |
+
return "\n".join(lines)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
# Singleton instance (created lazily on first access)
_rag_instance = None


def get_guidelines_rag() -> GuidelinesRAG:
    """Return the process-wide GuidelinesRAG singleton, creating it lazily."""
    global _rag_instance
    instance = _rag_instance
    if instance is None:
        instance = GuidelinesRAG()
        _rag_instance = instance
    return instance
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
# CLI entry point: build (or rebuild with --force) the guidelines index,
# then run a few smoke-test queries against it.
if __name__ == "__main__":
    print("=" * 60)
    print(" Guidelines RAG System - Index Builder")
    print("=" * 60)

    rag = GuidelinesRAG()

    # Build or rebuild index ("--force" anywhere in argv triggers a rebuild)
    import sys
    force = "--force" in sys.argv
    rag.build_index(force_rebuild=force)

    # Test search
    print("\n" + "=" * 60)
    print(" Testing Search")
    print("=" * 60)

    test_queries = [
        "melanoma management",
        "actinic keratosis treatment",
        "surgical excision margins"
    ]

    for query in test_queries:
        print(f"\nQuery: '{query}'")
        results = rag.search(query, k=2)
        for r in results:
            # Show score, source file, and a 100-char preview of each hit
            print(f" [{r['score']:.3f}] {r['source']}: {r['text'][:100]}...")
|
models/medgemma_agent.py
ADDED
|
@@ -0,0 +1,927 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MedGemma Agent - LLM agent with tool calling and staged thinking feedback
|
| 3 |
+
|
| 4 |
+
Pipeline: MedGemma independent exam → Tools (MONET/ConvNeXt/GradCAM) → MedGemma reconciliation → Management
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import sys
|
| 8 |
+
import time
|
| 9 |
+
import random
|
| 10 |
+
import json
|
| 11 |
+
import os
|
| 12 |
+
import subprocess
|
| 13 |
+
import threading
|
| 14 |
+
from typing import Optional, Generator, Dict, Any
|
| 15 |
+
from PIL import Image
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class MCPClient:
    """
    Minimal MCP client that communicates with a FastMCP subprocess over stdio.

    Uses raw newline-delimited JSON-RPC 2.0 so the main process (Python 3.9)
    does not need the mcp library. The subprocess is launched with python3.11
    which has mcp installed.
    """

    def __init__(self):
        # Subprocess handle; None until start() and again after stop()
        self._process = None
        # Serializes request/response pairs across calling threads
        self._lock = threading.Lock()
        # Monotonically increasing JSON-RPC request id
        self._id_counter = 0

    def _next_id(self) -> int:
        """Allocate the next JSON-RPC request id."""
        self._id_counter += 1
        return self._id_counter

    def _send(self, obj: dict):
        """Write one newline-delimited JSON message to the server's stdin."""
        payload = json.dumps(obj) + "\n"
        self._process.stdin.write(payload)
        self._process.stdin.flush()

    def _recv(self) -> dict:
        """Block until a response (a message carrying an 'id') arrives."""
        while True:
            raw = self._process.stdout.readline()
            if not raw:
                raise RuntimeError("MCP server closed connection unexpectedly")
            raw = raw.strip()
            if not raw:
                continue
            message = json.loads(raw)
            # Server-initiated notifications have no "id" key; skip them.
            if "id" in message:
                return message

    def _initialize(self):
        """Send MCP initialize handshake."""
        request = {
            "jsonrpc": "2.0",
            "id": self._next_id(),
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": {"name": "SkinProAI", "version": "1.0.0"},
            },
        }
        self._send(request)
        self._recv()  # consume initialize response
        # Confirm initialization (notification, so no id / no response)
        self._send({
            "jsonrpc": "2.0",
            "method": "notifications/initialized",
            "params": {},
        })

    def start(self):
        """Spawn the MCP server subprocess and complete the handshake."""
        root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        server_script = os.path.join(root, "mcp_server", "server.py")
        # Use the same venv Python (has all ML packages); text mode with
        # line buffering so newline-delimited JSON flows immediately.
        self._process = subprocess.Popen(
            [sys.executable, server_script],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            bufsize=1,
        )
        self._initialize()

    def call_tool_sync(self, tool_name: str, arguments: dict) -> dict:
        """Call a tool synchronously and return the parsed result dict."""
        with self._lock:
            self._send({
                "jsonrpc": "2.0",
                "id": self._next_id(),
                "method": "tools/call",
                "params": {"name": tool_name, "arguments": arguments},
            })
            response = self._recv()

        # Protocol-level error (e.g. unknown method)
        if "error" in response:
            raise RuntimeError(
                f"MCP tool '{tool_name}' failed: {response['error']}"
            )

        result = response["result"]
        content_text = result["content"][0]["text"]

        # Tool-level error (isError=True means the tool itself raised an exception)
        if result.get("isError"):
            raise RuntimeError(f"MCP tool '{tool_name}' error: {content_text}")

        return json.loads(content_text)

    def stop(self):
        """Terminate the MCP server subprocess."""
        if self._process:
            try:
                self._process.stdin.close()
                self._process.terminate()
                self._process.wait(timeout=5)
            except Exception:
                # Best-effort shutdown; the process may already be gone.
                pass
            self._process = None
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# Rotating verbs for spinner effect (picked at random by get_verb() so the
# streamed status text varies between runs)
ANALYSIS_VERBS = [
    "Analyzing", "Examining", "Processing", "Inspecting", "Evaluating",
    "Scanning", "Assessing", "Reviewing", "Studying", "Interpreting"
]

# Comprehensive visual exam prompt (combined from 4 separate stages).
# Sent with the lesion image as the first MedGemma pass, before any
# classification tools run.
COMPREHENSIVE_EXAM_PROMPT = """Perform a systematic dermoscopic examination of this skin lesion. Assess ALL of the following in a SINGLE concise analysis:

1. PATTERN: Overall architecture, symmetry (symmetric/asymmetric), organization
2. COLORS: List all colors present (brown, black, blue, white, red, pink) and distribution
3. BORDER: Sharp vs gradual, regular vs irregular, any disruptions
4. STRUCTURES: Pigment network, dots/globules, streaks, blue-white veil, regression, vessels

Then provide:
- Top 3 differential diagnoses with brief reasoning
- Concern level (1-5, where 5=urgent)
- Single most important feature driving your assessment

Be CONCISE - focus on clinically relevant findings only."""
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def get_verb():
    """Return a randomly chosen analysis verb (spinner flavor text)."""
    verbs = ANALYSIS_VERBS
    return random.choice(verbs)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class MedGemmaAgent:
|
| 156 |
+
"""
|
| 157 |
+
Medical image analysis agent with:
|
| 158 |
+
- Staged thinking display (no emojis)
|
| 159 |
+
- Tool calling (MONET, ConvNeXt, Grad-CAM)
|
| 160 |
+
- Streaming responses
|
| 161 |
+
"""
|
| 162 |
+
|
| 163 |
+
def __init__(self, verbose: bool = True):
|
| 164 |
+
self.verbose = verbose
|
| 165 |
+
self.pipe = None
|
| 166 |
+
self.model_id = "google/medgemma-4b-it"
|
| 167 |
+
self.loaded = False
|
| 168 |
+
|
| 169 |
+
# Tools (legacy direct instances, kept for fallback / non-MCP use)
|
| 170 |
+
self.monet_tool = None
|
| 171 |
+
self.convnext_tool = None
|
| 172 |
+
self.gradcam_tool = None
|
| 173 |
+
self.rag_tool = None
|
| 174 |
+
self.tools_loaded = False
|
| 175 |
+
|
| 176 |
+
# MCP client
|
| 177 |
+
self.mcp_client = None
|
| 178 |
+
|
| 179 |
+
# State for confirmation flow
|
| 180 |
+
self.last_diagnosis = None
|
| 181 |
+
self.last_monet_result = None
|
| 182 |
+
self.last_image = None
|
| 183 |
+
self.last_medgemma_exam = None # Store independent MedGemma findings
|
| 184 |
+
self.last_reconciliation = None
|
| 185 |
+
|
| 186 |
+
def reset_state(self):
|
| 187 |
+
"""Reset analysis state for new analysis (keeps models loaded)"""
|
| 188 |
+
self.last_diagnosis = None
|
| 189 |
+
self.last_monet_result = None
|
| 190 |
+
self.last_image = None
|
| 191 |
+
self.last_medgemma_exam = None
|
| 192 |
+
self.last_reconciliation = None
|
| 193 |
+
|
| 194 |
+
def _print(self, message: str):
|
| 195 |
+
"""Print if verbose"""
|
| 196 |
+
if self.verbose:
|
| 197 |
+
print(message, flush=True)
|
| 198 |
+
|
| 199 |
+
    def load_model(self):
        """Load the MedGemma image-text-to-text pipeline (idempotent).

        Selects CUDA > MPS > CPU. Uses bfloat16 on accelerators and float32
        on CPU; `device_map="auto"` lets accelerate place the weights, so
        the `device` variable here only drives the dtype choice and logging.
        """
        if self.loaded:
            return

        self._print("Initializing MedGemma agent...")

        # Heavy imports are deferred so importing this module stays cheap
        import torch
        from transformers import pipeline

        self._print(f"Loading model: {self.model_id}")

        if torch.cuda.is_available():
            device = "cuda"
            self._print(f"Using GPU: {torch.cuda.get_device_name(0)}")
        elif torch.backends.mps.is_available():
            device = "mps"
            self._print("Using Apple Silicon (MPS)")
        else:
            device = "cpu"
            self._print("Using CPU")

        model_kwargs = dict(
            # bfloat16 halves memory on accelerators; CPU keeps float32
            torch_dtype=torch.bfloat16 if device != "cpu" else torch.float32,
            device_map="auto",
        )

        start = time.time()
        self.pipe = pipeline(
            "image-text-to-text",
            model=self.model_id,
            model_kwargs=model_kwargs
        )

        self._print(f"Model loaded in {time.time() - start:.1f}s")
        self.loaded = True
|
| 235 |
+
|
| 236 |
+
    def load_tools(self):
        """Load tool models (MONET + ConvNeXt + Grad-CAM + RAG).

        Legacy in-process path; `load_tools_via_mcp` is the subprocess
        alternative. Idempotent via the `tools_loaded` flag.
        """
        if self.tools_loaded:
            return

        # Project imports are local so the agent can be constructed without
        # the tool dependencies installed
        from models.monet_tool import MonetTool
        self.monet_tool = MonetTool()
        self.monet_tool.load()

        from models.convnext_classifier import ConvNeXtClassifier
        self.convnext_tool = ConvNeXtClassifier()
        self.convnext_tool.load()

        from models.gradcam_tool import GradCAMTool
        # Grad-CAM reuses the already-loaded classifier for attention maps
        self.gradcam_tool = GradCAMTool(classifier=self.convnext_tool)
        self.gradcam_tool.load()

        from models.guidelines_rag import get_guidelines_rag
        self.rag_tool = get_guidelines_rag()
        if not self.rag_tool.loaded:
            self.rag_tool.load_index()

        self.tools_loaded = True
|
| 259 |
+
|
| 260 |
+
def load_tools_via_mcp(self):
|
| 261 |
+
"""Start the MCP server subprocess and mark tools as loaded."""
|
| 262 |
+
if self.tools_loaded:
|
| 263 |
+
return
|
| 264 |
+
self.mcp_client = MCPClient()
|
| 265 |
+
self.mcp_client.start()
|
| 266 |
+
self.tools_loaded = True
|
| 267 |
+
|
| 268 |
+
    def _multi_pass_visual_exam(self, image, question: Optional[str] = None) -> Generator[str, None, Dict[str, str]]:
        """
        MedGemma performs comprehensive visual examination BEFORE tools run.
        Single prompt covers pattern, colors, borders, structures, and differentials.
        Returns findings dict after yielding all output.

        Yields UI protocol markers ([STAGE]/[THINKING]/[RESPONSE]/[ERROR])
        interleaved with the response text streamed word-by-word. Findings
        are also stored on `self.last_medgemma_exam`, since a generator's
        return value is awkward for callers to capture.
        """
        findings = {}

        yield f"\n[STAGE:medgemma_exam]MedGemma Visual Examination[/STAGE]\n"
        yield f"[THINKING]Performing systematic dermoscopic assessment...[/THINKING]\n"

        # Build prompt with optional clinical question appended
        exam_prompt = COMPREHENSIVE_EXAM_PROMPT
        if question:
            exam_prompt += f"\n\nCLINICAL QUESTION: {question}"

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": exam_prompt}
                ]
            }
        ]

        try:
            time.sleep(0.2)  # brief pause so the stage header renders before heavy inference
            output = self.pipe(messages, max_new_tokens=400)
            # Last message in generated_text is the assistant reply
            result = output[0]["generated_text"][-1]["content"]
            findings['synthesis'] = result

            yield f"[RESPONSE]\n"
            # Stream word-by-word for a typing effect in the UI
            words = result.split()
            for i, word in enumerate(words):
                time.sleep(0.015)
                yield word + (" " if i < len(words) - 1 else "")
            yield f"\n[/RESPONSE]\n"

        except Exception as e:
            # Keep a failure note in the findings so downstream stages degrade gracefully
            findings['synthesis'] = f"Analysis failed: {e}"
            yield f"[ERROR]Visual examination failed: {e}[/ERROR]\n"

        self.last_medgemma_exam = findings
        return findings
|
| 313 |
+
|
| 314 |
+
    def _reconcile_findings(
        self,
        image,
        medgemma_exam: Dict[str, str],
        monet_result: Dict[str, Any],
        convnext_result: Dict[str, Any],
        question: Optional[str] = None
    ) -> Generator[str, None, None]:
        """
        MedGemma reconciles its independent findings with tool outputs.
        Identifies agreements, disagreements, and provides integrated assessment.

        Args:
            image: PIL image of the lesion (re-shown to the model).
            medgemma_exam: Findings dict from the independent visual exam
                (the 'synthesis' text is quoted back, truncated to 600 chars).
            monet_result: MONET output; 'features' maps concept name -> score.
            convnext_result: Classifier output; 'predictions' sorted best-first.
            question: Optional clinical question to carry through.

        Side effects: stores the reconciliation text on self.last_reconciliation.
        """
        yield f"\n[STAGE:reconciliation]Reconciling MedGemma Findings with Tool Results[/STAGE]\n"
        yield f"[THINKING]Comparing independent visual assessment against AI classification tools...[/THINKING]\n"

        top = convnext_result['predictions'][0]
        runner_up = convnext_result['predictions'][1] if len(convnext_result['predictions']) > 1 else None

        # Build MONET features string (top 5 concepts by score)
        monet_top = sorted(monet_result["features"].items(), key=lambda x: x[1], reverse=True)[:5]
        monet_str = ", ".join([f"{k.replace('MONET_', '').replace('_', ' ')}: {v:.0%}" for k, v in monet_top])

        reconciliation_prompt = f"""You performed an independent visual examination of this lesion and concluded:

YOUR ASSESSMENT:
{medgemma_exam.get('synthesis', 'Not available')[:600]}

The AI classification tools produced these results:
- ConvNeXt classifier: {top['full_name']} ({top['probability']:.1%} confidence)
{f"- Runner-up: {runner_up['full_name']} ({runner_up['probability']:.1%})" if runner_up else ""}
- Key MONET features: {monet_str}

{f'CLINICAL QUESTION: {question}' if question else ''}

Reconcile your visual findings with the AI classification:
1. AGREEMENT/DISAGREEMENT: Do your findings support the AI diagnosis? Any conflicts?
2. INTEGRATED ASSESSMENT: Final diagnosis considering all evidence
3. CONFIDENCE (1-10): How certain? What would change your assessment?

Be concise and specific."""

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": reconciliation_prompt}
                ]
            }
        ]

        try:
            output = self.pipe(messages, max_new_tokens=300)
            reconciliation = output[0]["generated_text"][-1]["content"]
            self.last_reconciliation = reconciliation

            yield f"[RESPONSE]\n"
            # Word-by-word streaming for the typing effect
            words = reconciliation.split()
            for i, word in enumerate(words):
                time.sleep(0.015)
                yield word + (" " if i < len(words) - 1 else "")
            yield f"\n[/RESPONSE]\n"

        except Exception as e:
            yield f"[ERROR]Reconciliation failed: {e}[/ERROR]\n"
|
| 379 |
+
|
| 380 |
+
    def analyze_image_stream(
        self,
        image_path: str,
        question: Optional[str] = None,
        max_tokens: int = 512,
        use_tools: bool = True
    ) -> Generator[str, None, None]:
        """
        Stream analysis with new pipeline:
        1. MedGemma independent multi-pass exam
        2. MONET + ConvNeXt + GradCAM tools
        3. MedGemma reconciliation
        4. Confirmation request

        Args:
            image_path: Path to the lesion image on disk.
            question: Optional clinical question passed to both MedGemma passes.
            max_tokens: Accepted for interface compatibility; not referenced in
                this body (the sub-stages use their own token limits).
            use_tools: When False, only the MedGemma visual exam runs.

        Yields UI protocol chunks ([STAGE]/[THINKING]/[TOOL_OUTPUT]/[RESULT]/
        [GRADCAM_IMAGE]/[ERROR]/[CONFIRM]) plus streamed response text.
        """
        if not self.loaded:
            yield "[STAGE:loading]Initializing MedGemma...[/STAGE]\n"
            self.load_model()

        yield f"[STAGE:image]{get_verb()} image...[/STAGE]\n"

        try:
            image = Image.open(image_path).convert("RGB")
            self.last_image = image
        except Exception as e:
            yield f"[ERROR]Failed to load image: {e}[/ERROR]\n"
            return

        # Load tools early via MCP subprocess
        if use_tools and not self.tools_loaded:
            yield f"[STAGE:tools]Loading analysis tools...[/STAGE]\n"
            self.load_tools_via_mcp()

        # ===== PHASE 1: MedGemma Independent Visual Examination =====
        medgemma_exam = {}
        for chunk in self._multi_pass_visual_exam(image, question):
            yield chunk
            # NOTE(review): _multi_pass_visual_exam only ever *yields* strings
            # (its dict is the generator's return value), so this branch looks
            # unreachable; the line after the loop recovers the findings instead.
            if isinstance(chunk, dict):
                medgemma_exam = chunk
        medgemma_exam = self.last_medgemma_exam or {}

        monet_result = None
        convnext_result = None

        if use_tools:
            # ===== PHASE 2: Run Classification Tools =====
            yield f"\n[STAGE:tools_run]Running AI Classification Tools[/STAGE]\n"
            yield f"[THINKING]Now running MONET and ConvNeXt to compare against visual examination...[/THINKING]\n"

            # MONET Feature Extraction
            time.sleep(0.2)
            yield f"\n[STAGE:monet]MONET Feature Extraction[/STAGE]\n"

            try:
                monet_result = self.mcp_client.call_tool_sync(
                    "monet_analyze", {"image_path": image_path}
                )
                self.last_monet_result = monet_result

                yield f"[TOOL_OUTPUT:MONET Features]\n"
                for name, score in monet_result["features"].items():
                    short_name = name.replace("MONET_", "").replace("_", " ").title()
                    # ASCII bar chart, 10 cells wide
                    bar_filled = int(score * 10)
                    bar = "|" + "=" * bar_filled + "-" * (10 - bar_filled) + "|"
                    yield f" {short_name}: {bar} {score:.0%}\n"
                yield f"[/TOOL_OUTPUT]\n"

            except Exception as e:
                yield f"[ERROR]MONET failed: {e}[/ERROR]\n"

            # ConvNeXt Classification (MONET vector is passed as auxiliary input when available)
            time.sleep(0.2)
            yield f"\n[STAGE:convnext]ConvNeXt Classification[/STAGE]\n"

            try:
                monet_scores = monet_result["vector"] if monet_result else None
                convnext_result = self.mcp_client.call_tool_sync(
                    "classify_lesion",
                    {
                        "image_path": image_path,
                        "monet_scores": monet_scores,
                    },
                )
                self.last_diagnosis = convnext_result

                yield f"[TOOL_OUTPUT:Classification Results]\n"
                for pred in convnext_result["predictions"][:5]:
                    prob = pred['probability']
                    # ASCII bar chart, 20 cells wide
                    bar_filled = int(prob * 20)
                    bar = "|" + "=" * bar_filled + "-" * (20 - bar_filled) + "|"
                    yield f" {pred['class']}: {bar} {prob:.1%}\n"
                    yield f" {pred['full_name']}\n"
                yield f"[/TOOL_OUTPUT]\n"

                top = convnext_result['predictions'][0]
                yield f"[RESULT]ConvNeXt Primary: {top['full_name']} ({top['probability']:.1%})[/RESULT]\n"

            except Exception as e:
                yield f"[ERROR]ConvNeXt failed: {e}[/ERROR]\n"

            # Grad-CAM Visualization
            time.sleep(0.2)
            yield f"\n[STAGE:gradcam]Grad-CAM Attention Map[/STAGE]\n"

            try:
                gradcam_result = self.mcp_client.call_tool_sync(
                    "generate_gradcam", {"image_path": image_path}
                )
                gradcam_path = gradcam_result["gradcam_path"]
                yield f"[GRADCAM_IMAGE:{gradcam_path}]\n"
            except Exception as e:
                yield f"[ERROR]Grad-CAM failed: {e}[/ERROR]\n"

            # ===== PHASE 3: MedGemma Reconciliation =====
            # Only runs when both tools succeeded and the exam produced findings
            if convnext_result and monet_result and medgemma_exam:
                for chunk in self._reconcile_findings(
                    image, medgemma_exam, monet_result, convnext_result, question
                ):
                    yield chunk

        # Yield confirmation request
        if convnext_result:
            # NOTE(review): `top` is assigned but unused in this branch
            top = convnext_result['predictions'][0]
            yield f"\n[CONFIRM:diagnosis]Do you agree with the integrated assessment?[/CONFIRM]\n"
|
| 503 |
+
|
| 504 |
+
def generate_management_guidance(
|
| 505 |
+
self,
|
| 506 |
+
user_confirmed: bool = True,
|
| 507 |
+
user_feedback: Optional[str] = None
|
| 508 |
+
) -> Generator[str, None, None]:
|
| 509 |
+
"""
|
| 510 |
+
Generate LESION-SPECIFIC management guidance using RAG + MedGemma reasoning.
|
| 511 |
+
References specific findings from this analysis, not generic textbook management.
|
| 512 |
+
"""
|
| 513 |
+
if not self.last_diagnosis:
|
| 514 |
+
yield "[ERROR]No diagnosis available. Please analyze an image first.[/ERROR]\n"
|
| 515 |
+
return
|
| 516 |
+
|
| 517 |
+
top = self.last_diagnosis['predictions'][0]
|
| 518 |
+
runner_up = self.last_diagnosis['predictions'][1] if len(self.last_diagnosis['predictions']) > 1 else None
|
| 519 |
+
diagnosis = top['full_name']
|
| 520 |
+
|
| 521 |
+
if not user_confirmed and user_feedback:
|
| 522 |
+
yield f"[THINKING]Clinician provided alternative assessment: {user_feedback}[/THINKING]\n"
|
| 523 |
+
diagnosis = user_feedback
|
| 524 |
+
|
| 525 |
+
# Stage: RAG Search
|
| 526 |
+
time.sleep(0.3)
|
| 527 |
+
yield f"\n[STAGE:guidelines]Searching clinical guidelines for {diagnosis}...[/STAGE]\n"
|
| 528 |
+
|
| 529 |
+
# Get RAG context via MCP
|
| 530 |
+
features_desc = self.last_monet_result.get('description', '') if self.last_monet_result else ''
|
| 531 |
+
rag_data = self.mcp_client.call_tool_sync(
|
| 532 |
+
"search_guidelines",
|
| 533 |
+
{"query": features_desc, "diagnosis": diagnosis},
|
| 534 |
+
)
|
| 535 |
+
context = rag_data["context"]
|
| 536 |
+
references = rag_data["references"]
|
| 537 |
+
|
| 538 |
+
# Check guideline relevance
|
| 539 |
+
has_relevant_guidelines = False
|
| 540 |
+
if references:
|
| 541 |
+
diagnosis_lower = diagnosis.lower()
|
| 542 |
+
for ref in references:
|
| 543 |
+
source_lower = ref['source'].lower()
|
| 544 |
+
if any(term in diagnosis_lower for term in ['melanoma']) and 'melanoma' in source_lower:
|
| 545 |
+
has_relevant_guidelines = True
|
| 546 |
+
break
|
| 547 |
+
elif 'actinic' in diagnosis_lower and 'actinic' in source_lower:
|
| 548 |
+
has_relevant_guidelines = True
|
| 549 |
+
break
|
| 550 |
+
elif ref.get('score', 0) > 0.7:
|
| 551 |
+
has_relevant_guidelines = True
|
| 552 |
+
break
|
| 553 |
+
|
| 554 |
+
if not references or not has_relevant_guidelines:
|
| 555 |
+
yield f"[THINKING]No specific published guidelines for {diagnosis}. Using clinical knowledge.[/THINKING]\n"
|
| 556 |
+
context = "No specific clinical guidelines available."
|
| 557 |
+
references = []
|
| 558 |
+
|
| 559 |
+
# Build MONET features for context
|
| 560 |
+
monet_features = ""
|
| 561 |
+
if self.last_monet_result:
|
| 562 |
+
top_features = sorted(self.last_monet_result["features"].items(), key=lambda x: x[1], reverse=True)[:5]
|
| 563 |
+
monet_features = ", ".join([f"{k.replace('MONET_', '').replace('_', ' ')}: {v:.0%}" for k, v in top_features])
|
| 564 |
+
|
| 565 |
+
# Stage: Lesion-Specific Management Reasoning
|
| 566 |
+
time.sleep(0.3)
|
| 567 |
+
yield f"\n[STAGE:management]Generating Lesion-Specific Management Plan[/STAGE]\n"
|
| 568 |
+
yield f"[THINKING]Creating management plan tailored to THIS lesion's specific characteristics...[/THINKING]\n"
|
| 569 |
+
|
| 570 |
+
management_prompt = f"""Generate a CONCISE management plan for this lesion:
|
| 571 |
+
|
| 572 |
+
DIAGNOSIS: {diagnosis} ({top['probability']:.1%})
|
| 573 |
+
{f"Alternative: {runner_up['full_name']} ({runner_up['probability']:.1%})" if runner_up else ""}
|
| 574 |
+
KEY FEATURES: {monet_features}
|
| 575 |
+
|
| 576 |
+
{f"GUIDELINES: {context[:800]}" if context else ""}
|
| 577 |
+
|
| 578 |
+
Provide:
|
| 579 |
+
1. RECOMMENDED ACTION: Biopsy, excision, monitoring, or discharge - with specific reasoning
|
| 580 |
+
2. URGENCY: Routine vs urgent vs same-day referral
|
| 581 |
+
3. KEY CONCERNS: What features drive this recommendation
|
| 582 |
+
|
| 583 |
+
Be specific to THIS lesion. 3-5 sentences maximum."""
|
| 584 |
+
|
| 585 |
+
messages = [
|
| 586 |
+
{
|
| 587 |
+
"role": "user",
|
| 588 |
+
"content": [
|
| 589 |
+
{"type": "image", "image": self.last_image},
|
| 590 |
+
{"type": "text", "text": management_prompt}
|
| 591 |
+
]
|
| 592 |
+
}
|
| 593 |
+
]
|
| 594 |
+
|
| 595 |
+
# Generate response
|
| 596 |
+
start = time.time()
|
| 597 |
+
try:
|
| 598 |
+
output = self.pipe(messages, max_new_tokens=250)
|
| 599 |
+
response = output[0]["generated_text"][-1]["content"]
|
| 600 |
+
|
| 601 |
+
yield f"[RESPONSE]\n"
|
| 602 |
+
words = response.split()
|
| 603 |
+
for i, word in enumerate(words):
|
| 604 |
+
time.sleep(0.015)
|
| 605 |
+
yield word + (" " if i < len(words) - 1 else "")
|
| 606 |
+
yield f"\n[/RESPONSE]\n"
|
| 607 |
+
|
| 608 |
+
except Exception as e:
|
| 609 |
+
yield f"[ERROR]Management generation failed: {e}[/ERROR]\n"
|
| 610 |
+
|
| 611 |
+
# Output references (pre-formatted by MCP server)
|
| 612 |
+
if references:
|
| 613 |
+
yield rag_data["references_display"]
|
| 614 |
+
|
| 615 |
+
yield f"\n[COMPLETE]Lesion-specific management plan generated in {time.time() - start:.1f}s[/COMPLETE]\n"
|
| 616 |
+
|
| 617 |
+
# Store response for recommendation extraction
|
| 618 |
+
self.last_management_response = response
|
| 619 |
+
|
| 620 |
+
def extract_recommendation(self) -> Generator[str, None, Dict[str, Any]]:
|
| 621 |
+
"""
|
| 622 |
+
Extract structured recommendation from management guidance.
|
| 623 |
+
Determines: BIOPSY, EXCISION, FOLLOWUP, or DISCHARGE
|
| 624 |
+
For BIOPSY/EXCISION, gets coordinates from MedGemma.
|
| 625 |
+
"""
|
| 626 |
+
if not self.last_management_response or not self.last_image:
|
| 627 |
+
yield "[ERROR]No management guidance available[/ERROR]\n"
|
| 628 |
+
return {"action": "UNKNOWN"}
|
| 629 |
+
|
| 630 |
+
yield f"\n[STAGE:recommendation]Extracting Clinical Recommendation[/STAGE]\n"
|
| 631 |
+
|
| 632 |
+
# Ask MedGemma to classify the recommendation
|
| 633 |
+
classification_prompt = f"""Based on the management plan you just provided:
|
| 634 |
+
|
| 635 |
+
{self.last_management_response[:1000]}
|
| 636 |
+
|
| 637 |
+
Classify the PRIMARY recommended action into exactly ONE of these categories:
|
| 638 |
+
- BIOPSY: If punch biopsy, shave biopsy, or incisional biopsy is recommended
|
| 639 |
+
- EXCISION: If complete surgical excision is recommended
|
| 640 |
+
- FOLLOWUP: If monitoring with repeat photography/dermoscopy is recommended
|
| 641 |
+
- DISCHARGE: If the lesion is clearly benign and no follow-up needed
|
| 642 |
+
|
| 643 |
+
Respond with ONLY the category name (BIOPSY, EXCISION, FOLLOWUP, or DISCHARGE) on the first line.
|
| 644 |
+
Then on the second line, provide a brief (1 sentence) justification."""
|
| 645 |
+
|
| 646 |
+
messages = [
|
| 647 |
+
{
|
| 648 |
+
"role": "user",
|
| 649 |
+
"content": [
|
| 650 |
+
{"type": "image", "image": self.last_image},
|
| 651 |
+
{"type": "text", "text": classification_prompt}
|
| 652 |
+
]
|
| 653 |
+
}
|
| 654 |
+
]
|
| 655 |
+
|
| 656 |
+
try:
|
| 657 |
+
output = self.pipe(messages, max_new_tokens=100)
|
| 658 |
+
response = output[0]["generated_text"][-1]["content"].strip()
|
| 659 |
+
lines = response.split('\n')
|
| 660 |
+
action = lines[0].strip().upper()
|
| 661 |
+
justification = lines[1].strip() if len(lines) > 1 else ""
|
| 662 |
+
|
| 663 |
+
# Validate action
|
| 664 |
+
valid_actions = ["BIOPSY", "EXCISION", "FOLLOWUP", "DISCHARGE"]
|
| 665 |
+
if action not in valid_actions:
|
| 666 |
+
# Try to extract from response
|
| 667 |
+
for valid in valid_actions:
|
| 668 |
+
if valid in response.upper():
|
| 669 |
+
action = valid
|
| 670 |
+
break
|
| 671 |
+
else:
|
| 672 |
+
action = "FOLLOWUP" # Default to safe option
|
| 673 |
+
|
| 674 |
+
yield f"[RESULT]Recommended Action: {action}[/RESULT]\n"
|
| 675 |
+
yield f"[OBSERVATION]{justification}[/OBSERVATION]\n"
|
| 676 |
+
|
| 677 |
+
result = {
|
| 678 |
+
"action": action,
|
| 679 |
+
"justification": justification
|
| 680 |
+
}
|
| 681 |
+
|
| 682 |
+
return result
|
| 683 |
+
|
| 684 |
+
except Exception as e:
|
| 685 |
+
yield f"[ERROR]Failed to extract recommendation: {e}[/ERROR]\n"
|
| 686 |
+
return {"action": "UNKNOWN", "error": str(e)}
|
| 687 |
+
|
| 688 |
+
def compare_followup_images(
|
| 689 |
+
self,
|
| 690 |
+
previous_image_path: str,
|
| 691 |
+
current_image_path: str
|
| 692 |
+
) -> Generator[str, None, None]:
|
| 693 |
+
"""
|
| 694 |
+
Compare a follow-up image with the previous one.
|
| 695 |
+
Runs full analysis pipeline on current image, then compares findings.
|
| 696 |
+
"""
|
| 697 |
+
yield f"\n[STAGE:comparison]Follow-up Comparison Analysis[/STAGE]\n"
|
| 698 |
+
|
| 699 |
+
try:
|
| 700 |
+
current_image = Image.open(current_image_path).convert("RGB")
|
| 701 |
+
except Exception as e:
|
| 702 |
+
yield f"[ERROR]Failed to load images: {e}[/ERROR]\n"
|
| 703 |
+
return
|
| 704 |
+
|
| 705 |
+
# Store previous analysis state
|
| 706 |
+
prev_exam = self.last_medgemma_exam
|
| 707 |
+
|
| 708 |
+
# Generate comparison image and MONET deltas via MCP
|
| 709 |
+
yield f"\n[STAGE:current_analysis]Analyzing Current Image[/STAGE]\n"
|
| 710 |
+
|
| 711 |
+
if self.tools_loaded:
|
| 712 |
+
try:
|
| 713 |
+
compare_data = self.mcp_client.call_tool_sync(
|
| 714 |
+
"compare_images",
|
| 715 |
+
{
|
| 716 |
+
"image1_path": previous_image_path,
|
| 717 |
+
"image2_path": current_image_path,
|
| 718 |
+
},
|
| 719 |
+
)
|
| 720 |
+
yield f"[COMPARISON_IMAGE:{compare_data['comparison_path']}]\n"
|
| 721 |
+
|
| 722 |
+
# Side-by-side GradCAM comparison if both paths available
|
| 723 |
+
prev_gc = compare_data.get("prev_gradcam_path")
|
| 724 |
+
curr_gc = compare_data.get("curr_gradcam_path")
|
| 725 |
+
if prev_gc and curr_gc:
|
| 726 |
+
yield f"[GRADCAM_COMPARE:{prev_gc}:{curr_gc}]\n"
|
| 727 |
+
|
| 728 |
+
# Display MONET feature deltas
|
| 729 |
+
if compare_data["monet_deltas"]:
|
| 730 |
+
yield f"[TOOL_OUTPUT:Feature Comparison]\n"
|
| 731 |
+
for name, delta_info in compare_data["monet_deltas"].items():
|
| 732 |
+
prev_val = delta_info["previous"]
|
| 733 |
+
curr_val = delta_info["current"]
|
| 734 |
+
diff = delta_info["delta"]
|
| 735 |
+
short_name = name.replace("MONET_", "").replace("_", " ").title()
|
| 736 |
+
direction = "↑" if diff > 0 else "↓"
|
| 737 |
+
yield f" {short_name}: {prev_val:.0%} → {curr_val:.0%} ({direction}{abs(diff):.0%})\n"
|
| 738 |
+
yield f"[/TOOL_OUTPUT]\n"
|
| 739 |
+
|
| 740 |
+
except Exception as e:
|
| 741 |
+
yield f"[ERROR]MCP comparison failed: {e}[/ERROR]\n"
|
| 742 |
+
|
| 743 |
+
# MedGemma comparison analysis
|
| 744 |
+
comparison_prompt = f"""You are comparing TWO images of the same skin lesion taken at different times.
|
| 745 |
+
|
| 746 |
+
PREVIOUS ANALYSIS:
|
| 747 |
+
{prev_exam.get('synthesis', 'Not available')[:500] if prev_exam else 'Not available'}
|
| 748 |
+
|
| 749 |
+
Now examine the CURRENT image and compare to your memory of the previous findings.
|
| 750 |
+
|
| 751 |
+
Assess for changes in:
|
| 752 |
+
1. SIZE: Has the lesion grown, shrunk, or stayed the same?
|
| 753 |
+
2. COLOR: Any new colors appeared? Any colors faded?
|
| 754 |
+
3. SHAPE/SYMMETRY: Has the shape changed? More or less symmetric?
|
| 755 |
+
4. BORDERS: Sharper, more irregular, or unchanged?
|
| 756 |
+
5. STRUCTURES: New dermoscopic structures? Lost structures?
|
| 757 |
+
|
| 758 |
+
Provide your assessment:
|
| 759 |
+
- CHANGE_LEVEL: SIGNIFICANT_CHANGE / MINOR_CHANGE / STABLE / IMPROVED
|
| 760 |
+
- Specific changes observed
|
| 761 |
+
- Clinical recommendation based on changes"""
|
| 762 |
+
|
| 763 |
+
messages = [
|
| 764 |
+
{
|
| 765 |
+
"role": "user",
|
| 766 |
+
"content": [
|
| 767 |
+
{"type": "image", "image": current_image},
|
| 768 |
+
{"type": "text", "text": comparison_prompt}
|
| 769 |
+
]
|
| 770 |
+
}
|
| 771 |
+
]
|
| 772 |
+
|
| 773 |
+
try:
|
| 774 |
+
yield f"[THINKING]Comparing current image to previous findings...[/THINKING]\n"
|
| 775 |
+
output = self.pipe(messages, max_new_tokens=400)
|
| 776 |
+
comparison_result = output[0]["generated_text"][-1]["content"]
|
| 777 |
+
|
| 778 |
+
yield f"[RESPONSE]\n"
|
| 779 |
+
words = comparison_result.split()
|
| 780 |
+
for i, word in enumerate(words):
|
| 781 |
+
time.sleep(0.02)
|
| 782 |
+
yield word + (" " if i < len(words) - 1 else "")
|
| 783 |
+
yield f"\n[/RESPONSE]\n"
|
| 784 |
+
|
| 785 |
+
# Extract change level
|
| 786 |
+
change_level = "UNKNOWN"
|
| 787 |
+
for level in ["SIGNIFICANT_CHANGE", "MINOR_CHANGE", "STABLE", "IMPROVED"]:
|
| 788 |
+
if level in comparison_result.upper():
|
| 789 |
+
change_level = level
|
| 790 |
+
break
|
| 791 |
+
|
| 792 |
+
if change_level == "SIGNIFICANT_CHANGE":
|
| 793 |
+
yield f"[RESULT]⚠️ SIGNIFICANT CHANGES DETECTED - Further evaluation recommended[/RESULT]\n"
|
| 794 |
+
elif change_level == "IMPROVED":
|
| 795 |
+
yield f"[RESULT]✓ LESION IMPROVED - Continue monitoring[/RESULT]\n"
|
| 796 |
+
elif change_level == "STABLE":
|
| 797 |
+
yield f"[RESULT]✓ LESION STABLE - Continue scheduled follow-up[/RESULT]\n"
|
| 798 |
+
else:
|
| 799 |
+
yield f"[RESULT]Minor changes noted - Clinical correlation recommended[/RESULT]\n"
|
| 800 |
+
|
| 801 |
+
except Exception as e:
|
| 802 |
+
yield f"[ERROR]Comparison analysis failed: {e}[/ERROR]\n"
|
| 803 |
+
|
| 804 |
+
yield f"\n[COMPLETE]Follow-up comparison complete[/COMPLETE]\n"
|
| 805 |
+
|
| 806 |
+
def chat(self, message: str, image_path: Optional[str] = None) -> str:
|
| 807 |
+
"""Simple chat interface"""
|
| 808 |
+
if not self.loaded:
|
| 809 |
+
self.load_model()
|
| 810 |
+
|
| 811 |
+
content = []
|
| 812 |
+
if image_path:
|
| 813 |
+
image = Image.open(image_path).convert("RGB")
|
| 814 |
+
content.append({"type": "image", "image": image})
|
| 815 |
+
content.append({"type": "text", "text": message})
|
| 816 |
+
|
| 817 |
+
messages = [{"role": "user", "content": content}]
|
| 818 |
+
output = self.pipe(messages, max_new_tokens=512)
|
| 819 |
+
return output[0]["generated_text"][-1]["content"]
|
| 820 |
+
|
| 821 |
+
def chat_followup(self, message: str) -> Generator[str, None, None]:
|
| 822 |
+
"""
|
| 823 |
+
Handle follow-up questions using the stored analysis context.
|
| 824 |
+
Uses the last analyzed image and diagnosis to provide contextual responses.
|
| 825 |
+
"""
|
| 826 |
+
if not self.loaded:
|
| 827 |
+
yield "[ERROR]Model not loaded[/ERROR]\n"
|
| 828 |
+
return
|
| 829 |
+
|
| 830 |
+
if not self.last_diagnosis or not self.last_image:
|
| 831 |
+
yield "[ERROR]No previous analysis context. Please analyze an image first.[/ERROR]\n"
|
| 832 |
+
return
|
| 833 |
+
|
| 834 |
+
# Build context from previous analysis
|
| 835 |
+
top_diagnosis = self.last_diagnosis['predictions'][0]
|
| 836 |
+
differentials = ", ".join([
|
| 837 |
+
f"{p['class']} ({p['probability']:.0%})"
|
| 838 |
+
for p in self.last_diagnosis['predictions'][:3]
|
| 839 |
+
])
|
| 840 |
+
|
| 841 |
+
monet_desc = ""
|
| 842 |
+
if self.last_monet_result:
|
| 843 |
+
monet_desc = self.last_monet_result.get('description', '')
|
| 844 |
+
|
| 845 |
+
context_prompt = f"""You are a dermatology assistant helping with skin lesion analysis.
|
| 846 |
+
|
| 847 |
+
PREVIOUS ANALYSIS CONTEXT:
|
| 848 |
+
- Primary diagnosis: {top_diagnosis['full_name']} ({top_diagnosis['probability']:.1%} confidence)
|
| 849 |
+
- Differential diagnoses: {differentials}
|
| 850 |
+
- Visual features: {monet_desc}
|
| 851 |
+
|
| 852 |
+
The user has a follow-up question about this lesion. Please provide a helpful, medically accurate response.
|
| 853 |
+
|
| 854 |
+
USER QUESTION: {message}
|
| 855 |
+
|
| 856 |
+
Provide a concise, informative response. If the question is outside your expertise or requires in-person examination, say so."""
|
| 857 |
+
|
| 858 |
+
messages = [
|
| 859 |
+
{
|
| 860 |
+
"role": "user",
|
| 861 |
+
"content": [
|
| 862 |
+
{"type": "image", "image": self.last_image},
|
| 863 |
+
{"type": "text", "text": context_prompt}
|
| 864 |
+
]
|
| 865 |
+
}
|
| 866 |
+
]
|
| 867 |
+
|
| 868 |
+
try:
|
| 869 |
+
yield f"[THINKING]Considering your question in context of the previous analysis...[/THINKING]\n"
|
| 870 |
+
time.sleep(0.2)
|
| 871 |
+
|
| 872 |
+
output = self.pipe(messages, max_new_tokens=400)
|
| 873 |
+
response = output[0]["generated_text"][-1]["content"]
|
| 874 |
+
|
| 875 |
+
yield f"[RESPONSE]\n"
|
| 876 |
+
# Stream word by word for typewriter effect
|
| 877 |
+
words = response.split()
|
| 878 |
+
for i, word in enumerate(words):
|
| 879 |
+
time.sleep(0.02)
|
| 880 |
+
yield word + (" " if i < len(words) - 1 else "")
|
| 881 |
+
yield f"\n[/RESPONSE]\n"
|
| 882 |
+
|
| 883 |
+
except Exception as e:
|
| 884 |
+
yield f"[ERROR]Failed to generate response: {e}[/ERROR]\n"
|
| 885 |
+
|
| 886 |
+
|
| 887 |
+
def main():
    """Interactive terminal interface"""
    banner = "=" * 60
    print(banner)
    print(" MedGemma Agent - Medical Image Analysis")
    print(banner)

    agent = MedGemmaAgent(verbose=True)
    agent.load_model()

    print("\nCommands: analyze <path>, chat <message>, quit")

    # Simple REPL: read a command line, dispatch, repeat until quit/Ctrl-C.
    while True:
        try:
            line = input("\n> ").strip()
            if not line:
                continue

            if line.lower() in ["quit", "exit", "q"]:
                break

            tokens = line.split(maxsplit=1)
            command = tokens[0].lower()
            argument = tokens[1] if len(tokens) > 1 else None

            if argument is None:
                print("Unknown command")
            elif command == "analyze":
                for chunk in agent.analyze_image_stream(argument.strip()):
                    print(chunk, end="", flush=True)
            elif command == "chat":
                print(agent.chat(argument))
            else:
                print("Unknown command")

        except KeyboardInterrupt:
            break
        except Exception as e:
            print(f"Error: {e}")
|
| 924 |
+
|
| 925 |
+
|
| 926 |
+
# Run the interactive terminal interface when executed directly.
if __name__ == "__main__":
    main()
|
models/medsiglip_convnext_fusion.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# models/medsiglip_convnext_fusion.py
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from typing import Dict, List, Tuple, Optional
|
| 6 |
+
import numpy as np
|
| 7 |
+
import timm
|
| 8 |
+
from transformers import AutoModel, AutoProcessor
|
| 9 |
+
|
| 10 |
+
class MedSigLIPConvNeXtFusion(nn.Module):
    """
    Your trained MedSigLIP-ConvNeXt fusion model from MILK10 challenge
    Supports 11-class skin lesion classification

    Architecture: MedSigLIP vision encoder + ConvNeXt backbone, concatenated
    (optionally with an encoded metadata vector), passed through an MLP fusion
    head and a linear classifier.
    """

    # Class names from your training
    CLASS_NAMES = [
        'AKIEC',    # Actinic Keratoses and Intraepithelial Carcinoma
        'BCC',      # Basal Cell Carcinoma
        'BEN_OTH',  # Benign Other
        'BKL',      # Benign Keratosis-like Lesions
        'DF',       # Dermatofibroma
        'INF',      # Inflammatory
        'MAL_OTH',  # Malignant Other
        'MEL',      # Melanoma
        'NV',       # Melanocytic Nevi
        'SCCKA',    # Squamous Cell Carcinoma and Keratoacanthoma
        'VASC'      # Vascular Lesions
    ]

    def __init__(
        self,
        num_classes: int = 11,
        medsiglip_model: str = "google/medsiglip-base",
        convnext_variant: str = "convnext_base",
        fusion_dim: int = 512,
        dropout: float = 0.3,
        metadata_dim: int = 20  # For metadata features; 0 disables the branch
    ):
        super().__init__()

        self.num_classes = num_classes

        # MedSigLIP Vision Encoder
        print(f"Loading MedSigLIP: {medsiglip_model}")
        self.medsiglip = AutoModel.from_pretrained(medsiglip_model)
        self.medsiglip_processor = AutoProcessor.from_pretrained(medsiglip_model)

        # ConvNeXt Backbone (headless, global-average pooled feature extractor)
        print(f"Loading ConvNeXt: {convnext_variant}")
        self.convnext = timm.create_model(
            convnext_variant,
            pretrained=True,
            num_classes=0,
            global_pool='avg'
        )

        # Feature dimensions
        self.medsiglip_dim = self.medsiglip.config.hidden_size  # 768
        self.convnext_dim = self.convnext.num_features  # 1024

        # Optional metadata branch (small MLP projecting metadata to 32 dims)
        self.use_metadata = metadata_dim > 0
        if self.use_metadata:
            self.metadata_encoder = nn.Sequential(
                nn.Linear(metadata_dim, 64),
                nn.LayerNorm(64),
                nn.GELU(),
                nn.Dropout(0.2),
                nn.Linear(64, 32)
            )
            total_dim = self.medsiglip_dim + self.convnext_dim + 32
        else:
            total_dim = self.medsiglip_dim + self.convnext_dim

        # Fusion layers
        self.fusion = nn.Sequential(
            nn.Linear(total_dim, fusion_dim),
            nn.LayerNorm(fusion_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(fusion_dim, fusion_dim // 2),
            nn.LayerNorm(fusion_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout)
        )

        # Classification head
        self.classifier = nn.Linear(fusion_dim // 2, num_classes)

        # Store intermediate features for Grad-CAM
        self.convnext_features = None
        self.medsiglip_features = None

        # Register hooks (captures the last ConvNeXt stage's feature maps)
        self.convnext.stages[-1].register_forward_hook(self._save_convnext_features)

    def _save_convnext_features(self, module, input, output):
        """Hook to save ConvNeXt feature maps for Grad-CAM"""
        self.convnext_features = output

    def forward(
        self,
        image: torch.Tensor,
        metadata: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Forward pass

        Args:
            image: [B, 3, H, W] tensor
            metadata: [B, metadata_dim] optional metadata features

        Returns:
            logits: [B, num_classes]
        """
        # MedSigLIP features
        medsiglip_out = self.medsiglip.vision_model(image)
        medsiglip_features = medsiglip_out.pooler_output  # [B, 768]
        # BUG FIX: cache pooled MedSigLIP features so predict() can expose
        # them; previously this attribute was initialized to None in __init__
        # and never populated, so predict() always returned None for it.
        self.medsiglip_features = medsiglip_features

        # ConvNeXt features
        convnext_features = self.convnext(image)  # [B, 1024]

        # Concatenate vision features
        fused = torch.cat([medsiglip_features, convnext_features], dim=1)

        # Add metadata if available
        if self.use_metadata and metadata is not None:
            metadata_features = self.metadata_encoder(metadata)
            fused = torch.cat([fused, metadata_features], dim=1)

        # Fusion layers
        fused = self.fusion(fused)

        # Classification
        logits = self.classifier(fused)

        return logits

    def predict(
        self,
        image: torch.Tensor,
        metadata: Optional[torch.Tensor] = None,
        top_k: int = 5
    ) -> Dict:
        """
        Get predictions with probabilities

        Args:
            image: [B, 3, H, W] or [3, H, W]
            metadata: Optional metadata features
            top_k: Number of top predictions

        Returns:
            Dictionary with predictions and features
        """
        # Accept a single unbatched image for convenience.
        if image.dim() == 3:
            image = image.unsqueeze(0)

        self.eval()
        with torch.no_grad():
            logits = self.forward(image, metadata)
            probs = torch.softmax(logits, dim=1)

        # Top-k predictions (clamped to the number of classes)
        top_probs, top_indices = torch.topk(
            probs,
            k=min(top_k, self.num_classes),
            dim=1
        )

        # Format results (batch element 0 only)
        predictions = []
        for i in range(top_probs.size(1)):
            predictions.append({
                'class': self.CLASS_NAMES[top_indices[0, i].item()],
                'probability': top_probs[0, i].item(),
                'class_idx': top_indices[0, i].item()
            })

        return {
            'predictions': predictions,
            'all_probabilities': probs[0].cpu().numpy(),
            'logits': logits[0].cpu().numpy(),
            'convnext_features': self.convnext_features,
            'medsiglip_features': self.medsiglip_features
        }

    @classmethod
    def load_from_checkpoint(
        cls,
        medsiglip_path: str,
        convnext_path: Optional[str] = None,
        ensemble_weights: tuple = (0.6, 0.4),
        device: str = 'cpu'
    ):
        """
        Load model from your training checkpoints

        Args:
            medsiglip_path: Path to MedSigLIP model weights
            convnext_path: Path to ConvNeXt model weights (optional)
            ensemble_weights: (w_medsiglip, w_convnext)
            device: Device to load on

        NOTE(review): convnext_path is currently accepted but never loaded —
        only the MedSigLIP checkpoint is applied; confirm whether separate
        ConvNeXt weights should also be restored here.
        """
        model = cls(num_classes=11)

        # Load MedSigLIP weights
        print(f"Loading MedSigLIP from: {medsiglip_path}")
        medsiglip_state = torch.load(medsiglip_path, map_location=device)

        # Handle different checkpoint formats (raw state dict vs. wrapped)
        if 'model_state_dict' in medsiglip_state:
            model.load_state_dict(medsiglip_state['model_state_dict'])
        else:
            model.load_state_dict(medsiglip_state)

        # Store ensemble weights for prediction fusion
        model.ensemble_weights = ensemble_weights

        model.to(device)
        model.eval()

        return model
|
models/monet_concepts.py
ADDED
|
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# models/monet_concepts.py
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
from typing import Dict, List
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
|
| 8 |
+
@dataclass
class ConceptScore:
    """Single MONET concept with score and evidence"""
    name: str                # Human-readable concept name
    score: float             # Raw MONET concept probability
    confidence: float        # How confidently the score is HIGH/LOW
    description: str         # Concept description with level annotation
    clinical_relevance: str  # How this affects diagnosis
|
| 16 |
+
|
| 17 |
+
class MONETConceptScorer:
|
| 18 |
+
"""
|
| 19 |
+
MONET concept scoring using your trained metadata patterns
|
| 20 |
+
Integrates the boosting logic from your ensemble code
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
# MONET concepts used in your training
|
| 24 |
+
CONCEPT_DEFINITIONS = {
|
| 25 |
+
'MONET_ulceration_crust': {
|
| 26 |
+
'description': 'Ulceration or crusting present',
|
| 27 |
+
'high_in': ['SCCKA', 'BCC', 'MAL_OTH'],
|
| 28 |
+
'low_in': ['NV', 'BKL'],
|
| 29 |
+
'threshold_high': 0.50
|
| 30 |
+
},
|
| 31 |
+
'MONET_erythema': {
|
| 32 |
+
'description': 'Redness or inflammation',
|
| 33 |
+
'high_in': ['INF', 'BCC', 'SCCKA'],
|
| 34 |
+
'low_in': ['MEL', 'NV'],
|
| 35 |
+
'threshold_high': 0.40
|
| 36 |
+
},
|
| 37 |
+
'MONET_pigmented': {
|
| 38 |
+
'description': 'Pigmentation present',
|
| 39 |
+
'high_in': ['MEL', 'NV', 'BKL'],
|
| 40 |
+
'low_in': ['BCC', 'SCCKA', 'INF'],
|
| 41 |
+
'threshold_high': 0.55
|
| 42 |
+
},
|
| 43 |
+
'MONET_vasculature_vessels': {
|
| 44 |
+
'description': 'Vascular structures visible',
|
| 45 |
+
'high_in': ['VASC', 'BCC'],
|
| 46 |
+
'low_in': ['MEL', 'NV'],
|
| 47 |
+
'threshold_high': 0.35
|
| 48 |
+
},
|
| 49 |
+
'MONET_hair': {
|
| 50 |
+
'description': 'Hair follicles present',
|
| 51 |
+
'high_in': ['NV', 'BKL'],
|
| 52 |
+
'low_in': ['BCC', 'MEL'],
|
| 53 |
+
'threshold_high': 0.30
|
| 54 |
+
},
|
| 55 |
+
'MONET_gel_water_drop_fluid_dermoscopy_liquid': {
|
| 56 |
+
'description': 'Gel/fluid artifacts',
|
| 57 |
+
'high_in': [],
|
| 58 |
+
'low_in': [],
|
| 59 |
+
'threshold_high': 0.40
|
| 60 |
+
},
|
| 61 |
+
'MONET_skin_markings_pen_ink_purple_pen': {
|
| 62 |
+
'description': 'Pen markings present',
|
| 63 |
+
'high_in': [],
|
| 64 |
+
'low_in': [],
|
| 65 |
+
'threshold_high': 0.40
|
| 66 |
+
}
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
# Class-specific patterns from your metadata boosting
|
| 70 |
+
CLASS_PATTERNS = {
|
| 71 |
+
'MAL_OTH': {
|
| 72 |
+
'sex': 'male', # 88.9% male
|
| 73 |
+
'site_preference': 'trunk',
|
| 74 |
+
'age_range': (60, 80),
|
| 75 |
+
'key_concepts': {'MONET_ulceration_crust': 0.35}
|
| 76 |
+
},
|
| 77 |
+
'INF': {
|
| 78 |
+
'key_concepts': {
|
| 79 |
+
'MONET_erythema': 0.42,
|
| 80 |
+
'MONET_pigmented': (None, 0.30) # Low pigmentation
|
| 81 |
+
}
|
| 82 |
+
},
|
| 83 |
+
'BEN_OTH': {
|
| 84 |
+
'site_preference': ['head', 'neck', 'face'], # 47.7%
|
| 85 |
+
'key_concepts': {'MONET_pigmented': (0.30, 0.50)}
|
| 86 |
+
},
|
| 87 |
+
'DF': {
|
| 88 |
+
'site_preference': ['lower', 'leg', 'ankle', 'foot'], # 65.4%
|
| 89 |
+
'age_range': (40, 65)
|
| 90 |
+
},
|
| 91 |
+
'SCCKA': {
|
| 92 |
+
'age_range': (65, None),
|
| 93 |
+
'key_concepts': {
|
| 94 |
+
'MONET_ulceration_crust': 0.50,
|
| 95 |
+
'MONET_pigmented': (None, 0.15)
|
| 96 |
+
}
|
| 97 |
+
},
|
| 98 |
+
'MEL': {
|
| 99 |
+
'age_range': (55, None), # 61.8 years average
|
| 100 |
+
'key_concepts': {'MONET_pigmented': 0.55}
|
| 101 |
+
},
|
| 102 |
+
'NV': {
|
| 103 |
+
'age_range': (None, 45), # 42.0 years average
|
| 104 |
+
'key_concepts': {'MONET_pigmented': 0.55}
|
| 105 |
+
}
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
def __init__(self):
|
| 109 |
+
"""Initialize MONET scorer with class patterns"""
|
| 110 |
+
self.class_names = [
|
| 111 |
+
'AKIEC', 'BCC', 'BEN_OTH', 'BKL', 'DF',
|
| 112 |
+
'INF', 'MAL_OTH', 'MEL', 'NV', 'SCCKA', 'VASC'
|
| 113 |
+
]
|
| 114 |
+
|
| 115 |
+
def compute_concept_scores(
|
| 116 |
+
self,
|
| 117 |
+
metadata: Dict[str, float]
|
| 118 |
+
) -> Dict[str, ConceptScore]:
|
| 119 |
+
"""
|
| 120 |
+
Compute MONET concept scores from metadata
|
| 121 |
+
|
| 122 |
+
Args:
|
| 123 |
+
metadata: Dictionary with MONET scores, age, sex, site, etc.
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
Dictionary of concept scores
|
| 127 |
+
"""
|
| 128 |
+
concept_scores = {}
|
| 129 |
+
|
| 130 |
+
for concept_name, definition in self.CONCEPT_DEFINITIONS.items():
|
| 131 |
+
score = metadata.get(concept_name, 0.0)
|
| 132 |
+
|
| 133 |
+
# Determine confidence based on how extreme the score is
|
| 134 |
+
if score > definition['threshold_high']:
|
| 135 |
+
confidence = min((score - definition['threshold_high']) / 0.2, 1.0)
|
| 136 |
+
level = "HIGH"
|
| 137 |
+
elif score < 0.2:
|
| 138 |
+
confidence = min((0.2 - score) / 0.2, 1.0)
|
| 139 |
+
level = "LOW"
|
| 140 |
+
else:
|
| 141 |
+
confidence = 0.5
|
| 142 |
+
level = "MODERATE"
|
| 143 |
+
|
| 144 |
+
# Clinical relevance
|
| 145 |
+
if level == "HIGH":
|
| 146 |
+
relevant_classes = definition['high_in']
|
| 147 |
+
clinical_relevance = f"Supports: {', '.join(relevant_classes)}"
|
| 148 |
+
elif level == "LOW":
|
| 149 |
+
excluded_classes = definition['low_in']
|
| 150 |
+
clinical_relevance = f"Against: {', '.join(excluded_classes)}"
|
| 151 |
+
else:
|
| 152 |
+
clinical_relevance = "Non-specific"
|
| 153 |
+
|
| 154 |
+
concept_scores[concept_name] = ConceptScore(
|
| 155 |
+
name=concept_name.replace('MONET_', '').replace('_', ' ').title(),
|
| 156 |
+
score=score,
|
| 157 |
+
confidence=confidence,
|
| 158 |
+
description=f"{definition['description']} ({level})",
|
| 159 |
+
clinical_relevance=clinical_relevance
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
return concept_scores
|
| 163 |
+
|
| 164 |
+
def apply_metadata_boosting(
|
| 165 |
+
self,
|
| 166 |
+
probs: np.ndarray,
|
| 167 |
+
metadata: Dict
|
| 168 |
+
) -> np.ndarray:
|
| 169 |
+
"""
|
| 170 |
+
Apply your metadata boosting logic
|
| 171 |
+
This is directly from your ensemble optimization code
|
| 172 |
+
|
| 173 |
+
Args:
|
| 174 |
+
probs: [11] probability array
|
| 175 |
+
metadata: Dictionary with age, sex, site, MONET scores
|
| 176 |
+
|
| 177 |
+
Returns:
|
| 178 |
+
boosted_probs: [11] adjusted probabilities
|
| 179 |
+
"""
|
| 180 |
+
boosted_probs = probs.copy()
|
| 181 |
+
|
| 182 |
+
# 1. MAL_OTH boosting
|
| 183 |
+
if metadata.get('sex') == 'male':
|
| 184 |
+
site = str(metadata.get('site', '')).lower()
|
| 185 |
+
if 'trunk' in site:
|
| 186 |
+
age = metadata.get('age_approx', 60)
|
| 187 |
+
ulceration = metadata.get('MONET_ulceration_crust', 0)
|
| 188 |
+
|
| 189 |
+
score = 0
|
| 190 |
+
score += 3 if metadata.get('sex') == 'male' else 0
|
| 191 |
+
score += 2 if 'trunk' in site else 0
|
| 192 |
+
score += 1 if 60 <= age <= 80 else 0
|
| 193 |
+
score += 2 if ulceration > 0.35 else 0
|
| 194 |
+
|
| 195 |
+
confidence = score / 8.0
|
| 196 |
+
if confidence > 0.5:
|
| 197 |
+
boosted_probs[6] *= (1.0 + confidence) # MAL_OTH index
|
| 198 |
+
|
| 199 |
+
# 2. INF boosting
|
| 200 |
+
erythema = metadata.get('MONET_erythema', 0)
|
| 201 |
+
pigmentation = metadata.get('MONET_pigmented', 0)
|
| 202 |
+
|
| 203 |
+
if erythema > 0.42 and pigmentation < 0.30:
|
| 204 |
+
confidence = min((erythema - 0.42) / 0.10 + 0.5, 1.0)
|
| 205 |
+
boosted_probs[5] *= (1.0 + confidence * 0.8) # INF index
|
| 206 |
+
|
| 207 |
+
# 3. BEN_OTH boosting
|
| 208 |
+
site = str(metadata.get('site', '')).lower()
|
| 209 |
+
is_head_neck = any(x in site for x in ['head', 'neck', 'face'])
|
| 210 |
+
|
| 211 |
+
if is_head_neck and 0.30 < pigmentation < 0.50:
|
| 212 |
+
ulceration = metadata.get('MONET_ulceration_crust', 0)
|
| 213 |
+
confidence = 0.7 if ulceration < 0.30 else 0.4
|
| 214 |
+
boosted_probs[2] *= (1.0 + confidence * 0.5) # BEN_OTH index
|
| 215 |
+
|
| 216 |
+
# 4. DF boosting
|
| 217 |
+
is_lower_ext = any(x in site for x in ['lower', 'leg', 'ankle', 'foot'])
|
| 218 |
+
|
| 219 |
+
if is_lower_ext:
|
| 220 |
+
age = metadata.get('age_approx', 60)
|
| 221 |
+
if 40 <= age <= 65:
|
| 222 |
+
boosted_probs[4] *= 1.8 # DF index
|
| 223 |
+
elif 30 <= age <= 75:
|
| 224 |
+
boosted_probs[4] *= 1.5
|
| 225 |
+
|
| 226 |
+
# 5. SCCKA boosting
|
| 227 |
+
ulceration = metadata.get('MONET_ulceration_crust', 0)
|
| 228 |
+
age = metadata.get('age_approx', 60)
|
| 229 |
+
|
| 230 |
+
if ulceration > 0.50 and age >= 65 and pigmentation < 0.15:
|
| 231 |
+
boosted_probs[9] *= 1.9 # SCCKA index
|
| 232 |
+
elif ulceration > 0.45 and age >= 60 and pigmentation < 0.20:
|
| 233 |
+
boosted_probs[9] *= 1.5
|
| 234 |
+
|
| 235 |
+
# 6. MEL vs NV age separation
|
| 236 |
+
if pigmentation > 0.55:
|
| 237 |
+
if age >= 55:
|
| 238 |
+
age_score = min((age - 55) / 20.0, 1.0)
|
| 239 |
+
boosted_probs[7] *= (1.0 + age_score * 0.5) # MEL
|
| 240 |
+
boosted_probs[8] *= (1.0 - age_score * 0.3) # NV
|
| 241 |
+
elif age <= 45:
|
| 242 |
+
age_score = min((45 - age) / 30.0, 1.0)
|
| 243 |
+
boosted_probs[7] *= (1.0 - age_score * 0.3) # MEL
|
| 244 |
+
boosted_probs[8] *= (1.0 + age_score * 0.5) # NV
|
| 245 |
+
|
| 246 |
+
# 7. Exclusions based on pigmentation/erythema
|
| 247 |
+
if pigmentation > 0.50:
|
| 248 |
+
boosted_probs[0] *= 0.7 # AKIEC
|
| 249 |
+
boosted_probs[1] *= 0.6 # BCC
|
| 250 |
+
boosted_probs[5] *= 0.5 # INF
|
| 251 |
+
boosted_probs[9] *= 0.3 # SCCKA
|
| 252 |
+
|
| 253 |
+
if erythema > 0.40:
|
| 254 |
+
boosted_probs[7] *= 0.7 # MEL
|
| 255 |
+
boosted_probs[8] *= 0.7 # NV
|
| 256 |
+
|
| 257 |
+
if pigmentation < 0.20:
|
| 258 |
+
boosted_probs[7] *= 0.5 # MEL
|
| 259 |
+
boosted_probs[8] *= 0.5 # NV
|
| 260 |
+
|
| 261 |
+
# Renormalize
|
| 262 |
+
return boosted_probs / boosted_probs.sum()
|
| 263 |
+
|
| 264 |
+
def explain_prediction(
|
| 265 |
+
self,
|
| 266 |
+
probs: np.ndarray,
|
| 267 |
+
concept_scores: Dict[str, ConceptScore],
|
| 268 |
+
metadata: Dict
|
| 269 |
+
) -> str:
|
| 270 |
+
"""
|
| 271 |
+
Generate natural language explanation
|
| 272 |
+
|
| 273 |
+
Args:
|
| 274 |
+
probs: Class probabilities
|
| 275 |
+
concept_scores: MONET concept scores
|
| 276 |
+
metadata: Clinical metadata
|
| 277 |
+
|
| 278 |
+
Returns:
|
| 279 |
+
Natural language explanation
|
| 280 |
+
"""
|
| 281 |
+
predicted_idx = np.argmax(probs)
|
| 282 |
+
predicted_class = self.class_names[predicted_idx]
|
| 283 |
+
confidence = probs[predicted_idx]
|
| 284 |
+
|
| 285 |
+
explanation = f"**Primary Diagnosis: {predicted_class}**\n"
|
| 286 |
+
explanation += f"Confidence: {confidence:.1%}\n\n"
|
| 287 |
+
|
| 288 |
+
# Key MONET features
|
| 289 |
+
explanation += "**Key Dermoscopic Features:**\n"
|
| 290 |
+
|
| 291 |
+
sorted_concepts = sorted(
|
| 292 |
+
concept_scores.values(),
|
| 293 |
+
key=lambda x: x.score * x.confidence,
|
| 294 |
+
reverse=True
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
for i, concept in enumerate(sorted_concepts[:5], 1):
|
| 298 |
+
if concept.score > 0.3 or concept.score < 0.2:
|
| 299 |
+
explanation += f"{i}. {concept.name}: {concept.score:.2f} - {concept.description}\n"
|
| 300 |
+
if concept.clinical_relevance != "Non-specific":
|
| 301 |
+
explanation += f" → {concept.clinical_relevance}\n"
|
| 302 |
+
|
| 303 |
+
# Clinical context
|
| 304 |
+
explanation += "\n**Clinical Context:**\n"
|
| 305 |
+
if 'age_approx' in metadata:
|
| 306 |
+
explanation += f"• Age: {metadata['age_approx']} years\n"
|
| 307 |
+
if 'sex' in metadata:
|
| 308 |
+
explanation += f"• Sex: {metadata['sex']}\n"
|
| 309 |
+
if 'site' in metadata:
|
| 310 |
+
explanation += f"• Location: {metadata['site']}\n"
|
| 311 |
+
|
| 312 |
+
return explanation
|
| 313 |
+
|
| 314 |
+
def get_top_concepts(
|
| 315 |
+
self,
|
| 316 |
+
concept_scores: Dict[str, ConceptScore],
|
| 317 |
+
top_k: int = 5,
|
| 318 |
+
min_score: float = 0.3
|
| 319 |
+
) -> List[ConceptScore]:
|
| 320 |
+
"""Get top-k most important concepts"""
|
| 321 |
+
filtered = [
|
| 322 |
+
cs for cs in concept_scores.values()
|
| 323 |
+
if cs.score >= min_score or cs.score < 0.2 # High or low
|
| 324 |
+
]
|
| 325 |
+
|
| 326 |
+
sorted_concepts = sorted(
|
| 327 |
+
filtered,
|
| 328 |
+
key=lambda x: x.score * x.confidence,
|
| 329 |
+
reverse=True
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
return sorted_concepts[:top_k]
|
models/monet_tool.py
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MONET Tool - Skin lesion feature extraction using MONET model
|
| 3 |
+
Correct implementation based on MONET tutorial: automatic_concept_annotation.ipynb
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
import numpy as np
|
| 9 |
+
import scipy.special
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from typing import Optional, Dict, List
|
| 12 |
+
import torchvision.transforms as T
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# The 7 MONET feature columns expected by ConvNeXt.
# Order matters: get_feature_vector() / get_feature_tensor() emit scores
# in exactly this order.
MONET_FEATURES = [
    "MONET_ulceration_crust",
    "MONET_hair",
    "MONET_vasculature_vessels",
    "MONET_erythema",
    "MONET_pigmented",
    "MONET_gel_water_drop_fluid_dermoscopy_liquid",
    "MONET_skin_markings_pen_ink_purple_pen",
]

# Concept terms for each MONET feature (multiple synonyms improve detection).
# Each list is expanded through every PROMPT_TEMPLATES entry and the scores
# are averaged over synonyms.
MONET_CONCEPT_TERMS = {
    "MONET_ulceration_crust": ["ulceration", "crust", "crusting", "ulcer"],
    "MONET_hair": ["hair", "hairy"],
    "MONET_vasculature_vessels": ["blood vessels", "vasculature", "vessels", "telangiectasia"],
    "MONET_erythema": ["erythema", "redness", "red"],
    "MONET_pigmented": ["pigmented", "pigmentation", "melanin", "brown"],
    "MONET_gel_water_drop_fluid_dermoscopy_liquid": ["dermoscopy gel", "fluid", "water drop", "immersion fluid"],
    "MONET_skin_markings_pen_ink_purple_pen": ["pen marking", "ink", "surgical marking", "purple pen"],
}

# Prompt templates (from MONET paper); "{}" is filled with a concept term.
PROMPT_TEMPLATES = [
    "This is skin image of {}",
    "This is dermatology image of {}",
    "This is image of {}",
]

# Reference prompts (baseline for contrastive scoring).
# NOTE: must stay aligned 1:1 with PROMPT_TEMPLATES —
# _get_concept_embedding() reshapes by len(PROMPT_TEMPLATES).
PROMPT_REFS = [
    ["This is skin image"],
    ["This is dermatology image"],
    ["This is image"],
]
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def get_transform(n_px=224):
    """Build the MONET image preprocessing pipeline.

    Resizes the short side to n_px (bicubic), center-crops to n_px square,
    forces RGB, converts to a tensor, and normalizes.
    """
    def _to_rgb(image):
        # Force 3-channel RGB so grayscale/RGBA inputs don't break ToTensor.
        return image.convert("RGB")

    steps = [
        T.Resize(n_px, interpolation=T.InterpolationMode.BICUBIC),
        T.CenterCrop(n_px),
        _to_rgb,
        T.ToTensor(),
        # NOTE(review): these are ImageNet statistics; CLIP models are
        # commonly trained with different constants — confirm they match
        # MONET's published preprocessing.
        T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
    return T.Compose(steps)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class MonetTool:
    """
    MONET tool for extracting concept presence scores from skin lesion images.

    Uses the contrastive scoring method from the MONET paper: each concept
    score is a softmax between the similarity of the image to concept
    prompts ("This is skin image of ulceration") and to reference prompts
    ("This is skin image").
    """

    def __init__(self, device: Optional[str] = None, use_hf: bool = True):
        """
        Args:
            device: Device to run on (cuda, mps, cpu). None = auto-detect
                at load() time.
            use_hf: Use HuggingFace implementation (True) or original CLIP (False)
        """
        # Model/processor are created lazily in load().
        self.model = None
        self.processor = None
        self.device = device
        self.use_hf = use_hf
        self.loaded = False
        self.transform = get_transform(224)

        # Cache for concept embeddings, keyed by MONET feature name.
        self._concept_embeddings = {}

    def load(self):
        """Load the MONET model and pre-compute concept text embeddings (idempotent)."""
        if self.loaded:
            return

        # Determine device: prefer CUDA, then Apple MPS, else CPU.
        if self.device is None:
            if torch.cuda.is_available():
                self.device = "cuda:0"
            elif torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"

        if self.use_hf:
            # HuggingFace implementation (CLIP-style zero-shot model).
            from transformers import AutoProcessor, AutoModelForZeroShotImageClassification

            self.processor = AutoProcessor.from_pretrained("chanwkim/monet")
            self.model = AutoModelForZeroShotImageClassification.from_pretrained("chanwkim/monet")
            self.model.to(self.device)
            self.model.eval()
        else:
            # Original CLIP implementation: load ViT-L/14 then overwrite the
            # weights with the MONET fine-tuned checkpoint from the authors.
            import clip

            self.model, _ = clip.load("ViT-L/14", device=self.device, jit=False)
            self.model.load_state_dict(
                torch.hub.load_state_dict_from_url(
                    "https://aimslab.cs.washington.edu/MONET/weight_clip.pt"
                )
            )
            self.model.eval()

        self.loaded = True

        # Pre-compute concept text embeddings once; images are then scored
        # with a single image-encoder forward pass each.
        self._precompute_concept_embeddings()

    def _encode_text(self, text_list: List[str]) -> torch.Tensor:
        """Encode a batch of strings to text embeddings (returned on CPU)."""
        if self.use_hf:
            inputs = self.processor(text=text_list, return_tensors="pt", padding=True)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                embeddings = self.model.get_text_features(**inputs)
        else:
            import clip
            tokens = clip.tokenize(text_list, truncate=True).to(self.device)
            with torch.no_grad():
                embeddings = self.model.encode_text(tokens)

        return embeddings.cpu()

    def _encode_image(self, image: Image.Image) -> torch.Tensor:
        """Encode a single PIL image to a [1, embed_dim] embedding (on CPU)."""
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)

        if self.use_hf:
            with torch.no_grad():
                embedding = self.model.get_image_features(image_tensor)
        else:
            with torch.no_grad():
                embedding = self.model.encode_image(image_tensor)

        return embedding.cpu()

    def _precompute_concept_embeddings(self):
        """Pre-compute and cache text embeddings for all MONET concepts."""
        for feature_name, concept_terms in MONET_CONCEPT_TERMS.items():
            self._concept_embeddings[feature_name] = self._get_concept_embedding(concept_terms)

    def _get_concept_embedding(self, concept_terms: List[str]) -> Dict:
        """
        Generate prompt embeddings for a concept using multiple templates.

        Args:
            concept_terms: List of synonymous terms for the concept

        Returns:
            dict with L2-normalized target embeddings
            [num_templates, num_terms, embed_dim] and reference embeddings
            [num_templates, 1, embed_dim].
        """
        # Target prompts: every template crossed with every synonym,
        # e.g. "This is skin image of ulceration".
        prompt_target = [
            [template.format(term) for term in concept_terms]
            for template in PROMPT_TEMPLATES
        ]

        # Flatten (template-major order) and encode in one batch.
        prompt_target_flat = [p for template_prompts in prompt_target for p in template_prompts]
        target_embeddings = self._encode_text(prompt_target_flat)

        # Reshape back to [num_templates, num_terms, embed_dim].
        num_templates = len(PROMPT_TEMPLATES)
        num_terms = len(concept_terms)
        embed_dim = target_embeddings.shape[-1]
        target_embeddings = target_embeddings.view(num_templates, num_terms, embed_dim)

        # L2-normalize along the embedding dimension so dot products below
        # are cosine similarities.
        target_embeddings_norm = F.normalize(target_embeddings, dim=2)

        # Reference prompts: the bare templates with no concept term,
        # e.g. "This is skin image". PROMPT_REFS is aligned 1:1 with
        # PROMPT_TEMPLATES, so the same view() shape applies.
        prompt_ref_flat = [p for ref_list in PROMPT_REFS for p in ref_list]
        ref_embeddings = self._encode_text(prompt_ref_flat)
        ref_embeddings = ref_embeddings.view(num_templates, -1, embed_dim)
        ref_embeddings_norm = F.normalize(ref_embeddings, dim=2)

        return {
            "target_embedding_norm": target_embeddings_norm,
            "ref_embedding_norm": ref_embeddings_norm,
        }

    def _calculate_concept_score(
        self,
        image_features_norm: torch.Tensor,
        concept_embedding: Dict,
        temp: float = 1 / np.exp(4.5944)
    ) -> float:
        """
        Calculate concept presence score using contrastive comparison.

        Args:
            image_features_norm: Normalized image embedding [1, embed_dim]
            concept_embedding: Dict with target and reference embeddings
            temp: Temperature for softmax. The default 1/exp(4.5944) is
                presumably the model's learned logit scale — TODO confirm
                against the MONET release.

        Returns:
            Concept presence score (0-1)
        """
        target_emb = concept_embedding["target_embedding_norm"].float()
        ref_emb = concept_embedding["ref_embedding_norm"].float()

        # Cosine similarity of the image to every prompt:
        # [num_templates, num_terms, embed_dim] @ [embed_dim, 1]
        # -> [num_templates, num_terms, 1]
        target_similarity = target_emb @ image_features_norm.T.float()
        ref_similarity = ref_emb @ image_features_norm.T.float()

        # Mean over terms for each template -> one scalar per template.
        target_mean = target_similarity.mean(dim=1).squeeze()  # [num_templates]
        ref_mean = ref_similarity.mean(dim=1).squeeze()  # [num_templates]

        # Two-way softmax between "concept present" and "reference" logits
        # (contrastive scoring); scores[0] is the per-template probability
        # that the concept is present.
        scores = scipy.special.softmax(
            np.array([target_mean.numpy() / temp, ref_mean.numpy() / temp]),
            axis=0
        )

        # Average the target probability across templates.
        return float(scores[0].mean())

    def extract_features(self, image: Image.Image) -> Dict[str, float]:
        """
        Extract MONET feature scores from a skin lesion image.

        Loads the model on first call. One image-encoder forward pass is
        shared across all 7 concept scores.

        Args:
            image: PIL Image to analyze

        Returns:
            dict with 7 MONET feature scores (0-1 range)
        """
        if not self.loaded:
            self.load()

        # Ensure RGB (the transform also converts, but this keeps the
        # input consistent for any mode-dependent PIL behavior).
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Get the normalized image embedding once.
        image_features = self._encode_image(image)
        image_features_norm = F.normalize(image_features, dim=1)

        # Score each MONET feature against its cached text embeddings.
        features = {}
        for feature_name in MONET_FEATURES:
            concept_emb = self._concept_embeddings[feature_name]
            score = self._calculate_concept_score(image_features_norm, concept_emb)
            features[feature_name] = score

        return features

    def get_feature_vector(self, image: Image.Image) -> List[float]:
        """Get MONET features as a list in the MONET_FEATURES order."""
        features = self.extract_features(image)
        return [features[f] for f in MONET_FEATURES]

    def get_feature_tensor(self, image: Image.Image) -> torch.Tensor:
        """Get MONET features as a float32 PyTorch tensor (MONET_FEATURES order)."""
        return torch.tensor(self.get_feature_vector(image), dtype=torch.float32)

    def describe_features(self, features: Dict[str, float], threshold: float = 0.6) -> str:
        """Generate a natural language description of the MONET features.

        Lists features scoring >= threshold; if none qualify, falls back to
        the top 3 features so the description is never empty.
        """
        descriptions = {
            "MONET_ulceration_crust": "ulceration or crusting",
            "MONET_hair": "visible hair",
            "MONET_vasculature_vessels": "visible blood vessels",
            "MONET_erythema": "erythema (redness)",
            "MONET_pigmented": "pigmentation",
            "MONET_gel_water_drop_fluid_dermoscopy_liquid": "dermoscopy gel/fluid",
            "MONET_skin_markings_pen_ink_purple_pen": "pen markings",
        }

        present = []
        for feature, score in features.items():
            if score >= threshold:
                desc = descriptions.get(feature, feature)
                present.append(f"{desc} ({score:.0%})")

        if not present:
            # Show top features even if below threshold
            sorted_features = sorted(features.items(), key=lambda x: x[1], reverse=True)[:3]
            present = [f"{descriptions.get(f, f)} ({s:.0%})" for f, s in sorted_features]

        return "Detected features: " + ", ".join(present)

    def analyze(self, image: Image.Image) -> Dict:
        """Full analysis returning features dict, ordered vector, and description."""
        features = self.extract_features(image)
        vector = [features[f] for f in MONET_FEATURES]
        description = self.describe_features(features)

        return {
            "features": features,
            "vector": vector,
            "description": description,
            "feature_names": MONET_FEATURES,
        }

    def __call__(self, image: Image.Image) -> Dict:
        """Shorthand for analyze()"""
        return self.analyze(image)
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
# Singleton instance shared by all callers (the model is expensive to load).
_monet_instance = None


def get_monet_tool() -> MonetTool:
    """Get or create the shared MonetTool instance (lazy; not thread-safe)."""
    global _monet_instance
    if _monet_instance is None:
        _monet_instance = MonetTool()
    return _monet_instance
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
if __name__ == "__main__":
    # Manual smoke test: `python monet_tool.py [image_path]`.
    # Downloads model weights on first run.
    import sys

    print("MONET Tool Test (Correct Implementation)")
    print("=" * 50)

    tool = MonetTool(use_hf=True)
    print("Loading model...")
    tool.load()
    print("Model loaded!")

    if len(sys.argv) > 1:
        image_path = sys.argv[1]
        print(f"\nAnalyzing: {image_path}")
        image = Image.open(image_path).convert("RGB")
        result = tool.analyze(image)

        # Print each score with a proportional bar chart.
        print("\nMONET Features (Contrastive Scores):")
        for name, score in result["features"].items():
            bar = "█" * int(score * 20)
            print(f"  {name}: {score:.3f} {bar}")

        print(f"\nDescription: {result['description']}")
        print(f"\nVector: {[f'{v:.3f}' for v in result['vector']]}")
|
models/overlay_tool.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Overlay Tool - Generates visual markers for biopsy sites and excision margins
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import io
|
| 6 |
+
import tempfile
|
| 7 |
+
from typing import Tuple, Optional, Dict, Any
|
| 8 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class OverlayTool:
|
| 12 |
+
"""
|
| 13 |
+
Generates image overlays for clinical decision visualization:
|
| 14 |
+
- Biopsy site markers (circles)
|
| 15 |
+
- Excision margins (dashed outlines with margin indicators)
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
    # Colors for different marker types, as RGBA tuples (alpha < 255 gives
    # semi-transparent overlays when alpha-composited onto the photo).
    COLORS = {
        'biopsy': (255, 69, 0, 200),    # Orange-red with alpha
        'excision': (220, 20, 60, 200),  # Crimson with alpha
        'margin': (255, 215, 0, 180),    # Gold for margin line
        'text': (255, 255, 255, 255),    # White text
        'text_bg': (0, 0, 0, 180),       # Semi-transparent black bg
    }
|
| 26 |
+
|
| 27 |
+
    def __init__(self):
        # Pure-PIL tool, nothing to load; flag kept for interface parity
        # with the model-backed tools.
        self.loaded = True
|
| 29 |
+
|
| 30 |
+
def generate_biopsy_overlay(
|
| 31 |
+
self,
|
| 32 |
+
image: Image.Image,
|
| 33 |
+
center_x: float,
|
| 34 |
+
center_y: float,
|
| 35 |
+
radius: float = 0.05,
|
| 36 |
+
label: str = "Biopsy Site"
|
| 37 |
+
) -> Dict[str, Any]:
|
| 38 |
+
"""
|
| 39 |
+
Generate biopsy site overlay with circle marker.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
image: PIL Image
|
| 43 |
+
center_x: X coordinate as fraction (0-1) of image width
|
| 44 |
+
center_y: Y coordinate as fraction (0-1) of image height
|
| 45 |
+
radius: Radius as fraction of image width
|
| 46 |
+
label: Text label for the marker
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
Dict with overlay image and metadata
|
| 50 |
+
"""
|
| 51 |
+
# Convert to RGBA for transparency
|
| 52 |
+
img = image.convert("RGBA")
|
| 53 |
+
width, height = img.size
|
| 54 |
+
|
| 55 |
+
# Create overlay layer
|
| 56 |
+
overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
|
| 57 |
+
draw = ImageDraw.Draw(overlay)
|
| 58 |
+
|
| 59 |
+
# Calculate pixel coordinates
|
| 60 |
+
cx = int(center_x * width)
|
| 61 |
+
cy = int(center_y * height)
|
| 62 |
+
r = int(radius * width)
|
| 63 |
+
|
| 64 |
+
# Draw outer circle (thicker)
|
| 65 |
+
for offset in range(3):
|
| 66 |
+
draw.ellipse(
|
| 67 |
+
[cx - r - offset, cy - r - offset, cx + r + offset, cy + r + offset],
|
| 68 |
+
outline=self.COLORS['biopsy'],
|
| 69 |
+
width=2
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
# Draw crosshairs
|
| 73 |
+
line_len = r // 2
|
| 74 |
+
draw.line([(cx - line_len, cy), (cx + line_len, cy)],
|
| 75 |
+
fill=self.COLORS['biopsy'], width=2)
|
| 76 |
+
draw.line([(cx, cy - line_len), (cx, cy + line_len)],
|
| 77 |
+
fill=self.COLORS['biopsy'], width=2)
|
| 78 |
+
|
| 79 |
+
# Draw label with background
|
| 80 |
+
try:
|
| 81 |
+
font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 14)
|
| 82 |
+
except:
|
| 83 |
+
font = ImageFont.load_default()
|
| 84 |
+
|
| 85 |
+
text_bbox = draw.textbbox((0, 0), label, font=font)
|
| 86 |
+
text_width = text_bbox[2] - text_bbox[0]
|
| 87 |
+
text_height = text_bbox[3] - text_bbox[1]
|
| 88 |
+
|
| 89 |
+
text_x = cx - text_width // 2
|
| 90 |
+
text_y = cy + r + 10
|
| 91 |
+
|
| 92 |
+
# Background rectangle for text
|
| 93 |
+
padding = 4
|
| 94 |
+
draw.rectangle(
|
| 95 |
+
[text_x - padding, text_y - padding,
|
| 96 |
+
text_x + text_width + padding, text_y + text_height + padding],
|
| 97 |
+
fill=self.COLORS['text_bg']
|
| 98 |
+
)
|
| 99 |
+
draw.text((text_x, text_y), label, fill=self.COLORS['text'], font=font)
|
| 100 |
+
|
| 101 |
+
# Composite
|
| 102 |
+
result = Image.alpha_composite(img, overlay)
|
| 103 |
+
|
| 104 |
+
# Save to temp file
|
| 105 |
+
temp_file = tempfile.NamedTemporaryFile(suffix="_biopsy_overlay.png", delete=False)
|
| 106 |
+
result.save(temp_file.name, "PNG")
|
| 107 |
+
temp_file.close()
|
| 108 |
+
|
| 109 |
+
return {
|
| 110 |
+
"overlay": result,
|
| 111 |
+
"path": temp_file.name,
|
| 112 |
+
"type": "biopsy",
|
| 113 |
+
"coordinates": {
|
| 114 |
+
"center_x": center_x,
|
| 115 |
+
"center_y": center_y,
|
| 116 |
+
"radius": radius
|
| 117 |
+
},
|
| 118 |
+
"label": label
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
def generate_excision_overlay(
|
| 122 |
+
self,
|
| 123 |
+
image: Image.Image,
|
| 124 |
+
center_x: float,
|
| 125 |
+
center_y: float,
|
| 126 |
+
lesion_radius: float,
|
| 127 |
+
margin_mm: int = 5,
|
| 128 |
+
pixels_per_mm: float = 10.0,
|
| 129 |
+
label: str = "Excision Margin"
|
| 130 |
+
) -> Dict[str, Any]:
|
| 131 |
+
"""
|
| 132 |
+
Generate excision margin overlay with inner (lesion) and outer (margin) boundaries.
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
image: PIL Image
|
| 136 |
+
center_x: X coordinate as fraction (0-1)
|
| 137 |
+
center_y: Y coordinate as fraction (0-1)
|
| 138 |
+
lesion_radius: Lesion radius as fraction of image width
|
| 139 |
+
margin_mm: Excision margin in millimeters
|
| 140 |
+
pixels_per_mm: Estimated pixels per mm (for margin calculation)
|
| 141 |
+
label: Text label
|
| 142 |
+
|
| 143 |
+
Returns:
|
| 144 |
+
Dict with overlay image and metadata
|
| 145 |
+
"""
|
| 146 |
+
img = image.convert("RGBA")
|
| 147 |
+
width, height = img.size
|
| 148 |
+
|
| 149 |
+
overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
|
| 150 |
+
draw = ImageDraw.Draw(overlay)
|
| 151 |
+
|
| 152 |
+
# Calculate coordinates
|
| 153 |
+
cx = int(center_x * width)
|
| 154 |
+
cy = int(center_y * height)
|
| 155 |
+
inner_r = int(lesion_radius * width)
|
| 156 |
+
|
| 157 |
+
# Calculate margin in pixels
|
| 158 |
+
margin_px = int(margin_mm * pixels_per_mm)
|
| 159 |
+
outer_r = inner_r + margin_px
|
| 160 |
+
|
| 161 |
+
# Draw outer margin (dashed effect using multiple arcs)
|
| 162 |
+
dash_length = 10
|
| 163 |
+
for angle in range(0, 360, dash_length * 2):
|
| 164 |
+
draw.arc(
|
| 165 |
+
[cx - outer_r, cy - outer_r, cx + outer_r, cy + outer_r],
|
| 166 |
+
start=angle,
|
| 167 |
+
end=angle + dash_length,
|
| 168 |
+
fill=self.COLORS['margin'],
|
| 169 |
+
width=3
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
# Draw inner lesion boundary (solid)
|
| 173 |
+
draw.ellipse(
|
| 174 |
+
[cx - inner_r, cy - inner_r, cx + inner_r, cy + inner_r],
|
| 175 |
+
outline=self.COLORS['excision'],
|
| 176 |
+
width=2
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
# Draw margin indicator lines (radial)
|
| 180 |
+
for angle in [0, 90, 180, 270]:
|
| 181 |
+
import math
|
| 182 |
+
rad = math.radians(angle)
|
| 183 |
+
inner_x = cx + int(inner_r * math.cos(rad))
|
| 184 |
+
inner_y = cy + int(inner_r * math.sin(rad))
|
| 185 |
+
outer_x = cx + int(outer_r * math.cos(rad))
|
| 186 |
+
outer_y = cy + int(outer_r * math.sin(rad))
|
| 187 |
+
draw.line([(inner_x, inner_y), (outer_x, outer_y)],
|
| 188 |
+
fill=self.COLORS['margin'], width=2)
|
| 189 |
+
|
| 190 |
+
# Draw labels
|
| 191 |
+
try:
|
| 192 |
+
font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 12)
|
| 193 |
+
font_small = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 10)
|
| 194 |
+
except:
|
| 195 |
+
font = ImageFont.load_default()
|
| 196 |
+
font_small = font
|
| 197 |
+
|
| 198 |
+
# Main label
|
| 199 |
+
text_bbox = draw.textbbox((0, 0), label, font=font)
|
| 200 |
+
text_width = text_bbox[2] - text_bbox[0]
|
| 201 |
+
text_height = text_bbox[3] - text_bbox[1]
|
| 202 |
+
|
| 203 |
+
text_x = cx - text_width // 2
|
| 204 |
+
text_y = cy + outer_r + 15
|
| 205 |
+
|
| 206 |
+
padding = 4
|
| 207 |
+
draw.rectangle(
|
| 208 |
+
[text_x - padding, text_y - padding,
|
| 209 |
+
text_x + text_width + padding, text_y + text_height + padding],
|
| 210 |
+
fill=self.COLORS['text_bg']
|
| 211 |
+
)
|
| 212 |
+
draw.text((text_x, text_y), label, fill=self.COLORS['text'], font=font)
|
| 213 |
+
|
| 214 |
+
# Margin measurement label
|
| 215 |
+
margin_label = f"{margin_mm}mm margin"
|
| 216 |
+
margin_bbox = draw.textbbox((0, 0), margin_label, font=font_small)
|
| 217 |
+
margin_width = margin_bbox[2] - margin_bbox[0]
|
| 218 |
+
|
| 219 |
+
margin_text_x = cx + outer_r + 5
|
| 220 |
+
margin_text_y = cy - 6
|
| 221 |
+
|
| 222 |
+
draw.rectangle(
|
| 223 |
+
[margin_text_x - 2, margin_text_y - 2,
|
| 224 |
+
margin_text_x + margin_width + 2, margin_text_y + 12],
|
| 225 |
+
fill=self.COLORS['text_bg']
|
| 226 |
+
)
|
| 227 |
+
draw.text((margin_text_x, margin_text_y), margin_label,
|
| 228 |
+
fill=self.COLORS['margin'], font=font_small)
|
| 229 |
+
|
| 230 |
+
# Composite
|
| 231 |
+
result = Image.alpha_composite(img, overlay)
|
| 232 |
+
|
| 233 |
+
temp_file = tempfile.NamedTemporaryFile(suffix="_excision_overlay.png", delete=False)
|
| 234 |
+
result.save(temp_file.name, "PNG")
|
| 235 |
+
temp_file.close()
|
| 236 |
+
|
| 237 |
+
return {
|
| 238 |
+
"overlay": result,
|
| 239 |
+
"path": temp_file.name,
|
| 240 |
+
"type": "excision",
|
| 241 |
+
"coordinates": {
|
| 242 |
+
"center_x": center_x,
|
| 243 |
+
"center_y": center_y,
|
| 244 |
+
"lesion_radius": lesion_radius,
|
| 245 |
+
"margin_mm": margin_mm,
|
| 246 |
+
"total_radius": outer_r / width
|
| 247 |
+
},
|
| 248 |
+
"label": label
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
def generate_comparison_overlay(
|
| 252 |
+
self,
|
| 253 |
+
image1: Image.Image,
|
| 254 |
+
image2: Image.Image,
|
| 255 |
+
label1: str = "Previous",
|
| 256 |
+
label2: str = "Current"
|
| 257 |
+
) -> Dict[str, Any]:
|
| 258 |
+
"""
|
| 259 |
+
Generate side-by-side comparison of two images for follow-up.
|
| 260 |
+
|
| 261 |
+
Args:
|
| 262 |
+
image1: First (previous) image
|
| 263 |
+
image2: Second (current) image
|
| 264 |
+
label1: Label for first image
|
| 265 |
+
label2: Label for second image
|
| 266 |
+
|
| 267 |
+
Returns:
|
| 268 |
+
Dict with comparison image and metadata
|
| 269 |
+
"""
|
| 270 |
+
# Resize to same height
|
| 271 |
+
max_height = 400
|
| 272 |
+
|
| 273 |
+
# Calculate sizes maintaining aspect ratio
|
| 274 |
+
w1, h1 = image1.size
|
| 275 |
+
w2, h2 = image2.size
|
| 276 |
+
|
| 277 |
+
ratio1 = max_height / h1
|
| 278 |
+
ratio2 = max_height / h2
|
| 279 |
+
|
| 280 |
+
new_w1 = int(w1 * ratio1)
|
| 281 |
+
new_w2 = int(w2 * ratio2)
|
| 282 |
+
|
| 283 |
+
img1 = image1.resize((new_w1, max_height), Image.Resampling.LANCZOS)
|
| 284 |
+
img2 = image2.resize((new_w2, max_height), Image.Resampling.LANCZOS)
|
| 285 |
+
|
| 286 |
+
# Create comparison canvas
|
| 287 |
+
gap = 20
|
| 288 |
+
total_width = new_w1 + gap + new_w2
|
| 289 |
+
header_height = 30
|
| 290 |
+
total_height = max_height + header_height
|
| 291 |
+
|
| 292 |
+
canvas = Image.new("RGB", (total_width, total_height), (255, 255, 255))
|
| 293 |
+
draw = ImageDraw.Draw(canvas)
|
| 294 |
+
|
| 295 |
+
# Draw labels
|
| 296 |
+
try:
|
| 297 |
+
font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 14)
|
| 298 |
+
except:
|
| 299 |
+
font = ImageFont.load_default()
|
| 300 |
+
|
| 301 |
+
# Previous label
|
| 302 |
+
draw.rectangle([0, 0, new_w1, header_height], fill=(70, 130, 180))
|
| 303 |
+
bbox1 = draw.textbbox((0, 0), label1, font=font)
|
| 304 |
+
text_w1 = bbox1[2] - bbox1[0]
|
| 305 |
+
draw.text(((new_w1 - text_w1) // 2, 8), label1, fill=(255, 255, 255), font=font)
|
| 306 |
+
|
| 307 |
+
# Current label
|
| 308 |
+
draw.rectangle([new_w1 + gap, 0, total_width, header_height], fill=(60, 179, 113))
|
| 309 |
+
bbox2 = draw.textbbox((0, 0), label2, font=font)
|
| 310 |
+
text_w2 = bbox2[2] - bbox2[0]
|
| 311 |
+
draw.text((new_w1 + gap + (new_w2 - text_w2) // 2, 8), label2,
|
| 312 |
+
fill=(255, 255, 255), font=font)
|
| 313 |
+
|
| 314 |
+
# Paste images
|
| 315 |
+
canvas.paste(img1, (0, header_height))
|
| 316 |
+
canvas.paste(img2, (new_w1 + gap, header_height))
|
| 317 |
+
|
| 318 |
+
# Draw divider
|
| 319 |
+
draw.line([(new_w1 + gap // 2, header_height), (new_w1 + gap // 2, total_height)],
|
| 320 |
+
fill=(200, 200, 200), width=2)
|
| 321 |
+
|
| 322 |
+
temp_file = tempfile.NamedTemporaryFile(suffix="_comparison.png", delete=False)
|
| 323 |
+
canvas.save(temp_file.name, "PNG")
|
| 324 |
+
temp_file.close()
|
| 325 |
+
|
| 326 |
+
return {
|
| 327 |
+
"comparison": canvas,
|
| 328 |
+
"path": temp_file.name,
|
| 329 |
+
"type": "comparison"
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def get_overlay_tool() -> OverlayTool:
|
| 334 |
+
"""Get overlay tool instance"""
|
| 335 |
+
return OverlayTool()
|
requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# requirements.txt
|
| 2 |
+
|
| 3 |
+
torch>=2.0.0
|
| 4 |
+
torchvision>=0.15.0
|
| 5 |
+
transformers>=4.40.0
|
| 6 |
+
timm>=0.9.0
|
| 7 |
+
gradio==4.44.0
|
| 8 |
+
gradio-client==1.3.0
|
| 9 |
+
opencv-python>=4.8.0
|
| 10 |
+
numpy>=1.24.0
|
| 11 |
+
Pillow>=10.0.0
|
| 12 |
+
sentencepiece>=0.1.99
|
| 13 |
+
accelerate>=0.25.0
|
| 14 |
+
protobuf>=4.0.0
|
| 15 |
+
mcp>=1.0.0 # installed via python3.11 (requires Python >=3.10)
|
test_models.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Test script to verify model loading"""
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import timm
|
| 7 |
+
from transformers import AutoModel, AutoProcessor
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
DEVICE = "cpu"
|
| 11 |
+
print(f"Device: {DEVICE}")
|
| 12 |
+
|
| 13 |
+
# ConvNeXt model definition (matching checkpoint)
|
| 14 |
+
class ConvNeXtDualEncoder(nn.Module):
|
| 15 |
+
def __init__(self, model_name="convnext_base.fb_in22k_ft_in1k",
|
| 16 |
+
metadata_dim=19, num_classes=11, dropout=0.3):
|
| 17 |
+
super().__init__()
|
| 18 |
+
self.backbone = timm.create_model(model_name, pretrained=False, num_classes=0)
|
| 19 |
+
backbone_dim = self.backbone.num_features
|
| 20 |
+
self.meta_mlp = nn.Sequential(
|
| 21 |
+
nn.Linear(metadata_dim, 64), nn.LayerNorm(64), nn.GELU(), nn.Dropout(dropout)
|
| 22 |
+
)
|
| 23 |
+
fusion_dim = backbone_dim * 2 + 64
|
| 24 |
+
self.classifier = nn.Sequential(
|
| 25 |
+
nn.Linear(fusion_dim, 512), nn.LayerNorm(512), nn.GELU(), nn.Dropout(dropout),
|
| 26 |
+
nn.Linear(512, 256), nn.LayerNorm(256), nn.GELU(), nn.Dropout(dropout),
|
| 27 |
+
nn.Linear(256, num_classes)
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
def forward(self, clinical_img, derm_img=None, metadata=None):
|
| 31 |
+
clinical_features = self.backbone(clinical_img)
|
| 32 |
+
derm_features = self.backbone(derm_img) if derm_img is not None else clinical_features
|
| 33 |
+
if metadata is not None:
|
| 34 |
+
meta_features = self.meta_mlp(metadata)
|
| 35 |
+
else:
|
| 36 |
+
meta_features = torch.zeros(clinical_features.size(0), 64, device=clinical_features.device)
|
| 37 |
+
fused = torch.cat([clinical_features, derm_features, meta_features], dim=1)
|
| 38 |
+
return self.classifier(fused)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# MedSigLIP model definition
|
| 42 |
+
class MedSigLIPClassifier(nn.Module):
|
| 43 |
+
def __init__(self, num_classes=11, model_name="google/siglip-base-patch16-384"):
|
| 44 |
+
super().__init__()
|
| 45 |
+
self.siglip = AutoModel.from_pretrained(model_name)
|
| 46 |
+
self.processor = AutoProcessor.from_pretrained(model_name)
|
| 47 |
+
hidden_dim = self.siglip.config.vision_config.hidden_size
|
| 48 |
+
self.classifier = nn.Sequential(
|
| 49 |
+
nn.Linear(hidden_dim, 512), nn.LayerNorm(512), nn.GELU(), nn.Dropout(0.3),
|
| 50 |
+
nn.Linear(512, num_classes)
|
| 51 |
+
)
|
| 52 |
+
for param in self.siglip.parameters():
|
| 53 |
+
param.requires_grad = False
|
| 54 |
+
|
| 55 |
+
def forward(self, pixel_values):
|
| 56 |
+
vision_outputs = self.siglip.vision_model(pixel_values=pixel_values)
|
| 57 |
+
pooled_features = vision_outputs.pooler_output
|
| 58 |
+
return self.classifier(pooled_features)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
if __name__ == "__main__":
|
| 62 |
+
print("\n[1/2] Loading ConvNeXt...")
|
| 63 |
+
convnext_model = ConvNeXtDualEncoder()
|
| 64 |
+
ckpt = torch.load("models/seed42_fold0.pt", map_location=DEVICE, weights_only=False)
|
| 65 |
+
convnext_model.load_state_dict(ckpt)
|
| 66 |
+
convnext_model.eval()
|
| 67 |
+
print(" ConvNeXt loaded!")
|
| 68 |
+
|
| 69 |
+
print("\n[2/2] Loading MedSigLIP...")
|
| 70 |
+
medsiglip_model = MedSigLIPClassifier()
|
| 71 |
+
medsiglip_model.eval()
|
| 72 |
+
print(" MedSigLIP loaded!")
|
| 73 |
+
|
| 74 |
+
# Quick inference test
|
| 75 |
+
print("\nTesting inference...")
|
| 76 |
+
dummy_img = torch.randn(1, 3, 384, 384)
|
| 77 |
+
with torch.no_grad():
|
| 78 |
+
convnext_out = convnext_model(dummy_img)
|
| 79 |
+
print(f" ConvNeXt output: {convnext_out.shape}")
|
| 80 |
+
|
| 81 |
+
dummy_pil = np.random.randint(0, 255, (384, 384, 3), dtype=np.uint8)
|
| 82 |
+
siglip_input = medsiglip_model.processor(images=[dummy_pil], return_tensors="pt")
|
| 83 |
+
siglip_out = medsiglip_model(siglip_input["pixel_values"])
|
| 84 |
+
print(f" MedSigLIP output: {siglip_out.shape}")
|
| 85 |
+
|
| 86 |
+
print("\nAll tests passed!")
|
web/index.html
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8" />
|
| 5 |
+
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
| 6 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
| 7 |
+
<title>SkinProAI</title>
|
| 8 |
+
<link rel="preconnect" href="https://fonts.googleapis.com">
|
| 9 |
+
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
| 10 |
+
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap" rel="stylesheet">
|
| 11 |
+
</head>
|
| 12 |
+
<body>
|
| 13 |
+
<div id="root"></div>
|
| 14 |
+
<script type="module" src="/src/main.tsx"></script>
|
| 15 |
+
</body>
|
| 16 |
+
</html>
|
web/package-lock.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
web/package.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"name": "skinproai-web",
|
| 3 |
+
"private": true,
|
| 4 |
+
"version": "1.0.0",
|
| 5 |
+
"type": "module",
|
| 6 |
+
"scripts": {
|
| 7 |
+
"dev": "vite",
|
| 8 |
+
"build": "tsc && vite build",
|
| 9 |
+
"preview": "vite preview"
|
| 10 |
+
},
|
| 11 |
+
"dependencies": {
|
| 12 |
+
"react": "^18.2.0",
|
| 13 |
+
"react-dom": "^18.2.0",
|
| 14 |
+
"react-markdown": "^10.1.0",
|
| 15 |
+
"react-router-dom": "^6.20.0"
|
| 16 |
+
},
|
| 17 |
+
"devDependencies": {
|
| 18 |
+
"@types/react": "^18.2.0",
|
| 19 |
+
"@types/react-dom": "^18.2.0",
|
| 20 |
+
"@vitejs/plugin-react": "^4.2.0",
|
| 21 |
+
"typescript": "^5.3.0",
|
| 22 |
+
"vite": "^5.0.0"
|
| 23 |
+
}
|
| 24 |
+
}
|
web/src/App.tsx
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { BrowserRouter, Routes, Route } from 'react-router-dom';
|
| 2 |
+
import { PatientsPage } from './pages/PatientsPage';
|
| 3 |
+
import { ChatPage } from './pages/ChatPage';
|
| 4 |
+
|
| 5 |
+
export function App() {
|
| 6 |
+
return (
|
| 7 |
+
<BrowserRouter>
|
| 8 |
+
<Routes>
|
| 9 |
+
<Route path="/" element={<PatientsPage />} />
|
| 10 |
+
<Route path="/chat/:patientId" element={<ChatPage />} />
|
| 11 |
+
</Routes>
|
| 12 |
+
</BrowserRouter>
|
| 13 |
+
);
|
| 14 |
+
}
|
web/src/components/MessageContent.css
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* ─── Root ───────────────────────────────────────────────────────────────── */
|
| 2 |
+
.mc-root {
|
| 3 |
+
display: flex;
|
| 4 |
+
flex-direction: column;
|
| 5 |
+
gap: 6px;
|
| 6 |
+
font-size: 0.9375rem;
|
| 7 |
+
line-height: 1.6;
|
| 8 |
+
color: var(--gray-800);
|
| 9 |
+
width: 100%;
|
| 10 |
+
}
|
| 11 |
+
|
| 12 |
+
/* ─── Stage header ───────────────────────────────────────────────────────── */
|
| 13 |
+
.mc-stage {
|
| 14 |
+
font-size: 0.75rem;
|
| 15 |
+
font-weight: 600;
|
| 16 |
+
color: var(--primary);
|
| 17 |
+
text-transform: uppercase;
|
| 18 |
+
letter-spacing: 0.06em;
|
| 19 |
+
padding: 8px 0 2px;
|
| 20 |
+
border-top: 1px solid var(--gray-100);
|
| 21 |
+
margin-top: 4px;
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
.mc-stage:first-child {
|
| 25 |
+
border-top: none;
|
| 26 |
+
margin-top: 0;
|
| 27 |
+
padding-top: 0;
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
/* ─── Thinking text ──────────────────────────────────────────────────────── */
|
| 31 |
+
.mc-thinking {
|
| 32 |
+
font-size: 0.8125rem;
|
| 33 |
+
color: var(--gray-500);
|
| 34 |
+
font-style: italic;
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
/* ─── Response block (markdown) ──────────────────────────────────────────── */
|
| 38 |
+
.mc-response {
|
| 39 |
+
color: var(--gray-800);
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
.mc-response p {
|
| 43 |
+
margin: 0 0 8px;
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
.mc-response p:last-child {
|
| 47 |
+
margin-bottom: 0;
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
.mc-response strong {
|
| 51 |
+
font-weight: 600;
|
| 52 |
+
color: var(--gray-900);
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
.mc-response em {
|
| 56 |
+
font-style: italic;
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
.mc-response ul,
|
| 60 |
+
.mc-response ol {
|
| 61 |
+
margin: 4px 0 8px 20px;
|
| 62 |
+
padding: 0;
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
.mc-response li {
|
| 66 |
+
margin-bottom: 2px;
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
.mc-response h1,
|
| 70 |
+
.mc-response h2,
|
| 71 |
+
.mc-response h3,
|
| 72 |
+
.mc-response h4 {
|
| 73 |
+
font-size: 0.9375rem;
|
| 74 |
+
font-weight: 600;
|
| 75 |
+
color: var(--gray-900);
|
| 76 |
+
margin: 10px 0 4px;
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
.mc-response code {
|
| 80 |
+
font-family: monospace;
|
| 81 |
+
font-size: 0.875em;
|
| 82 |
+
background: var(--gray-100);
|
| 83 |
+
padding: 1px 5px;
|
| 84 |
+
border-radius: 4px;
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
/* ─── Tool output (monospace block) ─────────────────────────────────────── */
|
| 88 |
+
.mc-tool-output {
|
| 89 |
+
background: var(--gray-900);
|
| 90 |
+
border-radius: 8px;
|
| 91 |
+
overflow: hidden;
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
.mc-tool-output-label {
|
| 95 |
+
font-size: 0.6875rem;
|
| 96 |
+
font-weight: 600;
|
| 97 |
+
color: var(--gray-400);
|
| 98 |
+
text-transform: uppercase;
|
| 99 |
+
letter-spacing: 0.05em;
|
| 100 |
+
padding: 6px 12px 4px;
|
| 101 |
+
background: rgba(255, 255, 255, 0.05);
|
| 102 |
+
border-bottom: 1px solid rgba(255, 255, 255, 0.08);
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
.mc-tool-output pre {
|
| 106 |
+
margin: 0;
|
| 107 |
+
padding: 10px 12px;
|
| 108 |
+
font-family: 'SF Mono', 'Fira Code', monospace;
|
| 109 |
+
font-size: 0.8rem;
|
| 110 |
+
line-height: 1.5;
|
| 111 |
+
color: #e2e8f0;
|
| 112 |
+
white-space: pre;
|
| 113 |
+
overflow-x: auto;
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
/* ─── Image blocks (GradCAM, comparison) ────────────────────────────────── */
|
| 117 |
+
.mc-image-block {
|
| 118 |
+
margin-top: 4px;
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
.mc-image-label {
|
| 122 |
+
font-size: 0.75rem;
|
| 123 |
+
font-weight: 600;
|
| 124 |
+
color: var(--gray-500);
|
| 125 |
+
text-transform: uppercase;
|
| 126 |
+
letter-spacing: 0.05em;
|
| 127 |
+
margin-bottom: 6px;
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
.mc-gradcam-img {
|
| 131 |
+
width: 100%;
|
| 132 |
+
max-width: 380px;
|
| 133 |
+
border-radius: 10px;
|
| 134 |
+
border: 1px solid var(--gray-200);
|
| 135 |
+
display: block;
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
.mc-comparison-img {
|
| 139 |
+
width: 100%;
|
| 140 |
+
max-width: 560px;
|
| 141 |
+
border-radius: 10px;
|
| 142 |
+
border: 1px solid var(--gray-200);
|
| 143 |
+
display: block;
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
/* ─── GradCAM side-by-side comparison ───────────────────────────────────── */
|
| 147 |
+
.mc-gradcam-compare {
|
| 148 |
+
display: grid;
|
| 149 |
+
grid-template-columns: 1fr 1fr;
|
| 150 |
+
gap: 10px;
|
| 151 |
+
max-width: 560px;
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
.mc-gradcam-compare-item {
|
| 155 |
+
display: flex;
|
| 156 |
+
flex-direction: column;
|
| 157 |
+
gap: 4px;
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
.mc-gradcam-compare-title {
|
| 161 |
+
font-size: 0.75rem;
|
| 162 |
+
font-weight: 600;
|
| 163 |
+
color: var(--gray-600);
|
| 164 |
+
text-align: center;
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
.mc-gradcam-compare-img {
|
| 168 |
+
width: 100%;
|
| 169 |
+
border-radius: 8px;
|
| 170 |
+
border: 1px solid var(--gray-200);
|
| 171 |
+
display: block;
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
/* ─── Result / error / complete / observation ───────────────────────────── */
|
| 175 |
+
.mc-result {
|
| 176 |
+
background: linear-gradient(135deg, #f0fdf4, #dcfce7);
|
| 177 |
+
border: 1px solid #86efac;
|
| 178 |
+
border-radius: 8px;
|
| 179 |
+
padding: 8px 12px;
|
| 180 |
+
font-size: 0.875rem;
|
| 181 |
+
font-weight: 500;
|
| 182 |
+
color: #15803d;
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
.mc-error {
|
| 186 |
+
background: #fef2f2;
|
| 187 |
+
border: 1px solid #fca5a5;
|
| 188 |
+
border-radius: 8px;
|
| 189 |
+
padding: 8px 12px;
|
| 190 |
+
font-size: 0.875rem;
|
| 191 |
+
color: #dc2626;
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
.mc-complete {
|
| 195 |
+
font-size: 0.8rem;
|
| 196 |
+
color: var(--gray-400);
|
| 197 |
+
text-align: right;
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
.mc-observation {
|
| 201 |
+
font-size: 0.875rem;
|
| 202 |
+
color: var(--gray-600);
|
| 203 |
+
font-style: italic;
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
/* ─── Plain streaming text ───────────────────────────────────────────────── */
|
| 207 |
+
.mc-text {
|
| 208 |
+
white-space: pre-wrap;
|
| 209 |
+
word-break: break-word;
|
| 210 |
+
color: var(--gray-700);
|
| 211 |
+
font-size: 0.875rem;
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
/* ─── References ─────────────────────────────────────────────────────────── */
|
| 215 |
+
.mc-references {
|
| 216 |
+
border-top: 1px solid var(--gray-100);
|
| 217 |
+
padding-top: 8px;
|
| 218 |
+
margin-top: 4px;
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
.mc-references-title {
|
| 222 |
+
font-size: 0.75rem;
|
| 223 |
+
font-weight: 600;
|
| 224 |
+
color: var(--gray-500);
|
| 225 |
+
text-transform: uppercase;
|
| 226 |
+
letter-spacing: 0.05em;
|
| 227 |
+
margin-bottom: 4px;
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
.mc-ref-item {
|
| 231 |
+
font-size: 0.8125rem;
|
| 232 |
+
color: var(--gray-600);
|
| 233 |
+
line-height: 1.5;
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
.mc-ref-sup {
|
| 237 |
+
font-size: 0.6875rem;
|
| 238 |
+
vertical-align: super;
|
| 239 |
+
margin-right: 4px;
|
| 240 |
+
color: var(--primary);
|
| 241 |
+
font-weight: 600;
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
.mc-ref-source {
|
| 245 |
+
font-style: italic;
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
.mc-ref-page {
|
| 249 |
+
color: var(--gray-400);
|
| 250 |
+
}
|
web/src/components/MessageContent.tsx
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ReactMarkdown from 'react-markdown';
|
| 2 |
+
import './MessageContent.css';
|
| 3 |
+
|
| 4 |
+
// Serve any temp visualization image (GradCAM, comparison) through the API
|
| 5 |
+
const TEMP_IMG_URL = (path: string) =>
|
| 6 |
+
`/api/patients/gradcam?path=${encodeURIComponent(path)}`;
|
| 7 |
+
|
| 8 |
+
// ─── Types ─────────────────────────────────────────────────────────────────
|
| 9 |
+
|
| 10 |
+
type Segment =
|
| 11 |
+
| { type: 'stage'; label: string }
|
| 12 |
+
| { type: 'thinking'; content: string }
|
| 13 |
+
| { type: 'response'; content: string }
|
| 14 |
+
| { type: 'tool_output'; label: string; content: string }
|
| 15 |
+
| { type: 'gradcam'; path: string }
|
| 16 |
+
| { type: 'comparison'; path: string }
|
| 17 |
+
| { type: 'gradcam_compare'; path1: string; path2: string }
|
| 18 |
+
| { type: 'result'; content: string }
|
| 19 |
+
| { type: 'error'; content: string }
|
| 20 |
+
| { type: 'complete'; content: string }
|
| 21 |
+
| { type: 'references'; content: string }
|
| 22 |
+
| { type: 'observation'; content: string }
|
| 23 |
+
| { type: 'text'; content: string };
|
| 24 |
+
|
| 25 |
+
// ─── Parser ────────────────────────────────────────────────────────────────
|
| 26 |
+
|
| 27 |
+
// Splits raw text by all known complete tag patterns (capturing group preserves them)
|
| 28 |
+
const TAG_SPLIT_RE = new RegExp(
|
| 29 |
+
'(' +
|
| 30 |
+
[
|
| 31 |
+
'\\[STAGE:[^\\]]*\\][\\s\\S]*?\\[\\/STAGE\\]',
|
| 32 |
+
'\\[THINKING\\][\\s\\S]*?\\[\\/THINKING\\]',
|
| 33 |
+
'\\[RESPONSE\\][\\s\\S]*?\\[\\/RESPONSE\\]',
|
| 34 |
+
'\\[TOOL_OUTPUT:[^\\]]*\\][\\s\\S]*?\\[\\/TOOL_OUTPUT\\]',
|
| 35 |
+
'\\[GRADCAM_IMAGE:[^\\]]+\\]',
|
| 36 |
+
'\\[COMPARISON_IMAGE:[^\\]]+\\]',
|
| 37 |
+
'\\[GRADCAM_COMPARE:[^:\\]]+:[^\\]]+\\]',
|
| 38 |
+
'\\[RESULT\\][\\s\\S]*?\\[\\/RESULT\\]',
|
| 39 |
+
'\\[ERROR\\][\\s\\S]*?\\[\\/ERROR\\]',
|
| 40 |
+
'\\[COMPLETE\\][\\s\\S]*?\\[\\/COMPLETE\\]',
|
| 41 |
+
'\\[REFERENCES\\][\\s\\S]*?\\[\\/REFERENCES\\]',
|
| 42 |
+
'\\[OBSERVATION\\][\\s\\S]*?\\[\\/OBSERVATION\\]',
|
| 43 |
+
'\\[CONFIRM:[^\\]]*\\][\\s\\S]*?\\[\\/CONFIRM\\]',
|
| 44 |
+
].join('|') +
|
| 45 |
+
')',
|
| 46 |
+
'g',
|
| 47 |
+
);
|
| 48 |
+
|
| 49 |
+
// Strips known opening tags that haven't yet been closed (mid-stream partial content)
|
| 50 |
+
function cleanStreamingText(text: string): string {
|
| 51 |
+
return text.replace(
|
| 52 |
+
/\[(STAGE:[^\]]*|THINKING|RESPONSE|TOOL_OUTPUT:[^\]]*|RESULT|ERROR|COMPLETE|REFERENCES|OBSERVATION|CONFIRM:[^\]]*)\]/g,
|
| 53 |
+
'',
|
| 54 |
+
);
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
function parseContent(raw: string): Segment[] {
|
| 58 |
+
const segments: Segment[] = [];
|
| 59 |
+
|
| 60 |
+
for (const part of raw.split(TAG_SPLIT_RE)) {
|
| 61 |
+
if (!part) continue;
|
| 62 |
+
|
| 63 |
+
let m: RegExpMatchArray | null;
|
| 64 |
+
|
| 65 |
+
if ((m = part.match(/^\[STAGE:([^\]]*)\]([\s\S]*)\[\/STAGE\]$/))) {
|
| 66 |
+
const label = m[2].trim();
|
| 67 |
+
if (label) segments.push({ type: 'stage', label });
|
| 68 |
+
|
| 69 |
+
} else if ((m = part.match(/^\[THINKING\]([\s\S]*)\[\/THINKING\]$/))) {
|
| 70 |
+
const c = m[1].trim();
|
| 71 |
+
if (c) segments.push({ type: 'thinking', content: c });
|
| 72 |
+
|
| 73 |
+
} else if ((m = part.match(/^\[RESPONSE\]([\s\S]*)\[\/RESPONSE\]$/))) {
|
| 74 |
+
const c = m[1].trim();
|
| 75 |
+
if (c) segments.push({ type: 'response', content: c });
|
| 76 |
+
|
| 77 |
+
} else if ((m = part.match(/^\[TOOL_OUTPUT:([^\]]*)\]([\s\S]*)\[\/TOOL_OUTPUT\]$/))) {
|
| 78 |
+
segments.push({ type: 'tool_output', label: m[1], content: m[2] });
|
| 79 |
+
|
| 80 |
+
} else if ((m = part.match(/^\[GRADCAM_IMAGE:([^\]]+)\]$/))) {
|
| 81 |
+
segments.push({ type: 'gradcam', path: m[1] });
|
| 82 |
+
|
| 83 |
+
} else if ((m = part.match(/^\[COMPARISON_IMAGE:([^\]]+)\]$/))) {
|
| 84 |
+
segments.push({ type: 'comparison', path: m[1] });
|
| 85 |
+
|
| 86 |
+
} else if ((m = part.match(/^\[GRADCAM_COMPARE:([^:\]]+):([^\]]+)\]$/))) {
|
| 87 |
+
segments.push({ type: 'gradcam_compare', path1: m[1], path2: m[2] });
|
| 88 |
+
|
| 89 |
+
} else if ((m = part.match(/^\[RESULT\]([\s\S]*)\[\/RESULT\]$/))) {
|
| 90 |
+
const c = m[1].trim();
|
| 91 |
+
if (c) segments.push({ type: 'result', content: c });
|
| 92 |
+
|
| 93 |
+
} else if ((m = part.match(/^\[ERROR\]([\s\S]*)\[\/ERROR\]$/))) {
|
| 94 |
+
const c = m[1].trim();
|
| 95 |
+
if (c) segments.push({ type: 'error', content: c });
|
| 96 |
+
|
| 97 |
+
} else if ((m = part.match(/^\[COMPLETE\]([\s\S]*)\[\/COMPLETE\]$/))) {
|
| 98 |
+
const c = m[1].trim();
|
| 99 |
+
if (c) segments.push({ type: 'complete', content: c });
|
| 100 |
+
|
| 101 |
+
} else if ((m = part.match(/^\[REFERENCES\]([\s\S]*)\[\/REFERENCES\]$/))) {
|
| 102 |
+
segments.push({ type: 'references', content: m[1].trim() });
|
| 103 |
+
|
| 104 |
+
} else if ((m = part.match(/^\[OBSERVATION\]([\s\S]*)\[\/OBSERVATION\]$/))) {
|
| 105 |
+
const c = m[1].trim();
|
| 106 |
+
if (c) segments.push({ type: 'observation', content: c });
|
| 107 |
+
|
| 108 |
+
} else if ((m = part.match(/^\[CONFIRM:[^\]]*\]([\s\S]*)\[\/CONFIRM\]$/))) {
|
| 109 |
+
const c = m[1].trim();
|
| 110 |
+
if (c) segments.push({ type: 'result', content: c });
|
| 111 |
+
|
| 112 |
+
} else {
|
| 113 |
+
// Plain text (may be mid-stream with incomplete opening tags)
|
| 114 |
+
const cleaned = cleanStreamingText(part);
|
| 115 |
+
if (cleaned.trim()) segments.push({ type: 'text', content: cleaned });
|
| 116 |
+
}
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
return segments;
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
// ─── References renderer ──────────────────��────────────────────────────────
|
| 123 |
+
|
| 124 |
+
function References({ content }: { content: string }) {
|
| 125 |
+
const refs = content.match(/\[REF:[^\]]+\]/g) ?? [];
|
| 126 |
+
if (!refs.length) return null;
|
| 127 |
+
|
| 128 |
+
return (
|
| 129 |
+
<div className="mc-references">
|
| 130 |
+
<div className="mc-references-title">References</div>
|
| 131 |
+
{refs.map((ref, i) => {
|
| 132 |
+
// [REF:id:source:page:file:superscript]
|
| 133 |
+
const parts = ref.slice(1, -1).split(':');
|
| 134 |
+
const source = parts[2] ?? '';
|
| 135 |
+
const page = parts[3] ?? '';
|
| 136 |
+
const sup = parts[5] ?? `[${i + 1}]`;
|
| 137 |
+
return (
|
| 138 |
+
<div key={i} className="mc-ref-item">
|
| 139 |
+
<span className="mc-ref-sup">{sup}</span>
|
| 140 |
+
<span className="mc-ref-source">{source}</span>
|
| 141 |
+
{page && <span className="mc-ref-page">, p.{page}</span>}
|
| 142 |
+
</div>
|
| 143 |
+
);
|
| 144 |
+
})}
|
| 145 |
+
</div>
|
| 146 |
+
);
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
// ─── Main component ────────────────────────────────────────────────────────
|
| 150 |
+
|
| 151 |
+
export function MessageContent({ text }: { text: string }) {
|
| 152 |
+
const segments = parseContent(text);
|
| 153 |
+
|
| 154 |
+
return (
|
| 155 |
+
<div className="mc-root">
|
| 156 |
+
{segments.map((seg, i) => {
|
| 157 |
+
switch (seg.type) {
|
| 158 |
+
case 'stage':
|
| 159 |
+
return <div key={i} className="mc-stage">{seg.label}</div>;
|
| 160 |
+
|
| 161 |
+
case 'thinking':
|
| 162 |
+
return <div key={i} className="mc-thinking">{seg.content}</div>;
|
| 163 |
+
|
| 164 |
+
case 'response':
|
| 165 |
+
return (
|
| 166 |
+
<div key={i} className="mc-response">
|
| 167 |
+
<ReactMarkdown>{seg.content}</ReactMarkdown>
|
| 168 |
+
</div>
|
| 169 |
+
);
|
| 170 |
+
|
| 171 |
+
case 'tool_output':
|
| 172 |
+
return (
|
| 173 |
+
<div key={i} className="mc-tool-output">
|
| 174 |
+
{seg.label && <div className="mc-tool-output-label">{seg.label}</div>}
|
| 175 |
+
<pre>{seg.content}</pre>
|
| 176 |
+
</div>
|
| 177 |
+
);
|
| 178 |
+
|
| 179 |
+
case 'gradcam':
|
| 180 |
+
return (
|
| 181 |
+
<div key={i} className="mc-image-block">
|
| 182 |
+
<div className="mc-image-label">Grad-CAM Attention Map</div>
|
| 183 |
+
<img
|
| 184 |
+
src={TEMP_IMG_URL(seg.path)}
|
| 185 |
+
className="mc-gradcam-img"
|
| 186 |
+
alt="Grad-CAM attention map"
|
| 187 |
+
/>
|
| 188 |
+
</div>
|
| 189 |
+
);
|
| 190 |
+
|
| 191 |
+
case 'comparison':
|
| 192 |
+
return (
|
| 193 |
+
<div key={i} className="mc-image-block">
|
| 194 |
+
<div className="mc-image-label">Lesion Comparison</div>
|
| 195 |
+
<img
|
| 196 |
+
src={TEMP_IMG_URL(seg.path)}
|
| 197 |
+
className="mc-comparison-img"
|
| 198 |
+
alt="Side-by-side lesion comparison"
|
| 199 |
+
/>
|
| 200 |
+
</div>
|
| 201 |
+
);
|
| 202 |
+
|
| 203 |
+
case 'gradcam_compare':
|
| 204 |
+
return (
|
| 205 |
+
<div key={i} className="mc-image-block">
|
| 206 |
+
<div className="mc-image-label">Grad-CAM Comparison</div>
|
| 207 |
+
<div className="mc-gradcam-compare">
|
| 208 |
+
<div className="mc-gradcam-compare-item">
|
| 209 |
+
<div className="mc-gradcam-compare-title">Previous</div>
|
| 210 |
+
<img
|
| 211 |
+
src={TEMP_IMG_URL(seg.path1)}
|
| 212 |
+
className="mc-gradcam-compare-img"
|
| 213 |
+
alt="Previous GradCAM"
|
| 214 |
+
/>
|
| 215 |
+
</div>
|
| 216 |
+
<div className="mc-gradcam-compare-item">
|
| 217 |
+
<div className="mc-gradcam-compare-title">Current</div>
|
| 218 |
+
<img
|
| 219 |
+
src={TEMP_IMG_URL(seg.path2)}
|
| 220 |
+
className="mc-gradcam-compare-img"
|
| 221 |
+
alt="Current GradCAM"
|
| 222 |
+
/>
|
| 223 |
+
</div>
|
| 224 |
+
</div>
|
| 225 |
+
</div>
|
| 226 |
+
);
|
| 227 |
+
|
| 228 |
+
case 'result':
|
| 229 |
+
return <div key={i} className="mc-result">{seg.content}</div>;
|
| 230 |
+
|
| 231 |
+
case 'error':
|
| 232 |
+
return <div key={i} className="mc-error">{seg.content}</div>;
|
| 233 |
+
|
| 234 |
+
case 'complete':
|
| 235 |
+
return <div key={i} className="mc-complete">{seg.content}</div>;
|
| 236 |
+
|
| 237 |
+
case 'references':
|
| 238 |
+
return <References key={i} content={seg.content} />;
|
| 239 |
+
|
| 240 |
+
case 'observation':
|
| 241 |
+
return <div key={i} className="mc-observation">{seg.content}</div>;
|
| 242 |
+
|
| 243 |
+
case 'text':
|
| 244 |
+
return seg.content.trim() ? (
|
| 245 |
+
<div key={i} className="mc-text">{seg.content}</div>
|
| 246 |
+
) : null;
|
| 247 |
+
|
| 248 |
+
default:
|
| 249 |
+
return null;
|
| 250 |
+
}
|
| 251 |
+
})}
|
| 252 |
+
</div>
|
| 253 |
+
);
|
| 254 |
+
}
|
web/src/components/ToolCallCard.css
ADDED
|
@@ -0,0 +1,338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* Styles for ToolCallCard: a collapsible card showing a tool invocation's
   status (running / done / error) and its rendered result body. */

/* ─── Card container ─────────────────────────────────────────────────────── */
.tool-card {
  border: 1px solid var(--gray-200);
  border-left: 3px solid var(--primary);
  border-radius: 10px;
  overflow: hidden;
  background: var(--gray-50);
  margin-top: 8px;
}

/* The left accent strip doubles as a status indicator. */
.tool-card.loading {
  border-left-color: var(--gray-400);
}

.tool-card.error {
  border-left-color: #ef4444;
}

/* ─── Header (collapsed row) ─────────────────────────────────────────────── */
.tool-card-header {
  width: 100%;
  display: flex;
  align-items: center;
  gap: 8px;
  padding: 10px 14px;
  background: transparent;
  border: none;
  cursor: pointer;
  text-align: left;
  transition: background 0.15s;
}

.tool-card-header:hover:not(:disabled) {
  background: var(--gray-100);
}

/* Header is a <button>; it is disabled while the tool is still running. */
.tool-card-header:disabled {
  cursor: default;
}

.tool-icon {
  font-size: 1rem;
  flex-shrink: 0;
}

.tool-label {
  flex: 1;
  font-size: 0.875rem;
  font-weight: 500;
  color: var(--gray-700);
  text-transform: capitalize;
}

.tool-status {
  font-size: 0.8125rem;
  flex-shrink: 0;
}

/* Fallback hex color in case --success is not defined by the page theme. */
.tool-status.done {
  color: var(--success, #22c55e);
  font-weight: 600;
}

.tool-status.calling {
  color: var(--gray-500);
  display: flex;
  align-items: center;
  gap: 5px;
}

.tool-status.error-text {
  color: #ef4444;
}

/* One-line result preview shown while the card is collapsed; truncated. */
.tool-header-summary {
  font-size: 0.8125rem;
  color: var(--gray-500);
  font-weight: 400;
  white-space: nowrap;
  overflow: hidden;
  text-overflow: ellipsis;
  max-width: 200px;
}

.tool-chevron {
  font-size: 0.625rem;
  color: var(--gray-400);
  margin-left: 2px;
  flex-shrink: 0;
}

/* ─── Spinner ────────────────────────────────────────────────────────────── */
.spinner {
  display: inline-block;
  width: 12px;
  height: 12px;
  border: 2px solid var(--gray-300);
  border-top-color: var(--gray-600);
  border-radius: 50%;
  animation: spin 0.8s linear infinite;
}

@keyframes spin {
  to { transform: rotate(360deg); }
}

/* ─── Card body ──────────────────────────────────────────────────────────── */
/* White body contrasts with the gray-50 card background when expanded. */
.tool-card-body {
  padding: 14px;
  border-top: 1px solid var(--gray-200);
  background: white;
}

/* ─── analyze_image ──────────────────────────────────────────────────────── */
.analyze-result {
  display: flex;
  flex-direction: column;
  gap: 12px;
}

.analyze-top {
  display: flex;
  gap: 14px;
  align-items: flex-start;
}

.analyze-thumb {
  width: 72px;
  height: 72px;
  object-fit: cover;
  border-radius: 8px;
  border: 1px solid var(--gray-200);
  flex-shrink: 0;
}

/* min-width: 0 lets the flex child shrink so long names can truncate. */
.analyze-info {
  flex: 1;
  min-width: 0;
}

.diagnosis-name {
  font-size: 0.9375rem;
  font-weight: 600;
  color: var(--gray-900);
  margin: 0 0 4px;
  line-height: 1.3;
}

/* Text color is set inline by the component to match the confidence tier. */
.confidence-label {
  font-size: 0.8125rem;
  font-weight: 500;
  margin: 0 0 6px;
}

.confidence-bar-track {
  height: 6px;
  background: var(--gray-200);
  border-radius: 999px;
  overflow: hidden;
}

/* Width and background are set inline by the component. */
.confidence-bar-fill {
  height: 100%;
  border-radius: 999px;
  transition: width 0.3s ease;
}

.analyze-summary {
  font-size: 0.875rem;
  color: var(--gray-700);
  line-height: 1.6;
  margin: 0;
  border-top: 1px solid var(--gray-100);
  padding-top: 10px;
  white-space: pre-wrap;
}

.other-predictions {
  list-style: none;
  padding: 0;
  margin: 0;
  display: flex;
  flex-direction: column;
  gap: 6px;
  border-top: 1px solid var(--gray-100);
  padding-top: 10px;
}

.prediction-row {
  display: flex;
  justify-content: space-between;
  font-size: 0.8125rem;
}

.pred-name {
  color: var(--gray-600);
}

/* tabular-nums keeps percentage columns vertically aligned. */
.pred-pct {
  color: var(--gray-500);
  font-variant-numeric: tabular-nums;
}

/* ─── compare_images ─────────────────────────────────────────────────────── */
.compare-result {
  display: flex;
  flex-direction: column;
  gap: 12px;
}

.carousel {
  position: relative;
  display: flex;
  align-items: center;
  justify-content: center;
  gap: 8px;
}

.carousel-image-wrap {
  position: relative;
  display: inline-block;
}

.carousel-image {
  width: 200px;
  height: 160px;
  object-fit: cover;
  border-radius: 10px;
  border: 1px solid var(--gray-200);
  display: block;
}

/* Pill label overlaid near the bottom edge of the carousel image. */
.carousel-label {
  position: absolute;
  bottom: 8px;
  left: 50%;
  transform: translateX(-50%);
  background: rgba(0, 0, 0, 0.55);
  color: white;
  font-size: 0.75rem;
  font-weight: 600;
  padding: 3px 10px;
  border-radius: 999px;
  white-space: nowrap;
}

.carousel-btn {
  background: white;
  border: 1px solid var(--gray-300);
  border-radius: 6px;
  width: 28px;
  height: 28px;
  cursor: pointer;
  font-size: 0.75rem;
  color: var(--gray-600);
  display: flex;
  align-items: center;
  justify-content: center;
  flex-shrink: 0;
}

.carousel-btn:hover {
  background: var(--gray-100);
}

/* Dots hang just below the carousel (negative bottom offset). */
.carousel-dots {
  position: absolute;
  bottom: -18px;
  left: 50%;
  transform: translateX(-50%);
  display: flex;
  gap: 5px;
}

.carousel-dot {
  width: 6px;
  height: 6px;
  border-radius: 50%;
  background: var(--gray-300);
  cursor: pointer;
  transition: background 0.15s;
}

.carousel-dot.active {
  background: var(--primary);
}

/* Status line color is set inline to match the change-status tier. */
.compare-status {
  font-size: 0.9375rem;
  margin-top: 6px;
}

.feature-changes {
  list-style: none;
  padding: 0;
  margin: 0;
  display: flex;
  flex-direction: column;
  gap: 6px;
}

.feature-row {
  display: flex;
  justify-content: space-between;
  font-size: 0.8125rem;
}

.feature-name {
  color: var(--gray-600);
  text-transform: capitalize;
}

.feature-delta {
  font-variant-numeric: tabular-nums;
  font-weight: 500;
}

.compare-summary {
  font-size: 0.875rem;
  color: var(--gray-600);
  line-height: 1.5;
  margin: 0;
  border-top: 1px solid var(--gray-100);
  padding-top: 10px;
}

/* ─── Generic fallback ───────────────────────────────────────────────────── */
/* Raw JSON dump for tools without a dedicated renderer. */
.generic-result {
  font-size: 0.75rem;
  background: var(--gray-50);
  border-radius: 6px;
  padding: 10px;
  overflow-x: auto;
  color: var(--gray-700);
  margin: 0;
  white-space: pre-wrap;
  word-break: break-all;
}
|
web/src/components/ToolCallCard.tsx
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useEffect, useState } from 'react';
|
| 2 |
+
import { ToolCall } from '../types';
|
| 3 |
+
import './ToolCallCard.css';
|
| 4 |
+
|
| 5 |
+
/** Props for ToolCallCard. */
interface ToolCallCardProps {
  // The tool invocation to render: carries tool name, status, and result.
  toolCall: ToolCall;
}
|
| 8 |
+
|
| 9 |
+
/** One-line summary shown in the collapsed header so results are visible at a glance */
|
| 10 |
+
function CollapsedSummary({ toolCall }: { toolCall: ToolCall }) {
|
| 11 |
+
const r = toolCall.result;
|
| 12 |
+
if (!r) return null;
|
| 13 |
+
|
| 14 |
+
if (toolCall.tool === 'analyze_image') {
|
| 15 |
+
const name = r.full_name ?? r.diagnosis;
|
| 16 |
+
const pct = r.confidence != null ? `${Math.round(r.confidence * 100)}%` : null;
|
| 17 |
+
if (name) return (
|
| 18 |
+
<span className="tool-header-summary">
|
| 19 |
+
{name}{pct ? ` — ${pct}` : ''}
|
| 20 |
+
</span>
|
| 21 |
+
);
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
if (toolCall.tool === 'compare_images') {
|
| 25 |
+
const key = r.status_label ?? 'STABLE';
|
| 26 |
+
const cfg = STATUS_CONFIG[key] ?? { emoji: '⚪', label: key };
|
| 27 |
+
return (
|
| 28 |
+
<span className="tool-header-summary">
|
| 29 |
+
{cfg.emoji} {cfg.label}
|
| 30 |
+
</span>
|
| 31 |
+
);
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
return null;
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
export function ToolCallCard({ toolCall }: ToolCallCardProps) {
|
| 38 |
+
// Auto-expand when the tool completes so results are immediately visible.
|
| 39 |
+
// User can collapse manually afterwards.
|
| 40 |
+
const [expanded, setExpanded] = useState(false);
|
| 41 |
+
|
| 42 |
+
useEffect(() => {
|
| 43 |
+
if (toolCall.status === 'complete') setExpanded(true);
|
| 44 |
+
}, [toolCall.status]);
|
| 45 |
+
|
| 46 |
+
const isLoading = toolCall.status === 'calling';
|
| 47 |
+
const isError = toolCall.status === 'error';
|
| 48 |
+
|
| 49 |
+
const icon = toolCall.tool === 'compare_images' ? '🔄' : '🔬';
|
| 50 |
+
const label = toolCall.tool.replace(/_/g, ' ');
|
| 51 |
+
|
| 52 |
+
return (
|
| 53 |
+
<div className={`tool-card ${isLoading ? 'loading' : ''} ${isError ? 'error' : ''}`}>
|
| 54 |
+
<button
|
| 55 |
+
className="tool-card-header"
|
| 56 |
+
onClick={() => !isLoading && setExpanded(e => !e)}
|
| 57 |
+
disabled={isLoading}
|
| 58 |
+
>
|
| 59 |
+
<span className="tool-icon">{icon}</span>
|
| 60 |
+
<span className="tool-label">{label}</span>
|
| 61 |
+
{isLoading ? (
|
| 62 |
+
<span className="tool-status calling">
|
| 63 |
+
<span className="spinner" /> running…
|
| 64 |
+
</span>
|
| 65 |
+
) : isError ? (
|
| 66 |
+
<span className="tool-status error-text">error</span>
|
| 67 |
+
) : (
|
| 68 |
+
<>
|
| 69 |
+
<span className="tool-status done">✓</span>
|
| 70 |
+
{!expanded && <CollapsedSummary toolCall={toolCall} />}
|
| 71 |
+
</>
|
| 72 |
+
)}
|
| 73 |
+
{!isLoading && (
|
| 74 |
+
<span className="tool-chevron">{expanded ? '▲' : '▼'}</span>
|
| 75 |
+
)}
|
| 76 |
+
</button>
|
| 77 |
+
|
| 78 |
+
{expanded && !isLoading && toolCall.result && (
|
| 79 |
+
<div className="tool-card-body">
|
| 80 |
+
{toolCall.tool === 'analyze_image' && (
|
| 81 |
+
<AnalyzeImageResult result={toolCall.result} />
|
| 82 |
+
)}
|
| 83 |
+
{toolCall.tool === 'compare_images' && (
|
| 84 |
+
<CompareImagesResult result={toolCall.result} />
|
| 85 |
+
)}
|
| 86 |
+
{toolCall.tool !== 'analyze_image' && toolCall.tool !== 'compare_images' && (
|
| 87 |
+
<GenericResult result={toolCall.result} />
|
| 88 |
+
)}
|
| 89 |
+
</div>
|
| 90 |
+
)}
|
| 91 |
+
</div>
|
| 92 |
+
);
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
/* ─── analyze_image renderer ─────────────────────────────────────────────── */
|
| 96 |
+
|
| 97 |
+
function AnalyzeImageResult({ result }: { result: ToolCall['result'] }) {
|
| 98 |
+
if (!result) return null;
|
| 99 |
+
|
| 100 |
+
const hasClassifier = result.diagnosis != null;
|
| 101 |
+
const topPrediction = result.all_predictions?.[0];
|
| 102 |
+
const otherPredictions = result.all_predictions?.slice(1) ?? [];
|
| 103 |
+
const confidence = result.confidence ?? topPrediction?.probability ?? 0;
|
| 104 |
+
const pct = Math.round(confidence * 100);
|
| 105 |
+
const statusColor = pct >= 70 ? '#ef4444' : pct >= 40 ? '#f59e0b' : '#22c55e';
|
| 106 |
+
|
| 107 |
+
return (
|
| 108 |
+
<div className="analyze-result">
|
| 109 |
+
<div className="analyze-top">
|
| 110 |
+
{result.image_url && (
|
| 111 |
+
<img
|
| 112 |
+
src={result.image_url}
|
| 113 |
+
alt="Analyzed lesion"
|
| 114 |
+
className="analyze-thumb"
|
| 115 |
+
/>
|
| 116 |
+
)}
|
| 117 |
+
<div className="analyze-info">
|
| 118 |
+
{hasClassifier ? (
|
| 119 |
+
<>
|
| 120 |
+
<p className="diagnosis-name">{result.full_name ?? result.diagnosis}</p>
|
| 121 |
+
<p className="confidence-label" style={{ color: statusColor }}>
|
| 122 |
+
Confidence: {pct}%
|
| 123 |
+
</p>
|
| 124 |
+
<div className="confidence-bar-track">
|
| 125 |
+
<div
|
| 126 |
+
className="confidence-bar-fill"
|
| 127 |
+
style={{ width: `${pct}%`, background: statusColor }}
|
| 128 |
+
/>
|
| 129 |
+
</div>
|
| 130 |
+
</>
|
| 131 |
+
) : (
|
| 132 |
+
<p className="diagnosis-name" style={{ color: 'var(--gray-500)', fontWeight: 400, fontSize: '0.875rem' }}>
|
| 133 |
+
Visual assessment complete — classifier unavailable
|
| 134 |
+
</p>
|
| 135 |
+
)}
|
| 136 |
+
</div>
|
| 137 |
+
</div>
|
| 138 |
+
|
| 139 |
+
{hasClassifier && otherPredictions.length > 0 && (
|
| 140 |
+
<ul className="other-predictions">
|
| 141 |
+
{otherPredictions.map(p => (
|
| 142 |
+
<li key={p.class} className="prediction-row">
|
| 143 |
+
<span className="pred-name">{p.full_name ?? p.class}</span>
|
| 144 |
+
<span className="pred-pct">{Math.round(p.probability * 100)}%</span>
|
| 145 |
+
</li>
|
| 146 |
+
))}
|
| 147 |
+
</ul>
|
| 148 |
+
)}
|
| 149 |
+
</div>
|
| 150 |
+
);
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
/* ─── compare_images renderer ────────────────────────────────────────────── */
|
| 154 |
+
|
| 155 |
+
// Per-status display config for compare_images results: human-readable label,
// accent color, and a colored-dot emoji. Call sites fall back to a neutral
// gray/white dot for any status key not listed here.
const STATUS_CONFIG: Record<string, { label: string; color: string; emoji: string }> = {
  STABLE: { label: 'Stable', color: '#22c55e', emoji: '🟢' },
  MINOR_CHANGE: { label: 'Minor Change', color: '#f59e0b', emoji: '🟡' },
  SIGNIFICANT_CHANGE: { label: 'Significant Change', color: '#ef4444', emoji: '🔴' },
  IMPROVED: { label: 'Improved', color: '#3b82f6', emoji: '🔵' },
};
|
| 161 |
+
|
| 162 |
+
/**
 * Result body for the compare_images tool: overall change status,
 * per-feature deltas (previous → current), and an optional summary line.
 */
function CompareImagesResult({ result }: { result: ToolCall['result'] }) {
  if (!result) return null;

  const statusKey = result.status_label ?? 'STABLE';
  // Unknown labels render with a neutral gray color and white dot.
  const status = STATUS_CONFIG[statusKey] ?? { label: statusKey, color: '#6b7280', emoji: '⚪' };
  const changeEntries = Object.entries(result.feature_changes ?? {});

  return (
    <div className="compare-result">
      <div className="compare-status" style={{ color: status.color }}>
        <strong>Status: {status.label} {status.emoji}</strong>
      </div>

      {changeEntries.length > 0 && (
        <ul className="feature-changes">
          {changeEntries.map(([featureName, vals]) => {
            const delta = vals.curr - vals.prev;
            // Explicit leading "+" for increases; toFixed already adds "-".
            const prefix = delta > 0 ? '+' : '';
            // Deltas beyond ±10 percentage points are highlighted in amber.
            const emphasize = Math.abs(delta) > 0.1;
            return (
              <li key={featureName} className="feature-row">
                <span className="feature-name">{featureName}</span>
                <span className="feature-delta" style={{ color: emphasize ? '#f59e0b' : '#6b7280' }}>
                  {prefix}{(delta * 100).toFixed(1)}%
                </span>
              </li>
            );
          })}
        </ul>
      )}

      {result.summary && (
        <p className="compare-summary">{result.summary}</p>
      )}
    </div>
  );
}
|
| 198 |
+
|
| 199 |
+
/* ─── Generic (unknown tool) renderer ───────────────────────────────────── */
|
| 200 |
+
|
| 201 |
+
function GenericResult({ result }: { result: ToolCall['result'] }) {
|
| 202 |
+
return (
|
| 203 |
+
<pre className="generic-result">
|
| 204 |
+
{JSON.stringify(result, null, 2)}
|
| 205 |
+
</pre>
|
| 206 |
+
);
|
| 207 |
+
}
|
web/src/index.css
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* Global reset: zero default margins/padding and use border-box sizing
   everywhere so widths include padding and borders. */
* {
  margin: 0;
  padding: 0;
  box-sizing: border-box;
}

/* Base typography and page colors. */
body {
  font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
  background: #f8fafc;
  color: #1e293b;
  line-height: 1.5;
}

/* Form controls do not inherit font-family by default; make them match. */
button {
  font-family: inherit;
  cursor: pointer;
}

input {
  font-family: inherit;
}

/* Design tokens used by component stylesheets via var().
   Gray scale follows the Tailwind "slate" palette values. */
:root {
  --primary: #6366f1;
  --primary-hover: #4f46e5;
  --success: #10b981;
  --error: #ef4444;
  --gray-50: #f8fafc;
  --gray-100: #f1f5f9;
  --gray-200: #e2e8f0;
  --gray-300: #cbd5e1;
  --gray-400: #94a3b8;
  --gray-500: #64748b;
  --gray-600: #475569;
  --gray-700: #334155;
  --gray-800: #1e293b;
  --gray-900: #0f172a;
}
|
web/src/main.tsx
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import React from 'react'
import ReactDOM from 'react-dom/client'
import { App } from './App'
import './index.css'

// Application entry point: mount <App /> under the #root element declared in
// index.html. The non-null assertion is safe as long as index.html always
// ships a #root node. StrictMode double-invokes renders/effects in
// development to surface impure code; it has no effect in production builds.
ReactDOM.createRoot(document.getElementById('root')!).render(
  <React.StrictMode>
    <App />
  </React.StrictMode>,
)
|
web/src/pages/ChatPage.css
ADDED
|
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* ─── Layout ───────────────────────────────────────────────────────────── */
|
| 2 |
+
.chat-page {
|
| 3 |
+
display: flex;
|
| 4 |
+
flex-direction: column;
|
| 5 |
+
height: 100vh;
|
| 6 |
+
background: var(--gray-50);
|
| 7 |
+
overflow: hidden;
|
| 8 |
+
}
|
| 9 |
+
|
| 10 |
+
/* ─── Header ────────────────────────────────────────────────────────────── */
|
| 11 |
+
.chat-header {
|
| 12 |
+
display: flex;
|
| 13 |
+
align-items: center;
|
| 14 |
+
gap: 12px;
|
| 15 |
+
padding: 0 16px;
|
| 16 |
+
height: 56px;
|
| 17 |
+
background: white;
|
| 18 |
+
border-bottom: 1px solid var(--gray-200);
|
| 19 |
+
flex-shrink: 0;
|
| 20 |
+
z-index: 10;
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
.header-back-btn {
|
| 24 |
+
width: 36px;
|
| 25 |
+
height: 36px;
|
| 26 |
+
display: flex;
|
| 27 |
+
align-items: center;
|
| 28 |
+
justify-content: center;
|
| 29 |
+
border: none;
|
| 30 |
+
background: transparent;
|
| 31 |
+
cursor: pointer;
|
| 32 |
+
color: var(--gray-600);
|
| 33 |
+
border-radius: 8px;
|
| 34 |
+
transition: background 0.15s;
|
| 35 |
+
flex-shrink: 0;
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
.header-back-btn:hover {
|
| 39 |
+
background: var(--gray-100);
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
.header-back-btn svg {
|
| 43 |
+
width: 20px;
|
| 44 |
+
height: 20px;
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
.header-center {
|
| 48 |
+
flex: 1;
|
| 49 |
+
display: flex;
|
| 50 |
+
flex-direction: column;
|
| 51 |
+
min-width: 0;
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
.header-app-name {
|
| 55 |
+
font-size: 0.7rem;
|
| 56 |
+
font-weight: 600;
|
| 57 |
+
color: var(--primary);
|
| 58 |
+
text-transform: uppercase;
|
| 59 |
+
letter-spacing: 0.05em;
|
| 60 |
+
line-height: 1;
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
.header-patient-name {
|
| 64 |
+
font-size: 1rem;
|
| 65 |
+
font-weight: 600;
|
| 66 |
+
color: var(--gray-900);
|
| 67 |
+
white-space: nowrap;
|
| 68 |
+
overflow: hidden;
|
| 69 |
+
text-overflow: ellipsis;
|
| 70 |
+
line-height: 1.3;
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
.header-clear-btn {
|
| 74 |
+
border: 1px solid var(--gray-300);
|
| 75 |
+
background: transparent;
|
| 76 |
+
border-radius: 8px;
|
| 77 |
+
padding: 6px 14px;
|
| 78 |
+
font-size: 0.8125rem;
|
| 79 |
+
color: var(--gray-600);
|
| 80 |
+
cursor: pointer;
|
| 81 |
+
transition: all 0.15s;
|
| 82 |
+
flex-shrink: 0;
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
.header-clear-btn:hover {
|
| 86 |
+
background: var(--gray-100);
|
| 87 |
+
border-color: var(--gray-400);
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
/* ─── Messages ──────────────────────────────────────────────────────────── */
|
| 91 |
+
.chat-messages {
|
| 92 |
+
flex: 1;
|
| 93 |
+
overflow-y: auto;
|
| 94 |
+
padding: 20px 16px;
|
| 95 |
+
display: flex;
|
| 96 |
+
flex-direction: column;
|
| 97 |
+
gap: 12px;
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
.chat-empty {
|
| 101 |
+
flex: 1;
|
| 102 |
+
display: flex;
|
| 103 |
+
flex-direction: column;
|
| 104 |
+
align-items: center;
|
| 105 |
+
justify-content: center;
|
| 106 |
+
color: var(--gray-400);
|
| 107 |
+
text-align: center;
|
| 108 |
+
gap: 12px;
|
| 109 |
+
margin: auto;
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
.chat-empty-icon svg {
|
| 113 |
+
width: 40px;
|
| 114 |
+
height: 40px;
|
| 115 |
+
color: var(--gray-300);
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
.chat-empty p {
|
| 119 |
+
font-size: 0.9375rem;
|
| 120 |
+
max-width: 280px;
|
| 121 |
+
line-height: 1.5;
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
.message-row {
|
| 125 |
+
display: flex;
|
| 126 |
+
max-width: 720px;
|
| 127 |
+
width: 100%;
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
.message-row.user {
|
| 131 |
+
align-self: flex-end;
|
| 132 |
+
justify-content: flex-end;
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
.message-row.assistant {
|
| 136 |
+
align-self: flex-start;
|
| 137 |
+
justify-content: flex-start;
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
/* ─── Bubbles ────────────────────────────────────────────────────────────── */
|
| 141 |
+
.bubble {
|
| 142 |
+
max-width: 85%;
|
| 143 |
+
border-radius: 16px;
|
| 144 |
+
padding: 12px 16px;
|
| 145 |
+
display: flex;
|
| 146 |
+
flex-direction: column;
|
| 147 |
+
gap: 8px;
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
.user-bubble {
|
| 151 |
+
background: var(--primary);
|
| 152 |
+
color: white;
|
| 153 |
+
border-bottom-right-radius: 4px;
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
.assistant-bubble {
|
| 157 |
+
background: white;
|
| 158 |
+
border: 1px solid var(--gray-200);
|
| 159 |
+
border-bottom-left-radius: 4px;
|
| 160 |
+
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.06);
|
| 161 |
+
max-width: 90%;
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
.bubble-text {
|
| 165 |
+
font-size: 0.9375rem;
|
| 166 |
+
line-height: 1.6;
|
| 167 |
+
white-space: pre-wrap;
|
| 168 |
+
word-break: break-word;
|
| 169 |
+
margin: 0;
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
.user-bubble .bubble-text {
|
| 173 |
+
color: white;
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
.assistant-text {
|
| 177 |
+
color: var(--gray-800);
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
/* Image in user bubble */
|
| 181 |
+
.message-image {
|
| 182 |
+
width: 100%;
|
| 183 |
+
max-width: 260px;
|
| 184 |
+
border-radius: 10px;
|
| 185 |
+
display: block;
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
/* ─── Thinking indicator ─────────────────────────────────────────────────── */
|
| 189 |
+
.thinking {
|
| 190 |
+
display: flex;
|
| 191 |
+
gap: 4px;
|
| 192 |
+
padding: 4px 0;
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
.dot {
|
| 196 |
+
width: 7px;
|
| 197 |
+
height: 7px;
|
| 198 |
+
background: var(--gray-400);
|
| 199 |
+
border-radius: 50%;
|
| 200 |
+
animation: bounce 1.2s infinite;
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
.dot:nth-child(2) { animation-delay: 0.2s; }
|
| 204 |
+
.dot:nth-child(3) { animation-delay: 0.4s; }
|
| 205 |
+
|
| 206 |
+
@keyframes bounce {
|
| 207 |
+
0%, 60%, 100% { transform: translateY(0); }
|
| 208 |
+
30% { transform: translateY(-6px); }
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
/* ─── Input bar ──────────────────────────────────────────────────────────── */
|
| 212 |
+
.chat-input-bar {
|
| 213 |
+
background: white;
|
| 214 |
+
border-top: 1px solid var(--gray-200);
|
| 215 |
+
padding: 12px 16px;
|
| 216 |
+
flex-shrink: 0;
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
.image-preview-container {
|
| 220 |
+
position: relative;
|
| 221 |
+
display: inline-block;
|
| 222 |
+
margin-bottom: 10px;
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
.image-preview-thumb {
|
| 226 |
+
width: 72px;
|
| 227 |
+
height: 72px;
|
| 228 |
+
object-fit: cover;
|
| 229 |
+
border-radius: 10px;
|
| 230 |
+
border: 2px solid var(--gray-200);
|
| 231 |
+
display: block;
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
.remove-image-btn {
|
| 235 |
+
position: absolute;
|
| 236 |
+
top: -8px;
|
| 237 |
+
right: -8px;
|
| 238 |
+
width: 22px;
|
| 239 |
+
height: 22px;
|
| 240 |
+
background: var(--gray-700);
|
| 241 |
+
color: white;
|
| 242 |
+
border: none;
|
| 243 |
+
border-radius: 50%;
|
| 244 |
+
font-size: 0.875rem;
|
| 245 |
+
line-height: 1;
|
| 246 |
+
cursor: pointer;
|
| 247 |
+
display: flex;
|
| 248 |
+
align-items: center;
|
| 249 |
+
justify-content: center;
|
| 250 |
+
}
|
| 251 |
+
|
| 252 |
+
.input-row {
|
| 253 |
+
display: flex;
|
| 254 |
+
align-items: flex-end;
|
| 255 |
+
gap: 8px;
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
.attach-btn {
|
| 259 |
+
width: 38px;
|
| 260 |
+
height: 38px;
|
| 261 |
+
border: 1px solid var(--gray-300);
|
| 262 |
+
background: transparent;
|
| 263 |
+
border-radius: 10px;
|
| 264 |
+
cursor: pointer;
|
| 265 |
+
color: var(--gray-500);
|
| 266 |
+
display: flex;
|
| 267 |
+
align-items: center;
|
| 268 |
+
justify-content: center;
|
| 269 |
+
flex-shrink: 0;
|
| 270 |
+
transition: all 0.15s;
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
.attach-btn:hover:not(:disabled) {
|
| 274 |
+
background: var(--gray-100);
|
| 275 |
+
border-color: var(--gray-400);
|
| 276 |
+
color: var(--gray-700);
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
.attach-btn:disabled {
|
| 280 |
+
opacity: 0.4;
|
| 281 |
+
cursor: not-allowed;
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
.attach-btn svg {
|
| 285 |
+
width: 18px;
|
| 286 |
+
height: 18px;
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
.chat-input {
|
| 290 |
+
flex: 1;
|
| 291 |
+
border: 1px solid var(--gray-300);
|
| 292 |
+
border-radius: 10px;
|
| 293 |
+
padding: 9px 14px;
|
| 294 |
+
font-size: 0.9375rem;
|
| 295 |
+
font-family: inherit;
|
| 296 |
+
resize: none;
|
| 297 |
+
line-height: 1.5;
|
| 298 |
+
max-height: 160px;
|
| 299 |
+
overflow-y: auto;
|
| 300 |
+
transition: border-color 0.15s;
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
.chat-input:focus {
|
| 304 |
+
outline: none;
|
| 305 |
+
border-color: var(--primary);
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
.chat-input:disabled {
|
| 309 |
+
background: var(--gray-50);
|
| 310 |
+
color: var(--gray-400);
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
.send-btn {
|
| 314 |
+
width: 38px;
|
| 315 |
+
height: 38px;
|
| 316 |
+
background: var(--primary);
|
| 317 |
+
color: white;
|
| 318 |
+
border: none;
|
| 319 |
+
border-radius: 10px;
|
| 320 |
+
cursor: pointer;
|
| 321 |
+
display: flex;
|
| 322 |
+
align-items: center;
|
| 323 |
+
justify-content: center;
|
| 324 |
+
flex-shrink: 0;
|
| 325 |
+
transition: background 0.15s;
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
.send-btn:hover:not(:disabled) {
|
| 329 |
+
background: var(--primary-hover);
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
.send-btn:disabled {
|
| 333 |
+
background: var(--gray-300);
|
| 334 |
+
cursor: not-allowed;
|
| 335 |
+
}
|
| 336 |
+
|
| 337 |
+
.send-btn svg {
|
| 338 |
+
width: 18px;
|
| 339 |
+
height: 18px;
|
| 340 |
+
}
|