Tirath5504 committed
Commit bf07f10 · 1 Parent(s): 54de50d
Files changed (44)
  1. .streamlit/config.toml +22 -0
  2. Dockerfile +18 -8
  3. apps/__pycache__/patient_chat_app_cloud.cpython-311.pyc +0 -0
  4. apps/patient_chat_app_cloud.py +666 -0
  5. apps/patient_chat_app_local.py +663 -0
  6. outputs/.DS_Store +0 -0
  7. outputs/best_densenet169.pth +3 -0
  8. outputs/best_efficientnetv2.pth +3 -0
  9. outputs/best_maxvit.pth +3 -0
  10. outputs/best_mobilenetv2.pth +3 -0
  11. outputs/best_swin.pth +3 -0
  12. requirements-prod.txt +21 -0
  13. src/__init__.py +24 -0
  14. src/__pycache__/__init__.cpython-311.pyc +0 -0
  15. src/agents/__init__.py +19 -0
  16. src/agents/__pycache__/__init__.cpython-311.pyc +0 -0
  17. src/agents/__pycache__/cross_validation_agent.cpython-311.pyc +0 -0
  18. src/agents/__pycache__/diagnostic_agent.cpython-311.pyc +0 -0
  19. src/agents/__pycache__/educational_agent.cpython-311.pyc +0 -0
  20. src/agents/__pycache__/explain_agent.cpython-311.pyc +0 -0
  21. src/agents/__pycache__/knowledge_agent.cpython-311.pyc +0 -0
  22. src/agents/cross_validation_agent.py +183 -0
  23. src/agents/diagnostic_agent.py +142 -0
  24. src/agents/educational_agent.py +148 -0
  25. src/agents/explain_agent.py +164 -0
  26. src/agents/knowledge_agent.py +109 -0
  27. src/analysis/__init__.py +10 -0
  28. src/analysis/analyze.py +104 -0
  29. src/analysis/analyze_2.py +210 -0
  30. src/analysis/visualize_gradcam.py +314 -0
  31. src/config/cloud_deployment.py +253 -0
  32. src/training/__init__.py +9 -0
  33. src/training/pipeline.py +394 -0
  34. src/training/pipeline_2.py +225 -0
  35. src/utils/__init__.py +12 -0
  36. src/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  37. src/utils/__pycache__/data_utils.cpython-311.pyc +0 -0
  38. src/utils/__pycache__/device_utils.cpython-311.pyc +0 -0
  39. src/utils/__pycache__/model_utils.cpython-311.pyc +0 -0
  40. src/utils/data_utils.py +59 -0
  41. src/utils/device_utils.py +18 -0
  42. src/utils/model_manager.py +190 -0
  43. src/utils/model_utils.py +54 -0
  44. streamlit_app.py +66 -0
.streamlit/config.toml ADDED
@@ -0,0 +1,22 @@
+ [theme]
+ primaryColor = "#1f77b4"
+ backgroundColor = "#ffffff"
+ secondaryBackgroundColor = "#f0f2f6"
+ textColor = "#262730"
+ font = "sans serif"
+
+ [client]
+ showErrorDetails = false
+
+ [logger]
+ level = "info"
+
+ [server]
+ port = 7860
+ address = "0.0.0.0"
+ headless = true
+ runOnSave = true
+ maxUploadSize = 500
+
+ [browser]
+ gatherUsageStats = false
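The `[server]` block above is what makes the app reachable inside the container: Hugging Face Spaces routes traffic to port 7860. A minimal sketch to sanity-check the parsed config, assuming Python 3.11+ (where `tomllib` is in the standard library, matching the `python:3.11-slim` base image below):

```python
import tomllib  # stdlib since Python 3.11, matching the python:3.11-slim image

with open(".streamlit/config.toml", "rb") as f:
    cfg = tomllib.load(f)

# The [server] settings must agree with the port the container serves on
assert cfg["server"]["port"] == 7860
assert cfg["server"]["address"] == "0.0.0.0"
print("config OK:", cfg["server"])
```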
Dockerfile CHANGED
@@ -1,20 +1,30 @@
- FROM python:3.13.5-slim
+ FROM python:3.11-slim
 
  WORKDIR /app
 
+ # Install system dependencies
  RUN apt-get update && apt-get install -y \
- build-essential \
+ libsm6 libxext6 libxrender-dev \
  curl \
- git \
  && rm -rf /var/lib/apt/lists/*
 
- COPY requirements.txt ./
- COPY src/ ./src/
+ # Copy requirements
+ COPY requirements-prod.txt .
 
- RUN pip3 install -r requirements.txt
+ # Install Python dependencies
+ RUN pip install --no-cache-dir -r requirements-prod.txt
+
+ # Copy application
+ COPY . .
+
+ # Create outputs directory for models
+ RUN mkdir -p outputs
 
- EXPOSE 8501
+ # Expose the Streamlit port (must match --server.port below and .streamlit/config.toml)
+ EXPOSE 7860
 
- HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
+ # Health check against the port the app actually serves on
+ HEALTHCHECK CMD curl --fail http://localhost:7860/_stcore/health || exit 1
 
- ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
+ # Run Streamlit
+ CMD ["streamlit", "run", "streamlit_app.py", "--logger.level=info", "--server.port=7860", "--server.address=0.0.0.0"]
apps/__pycache__/patient_chat_app_cloud.cpython-311.pyc ADDED
Binary file (34.5 kB).
 
apps/patient_chat_app_cloud.py ADDED
@@ -0,0 +1,666 @@
1
+ """
2
+ Streamlit-based Patient Chat Application for Fracture Detection and Diagnosis.
3
+ CLOUD VERSION - Uses Hugging Face Inference API instead of Ollama
4
+
5
+ Supports:
6
+ 1. Running individual agents (Diagnostic, Educational, Explainability, Knowledge)
7
+ 2. Running the complete workflow
8
+ 3. LLM-based Q&A via Hugging Face Inference API for patient education
9
+ """
10
+
11
+ import os
12
+ import sys
13
+ import streamlit as st
14
+ import requests
15
+ import json
16
+ import numpy as np
17
+ from typing import Dict, Any, List
18
+ from pathlib import Path
19
+
20
+ # Add parent directory to path for imports
21
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
22
+
23
+ # --- Import the Agents ---
24
+ from src.agents.diagnostic_agent import DiagnosticAgent
25
+ from src.agents.educational_agent import EducationalAgent
26
+ from src.agents.explain_agent import ExplainabilityAgent, generate_random_heatmap, calculate_heatmap_centroid
27
+ from src.agents.knowledge_agent import KnowledgeAgent, MEDICAL_KNOWLEDGE_BASE
28
+ from src.agents.cross_validation_agent import ModelEnsembleAgent
29
+ from src.utils import get_device
30
+
31
+ # --- Hugging Face Inference API Configuration ---
32
+ # Try both uppercase and lowercase key names for flexibility
33
+ HF_API_KEY = st.secrets.get("HUGGINGFACE_API_KEY", st.secrets.get("huggingface_api_key", ""))
34
+ # HF_API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2"
35
+ HF_API_URL = "https://router.huggingface.co/v1/chat/completions"
36
+ HF_HEADERS = {"Authorization": f"Bearer {HF_API_KEY}"}
37
+
38
+ # --- Constants ---
39
+ CLASS_NAMES = ["Comminuted", "Greenstick", "Healthy", "Oblique",
40
+ "Oblique Displaced", "Spiral", "Transverse", "Transverse Displaced"]
41
+ NUM_CLASSES = len(CLASS_NAMES)
42
+ IMG_SIZE = 224
43
+
44
+ # --- Page Configuration ---
45
+ st.set_page_config(
46
+ page_title="🦴 Fracture Detection AI System",
47
+ layout="wide",
48
+ initial_sidebar_state="expanded"
49
+ )
50
+
51
+ # --- Custom CSS for Better UI ---
52
+ st.markdown("""
53
+ <style>
54
+ .stTabs [data-baseweb="tab-list"] button {
55
+ font-size: 16px;
56
+ font-weight: bold;
57
+ }
58
+ .section-header {
59
+ font-size: 20px;
60
+ font-weight: bold;
61
+ margin-top: 20px;
62
+ margin-bottom: 10px;
63
+ }
64
+ </style>
65
+ """, unsafe_allow_html=True)
66
+
67
+
68
+ # ============================================================================
69
+ # --- 1. Hugging Face Inference API Patient Interaction Agent ---
70
+ # ============================================================================
71
+
72
+ class PatientInteractionAgent:
73
+ """
74
+ Handles communication with a hosted LLM (Llama 3.1 8B Instruct) via the Hugging Face router chat-completions API.
75
+ Free tier available with rate limiting.
76
+ """
77
+ def __init__(self, medical_summary: Dict[str, Any], patient_history: Dict[str, Any]):
78
+ """Initialize the agent with medical context."""
79
+ # --- Configuration Check ---
80
+ if not HF_API_KEY:
81
+ raise ValueError(
82
+ "⚠️ HUGGINGFACE_API_KEY not found in Streamlit Secrets. "
83
+ "Please set your Hugging Face API token in Settings > Secrets."
84
+ )
85
+
86
+ self.medical_summary = medical_summary
87
+ self.patient_history = patient_history
88
+ self.system_prompt = self._build_system_prompt()
89
+
90
+ def _build_system_prompt(self) -> str:
91
+ """Creates a detailed instruction set for the LLM (RAG Context)."""
92
+ guidelines = "\n- ".join(self.medical_summary.get('Guidelines', ["No specific guidelines available."]))
93
+
94
+ return f"""You are a highly compassionate, clear, and professional medical assistant. Your goal is to answer patient questions in natural language based ONLY on the following diagnostic information and patient history.
95
+
96
+ RULES:
97
+ 1. Maintain a reassuring, non-technical, and empathetic tone suitable for a patient.
98
+ 2. Keep answers concise and address the patient's underlying concern.
99
+ 3. ALWAYS conclude your answer by advising the patient to consult their orthopedic specialist or doctor.
100
+
101
+ --- DIAGNOSTIC INFORMATION ---
102
+ Diagnosis: {self.medical_summary.get('Diagnosis')} (Confidence: {self.medical_summary.get('Ensemble_Confidence')})
103
+ Definition: {self.medical_summary.get('Type')}
104
+ Severity: {self.medical_summary.get('Severity')}
105
+ Treatment Guidelines:
106
+ {guidelines}
107
+
108
+ --- PATIENT HISTORY ---
109
+ Age: {self.patient_history.get('age')}
110
+ Gender: {self.patient_history.get('gender')}
111
+ Medical History: {self.patient_history.get('history')}"""
112
+
113
+ def get_response(self, query: str) -> str:
114
+ """Queries the Hugging Face Inference API with the patient's question."""
115
+ try:
116
+ # Combine the system prompt and the patient's question into one user message
117
+ full_prompt = f"{self.system_prompt}\n\nPATIENT QUERY: {query}"
118
+
119
+ # payload = {
120
+ # "inputs": f"[INST] {full_prompt} [/INST]",
121
+ # "parameters": {
122
+ # "max_new_tokens": 512,
123
+ # "return_full_text": False,
124
+ # "temperature": 0.7,
125
+ # }
126
+ # }
127
+ payload = {
128
+ "messages": [
129
+ {
130
+ "role": "user",
131
+ "content": f"[INST] {full_prompt} [/INST]"
132
+ }
133
+ ],
134
+ "model": "meta-llama/Llama-3.1-8B-Instruct:cerebras"
135
+ }
136
+
137
+ response = requests.post(
138
+ HF_API_URL,
139
+ headers=HF_HEADERS,
140
+ json=payload,
141
+ timeout=60
142
+ )
143
+ response.raise_for_status()
144
+
145
+ result = response.json()
+
+ # Surface API errors (e.g., model loading, rate limiting) before parsing
+ if isinstance(result, dict) and "error" in result:
+ error_msg = result.get("error", "Unknown error")
+ if "rate_limit" in str(error_msg).lower():
+ return "⚠️ API rate limit reached. Please wait a moment and try again."
+ return f"⚠️ API Error: {error_msg}"
+
+ # OpenAI-style chat completion: the reply text lives at choices[0].message.content
+ try:
+ return result["choices"][0]["message"]["content"]
+ except (KeyError, IndexError, TypeError):
+ return "Error: Unknown API response format."
163
+
164
+ except requests.exceptions.Timeout:
165
+ return "⏱️ Request timed out. The model may be loading. Please try again."
166
+ except requests.exceptions.RequestException as e:
167
+ return f"❌ Network error: {str(e)}"
168
+ except Exception as e:
169
+ return f"❌ Unexpected error: {str(e)}"
170
+
171
+
172
+ # ============================================================================
173
+ # --- 2. Helper Functions ---
174
+ # ============================================================================
175
+
176
+ def save_uploaded_file(uploaded_file) -> str:
177
+ """Save uploaded file to a temporary location."""
178
+ if uploaded_file is None:
179
+ return None
180
+
181
+ try:
182
+ import tempfile
183
+ # Create a temporary file in temp_uploads directory
184
+ temp_dir = Path("./temp_uploads")
185
+ temp_dir.mkdir(exist_ok=True)
186
+
187
+ # Create temp file with proper extension
188
+ suffix = Path(uploaded_file.name).suffix or '.jpg'
189
+ with tempfile.NamedTemporaryFile(
190
+ dir=str(temp_dir),
191
+ suffix=suffix,
192
+ delete=False
193
+ ) as tmp_file:
194
+ tmp_file.write(uploaded_file.getbuffer())
195
+ return tmp_file.name # Returns full path
196
+ except Exception as e:
197
+ st.error(f"Error saving file: {e}")
198
+ return None
199
+
200
+
201
+ # ============================================================================
202
+ # --- 3. Workflow Functions ---
203
+ # ============================================================================
204
+
205
+ def run_diagnostic_agent(image_path: str) -> Dict[str, Any]:
206
+ """Run the diagnostic agent on an image."""
207
+ try:
208
+ # Placeholder checkpoint path - in production, use actual model checkpoint
209
+ checkpoint_path = "./outputs/best_swin.pth"
210
+
211
+ if not os.path.exists(checkpoint_path):
212
+ return {"error": f"Checkpoint not found at {checkpoint_path}"}
213
+
214
+ agent = DiagnosticAgent(
215
+ checkpoint_path=checkpoint_path,
216
+ model_name='swin',
217
+ num_classes=NUM_CLASSES,
218
+ img_size=IMG_SIZE,
219
+ class_names=CLASS_NAMES
220
+ )
221
+
222
+ result = agent.run_diagnosis(image_path)
223
+ return result
224
+ except Exception as e:
225
+ return {"error": str(e)}
226
+
227
+
228
+ def run_ensemble_agent(image_path: str) -> Dict[str, Any]:
229
+ """Run the ensemble agent on an image."""
230
+ try:
231
+ checkpoints_dir = "./outputs"
232
+
233
+ if not os.path.exists(checkpoints_dir):
234
+ return {"error": f"Checkpoints directory not found at {checkpoints_dir}"}
235
+
236
+ agent = ModelEnsembleAgent(
237
+ model_names=['swin', 'mobilenetv2', 'densenet169', 'efficientnetv2', 'maxvit'],
238
+ checkpoints_dir=checkpoints_dir,
239
+ num_classes=NUM_CLASSES,
240
+ class_names=CLASS_NAMES
241
+ )
242
+
243
+ result = agent.run_ensemble(image_path)
244
+ return result
245
+ except Exception as e:
246
+ return {"error": str(e)}
247
+
248
+
249
+ def run_educational_agent(diagnosis_result: Dict[str, Any], explanation_text: str = "") -> Dict[str, str]:
250
+ """Run the educational agent to translate diagnosis."""
251
+ try:
252
+ agent = EducationalAgent(doctor_name="your treating doctor")
253
+
254
+ # Map ensemble result format to educational agent format
255
+ # Ensemble uses: ensemble_prediction, ensemble_confidence
256
+ # EducationalAgent expects: predicted_class, confidence_score
257
+ mapped_result = {
258
+ "predicted_class": diagnosis_result.get("ensemble_prediction", "Unknown"),
259
+ "confidence_score": diagnosis_result.get("ensemble_confidence", 0.0),
260
+ "fracture_detected": diagnosis_result.get("fracture_detected", True)
261
+ }
262
+
263
+ result = agent.translate_to_layman_terms(mapped_result, explanation_text)
264
+ return result
265
+ except Exception as e:
266
+ return {"error": str(e)}
267
+
268
+
269
+ def run_explainability_agent(diagnosis_result: Dict[str, Any]) -> str:
270
+ """Run the explainability agent to generate explanations."""
271
+ try:
272
+ agent = ExplainabilityAgent(class_names=CLASS_NAMES, body_part="bone")
273
+
274
+ # Map ensemble result format to explainability agent format
275
+ # Ensemble uses: ensemble_prediction, ensemble_confidence
276
+ # ExplainabilityAgent expects: predicted_class, confidence_score
277
+ mapped_result = {
278
+ "predicted_class": diagnosis_result.get("ensemble_prediction", "Unknown"),
279
+ "confidence_score": diagnosis_result.get("ensemble_confidence", 0.0),
280
+ "fracture_detected": diagnosis_result.get("fracture_detected", True)
281
+ }
282
+
283
+ # Generate a random heatmap for demonstration
284
+ heatmap = generate_random_heatmap()
285
+
286
+ # Call with correct parameters
287
+ explanation = agent.generate_explanation(
288
+ diagnosis_result=mapped_result,
289
+ cam_array=heatmap
290
+ )
291
+ return explanation
292
+ except Exception as e:
293
+ return f"Error generating explanation: {str(e)}"
294
+
295
+
296
+ def run_knowledge_agent(diagnosis: str, confidence: float) -> Dict[str, Any]:
297
+ """Run the knowledge agent to retrieve medical information."""
298
+ try:
299
+ agent = KnowledgeAgent(knowledge_base=MEDICAL_KNOWLEDGE_BASE)
300
+ result = agent.get_medical_summary(diagnosis, confidence)
301
+ return result
302
+ except Exception as e:
303
+ return {"error": str(e)}
304
+
305
+
306
+ def run_complete_workflow(image_path: str) -> Dict[str, Any]:
307
+ """Run the complete workflow: Ensemble -> Education -> Knowledge."""
308
+ workflow_result = {
309
+ "ensemble_result": None,
310
+ "educational_result": None,
311
+ "knowledge_result": None,
312
+ "explanation_result": None
313
+ }
314
+
315
+ try:
316
+ # 1. Run Ensemble Agent
317
+ ensemble_result = run_ensemble_agent(image_path)
318
+ if "error" in ensemble_result:
319
+ return {"error": f"Ensemble failed: {ensemble_result['error']}"}
320
+
321
+ workflow_result["ensemble_result"] = ensemble_result
322
+
323
+ # 2. Run Educational Agent
324
+ educational_result = run_educational_agent(ensemble_result)
325
+ workflow_result["educational_result"] = educational_result
326
+
327
+ # 3. Run Explainability Agent
328
+ explanation = run_explainability_agent(ensemble_result)
329
+ workflow_result["explanation_result"] = explanation
330
+
331
+ # 4. Run Knowledge Agent
332
+ diagnosis = ensemble_result.get("ensemble_prediction", "Unknown")
333
+ confidence = ensemble_result.get("ensemble_confidence", 0.0)
334
+ knowledge_result = run_knowledge_agent(diagnosis, confidence)
335
+ workflow_result["knowledge_result"] = knowledge_result
336
+
337
+ return workflow_result
338
+ except Exception as e:
339
+ return {"error": str(e)}
340
+
341
+
342
+ # ============================================================================
343
+ # --- 4. Streamlit UI ---
344
+ # ============================================================================
345
+
346
+ def main():
347
+ """Main Streamlit application."""
348
+ st.title("🦴 AI Medical Assistant for Fracture Detection & Diagnosis")
349
+ st.info("⚠️ **Research/Educational Use Only** - This system is not approved for clinical use without professional oversight.")
350
+ st.markdown("---")
351
+
352
+ # Initialize session state
353
+ if "patient_context" not in st.session_state:
354
+ st.session_state.patient_context = {
355
+ "age": 45,
356
+ "gender": "Female",
357
+ "history": "No major past issues, but has mild osteoporosis."
358
+ }
359
+
360
+ # Initialize workflow results storage
361
+ if "workflow_result" not in st.session_state:
362
+ st.session_state.workflow_result = None
363
+
364
+ # --- Create Tabs ---
365
+ tab1, tab2, tab3, tab4, tab5 = st.tabs(
366
+ ["🏥 Single Agents", "⚙️ Complete Workflow", "💬 Patient Chat", "📋 Workflow Details", "ℹ️ About"]
367
+ )
368
+
369
+ # ========================================================================
370
+ # --- TAB 1: Individual Agents ---
371
+ # ========================================================================
372
+ with tab1:
373
+ st.header("Run Individual Agents")
374
+ st.markdown("Test each agent independently with sample diagnosis data.")
375
+
376
+ agent_choice = st.selectbox(
377
+ "Select an Agent",
378
+ ["Diagnostic Agent", "Ensemble Agent", "Educational Agent", "Explainability Agent", "Knowledge Agent"]
379
+ )
380
+
381
+ # Create columns for layout
382
+ col1, col2 = st.columns([2, 1])
383
+
384
+ with col1:
385
+ if agent_choice == "Diagnostic Agent":
386
+ st.subheader("🔍 Diagnostic Agent")
387
+ st.write("Runs a single model on an X-ray image to detect fractures.")
388
+
389
+ image_file = st.file_uploader("Upload X-ray image", type=["jpg", "png", "jpeg"])
390
+ if image_file and st.button("Run Diagnostic Agent"):
391
+ st.info("Note: Running this requires a valid model checkpoint at ./outputs/best_swin.pth")
392
+ with st.spinner("Running diagnostic agent..."):
393
+ image_path = save_uploaded_file(image_file)
394
+ result = run_diagnostic_agent(image_path)
395
+ st.json(result)
396
+
397
+ elif agent_choice == "Ensemble Agent":
398
+ st.subheader("🎯 Ensemble Agent (5 Models)")
399
+ st.write("Combines predictions from multiple models for robust diagnosis.")
400
+
401
+ image_file = st.file_uploader("Upload X-ray image", type=["jpg", "png", "jpeg"])
402
+ if image_file and st.button("Run Ensemble Agent"):
403
+ st.info("Note: Running this requires model checkpoints in ./outputs/")
404
+ with st.spinner("Running ensemble agent..."):
405
+ image_path = save_uploaded_file(image_file)
406
+ result = run_ensemble_agent(image_path)
407
+ st.json(result)
408
+
409
+ elif agent_choice == "Educational Agent":
410
+ st.subheader("📚 Educational Agent")
411
+ st.write("Translates technical diagnosis into patient-friendly language.")
412
+
413
+ # Sample diagnosis for demo
414
+ sample_diagnosis = {
415
+ "fracture_detected": True,
416
+ "predicted_class": "Transverse",
417
+ "confidence_score": 0.85,
418
+ "severity_type": "Transverse"
419
+ }
420
+
421
+ sample_explanation = "The bone shows a clear transverse break pattern."
422
+
423
+ if st.button("Run Educational Agent (Demo)"):
424
+ with st.spinner("Translating diagnosis..."):
425
+ result = run_educational_agent(sample_diagnosis, sample_explanation)
426
+ if isinstance(result, dict):
427
+ for key, value in result.items():
428
+ st.write(f"**{key}:**\n{value}")
429
+ else:
430
+ st.error(result)
431
+
432
+ elif agent_choice == "Explainability Agent":
433
+ st.subheader("🎨 Explainability Agent")
434
+ st.write("Generates human-readable explanations of model predictions.")
435
+
436
+ sample_diagnosis = {
437
+ "predicted_class": "Greenstick",
438
+ "confidence_score": 0.92,
439
+ "fracture_detected": True
440
+ }
441
+
442
+ if st.button("Run Explainability Agent (Demo)"):
443
+ with st.spinner("Generating explanation..."):
444
+ explanation = run_explainability_agent(sample_diagnosis)
445
+ st.write(explanation)
446
+
447
+ elif agent_choice == "Knowledge Agent":
448
+ st.subheader("🧠 Knowledge Agent")
449
+ st.write("Retrieves medical knowledge and guidelines for a diagnosis.")
450
+
451
+ diagnosis_input = st.selectbox("Select Diagnosis", CLASS_NAMES)
452
+ confidence_input = st.slider("Confidence Score", 0.0, 1.0, 0.85)
453
+
454
+ if st.button("Run Knowledge Agent"):
455
+ with st.spinner("Retrieving medical knowledge..."):
456
+ result = run_knowledge_agent(diagnosis_input, confidence_input)
457
+ if isinstance(result, dict):
458
+ st.json(result)
459
+ else:
460
+ st.error(result)
461
+
462
+ # ========================================================================
463
+ # --- TAB 2: Complete Workflow ---
464
+ # ========================================================================
465
+ with tab2:
466
+ st.header("Complete Diagnosis Workflow")
467
+ st.markdown("Upload an X-ray image and run the complete diagnostic pipeline.")
468
+
469
+ col1, col2 = st.columns([2, 1])
470
+
471
+ with col1:
472
+ st.subheader("📤 Upload X-ray Image")
473
+ image_file = st.file_uploader("Upload X-ray image for full diagnosis", type=["jpg", "png", "jpeg"])
474
+
475
+ if image_file:
476
+ st.image(image_file, caption="Uploaded Image", width='stretch')
477
+
478
+ with col2:
479
+ st.subheader("👤 Patient Information")
480
+ age = st.number_input("Age", min_value=1, max_value=120, value=st.session_state.patient_context["age"])
481
+ gender = st.selectbox("Gender", ["Male", "Female", "Other"],
482
+ index=0 if st.session_state.patient_context["gender"] == "Male" else
483
+ 1 if st.session_state.patient_context["gender"] == "Female" else 2)
484
+ history = st.text_area("Medical History", value=st.session_state.patient_context["history"])
485
+
486
+ st.session_state.patient_context = {"age": age, "gender": gender, "history": history}
487
+
488
+ if image_file and st.button("🚀 Run Complete Workflow", key="workflow"):
489
+ st.info("Note: Running this requires all model checkpoints in ./outputs/")
490
+ with st.spinner("Running complete diagnostic workflow..."):
491
+ image_path = save_uploaded_file(image_file)
492
+ workflow_result = run_complete_workflow(image_path)
493
+
494
+ # Store workflow result in session state for use in other tabs
495
+ st.session_state.workflow_result = workflow_result
496
+
497
+ if "error" in workflow_result:
498
+ st.error(f"❌ Error: {workflow_result['error']}")
499
+ else:
500
+ # Display results
501
+ st.success("✅ Workflow completed successfully!")
502
+
503
+ # Ensemble Results
504
+ if workflow_result["ensemble_result"]:
505
+ st.subheader("1️⃣ Ensemble Agent Results")
506
+ ensemble = workflow_result["ensemble_result"]
507
+ col1, col2, col3 = st.columns(3)
508
+ col1.metric("Prediction", ensemble.get("ensemble_prediction", "N/A"))
509
+ col2.metric("Confidence", f"{ensemble.get('ensemble_confidence', 0):.2%}")
510
+ col3.metric("Fracture Detected", "Yes" if ensemble.get("fracture_detected") else "No")
511
+
512
+ # Educational Results
513
+ if workflow_result["educational_result"]:
514
+ st.subheader("2️⃣ Patient-Friendly Summary")
515
+ educational = workflow_result["educational_result"]
516
+ for key, value in educational.items():
517
+ st.write(f"**{key}:**\n{value}")
518
+
519
+ # Explainability Results
520
+ if workflow_result["explanation_result"]:
521
+ st.subheader("3️⃣ Technical Explanation")
522
+ st.write(workflow_result["explanation_result"])
523
+
524
+ # Knowledge Results
525
+ if workflow_result["knowledge_result"]:
526
+ st.subheader("4️⃣ Medical Knowledge Base")
527
+ st.json(workflow_result["knowledge_result"])
528
+
529
+ # ========================================================================
530
+ # --- TAB 3: Patient Chat (Hugging Face) ---
531
+ # ========================================================================
532
+ with tab3:
533
+ st.header("💬 Patient Q&A with AI Assistant")
534
+ st.markdown("Ask questions about your fracture diagnosis using Hugging Face Inference API")
535
+
536
+ # Check if workflow has been run
537
+ if st.session_state.workflow_result is None or "error" in st.session_state.workflow_result:
538
+ st.info("ℹ️ Please run the 'Complete Workflow' first to generate a diagnosis for the chat feature.")
539
+ else:
540
+ # Check HF API configuration
541
+ if not HF_API_KEY:
542
+ st.error(
543
+ "❌ Hugging Face API key not configured. "
544
+ "Please add your HUGGINGFACE_API_KEY to Streamlit Secrets."
545
+ )
546
+ st.markdown("""
547
+ ### How to set up Hugging Face API:
548
+ 1. Get your API key from https://huggingface.co/settings/tokens
549
+ 2. In Streamlit Cloud, go to Settings > Secrets
550
+ 3. Add: `HUGGINGFACE_API_KEY = "hf_your_token_here"`
551
+ 4. Refresh the app
552
+ """)
553
+ else:
554
+ # Build medical summary from workflow results
555
+ ensemble_result = st.session_state.workflow_result.get("ensemble_result", {})
556
+ knowledge_result = st.session_state.workflow_result.get("knowledge_result", {})
557
+
558
+ diagnosis = ensemble_result.get("ensemble_prediction", "Unknown")
559
+ confidence = ensemble_result.get("ensemble_confidence", 0.0)
560
+
561
+ # Create medical summary from knowledge base
562
+ medical_summary = {
563
+ "Diagnosis": diagnosis,
564
+ "Ensemble_Confidence": f"{confidence:.2f}",
565
+ "Type": knowledge_result.get("Type", "Unknown fracture type"),
566
+ "Severity": knowledge_result.get("Severity", "Unknown"),
567
+ "Guidelines": knowledge_result.get("Guidelines", [])
568
+ }
569
+
570
+ try:
571
+ agent = PatientInteractionAgent(medical_summary, st.session_state.patient_context)
572
+
573
+ # Initialize chat history with diagnosis info
574
+ if "messages" not in st.session_state:
575
+ st.session_state.messages = []
576
+ st.session_state.messages.append({
577
+ "role": "assistant",
578
+ "content": f"Hello! I'm your AI medical assistant. I've reviewed your diagnosis: **{medical_summary['Diagnosis']}** (Confidence: {medical_summary['Ensemble_Confidence']}). How can I help answer your questions?"
579
+ })
580
+
581
+ # Display chat messages
582
+ for message in st.session_state.messages:
583
+ with st.chat_message(message["role"]):
584
+ st.markdown(message["content"])
585
+
586
+ # Accept user input
587
+ if prompt := st.chat_input("Ask a question about your diagnosis..."):
588
+ st.session_state.messages.append({"role": "user", "content": prompt})
589
+ with st.chat_message("user"):
590
+ st.markdown(prompt)
591
+
592
+ with st.chat_message("assistant"):
593
+ with st.spinner("🤖 Consulting Mistral 7B via Hugging Face..."):
594
+ response = agent.get_response(prompt)
595
+ st.markdown(response)
596
+
597
+ st.session_state.messages.append({"role": "assistant", "content": response})
598
+
599
+ except ValueError as e:
600
+ st.error(str(e))
601
+ except Exception as e:
602
+ st.error(f"❌ Error initializing chat agent: {str(e)}")
603
+
604
+ # ========================================================================
605
+ # --- TAB 4: Workflow Details ---
606
+ # ========================================================================
607
+ with tab4:
608
+ st.header("📋 Workflow Execution Details")
609
+
610
+ if st.session_state.workflow_result is None:
611
+ st.info("ℹ️ No workflow results available. Please run a workflow first.")
612
+ else:
613
+ if "error" in st.session_state.workflow_result:
614
+ st.error(f"Workflow Error: {st.session_state.workflow_result['error']}")
615
+ else:
616
+ st.success("Workflow executed successfully!")
617
+ st.json(st.session_state.workflow_result)
618
+
619
+ # ========================================================================
620
+ # --- TAB 5: About ---
621
+ # ========================================================================
622
+ with tab5:
623
+ st.header("ℹ️ About This Application")
624
+ st.markdown("""
625
+ ### 🦴 AI-Powered Fracture Detection System
626
+
627
+ This application uses advanced deep learning models to detect and classify fractures from X-ray images.
628
+
629
+ **Features:**
630
+ - **Multi-Model Ensemble:** Combines 5 different architectures (Swin, MobileNetV2, DenseNet, EfficientNet, MaxViT)
631
+ - **Explainability:** Generates human-readable explanations for predictions
632
+ - **Patient Education:** Translates medical terminology into patient-friendly language
633
+ - **AI Chatbot:** Ask questions about your diagnosis, powered by Llama 3.1 via Hugging Face
634
+
635
+ **Models Used:**
636
+ - Swin Transformer
637
+ - MobileNetV2
638
+ - DenseNet169
639
+ - EfficientNetV2
640
+ - MaxViT
641
+
642
+ **Fracture Types Detected:**
643
+ """)
644
+
645
+ for i, fracture_type in enumerate(CLASS_NAMES, 1):
646
+ st.write(f"{i}. {fracture_type}")
647
+
648
+ st.markdown("""
649
+ ### ⚠️ Important Disclaimer
650
+ This system is for **research and educational purposes only**.
651
+ It is **NOT approved for clinical use** without professional medical oversight.
652
+ Always consult with a qualified healthcare professional for medical diagnosis.
653
+
654
+ ### 🔧 Technology Stack
655
+ - **Frontend:** Streamlit
656
+ - **ML Models:** PyTorch
657
+ - **AI Assistant:** Hugging Face Inference API (Llama 3.1 8B Instruct)
658
+ - **Deployment:** Streamlit Cloud
659
+
660
+ ### 📞 Contact & Support
661
+ For issues or questions, please contact the development team.
662
+ """)
663
+
664
+
665
+ if __name__ == "__main__":
666
+ main()
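The exchange `PatientInteractionAgent.get_response` performs against the router can be reproduced outside Streamlit. A minimal sketch, assuming the token lives in an `HF_TOKEN` environment variable (the app itself reads it from Streamlit secrets) and using the model id from the diff above:

```python
import os
import requests

API_URL = "https://router.huggingface.co/v1/chat/completions"
headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}  # HF_TOKEN is an assumption

payload = {
    "model": "meta-llama/Llama-3.1-8B-Instruct:cerebras",
    "messages": [{"role": "user", "content": "Explain a transverse fracture in plain language."}],
}

resp = requests.post(API_URL, headers=headers, json=payload, timeout=60)
resp.raise_for_status()

# OpenAI-style response: the reply text lives at choices[0].message.content
print(resp.json()["choices"][0]["message"]["content"])
```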
apps/patient_chat_app_local.py ADDED
@@ -0,0 +1,663 @@
1
+ """
2
+ Streamlit-based Patient Chat Application for Fracture Detection and Diagnosis.
3
+
4
+ Supports:
5
+ 1. Running individual agents (Diagnostic, Educational, Explainability, Knowledge)
6
+ 2. Running the complete workflow
7
+ 3. LLM-based Q&A for patient education
8
+ """
9
+
10
+ import os
11
+ import sys
12
+ import streamlit as st
13
+ import requests
14
+ import json
15
+ import numpy as np
16
+ from typing import Dict, Any, List
17
+ from pathlib import Path
18
+
19
+ # Add parent directory to path for imports
20
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
21
+
22
+ # --- Import the Agents ---
23
+ from src.agents.diagnostic_agent import DiagnosticAgent
24
+ from src.agents.educational_agent import EducationalAgent
25
+ from src.agents.explain_agent import ExplainabilityAgent, generate_random_heatmap, calculate_heatmap_centroid
26
+ from src.agents.knowledge_agent import KnowledgeAgent, MEDICAL_KNOWLEDGE_BASE
27
+ from src.agents.cross_validation_agent import ModelEnsembleAgent
28
+ from src.utils import get_device
29
+
30
+ # --- Configuration for Ollama ---
31
+ # Support both localhost and host.docker.internal for Docker deployments
32
+ OLLAMA_ENDPOINT = os.getenv("OLLAMA_ENDPOINT", os.getenv("OLLAMA_HOST", "http://localhost:11434") + "/api/generate")
33
+ OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", "llama3") # Ensure you have pulled this model using 'ollama pull llama3'
34
+ OLLAMA_CHECK_URL = os.getenv("OLLAMA_HOST", "http://localhost:11434")
35
+
36
+ # --- Constants ---
37
+ CLASS_NAMES = ["Comminuted", "Greenstick", "Healthy", "Oblique",
38
+ "Oblique Displaced", "Spiral", "Transverse", "Transverse Displaced"]
39
+ NUM_CLASSES = len(CLASS_NAMES)
40
+ IMG_SIZE = 224
41
+
42
+ # --- Page Configuration ---
43
+ st.set_page_config(
44
+ page_title="🦴 Fracture Detection AI System",
45
+ layout="wide",
46
+ initial_sidebar_state="expanded"
47
+ )
48
+
49
+ # --- Custom CSS for Better UI ---
50
+ st.markdown("""
51
+ <style>
52
+ .stTabs [data-baseweb="tab-list"] button {
53
+ font-size: 16px;
54
+ font-weight: bold;
55
+ }
56
+ .section-header {
57
+ font-size: 20px;
58
+ font-weight: bold;
59
+ margin-top: 20px;
60
+ margin-bottom: 10px;
61
+ }
62
+ </style>
63
+ """, unsafe_allow_html=True)
64
+
65
+
66
+ # ============================================================================
67
+ # --- 1. Ollama-based Patient Interaction Agent ---
68
+ # ============================================================================
69
+
70
+ class PatientInteractionAgent:
71
+ """
72
+ Handles communication with the local Llama 3 model via the Ollama API endpoint.
73
+ """
74
+ def __init__(self, medical_summary: Dict[str, Any], patient_history: Dict[str, Any]):
75
+ """Initialize the agent with medical context."""
76
+ # --- Connection Check ---
77
+ try:
78
+ response = requests.get(OLLAMA_CHECK_URL, timeout=5)
79
+ if response.status_code != 200:
80
+ raise ConnectionError("Ollama server is not running or accessible.")
81
+ except requests.exceptions.ConnectionError:
82
+ raise ConnectionError("Ollama server is not running. Please start Ollama.")
83
+
84
+ self.medical_summary = medical_summary
85
+ self.patient_history = patient_history
86
+ self.system_prompt = self._build_system_prompt()
87
+
88
+ def _build_system_prompt(self) -> str:
89
+ """Creates a detailed instruction set for the LLM (RAG Context)."""
90
+ guidelines = "\n- ".join(self.medical_summary.get('Guidelines', ["No specific guidelines available."]))
91
+
92
+ return f"""
93
+ You are a highly compassionate, clear, and professional medical assistant. Your goal is to answer patient questions
94
+ in natural language based ONLY on the following diagnostic information and patient history.
95
+
96
+ RULES:
97
+ 1. Maintain a reassuring, non-technical, and empathetic tone suitable for a patient.
98
+ 2. Keep answers concise and address the patient's underlying concern.
99
+ 3. ALWAYS conclude your answer by advising the patient to consult their orthopedic specialist or doctor.
100
+
101
+ --- DIAGNOSTIC INFORMATION ---
102
+ Diagnosis: {self.medical_summary.get('Diagnosis')} (Confidence: {self.medical_summary.get('Ensemble_Confidence')})
103
+ Definition: {self.medical_summary.get('Type')}
104
+ Severity: {self.medical_summary.get('Severity')}
105
+ Treatment Guidelines:
106
+ {guidelines}
107
+
108
+ --- PATIENT HISTORY ---
109
+ Age: {self.patient_history.get('age')}
110
+ Gender: {self.patient_history.get('gender')}
111
+ Medical History: {self.patient_history.get('history')}
112
+ """
113
+
114
+ def get_response(self, query: str) -> str:
115
+ """Sends the user query to the Llama 3 model via Ollama."""
116
+ full_prompt = f"{self.system_prompt}\n\nPATIENT QUERY: {query}"
117
+
118
+ payload = {
119
+ "model": OLLAMA_MODEL,
120
+ "prompt": full_prompt,
121
+ "stream": False,
122
+ "options": {"temperature": 0.1}
123
+ }
124
+
125
+ try:
126
+ response = requests.post(OLLAMA_ENDPOINT, json=payload, timeout=300)
127
+ response.raise_for_status()
128
+ data = response.json()
129
+ return data.get("response", "Error: Could not extract response from Ollama.")
130
+ except requests.exceptions.RequestException as e:
131
+ return f"Error communicating with Ollama: {e}"
132
+ except Exception as e:
133
+ return f"An unexpected error occurred: {e}"
134
+
135
+
136
+ # ============================================================================
137
+ # --- 2. Helper Functions ---
138
+ # ============================================================================
139
+
140
+ def save_uploaded_file(uploaded_file) -> str:
141
+ """Save uploaded file to temp location and return path."""
142
+ if uploaded_file is None:
143
+ return None
144
+
145
+ try:
146
+ import tempfile
147
+ # Create a temporary file
148
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp_file:
149
+ tmp_file.write(uploaded_file.getbuffer())
150
+ return tmp_file.name
151
+ except Exception as e:
152
+ st.error(f"Error saving file: {e}")
153
+ return None
154
+
155
+
156
+ # ============================================================================
157
+ # --- 3. Workflow Functions ---
158
+ # ============================================================================
159
+
160
+ def run_diagnostic_agent(image_path: str) -> Dict[str, Any]:
161
+ """Run the diagnostic agent on an image."""
162
+ try:
163
+ # Placeholder checkpoint path - in production, use actual model checkpoint
164
+ checkpoint_path = "./outputs/best_swin.pth"
165
+
166
+ if not os.path.exists(checkpoint_path):
167
+ return {"error": f"Checkpoint not found at {checkpoint_path}"}
168
+
169
+ agent = DiagnosticAgent(
170
+ checkpoint_path=checkpoint_path,
171
+ model_name='swin',
172
+ num_classes=NUM_CLASSES,
173
+ img_size=IMG_SIZE,
174
+ class_names=CLASS_NAMES
175
+ )
176
+
177
+ result = agent.run_diagnosis(image_path)
178
+ return result
179
+ except Exception as e:
180
+ return {"error": str(e)}
181
+
182
+
183
+ def run_ensemble_agent(image_path: str) -> Dict[str, Any]:
184
+ """Run the ensemble agent on an image."""
185
+ try:
186
+ checkpoints_dir = "./outputs"
187
+
188
+ if not os.path.exists(checkpoints_dir):
189
+ return {"error": f"Checkpoints directory not found at {checkpoints_dir}"}
190
+
191
+ agent = ModelEnsembleAgent(
192
+ model_names=['swin', 'mobilenetv2', 'densenet169', 'efficientnetv2', 'maxvit'],
193
+ checkpoints_dir=checkpoints_dir,
194
+ num_classes=NUM_CLASSES,
195
+ class_names=CLASS_NAMES
196
+ )
197
+
198
+ result = agent.run_ensemble(image_path)
199
+ return result
200
+ except Exception as e:
201
+ return {"error": str(e)}
202
+
203
+
204
+ def run_educational_agent(diagnosis_result: Dict[str, Any], explanation_text: str = "") -> Dict[str, str]:
205
+ """Run the educational agent to translate diagnosis."""
206
+ try:
207
+ agent = EducationalAgent(doctor_name="your treating doctor")
208
+
209
+ # Map ensemble result format to educational agent format
210
+ # Ensemble uses: ensemble_prediction, ensemble_confidence
211
+ # EducationalAgent expects: predicted_class, confidence_score
212
+ mapped_result = {
213
+ "predicted_class": diagnosis_result.get("ensemble_prediction", "Unknown"),
214
+ "confidence_score": diagnosis_result.get("ensemble_confidence", 0.0),
215
+ "fracture_detected": diagnosis_result.get("fracture_detected", True)
216
+ }
217
+
218
+ result = agent.translate_to_layman_terms(mapped_result, explanation_text)
219
+ return result
220
+ except Exception as e:
221
+ return {"error": str(e)}
222
+
223
+
224
+ def run_explainability_agent(diagnosis_result: Dict[str, Any]) -> str:
225
+ """Run the explainability agent to generate explanations."""
226
+ try:
227
+ agent = ExplainabilityAgent(class_names=CLASS_NAMES, body_part="bone")
228
+
229
+ # Map ensemble result format to explainability agent format
230
+ # Ensemble uses: ensemble_prediction, ensemble_confidence
231
+ # ExplainabilityAgent expects: predicted_class, confidence_score
232
+ mapped_result = {
233
+ "predicted_class": diagnosis_result.get("ensemble_prediction", "Unknown"),
234
+ "confidence_score": diagnosis_result.get("ensemble_confidence", 0.0),
235
+ "fracture_detected": diagnosis_result.get("fracture_detected", True)
236
+ }
237
+
238
+ # Generate a random heatmap for demonstration
239
+ heatmap = generate_random_heatmap()
240
+
241
+ # Call with correct parameters
242
+ explanation = agent.generate_explanation(
243
+ diagnosis_result=mapped_result,
244
+ cam_array=heatmap
245
+ )
246
+ return explanation
247
+ except Exception as e:
248
+ return f"Error generating explanation: {str(e)}"
249
+
250
+
251
+ def run_knowledge_agent(diagnosis: str, confidence: float) -> Dict[str, Any]:
252
+ """Run the knowledge agent to retrieve medical information."""
253
+ try:
254
+ agent = KnowledgeAgent(knowledge_base=MEDICAL_KNOWLEDGE_BASE)
255
+ result = agent.get_medical_summary(diagnosis, confidence)
256
+ return result
257
+ except Exception as e:
258
+ return {"error": str(e)}
259
+
260
+
261
+ def run_complete_workflow(image_path: str) -> Dict[str, Any]:
262
+ """Run the complete workflow: Ensemble -> Education -> Knowledge."""
263
+ workflow_result = {
264
+ "ensemble_result": None,
265
+ "educational_result": None,
266
+ "knowledge_result": None,
267
+ "explanation_result": None
268
+ }
269
+
270
+ try:
271
+ # 1. Run Ensemble Agent
272
+ ensemble_result = run_ensemble_agent(image_path)
273
+ if "error" in ensemble_result:
274
+ return {"error": f"Ensemble failed: {ensemble_result['error']}"}
275
+
276
+ workflow_result["ensemble_result"] = ensemble_result
277
+
278
+ # 2. Run Educational Agent
279
+ educational_result = run_educational_agent(ensemble_result)
280
+ workflow_result["educational_result"] = educational_result
281
+
282
+ # 3. Run Explainability Agent
283
+ explanation = run_explainability_agent(ensemble_result)
284
+ workflow_result["explanation_result"] = explanation
285
+
286
+ # 4. Run Knowledge Agent
287
+ diagnosis = ensemble_result.get("ensemble_prediction", "Unknown")
288
+ confidence = ensemble_result.get("ensemble_confidence", 0.0)
289
+ knowledge_result = run_knowledge_agent(diagnosis, confidence)
290
+ workflow_result["knowledge_result"] = knowledge_result
291
+
292
+ return workflow_result
293
+ except Exception as e:
294
+ return {"error": str(e)}
295
+
296
+
297
+ # ============================================================================
298
+ # --- 4. Streamlit UI ---
299
+ # ============================================================================
300
+
301
+ def main():
302
+ """Main Streamlit application."""
303
+ st.title("🦴 AI Medical Assistant for Fracture Detection & Diagnosis")
304
+ st.info("⚠️ **Research/Educational Use Only** - This system is not approved for clinical use without professional oversight.")
305
+ st.markdown("---")
306
+
307
+ # Initialize session state
308
+ if "patient_context" not in st.session_state:
309
+ st.session_state.patient_context = {
310
+ "age": 45,
311
+ "gender": "Female",
312
+ "history": "No major past issues, but has mild osteoporosis."
313
+ }
314
+
315
+ # Initialize workflow results storage
316
+ if "workflow_result" not in st.session_state:
317
+ st.session_state.workflow_result = None
318
+
319
+ # --- Create Tabs ---
320
+ tab1, tab2, tab3, tab4, tab5 = st.tabs(
321
+ ["🏥 Single Agents", "⚙️ Complete Workflow", "💬 Patient Chat", "📋 Workflow Details", "ℹ️ About"]
322
+ )
323
+
324
+ # ========================================================================
325
+ # --- TAB 1: Individual Agents ---
326
+ # ========================================================================
327
+ with tab1:
328
+ st.header("Run Individual Agents")
329
+ st.markdown("Test each agent independently with sample diagnosis data.")
330
+
331
+ agent_choice = st.selectbox(
332
+ "Select an Agent",
333
+ ["Diagnostic Agent", "Ensemble Agent", "Educational Agent", "Explainability Agent", "Knowledge Agent"]
334
+ )
335
+
336
+ # Create columns for layout
337
+ col1, col2 = st.columns([2, 1])
338
+
339
+ with col1:
340
+ if agent_choice == "Diagnostic Agent":
341
+ st.subheader("🔍 Diagnostic Agent")
342
+ st.write("Runs a single model on an X-ray image to detect fractures.")
343
+
344
+ image_file = st.file_uploader("Upload X-ray image", type=["jpg", "png", "jpeg"])
345
+ if image_file and st.button("Run Diagnostic Agent"):
346
+ st.info("Note: Running this requires a valid model checkpoint at ./outputs/best_swin.pth")
347
+ with st.spinner("Running diagnostic agent..."):
348
+ image_path = save_uploaded_file(image_file)
349
+ result = run_diagnostic_agent(image_path)
350
+ st.json(result)
351
+
352
+ elif agent_choice == "Ensemble Agent":
353
+ st.subheader("🎯 Ensemble Agent (5 Models)")
354
+ st.write("Combines predictions from multiple models for robust diagnosis.")
355
+
356
+ image_file = st.file_uploader("Upload X-ray image", type=["jpg", "png", "jpeg"])
357
+ if image_file and st.button("Run Ensemble Agent"):
358
+ st.info("Note: Running this requires model checkpoints in ./outputs/")
359
+ with st.spinner("Running ensemble agent..."):
360
+ image_path = save_uploaded_file(image_file)
361
+ result = run_ensemble_agent(image_path)
362
+ st.json(result)
363
+
364
+ elif agent_choice == "Educational Agent":
365
+ st.subheader("📚 Educational Agent")
366
+ st.write("Translates technical diagnosis into patient-friendly language.")
367
+
368
+ # Sample diagnosis for demo
369
+ sample_diagnosis = {
370
+ "fracture_detected": True,
371
+ "predicted_class": "Transverse",
372
+ "confidence_score": 0.85,
373
+ "severity_type": "Transverse"
374
+ }
375
+
376
+ sample_explanation = "The bone shows a clear transverse break pattern."
377
+
378
+ if st.button("Run Educational Agent (Demo)"):
379
+ with st.spinner("Translating diagnosis..."):
380
+ result = run_educational_agent(sample_diagnosis, sample_explanation)
381
+ if isinstance(result, dict):
382
+ for key, value in result.items():
383
+ st.write(f"**{key}:**\n{value}")
384
+ else:
385
+ st.error(result)
386
+
387
+ elif agent_choice == "Explainability Agent":
388
+ st.subheader("🎨 Explainability Agent")
389
+ st.write("Generates human-readable explanations of model predictions.")
390
+
391
+ sample_diagnosis = {
392
+ "predicted_class": "Greenstick",
393
+ "confidence_score": 0.92,
394
+ "fracture_detected": True
395
+ }
396
+
397
+ if st.button("Run Explainability Agent (Demo)"):
398
+ with st.spinner("Generating explanation..."):
399
+ explanation = run_explainability_agent(sample_diagnosis)
400
+ st.write(explanation)
401
+
402
+ elif agent_choice == "Knowledge Agent":
403
+ st.subheader("🧠 Knowledge Agent")
404
+ st.write("Retrieves medical knowledge and guidelines for a diagnosis.")
405
+
406
+ diagnosis_input = st.selectbox("Select Diagnosis", CLASS_NAMES)
407
+ confidence_input = st.slider("Confidence Score", 0.0, 1.0, 0.85)
408
+
409
+ if st.button("Run Knowledge Agent"):
410
+ with st.spinner("Retrieving medical knowledge..."):
411
+ result = run_knowledge_agent(diagnosis_input, confidence_input)
412
+ if isinstance(result, dict):
413
+ st.json(result)
414
+ else:
415
+ st.error(result)
416
+
417
+ # ========================================================================
418
+ # --- TAB 2: Complete Workflow ---
419
+ # ========================================================================
420
+ with tab2:
421
+ st.header("Complete Diagnosis Workflow")
422
+ st.markdown("Upload an X-ray image and run the complete diagnostic pipeline.")
423
+
424
+ col1, col2 = st.columns([2, 1])
425
+
426
+ with col1:
427
+ st.subheader("📤 Upload X-ray Image")
428
+ image_file = st.file_uploader("Upload X-ray image for full diagnosis", type=["jpg", "png", "jpeg"])
429
+
430
+ if image_file:
431
+ st.image(image_file, caption="Uploaded Image", width='stretch')
432
+
433
+ with col2:
434
+ st.subheader("👤 Patient Information")
435
+ age = st.number_input("Age", min_value=1, max_value=120, value=st.session_state.patient_context["age"])
436
+ gender = st.selectbox("Gender", ["Male", "Female", "Other"],
437
+ index=0 if st.session_state.patient_context["gender"] == "Male" else
438
+ 1 if st.session_state.patient_context["gender"] == "Female" else 2)
439
+ history = st.text_area("Medical History", value=st.session_state.patient_context["history"])
440
+
441
+ st.session_state.patient_context = {"age": age, "gender": gender, "history": history}
442
+
443
+ if image_file and st.button("🚀 Run Complete Workflow", key="workflow"):
444
+ st.info("Note: Running this requires all model checkpoints in ./outputs/")
445
+ with st.spinner("Running complete diagnostic workflow..."):
446
+ image_path = save_uploaded_file(image_file)
447
+ workflow_result = run_complete_workflow(image_path)
448
+
449
+ # Store workflow result in session state for use in other tabs
450
+ st.session_state.workflow_result = workflow_result
451
+
452
+ if "error" in workflow_result:
453
+ st.error(f"❌ Error: {workflow_result['error']}")
454
+ else:
455
+ # Display results
456
+ st.success("✅ Workflow completed successfully!")
457
+
458
+ # Ensemble Results
459
+ if workflow_result["ensemble_result"]:
460
+ st.subheader("1️⃣ Ensemble Agent Results")
461
+ ensemble = workflow_result["ensemble_result"]
462
+ col1, col2, col3 = st.columns(3)
463
+ col1.metric("Prediction", ensemble.get("ensemble_prediction", "N/A"))
464
+ col2.metric("Confidence", f"{ensemble.get('ensemble_confidence', 0):.2%}")
465
+ col3.metric("Fracture Detected", "Yes" if ensemble.get("fracture_detected") else "No")
466
+
467
+ # Educational Results
468
+ if workflow_result["educational_result"]:
469
+ st.subheader("2️⃣ Patient-Friendly Summary")
470
+ educational = workflow_result["educational_result"]
471
+ for key, value in educational.items():
472
+ st.write(f"**{key}:**\n{value}")
473
+
474
+ # Explainability Results
475
+ if workflow_result["explanation_result"]:
476
+ st.subheader("3️⃣ Technical Explanation")
477
+ st.write(workflow_result["explanation_result"])
478
+
479
+ # Knowledge Results
480
+ if workflow_result["knowledge_result"]:
481
+ st.subheader("4️⃣ Medical Knowledge Base")
482
+ st.json(workflow_result["knowledge_result"])
483
+
484
+ # ========================================================================
485
+ # --- TAB 3: Patient Chat ---
486
+ # ========================================================================
487
+ with tab3:
488
+ st.header("💬 Patient Q&A with AI Assistant")
489
+ st.markdown("Ask questions about your fracture diagnosis (requires Ollama running)")
490
+
491
+ # Check if workflow has been run
492
+ if st.session_state.workflow_result is None or "error" in st.session_state.workflow_result:
493
+ st.info("ℹ️ Please run the 'Complete Workflow' first to generate a diagnosis for the chat feature.")
494
+ else:
495
+ # Check for Ollama availability
496
+ ollama_available = False
497
+ try:
498
+ response = requests.get(OLLAMA_CHECK_URL, timeout=2)
499
+ ollama_available = response.status_code == 200
500
+ except requests.exceptions.RequestException:
501
+ ollama_available = False
502
+
503
+ if not ollama_available:
504
+ st.warning("⚠️ Ollama server is not running. Please start Ollama to use the chat feature.")
505
+ st.info("Download Ollama from https://ollama.ai and run: ollama pull llama3")
506
+ else:
507
+ # Build medical summary from workflow results
508
+ ensemble_result = st.session_state.workflow_result.get("ensemble_result", {})
509
+ knowledge_result = st.session_state.workflow_result.get("knowledge_result", {})
510
+
511
+ diagnosis = ensemble_result.get("ensemble_prediction", "Unknown")
512
+ confidence = ensemble_result.get("ensemble_confidence", 0.0)
513
+
514
+ # Create medical summary from knowledge base
515
+ medical_summary = {
516
+ "Diagnosis": diagnosis,
517
+ "Ensemble_Confidence": f"{confidence:.2f}",
518
+ "Type": knowledge_result.get("Type", "Unknown fracture type"),
519
+ "Severity": knowledge_result.get("Severity", "Unknown"),
520
+ "Guidelines": knowledge_result.get("Guidelines", [])
521
+ }
522
+
523
+ try:
524
+ agent = PatientInteractionAgent(medical_summary, st.session_state.patient_context)
525
+
526
+ # Initialize chat history with diagnosis info
527
+ if "messages" not in st.session_state:
528
+ st.session_state.messages = []
529
+ st.session_state.messages.append({
530
+ "role": "assistant",
531
+ "content": f"Hello! I'm your AI medical assistant. I've reviewed your diagnosis: **{medical_summary['Diagnosis']}** (Confidence: {medical_summary['Ensemble_Confidence']}). How can I help answer your questions?"
532
+ })
533
+
534
+ # Display chat messages
535
+ for message in st.session_state.messages:
536
+ with st.chat_message(message["role"]):
537
+ st.markdown(message["content"])
538
+
539
+ # Accept user input
540
+ if prompt := st.chat_input("Ask a question about your diagnosis..."):
541
+ st.session_state.messages.append({"role": "user", "content": prompt})
542
+ with st.chat_message("user"):
543
+ st.markdown(prompt)
544
+
545
+ with st.chat_message("assistant"):
546
+ with st.spinner(f"Asking {OLLAMA_MODEL}..."):
547
+ response = agent.get_response(prompt)
548
+ st.markdown(response)
549
+
550
+ st.session_state.messages.append({"role": "assistant", "content": response})
551
+
552
+ except ConnectionError as e:
553
+ st.error(f"❌ Connection Error: {e}")
554
+ except Exception as e:
555
+ st.error(f"❌ Error: {e}")
556
+
557
+ # ========================================================================
558
+ # --- TAB 4: Workflow Details ---
559
+ # ========================================================================
560
+ with tab4:
561
+ st.header("📋 System Architecture & Workflow")
562
+
563
+ st.subheader("1. Ensemble Agent (Cross-Validation)")
564
+ st.write("""
565
+ - **Purpose**: Combines predictions from 5 different deep learning models
566
+ - **Models**: Swin, MobileNetV2, DenseNet169, EfficientNetV2, MaxViT
567
+ - **Output**: Ensemble prediction with confidence score
568
+ - **Benefit**: More robust and reliable predictions than single model
569
+ """)
570
+
571
+ st.subheader("2. Educational Agent")
572
+ st.write("""
573
+ - **Purpose**: Translates technical diagnosis into patient-friendly language
574
+ - **Input**: Diagnosis result from ensemble
575
+ - **Output**:
576
+ - Patient summary
577
+ - Severity assessment in simple terms
578
+ - Next steps and action plan
579
+ """)
580
+
581
+ st.subheader("3. Explainability Agent")
582
+ st.write("""
583
+ - **Purpose**: Generates visual and textual explanations of predictions
584
+ - **Input**: Diagnosis result and Grad-CAM heatmap
585
+ - **Output**: Human-readable explanation of what the model "saw"
586
+ """)
587
+
588
+ st.subheader("4. Knowledge Agent")
589
+ st.write("""
590
+ - **Purpose**: Retrieves medical knowledge for each diagnosis
591
+ - **Input**: Final diagnosis and confidence
592
+ - **Output**:
593
+ - Medical definition
594
+ - ICD code
595
+ - Treatment guidelines
596
+ - Severity level
597
+ """)
598
+
599
+ st.markdown("---")
600
+ st.markdown("### Workflow Pipeline")
601
+ st.markdown("""
602
+ ```
603
+ X-ray Image
604
+
605
+ [Ensemble Agent] → Diagnosis + Confidence
606
+
607
+ [Educational Agent] → Patient-Friendly Summary
608
+
609
+ [Explainability Agent] → Visual Explanation
610
+
611
+ [Knowledge Agent] → Medical Guidelines
612
+
613
+ Patient Report
614
+ ```
615
+ """)
616
+
617
+ # ========================================================================
618
+ # --- TAB 5: About ---
619
+ # ========================================================================
620
+ with tab5:
621
+ st.header("ℹ️ About This System")
622
+
623
+ st.markdown("""
624
+ ## MedAI - Explainable Fracture Detection
625
+
626
+ This application demonstrates an AI-powered medical diagnosis system designed to assist
627
+ healthcare professionals in fracture detection and patient education.
628
+
629
+ ### Features:
630
+ - 🎯 **Ensemble Learning**: 5 deep learning models for robust predictions
631
+ - 📚 **Patient Education**: Automatic translation of technical diagnoses
632
+ - 🎨 **Explainability**: Visual and textual explanations of AI decisions
633
+ - 🧠 **Knowledge Integration**: Evidence-based medical guidelines
634
+ - 💬 **LLM Integration**: Natural language Q&A with Llama 3
635
+
636
+ ### Supported Fracture Types:
637
+ - Comminuted
638
+ - Greenstick
639
+ - Oblique
640
+ - Oblique Displaced
641
+ - Spiral
642
+ - Transverse
643
+ - Transverse Displaced
644
+ - Healthy (No fracture)
645
+
646
+ ### Technical Stack:
647
+ - **Deep Learning**: PyTorch with timm models
648
+ - **Frontend**: Streamlit
649
+ - **LLM**: Llama 3 via Ollama
650
+ - **Explainability**: Grad-CAM + Centroid analysis
651
+
652
+ ### Disclaimer:
653
+ This system is for educational and research purposes. It should not be used
654
+ for actual medical diagnosis without proper clinical validation and oversight.
655
+ Always consult with qualified medical professionals for diagnosis and treatment.
656
+ """)
657
+
658
+ st.markdown("---")
659
+ st.info("📧 For more information, visit the project repository on GitHub.")
660
+
661
+
662
+ if __name__ == "__main__":
663
+ main()
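The chat tab only works against a live Ollama endpoint. A minimal sketch of the round-trip behind `agent.get_response()`, assuming Ollama's default port 11434 and a pulled `llama3` model (the system prompt here is illustrative, not the agent's actual prompt):

```python
import requests

def ask_ollama(question: str, medical_summary: dict, model: str = "llama3") -> str:
    """One non-streaming chat turn against a local Ollama server."""
    payload = {
        "model": model,
        "messages": [
            {"role": "system",
             "content": f"You are a patient-education assistant. Case context: {medical_summary}"},
            {"role": "user", "content": question},
        ],
        "stream": False,
    }
    resp = requests.post("http://localhost:11434/api/chat", json=payload, timeout=120)
    resp.raise_for_status()
    return resp.json()["message"]["content"]
```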
outputs/.DS_Store ADDED
Binary file (8.2 kB)
 
outputs/best_densenet169.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3686669a0348a9043c9a697ec4a13463b4acc07649cb682cf44814ea6c3ac085
3
+ size 195498793
outputs/best_efficientnetv2.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2979a53289c91ab60e2697141952c0ecb66174c8d3ae998ccb16b8065b5fc93f
3
+ size 195498793
outputs/best_maxvit.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4819e32d99287b04501f7e52fd80e9eff86a92584eeef7c394f70729180538a9
3
+ size 195498793
outputs/best_mobilenetv2.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:efea091ddf0c43ce7cbca8a7eefe450974810d537984ae9521dc44ed3aa9055a
3
+ size 195498793
outputs/best_swin.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:095f19e8c238fda959f7412f6c39a4f43362ea03178efb1535f93dc8386763f7
3
+ size 195498793
requirements-prod.txt ADDED
@@ -0,0 +1,21 @@
1
+ # Core ML/DL
2
+ torch==2.9.0
3
+ torchvision==0.24.0
4
+ timm==1.0.20
5
+
6
+ # Web Framework
7
+ streamlit==1.51.0
8
+
9
+ # Data & ML Tools
10
+ numpy
11
+ pandas
12
+ scikit-learn==1.7.2
13
+ Pillow
14
+ opencv-python
15
+
16
+ # API & HTTP
17
+ requests
18
+
19
+ # Utilities
20
+ tqdm
21
+ pyyaml
src/__init__.py ADDED
@@ -0,0 +1,24 @@
1
+ """
2
+ Source code for MedAI Fracture Detection System.
3
+ """
4
+
5
+ __version__ = "1.0.0"
6
+
7
+ # Import key utilities for easy access
8
+ from .utils import (
9
+ get_device,
10
+ require_mps,
11
+ DEVICE,
12
+ get_model,
13
+ get_transforms,
14
+ FractureDataset
15
+ )
16
+
17
+ __all__ = [
18
+ 'get_device',
19
+ 'require_mps',
20
+ 'DEVICE',
21
+ 'get_model',
22
+ 'get_transforms',
23
+ 'FractureDataset'
24
+ ]
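With these re-exports, downstream scripts can pull the device, model, and transform helpers from a single import. A minimal sketch, assuming the 8-class setup and 224-pixel input size used by the checkpoints in this commit:

```python
from src import get_device, get_model, get_transforms

device = get_device()                    # MPS/CUDA/CPU, whichever is available
model = get_model("swin", num_classes=8, pretrained=False).to(device)
val_tf = get_transforms("val", 224)      # same transform the agents use at inference
```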
src/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (532 Bytes)
 
src/agents/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ """
2
+ Agent modules for fracture detection and diagnosis system.
3
+ """
4
+
5
+ from .diagnostic_agent import DiagnosticAgent
6
+ from .explain_agent import generate_random_heatmap, calculate_heatmap_centroid
7
+ from .educational_agent import EducationalAgent
8
+ from .knowledge_agent import KnowledgeAgent, MEDICAL_KNOWLEDGE_BASE
9
+ from .cross_validation_agent import ModelEnsembleAgent
10
+
11
+ __all__ = [
12
+ 'DiagnosticAgent',
13
+ 'generate_random_heatmap',
14
+ 'calculate_heatmap_centroid',
15
+ 'EducationalAgent',
16
+ 'KnowledgeAgent',
17
+ 'MEDICAL_KNOWLEDGE_BASE',
18
+ 'ModelEnsembleAgent'
19
+ ]
src/agents/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (785 Bytes)
 
src/agents/__pycache__/cross_validation_agent.cpython-311.pyc ADDED
Binary file (10.8 kB)
 
src/agents/__pycache__/diagnostic_agent.cpython-311.pyc ADDED
Binary file (8.22 kB)
 
src/agents/__pycache__/educational_agent.cpython-311.pyc ADDED
Binary file (6.68 kB)
 
src/agents/__pycache__/explain_agent.cpython-311.pyc ADDED
Binary file (7.19 kB)
 
src/agents/__pycache__/knowledge_agent.cpython-311.pyc ADDED
Binary file (5.22 kB)
 
src/agents/cross_validation_agent.py ADDED
@@ -0,0 +1,183 @@
1
+ import os
2
+ import sys
3
+ import argparse
4
+ import torch
5
+ import torch.nn as nn
6
+ import torchvision.transforms as T
7
+ import numpy as np
8
+ from PIL import Image
9
+ from typing import List, Dict, Any
10
+ import timm
11
+
12
+ # Add parent directory to path for imports
13
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
14
+
15
+ from src.utils import get_device, get_model, get_transforms
16
+
17
+ # ----------------------------------------------------------------------
18
+ # --- Global Variables ---
19
+ # ----------------------------------------------------------------------
20
+
21
+ DEVICE = get_device()
22
+ IMG_SIZE = 224
23
+
24
+ # ----------------------------------------------------------------------
25
+ # --- Model Ensemble Agent Core (with all fixes) ---
26
+ # ----------------------------------------------------------------------
27
+
28
+ class ModelEnsembleAgent:
29
+ def __init__(self, model_names: List[str], checkpoints_dir: str, num_classes: int, class_names: List[str]):
30
+ self.models = {}
31
+ self.model_names = model_names
32
+ self.num_classes = num_classes
33
+ self.class_names = class_names
34
+ self.transforms = get_transforms('val', IMG_SIZE)
35
+
36
+ self.device = DEVICE
37
+ self._load_all_models(checkpoints_dir)
38
+
39
+ def _load_all_models(self, checkpoints_dir: str):
40
+ """Loads all specified model checkpoints with strict=False fallback."""
41
+ print(f"Loading {len(self.model_names)} models from {checkpoints_dir} on {self.device}...")
42
+
43
+ for name in self.model_names:
44
+
45
+ # FIX: Corrected file naming convention (best_modelname.pth)
46
+ checkpoint_path = os.path.join(checkpoints_dir, f"best_{name}.pth")
47
+
48
+ print(f" Attempting to load {name} from expected path: {checkpoint_path}...")
49
+
50
+ try:
51
+ model = get_model(name, self.num_classes, pretrained=False).to(self.device)
52
+
53
+ checkpoint = torch.load(checkpoint_path, map_location=self.device)
54
+ state_dict = checkpoint.get('model_state_dict', checkpoint)
55
+
56
+ # FIX: Filter out incompatible head layers that have size mismatches
57
+ # This handles cases where checkpoint was trained with different head architecture
58
+ model_state = model.state_dict()
59
+ filtered_state_dict = {}
60
+ for key, value in state_dict.items():
61
+ if key in model_state and model_state[key].shape == value.shape:
62
+ filtered_state_dict[key] = value
63
+ elif key not in model_state:
64
+ # Key doesn't exist in current model, skip it
65
+ pass
66
+ else:
67
+ # Shape mismatch, skip this layer (usually head layers)
68
+ print(f" (Skipping layer '{key}' due to shape mismatch: {value.shape} vs {model_state[key].shape})")
69
+
70
+ # Load only compatible layers
71
+ model.load_state_dict(filtered_state_dict, strict=False)
72
+
73
+ model.eval()
74
+ self.models[name] = model
75
+ print(f" ✅ Successfully loaded {name}.")
76
+
77
+ except FileNotFoundError:
78
+ print(f" ❌ Checkpoint not found at: {checkpoint_path}. Skipping.")
79
+ except Exception as e:
80
+ # FIX: Detailed error reporting to show the full RuntimeError message
81
+ print(f" ❌ Failed to load {name}. Error: {e.__class__.__name__}. Details: {e}. Skipping.")
82
+
83
+ if not self.models:
84
+ raise RuntimeError("No models were successfully loaded. Cannot run ensemble.")
85
+
86
+ @torch.no_grad()
87
+ def run_ensemble(self, image_path: str) -> Dict[str, Any]:
88
+ """Runs inference across all loaded models and computes the ensemble prediction."""
89
+
90
+ try:
91
+ image = Image.open(image_path).convert('RGB')
92
+ input_tensor = self.transforms(image).unsqueeze(0).to(self.device)
93
+ except Exception as e:
94
+ return {"error": f"Failed to load or process image: {e}"}
95
+
96
+ all_probs = []
97
+ individual_predictions = {}
98
+
99
+ for name, model in self.models.items():
100
+ outputs = model(input_tensor)
101
+ probs = torch.softmax(outputs, dim=1).cpu().numpy()[0]
102
+
103
+ all_probs.append(probs)
104
+
105
+ pred_idx = np.argmax(probs)
106
+ pred_conf = probs[pred_idx]
107
+
108
+ individual_predictions[name] = {
109
+ "class": self.class_names[pred_idx],
110
+ "confidence": float(pred_conf)
111
+ }
112
+
113
+ # Ensemble Decision (Weighted Voting)
114
+ # Use max confidence from each model as the weight
115
+ weights = np.array([np.max(probs) for probs in all_probs])
116
+ # Normalize weights
117
+ weights = weights / np.sum(weights)
118
+
119
+ # Weighted average of probabilities
120
+ weighted_avg_probs = np.average(all_probs, axis=0, weights=weights)
121
+ ensemble_idx = np.argmax(weighted_avg_probs)
122
+ ensemble_confidence = weighted_avg_probs[ensemble_idx]
123
+ ensemble_class = self.class_names[ensemble_idx]
124
+
125
+ return {
126
+ "image_path": image_path,
127
+ "ensemble_prediction": ensemble_class,
128
+ "ensemble_confidence": float(ensemble_confidence),
129
+ "individual_predictions": individual_predictions,
130
+ "fracture_detected": ensemble_class != "Healthy"
131
+ }
132
+
133
+ # ----------------------------------------------------------------------
134
+ # --- Execution Block ---
135
+ # ----------------------------------------------------------------------
136
+
137
+ if __name__ == '__main__':
138
+ parser = argparse.ArgumentParser(description='Multi-Model Ensemble (Cross-Validation) Agent.')
139
+ parser.add_argument('--image-path', required=True, help='Path to the image for inference.')
140
+ parser.add_argument('--checkpoints-dir', required=True, # Made required since default path was confusing
141
+ help='Absolute path to the directory containing the model checkpoints (e.g., best_swin.pth).')
142
+ parser.add_argument('--models', type=str, default='swin,mobilenetv2,efficientnetv2,maxvit,densenet169',
143
+ help='Comma-separated names of the models to load.')
144
+ parser.add_argument('--num-classes', type=int, default=8)
145
+ parser.add_argument('--class-names', required=True,
146
+ help='Comma-separated list of class names.')
147
+
148
+ args = parser.parse_args()
149
+
150
+ models_list = [m.strip() for m in args.models.split(',')]
151
+ class_names_list = [c.strip() for c in args.class_names.split(',')]
152
+
153
+ try:
154
+ ensemble_agent = ModelEnsembleAgent(
155
+ model_names=models_list,
156
+ checkpoints_dir=args.checkpoints_dir,
157
+ num_classes=args.num_classes,
158
+ class_names=class_names_list
159
+ )
160
+ except RuntimeError as e:
161
+ print(f"\nFATAL ERROR during initialization: {e}")
162
+ exit(1)
163
+
164
+ result = ensemble_agent.run_ensemble(args.image_path)
165
+
166
+ print("\n--- ENSEMBLE AGENT RESULT ---")
167
+ if "error" in result:
168
+ print(f"Error: {result['error']}")
169
+ else:
170
+ print(f"Image: {os.path.basename(result['image_path'])}")
171
+ print(f"FINAL ENSEMBLE PREDICTION: **{result['ensemble_prediction']}** (Confidence: {result['ensemble_confidence']:.4f})")
172
+
173
+ print("\nIndividual Model Predictions:")
174
+ loaded_model_names = ensemble_agent.models.keys()
175
+
176
+ for name in models_list:
177
+ if name in loaded_model_names:
178
+ pred = result['individual_predictions'][name]
179
+ print(f" {name.upper():<15}: {pred['class']:<20} (Conf: {pred['confidence']:.4f})")
180
+ else:
181
+ print(f" {name.upper():<15}: (Skipped/Failed to Load)")
182
+
183
+ print("-----------------------------\n")
src/agents/diagnostic_agent.py ADDED
@@ -0,0 +1,142 @@
1
+ import os
2
+ import sys
3
+ import argparse
4
+ import torch
5
+ from PIL import Image
6
+ from typing import Dict, Any, List
7
+
8
+ # Add parent directory to path for imports
9
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
10
+
11
+ # --- 1. CONFIGURATION ---
12
+ from src.utils import get_device, get_model, get_transforms
13
+
14
+ DEVICE = get_device()
15
+
16
+ # --- 2. DIAGNOSTIC AGENT CORE ---
17
+
18
+ class DiagnosticAgent:
19
+ def __init__(self, checkpoint_path: str, model_name: str, num_classes: int, img_size: int, class_names: List[str]):
20
+ self.device = DEVICE
21
+ self.img_size = img_size
22
+ self.class_names = class_names
23
+ self.model_name = model_name
24
+
25
+ # 1. Load Model Architecture
26
+ self.model = get_model(model_name, num_classes, pretrained=False).to(self.device)
27
+
28
+ # 2. Load Weights from Checkpoint
29
+ try:
30
+ ck = torch.load(checkpoint_path, map_location=self.device)
31
+ state_dict = ck.get('model_state_dict', ck)
32
+ self.model.load_state_dict(state_dict)
33
+ self.model.eval()
34
+ print(f"✅ Diagnostic Agent loaded model from {checkpoint_path} on {self.device}.")
35
+ except FileNotFoundError:
36
+ print(f"❌ Error: Checkpoint file not found at {checkpoint_path}")
37
+ exit(1)
38
+ except Exception as e:
39
+ print(f"❌ Error loading model state: {e}")
40
+ exit(1)
41
+
42
+ # 3. Setup Transforms
43
+ self.transform = get_transforms('val', self.img_size)
44
+
45
+ def run_diagnosis(self, image_path: str) -> Dict[str, Any]:
46
+ """
47
+ Runs the image classification model, detects fractures, and outputs scores.
48
+
49
+ This method includes the fix for FileNotFoundError by resolving the path.
50
+ """
51
+
52
+ # CRITICAL FIX: Convert relative path to absolute path for reliable file access
53
+ full_image_path = os.path.abspath(image_path)
54
+
55
+ if not os.path.exists(full_image_path):
56
+ # Report the original path back to the user for clarity
57
+ return {"error": f"Image file not found at {image_path}"}
58
+
59
+ # 1. Image Loading and Preprocessing
60
+ try:
61
+ # Use the resolved full path for PIL to open
62
+ img = Image.open(full_image_path).convert('RGB')
63
+ except Exception as e:
64
+ return {"error": f"Failed to open image at {full_image_path}. Reason: {e}"}
65
+
66
+
67
+ img_tensor = self.transform(img).unsqueeze(0).to(self.device)
68
+
69
+ # 2. Model Inference
70
+ with torch.no_grad():
71
+ outputs = self.model(img_tensor)
72
+
73
+ # Softmax to get probabilities (confidence scores)
74
+ probabilities = torch.softmax(outputs, dim=1).squeeze(0)
75
+
76
+ # 3. Score Calculation
77
+
78
+ predicted_idx = torch.argmax(probabilities).item()
79
+ confidence = probabilities[predicted_idx].item()
80
+ uncertainty = 1.0 - confidence
81
+ predicted_class_name = self.class_names[predicted_idx]
82
+
83
+ # Determine Fracture Presence (assuming 'Healthy' is a known class)
84
+ is_fracture_detected = (predicted_class_name != 'Healthy')
85
+
86
+ return {
87
+ "image_path": image_path,
88
+ "fracture_detected": is_fracture_detected,
89
+ "predicted_class": predicted_class_name,
90
+ "severity_type": predicted_class_name, # Proxy for severity
91
+ "confidence_score": confidence,
92
+ "uncertainty_score": uncertainty,
93
+ "all_probabilities": probabilities.cpu().numpy().tolist()
94
+ }
95
+
96
+ # --- 3. EXECUTION ---
97
+
98
+ if __name__ == '__main__':
99
+ parser = argparse.ArgumentParser(description='Run a diagnostic agent on a single image.')
100
+ parser.add_argument('--image-path', type=str, required=True, help='Path to the image file to diagnose.')
101
+ parser.add_argument('--checkpoint', type=str, required=True, help='Path to the model checkpoint (e.g., outputs/swin_mps/best.pth)')
102
+ parser.add_argument('--model', type=str, default='swin', choices=['swin', 'convnext', 'densenet'])
103
+ parser.add_argument('--num-classes', type=int, default=8)
104
+ parser.add_argument('--img-size', type=int, default=224)
105
+ parser.add_argument('--class-names', type=str, required=True,
106
+ help='Comma-separated list of class names (e.g., "A,B,C")')
107
+
108
+ args = parser.parse_args()
109
+
110
+ # Convert class names string to a list
111
+ class_names_list = [c.strip() for c in args.class_names.split(',')]
112
+
113
+ # Ensure 'Healthy' is in the list for the 'fracture_detected' check to work reliably
114
+ if 'Healthy' not in class_names_list:
115
+ print("Warning: 'Healthy' class not found in --class-names list. Fracture detection may be inaccurate.")
116
+
117
+ # Initialize the Agent
118
+ agent = DiagnosticAgent(
119
+ checkpoint_path=args.checkpoint,
120
+ model_name=args.model,
121
+ num_classes=args.num_classes,
122
+ img_size=args.img_size,
123
+ class_names=class_names_list
124
+ )
125
+
126
+ # Run the Diagnosis
127
+ result = agent.run_diagnosis(args.image_path)
128
+
129
+ # Output Results
130
+ print("\n--- DIAGNOSTIC RESULTS ---")
131
+ if "error" in result:
132
+ print(f"Status: FAILED\nReason: {result['error']}")
133
+ else:
134
+ print(f"Status: SUCCESS")
135
+ print(f"Image: {result['image_path']}")
136
+ print(f"Fracture Detected: {'YES' if result['fracture_detected'] else 'NO'}")
137
+ print(f"Predicted Class: {result['predicted_class']}")
138
+ print(f"--- Scores ---")
139
+ print(f"Severity Type: {result['severity_type']}")
140
+ print(f"Confidence Score: {result['confidence_score']:.4f}")
141
+ print(f"Uncertainty Score: {result['uncertainty_score']:.4f}")
142
+ print("--------------------------\n")
src/agents/educational_agent.py ADDED
@@ -0,0 +1,148 @@
1
+ import json
2
+ from typing import Dict, Any
3
+
4
+ class EducationalAgent:
5
+ """
6
+ Translates technical diagnosis and explanation into simple, patient-friendly terms.
7
+ """
8
+ def __init__(self, doctor_name: str = "your treating doctor"):
9
+ self.doctor_name = doctor_name
10
+
11
+ def translate_to_layman_terms(self, diagnosis_result: Dict[str, Any], explanation_text: str) -> Dict[str, str]:
12
+ """
13
+ Generates simple summaries and next steps for the patient.
14
+
15
+ Args:
16
+ diagnosis_result: The dictionary output from DiagnosticAgent.
17
+ explanation_text: The string output from ExplainabilityAgent.
18
+
19
+ Returns:
20
+ A dictionary containing patient-friendly summary, severity, and next steps.
21
+ """
22
+
23
+ # 1. Extract Key Findings
24
+ fracture_detected = diagnosis_result.get("fracture_detected", False)
25
+ predicted_class = diagnosis_result.get("predicted_class", "a specific type of injury")
26
+ confidence = diagnosis_result.get("confidence_score", 0.0)
27
+
28
+ # 2. Determine Severity in Layman Terms
29
+ severity_map = {
30
+ "Healthy": "None",
31
+ "Greenstick": "Mild (The bone is cracked but not completely broken through.)",
32
+ "Transverse": "Moderate (A clean break straight across the bone.)",
33
+ "Oblique": "Moderate (A clean break at an angle.)",
34
+ "Spiral": "Serious (A twisting break that spirals around the bone.)",
35
+ "Comminuted": "Severe (The bone has broken into three or more pieces.)",
36
+ "Oblique Displaced": "Serious (The bone is broken at an angle, and the pieces are shifted out of place.)",
37
+ "Transverse Displaced": "Serious (The bone is broken straight across, and the pieces are shifted out of place.)",
38
+ }
39
+
40
+ layman_severity = severity_map.get(predicted_class, "We need more information on this type of break.")
41
+
42
+ # 3. Simplify the Explanation
43
+
44
+ # Clean up the technical explanation to remove technical jargon like 'centroid' or 'activation'
45
+ simple_explanation = explanation_text.replace("consistent with a", "which looks like a")
46
+ simple_explanation = simple_explanation.replace("Confidence:", "Our computer model is highly sure (")
47
+ simple_explanation = simple_explanation.replace("The model's focus is", "The computer saw a clear sign of this")
48
+ simple_explanation = simple_explanation.replace("distal end", "end of the bone near the hand/foot")
49
+ simple_explanation = simple_explanation.replace("proximal end", "end of the bone near the shoulder/hip")
50
+ simple_explanation = simple_explanation.replace("humerus", "upper arm bone")
51
+ simple_explanation = simple_explanation.replace("radius", "lower arm bone")
52
+ simple_explanation = simple_explanation.replace("tibia", "shin bone")
53
+ simple_explanation = simple_explanation.replace("mild", "small")
54
+ simple_explanation = simple_explanation.replace("strong", "very clear")
55
+
56
+
57
+ # 4. Generate Final Summary and Next Steps
58
+
59
+ if not fracture_detected or predicted_class == "Healthy":
60
+ patient_summary = (
61
+ f"**Great news!** Our analysis suggests your bone is **healthy** "
62
+ f"with high confidence ({confidence:.2f}). There are no signs of a fracture."
63
+ )
64
+ next_steps = (
65
+ "You can discuss your pain symptoms with your doctor, but based on this image, "
66
+ "a fracture is highly unlikely. No immediate orthopedic action is needed."
67
+ )
68
+ else:
69
+ patient_summary = (
70
+ f"Our computer analysis strongly indicates a **break in the bone** (a fracture). "
71
+ f"The specific type appears to be a **{predicted_class}** fracture."
72
+ )
73
+
74
+ # Combine simple explanation and confidence
75
+ patient_summary += f"\n\n**What the computer saw:** {simple_explanation}"
76
+ patient_summary += f".\n\n**Severity Level:** {layman_severity}"
77
+
78
+ next_steps = (
79
+ "This finding requires immediate medical follow-up. Please do the following:\n"
80
+ f"* **Do not move** the affected area.\n"
81
+ f"* **Immediately share these findings** with {self.doctor_name}.\n"
82
+ f"* Your doctor will confirm the diagnosis and determine the best treatment, "
83
+ "which may involve a cast, splint, or surgery."
84
+ )
85
+
86
+ return {
87
+ "patient_summary": patient_summary,
88
+ "patient_severity_assessment": layman_severity,
89
+ "next_steps_action_plan": next_steps,
90
+ }
91
+
92
+ # --- EXAMPLE USAGE ---
93
+
94
+ if __name__ == '__main__':
95
+ # --- SIMULATED INPUT from Diagnostic & Explainability Agents ---
96
+
97
+ # Example 1: Serious Fracture
98
+ SIMULATED_DIAGNOSIS_1 = {
99
+ "image_path": "fracture_image.jpg",
100
+ "fracture_detected": True,
101
+ "predicted_class": "Spiral",
102
+ "severity_type": "Spiral",
103
+ "confidence_score": 0.96,
104
+ "uncertainty_score": 0.04,
105
+ "all_probabilities": [0.01, 0.01, 0.01, 0.01, 0.01, 0.96, 0.01, 0.01]
106
+ }
107
+ SIMULATED_EXPLANATION_1 = (
108
+ "A fracture pattern consistent with a **Spiral** type is detected (Confidence: 0.96). "
109
+ "The model's focus is clear near the **middle region** of the humerus in the center. "
110
+ "This is based on a distinct linear focus."
111
+ )
112
+
113
+ # Example 2: Healthy Bone
114
+ SIMULATED_DIAGNOSIS_2 = {
115
+ "image_path": "healthy_image.jpg",
116
+ "fracture_detected": False,
117
+ "predicted_class": "Healthy",
118
+ "severity_type": "Healthy",
119
+ "confidence_score": 0.99,
120
+ "uncertainty_score": 0.01,
121
+ "all_probabilities": [0.00, 0.00, 0.99, 0.00, 0.00, 0.00, 0.00, 0.01]
122
+ }
123
+ SIMULATED_EXPLANATION_2 = (
124
+ "The bone appears **healthy** with high confidence (0.99). No fracture pattern was detected."
125
+ )
126
+
127
+
128
+ # --- Run Agent ---
129
+
130
+ agent = EducationalAgent(doctor_name="Dr. Smith")
131
+
132
+ # Run Example 1
133
+ results_1 = agent.translate_to_layman_terms(SIMULATED_DIAGNOSIS_1, SIMULATED_EXPLANATION_1)
134
+
135
+ print("\n--- PATIENT REPORT (FRACTURE DETECTED) ---")
136
+ print(f"**SUMMARY:** {results_1['patient_summary']}")
137
+ print("\n**ACTION PLAN:**")
138
+ print(results_1['next_steps_action_plan'])
139
+ print("-------------------------------------------\n")
140
+
141
+ # Run Example 2
142
+ results_2 = agent.translate_to_layman_terms(SIMULATED_DIAGNOSIS_2, SIMULATED_EXPLANATION_2)
143
+
144
+ print("\n--- PATIENT REPORT (HEALTHY BONE) ---")
145
+ print(f"**SUMMARY:** {results_2['patient_summary']}")
146
+ print("\n**ACTION PLAN:**")
147
+ print(results_2['next_steps_action_plan'])
148
+ print("-------------------------------------\n")
src/agents/explain_agent.py ADDED
@@ -0,0 +1,164 @@
1
+ import numpy as np
2
+ import json
3
+ from typing import Dict, Any
4
+
5
+ # --- New Helper Function for Dynamic Testing ---
6
+ def generate_random_heatmap(size: int = 224) -> np.ndarray:
7
+ """
8
+ Generates a randomized, plausible heatmap array for testing the agent's dynamism.
9
+ The heatmap will have a focused, high-intensity area somewhere random.
10
+ """
11
+ # Create a base array of zeros
12
+ cam_array = np.zeros((size, size), dtype=np.float32)
13
+
14
+ # 1. Define random center and size for the activation zone
15
+ center_y = np.random.randint(size // 4, size * 3 // 4)
16
+ center_x = np.random.randint(size // 4, size * 3 // 4)
17
+ height = np.random.randint(30, 80)
18
+ width = np.random.randint(30, 80)
19
+
20
+ # Define activation bounds (ensure they stay within the array limits)
21
+ y_min = max(0, center_y - height // 2)
22
+ y_max = min(size, center_y + height // 2)
23
+ x_min = max(0, center_x - width // 2)
24
+ x_max = min(size, center_x + width // 2)
25
+
26
+ # 2. Apply activation with random strength
27
+ random_strength = np.random.uniform(0.6, 1.0)
28
+ cam_array[y_min:y_max, x_min:x_max] = random_strength
29
+
30
+ # Optional: Add minor noise to make it less blocky
31
+ cam_array = cam_array + np.random.uniform(0, 0.1, (size, size))
32
+ cam_array = np.clip(cam_array, 0, 1)
33
+
34
+ return cam_array
35
+
36
+ # --- Helper function for localization (No changes needed, it is dynamic) ---
37
+
38
+ def calculate_heatmap_centroid(cam_array: np.ndarray, threshold: float = 0.5) -> tuple:
39
+ """
40
+ Calculates the centroid (center of mass) of the significant activation area
41
+ in the Grad-CAM heatmap.
42
+ """
43
+ # 1. Apply threshold to isolate the 'hot' region
44
+ binary_map = cam_array > threshold
45
+
46
+ if not np.any(binary_map):
47
+ return (0.5, 0.5, 0.0)
48
+
49
+ # 2. Calculate coordinates and weights (activation values)
50
+ coords = np.argwhere(binary_map)
51
+ weights = cam_array[binary_map]
52
+
53
+ if len(weights) == 0:
54
+ return (0.5, 0.5, 0.0)
55
+
56
+ # 3. Calculate weighted average for the centroid
57
+ y_coords = coords[:, 0] # Rows (Y)
58
+ x_coords = coords[:, 1] # Columns (X)
59
+
60
+ sum_weights = np.sum(weights)
61
+
62
+ centroid_x = np.sum(x_coords * weights) / sum_weights
63
+ centroid_y = np.sum(y_coords * weights) / sum_weights
64
+
65
+ # Normalize to [0, 1] based on map size
66
+ h, w = cam_array.shape
67
+ norm_x = centroid_x / w
68
+ norm_y = centroid_y / h
69
+
70
+ max_activation = np.max(weights)
71
+
72
+ return (norm_x, norm_y, max_activation)
73
+
74
+ # --- Explainability Agent Core (No changes needed, logic is dynamic) ---
75
+
76
+ class ExplainabilityAgent:
77
+ def __init__(self, class_names: list, body_part: str = "bone"):
78
+ self.class_names = class_names
79
+ self.body_part = body_part
80
+
81
+ def generate_explanation(self, diagnosis_result: Dict[str, Any], cam_array: np.ndarray) -> str:
82
+ """
83
+ Converts the Grad-CAM heatmap and prediction result into a textual explanation.
84
+ """
85
+ predicted_class = diagnosis_result.get("predicted_class", "Unknown")
86
+ confidence = diagnosis_result.get("confidence_score", 0.0)
87
+
88
+ # 1. Analyze Heatmap
89
+ norm_x, norm_y, strength = calculate_heatmap_centroid(cam_array, threshold=0.4)
90
+
91
+ # Determine general location (Simplified)
92
+ x_loc = "right side" if norm_x > 0.65 else ("left side" if norm_x < 0.35 else "center")
93
+ y_loc = "distal end" if norm_y > 0.65 else ("proximal end" if norm_y < 0.35 else "middle region")
94
+
95
+ # 2. Build Textual Explanation based on Prediction
96
+
97
+ if predicted_class == "Healthy":
98
+ if confidence > 0.90:
99
+ return f"The {self.body_part} appears **healthy** with high confidence ({confidence:.2f}). No fracture pattern was detected."
100
+ else:
101
+ return f"The {self.body_part} is likely **healthy** ({confidence:.2f}), though there is some low activation in the {y_loc} of the {x_loc} that warrants a closer look."
102
+
103
+ if not diagnosis_result.get("fracture_detected", True): # Default to True if key missing
104
+ return f"Diagnosis is **inconclusive** or data is missing."
105
+
106
+ # 3. Explanation for Detected Fracture
107
+
108
+ intro = f"A fracture pattern consistent with a **{predicted_class}** type is detected"
109
+
110
+ # Strength adjective
111
+ if strength > 0.7:
112
+ strength_adj = "strong"
113
+ elif strength > 0.5:
114
+ strength_adj = "clear"
115
+ else:
116
+ strength_adj = "mild"
117
+
118
+ # Confidence statement
119
+ confidence_stmt = f"(Confidence: {confidence:.2f})"
120
+
121
+ # Location statement
122
+ location_stmt = f"near the **{y_loc}** of the {self.body_part} in the {x_loc}."
123
+
124
+ # Final Assembly
125
+ explanation = f"{intro} {confidence_stmt}. The model's focus is {strength_adj} {location_stmt}"
126
+
127
+ # Add a note on the type based on visual focus
128
+ if predicted_class in ["Transverse", "Oblique"]:
129
+ explanation += " This is based on a distinct linear focus."
130
+
131
+ return explanation
132
+
133
+
134
+ # --- 4. EXAMPLE USAGE ---
135
+
136
+ if __name__ == '__main__':
137
+
138
+ # --- SIMULATED INPUT ---
139
+ SIMULATED_RESULT = {
140
+ "image_path": "test_image.jpg",
141
+ "fracture_detected": True,
142
+ "predicted_class": "Spiral",
143
+ "severity_type": "Spiral",
144
+ "confidence_score": 0.95,
145
+ "uncertainty_score": 0.05,
146
+ }
147
+
148
+ CLASS_NAMES = ["Comminuted", "Greenstick", "Healthy", "Oblique", "Oblique Displaced", "Spiral", "Transverse", "Transverse Displaced"]
149
+ explainer = ExplainabilityAgent(class_names=CLASS_NAMES, body_part="humerus")
150
+
151
+ # Run 3 times to demonstrate dynamic output
152
+ print("\n--- Testing Dynamic Output (Run 1: Random Heatmap) ---")
153
+
154
+ # Use the new dynamic heatmap function!
155
+ dynamic_cam_1 = generate_random_heatmap()
156
+ explanation_text_1 = explainer.generate_explanation(SIMULATED_RESULT, dynamic_cam_1)
157
+ print(f"Explanation 1: {explanation_text_1}")
158
+
159
+ print("\n--- Testing Dynamic Output (Run 2: Another Random Heatmap) ---")
160
+ dynamic_cam_2 = generate_random_heatmap()
161
+ explanation_text_2 = explainer.generate_explanation(SIMULATED_RESULT, dynamic_cam_2)
162
+ print(f"Explanation 2: {explanation_text_2}")
163
+
164
+ print("--------------------------------------------------\n")
src/agents/knowledge_agent.py ADDED
@@ -0,0 +1,109 @@
1
+ import json
2
+ from typing import Dict, Any, List
3
+
4
+ # --- Pre-compiled Medical Knowledge Base (Simulated) ---
5
+ # In a real application, this would be a large database (e.g., SQL, MongoDB, or specialized API)
6
+ MEDICAL_KNOWLEDGE_BASE = {
7
+ "Comminuted": {
8
+ "definition": "A fracture where the bone is broken into three or more pieces.",
9
+ "icd_code": "S52.5",
10
+ "severity": "High",
11
+ "treatment_guidelines": ["Usually requires surgical intervention (ORIF - Open Reduction Internal Fixation).", "Long immobilization time (8-12 weeks).", "Requires physical therapy."],
12
+ "prognosis_notes": "Risk of non-union is higher. Full recovery may take 6+ months."
13
+ },
14
+ "Greenstick": {
15
+ "definition": "A partial fracture where the bone is cracked but not completely broken through. Common in children.",
16
+ "icd_code": "S52.3",
17
+ "severity": "Low-Moderate",
18
+ "treatment_guidelines": ["Immobilization with cast or splint.", "Careful monitoring for progression.", "Minimal surgical intervention usually needed."],
19
+ "prognosis_notes": "Generally good prognosis. Recovery typically within 4-6 weeks."
20
+ },
21
+ "Healthy": {
22
+ "definition": "No evidence of fracture. Bone appears normal.",
23
+ "icd_code": "Z00.0",
24
+ "severity": "None",
25
+ "treatment_guidelines": ["No treatment required.", "Continue normal activities as tolerated.", "Regular follow-up if there is persistent pain."],
26
+ "prognosis_notes": "Normal bone health. No intervention needed."
27
+ },
28
+ "Oblique": {
29
+ "definition": "A diagonal break across the bone at approximately 45 degrees.",
30
+ "icd_code": "S52.2",
31
+ "severity": "Moderate",
32
+ "treatment_guidelines": ["Immobilization with cast or splint.", "Regular X-rays to monitor healing.", "Physical therapy after immobilization period."],
33
+ "prognosis_notes": "Good prognosis with proper immobilization. Recovery typically 6-8 weeks."
34
+ },
35
+ "Oblique Displaced": {
36
+ "definition": "A diagonal break where the bone fragments are not aligned and have shifted out of place.",
37
+ "icd_code": "S52.9",
38
+ "severity": "Medium-High",
39
+ "treatment_guidelines": ["Requires reduction (closed or open).", "Often requires casting or sometimes surgery to stabilize.", "Regular X-rays to ensure proper alignment."],
40
+ "prognosis_notes": "Good prognosis if successfully reduced and stabilized. Recovery 8-12 weeks."
41
+ },
42
+ "Spiral": {
43
+ "definition": "A twisting break that spirals around the bone, typically caused by rotational forces.",
44
+ "icd_code": "S52.4",
45
+ "severity": "Serious",
46
+ "treatment_guidelines": ["Usually requires immobilization in a cast or brace.", "May require surgery if fragments are unstable.", "Requires extensive physical therapy."],
47
+ "prognosis_notes": "Variable recovery time. May take 8-16 weeks depending on severity."
48
+ },
49
+ "Transverse": {
50
+ "definition": "A clean break straight across the bone, perpendicular to the bone's long axis.",
51
+ "icd_code": "S52.1",
52
+ "severity": "Moderate",
53
+ "treatment_guidelines": ["Immobilization with cast or splint.", "Regular X-rays to monitor alignment.", "Physical therapy after healing begins."],
54
+ "prognosis_notes": "Good prognosis. Clean breaks typically heal well. Recovery 6-10 weeks."
55
+ },
56
+ "Transverse Displaced": {
57
+ "definition": "A straight break across the bone with fragments shifted out of place.",
58
+ "icd_code": "S52.8",
59
+ "severity": "Serious",
60
+ "treatment_guidelines": ["Requires reduction (closed or open).", "Often requires surgery to realign fragments.", "Long-term immobilization and rehabilitation."],
61
+ "prognosis_notes": "Good prognosis with treatment. Recovery 10-14 weeks."
62
+ }
63
+ }
64
+
65
+ class KnowledgeAgent:
66
+ def __init__(self, knowledge_base: Dict[str, Any]):
67
+ self.knowledge_base = knowledge_base
68
+
69
+ def get_medical_summary(self, diagnosis: str, confidence: float) -> Dict[str, Any]:
70
+ """
71
+ Retrieves and formats external medical knowledge based on the final diagnosis.
72
+ """
73
+ diagnosis = diagnosis.strip()
74
+
75
+ if diagnosis not in self.knowledge_base:
76
+ return {"error": "Diagnosis not found in the knowledge base."}
77
+
78
+ # 1. Retrieve Raw Data
79
+ raw_data = self.knowledge_base[diagnosis]
80
+
81
+ # 2. Format Summary for Professional Use (Example output)
82
+ summary = {
83
+ "Diagnosis": diagnosis,
84
+ "Ensemble_Confidence": f"{confidence:.2f}",
85
+ "Type": raw_data.get("definition"),
86
+ "ICD_Code": raw_data.get("icd_code", "N/A"),
87
+ "Severity": raw_data.get("severity"),
88
+ "Guidelines": raw_data.get("treatment_guidelines")
89
+ }
90
+
91
+ return summary
92
+
93
+ # --- Example Usage (Integration with Cross-Validation Agent Output) ---
94
+ if __name__ == '__main__':
95
+ # Assume this is the output from your cross_validation_agent:
96
+ cross_validation_result = {
97
+ "ensemble_prediction": "Oblique Displaced",
98
+ "ensemble_confidence": 0.85
99
+ }
100
+
101
+ agent = KnowledgeAgent(MEDICAL_KNOWLEDGE_BASE)
102
+
103
+ medical_report = agent.get_medical_summary(
104
+ diagnosis=cross_validation_result["ensemble_prediction"],
105
+ confidence=cross_validation_result["ensemble_confidence"]
106
+ )
107
+
108
+ print("\n--- 🧠 KNOWLEDGE AGENT REPORT ---")
109
+ print(json.dumps(medical_report, indent=4))
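For the example above, the report is assembled verbatim from the "Oblique Displaced" entry of the knowledge base, so the printed JSON corresponds to this dictionary:

```python
{
    "Diagnosis": "Oblique Displaced",
    "Ensemble_Confidence": "0.85",
    "Type": "A diagonal break where the bone fragments are not aligned "
            "and have shifted out of place.",
    "ICD_Code": "S52.9",
    "Severity": "Medium-High",
    "Guidelines": [
        "Requires reduction (closed or open).",
        "Often requires casting or sometimes surgery to stabilize.",
        "Regular X-rays to ensure proper alignment.",
    ],
}
```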
src/analysis/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ """
2
+ Analysis and visualization modules for model evaluation.
3
+ """
4
+
5
+ # Analysis scripts can be imported individually as needed
6
+ # from .analyze import main as analyze_results
7
+ # from .analyze_2 import main as analyze_results_2
8
+ # from .visualize_gradcam import main as visualize_gradcam
9
+
10
+ __all__ = []
src/analysis/analyze.py ADDED
@@ -0,0 +1,104 @@
1
+ # src/analysis/analyze.py
2
+ import os, sys, csv, argparse, numpy as np, matplotlib.pyplot as plt
3
+ from PIL import Image
4
+ import torch, torch.nn as nn, torchvision.transforms as T
5
+ import timm, torchvision.models as tvmodels
6
+ from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
7
+ import cv2
8
+
9
+ # Add parent directory to path for imports
10
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
11
+
12
+ from src.utils import get_device, get_model, get_transforms
13
+
14
+ def load_csv(path):
15
+ with open(path) as f:
16
+ reader = csv.DictReader(f)
17
+ return [r for r in reader]
18
+
19
+ def save_confusion(cm, labels, out_path):
20
+ fig, ax = plt.subplots(figsize=(8,8))
21
+ im = ax.imshow(cm, interpolation='nearest', cmap='Blues')
22
+ ax.set_xticks(range(len(labels))); ax.set_yticks(range(len(labels)))
23
+ ax.set_xticklabels(labels, rotation=45, ha='right'); ax.set_yticklabels(labels)
24
+ for i in range(len(labels)):
25
+ for j in range(len(labels)):
26
+ ax.text(j,i, str(cm[i,j]), ha='center', va='center', color='black')
27
+ plt.colorbar(im)
28
+ plt.tight_layout(); plt.savefig(out_path); plt.close(fig)
29
+
30
+ def main():
31
+ parser = argparse.ArgumentParser()
32
+ parser.add_argument('--checkpoint')
33
+ parser.add_argument('--test-csv')
34
+ parser.add_argument('--img-root', default='.')
35
+ parser.add_argument('--model', default='swin')
36
+ parser.add_argument('--img-size', type=int, default=224)
37
+ parser.add_argument('--class-names', required=True)
38
+ parser.add_argument('--out-dir', default='outputs/analysis')
39
+ args = parser.parse_args()
40
+ os.makedirs(args.out_dir, exist_ok=True)
41
+ class_names = [s.strip() for s in args.class_names.split(',')]
42
+ num_classes = len(class_names)
43
+ device = get_device()
44
+
45
+ model = get_model(args.model, num_classes, pretrained=False)
46
+ ck = torch.load(args.checkpoint, map_location='cpu')
47
+ model.load_state_dict(ck['model_state_dict'])
48
+ model.to(device); model.eval()
49
+
50
+ rows = load_csv(args.test_csv)
51
+ tf = get_transforms('val', args.img_size)
52
+ preds, trues, paths, probs = [], [], [], []
53
+ os.makedirs(os.path.join(args.out_dir,'examples'), exist_ok=True)
54
+
55
+ for r in rows:
56
+ img_path = r['image_path'] if os.path.isabs(r['image_path']) else os.path.join(args.img_root, r['image_path'])
57
+ img = Image.open(img_path).convert('RGB')
58
+ t = tf(img).unsqueeze(0).to(device)
59
+ with torch.no_grad():
60
+ out = model(t)
61
+ p = torch.softmax(out, dim=1).cpu().numpy()[0]
62
+ pred = int(p.argmax())
63
+ preds.append(pred); trues.append(int(r['label'])); paths.append(img_path); probs.append(p)
64
+
65
+ cm = confusion_matrix(trues, preds)
66
+ p, r, f1, _ = precision_recall_fscore_support(trues, preds, average=None, labels=list(range(num_classes)), zero_division=0)
67
+
68
+ # print per-class metrics
69
+ for i,name in enumerate(class_names):
70
+ print(f'{i} {name}: support={(cm[i].sum())}, prec={p[i]:.3f}, rec={r[i]:.3f}, f1={f1[i]:.3f}')
71
+ print('macro-f1:', np.mean(f1))
72
+
73
+ # save confusion matrix image
74
+ save_confusion(cm, class_names, os.path.join(args.out_dir,'confusion_matrix.png'))
75
+
76
+ # write misclassified csv
77
+ miscsv = os.path.join(args.out_dir,'misclassified.csv')
78
+ with open(miscsv,'w') as f:
79
+ writer = csv.writer(f); writer.writerow(['image_path','true','pred','top1','top2'])
80
+ for path, t, pr, prob in zip(paths,trues,preds,probs):
81
+ if t!=pr:
82
+ top2 = np.argsort(prob)[-2:][::-1].tolist()
83
+ writer.writerow([path, t, pr, int(np.argmax(prob)), int(top2[0])])
84
+
85
+ # Save example images for top confused pairs
86
+ # find the biggest off-diagonal cells
87
+ cm_off = cm.copy(); np.fill_diagonal(cm_off, 0)
88
+ flat = [(cm_off[i,j],i,j) for i in range(num_classes) for j in range(num_classes)]
89
+ flat = sorted(flat, reverse=True)
90
+ for count,i,j in flat[:6]: # top 6 confusion pairs
91
+ if count==0: continue
92
+ pair_dir = os.path.join(args.out_dir, 'examples', f'{i}_to_{j}')
93
+ os.makedirs(pair_dir, exist_ok=True)
94
+ saved=0
95
+ for path,t,pred,prob in zip(paths,trues,preds,probs):
96
+ if t==i and pred==j and saved<10:
97
+ img = Image.open(path).convert('RGB')
98
+ img.save(os.path.join(pair_dir, os.path.basename(path)))
99
+ saved+=1
100
+
101
+ print('Saved misclassified list and example images in', args.out_dir)
102
+
103
+ if __name__=='__main__':
104
+ main()
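The misclassified.csv written above is useful for triage on its own. A short sketch that tallies the most frequent true-to-pred confusions from it, assuming the default `--out-dir`:

```python
import csv
from collections import Counter

# Columns written by analyze.py: image_path,true,pred,top1,top2
with open("outputs/analysis/misclassified.csv") as f:
    rows = list(csv.DictReader(f))

pairs = Counter((r["true"], r["pred"]) for r in rows)
for (true_cls, pred_cls), n in pairs.most_common(5):
    print(f"true={true_cls} -> pred={pred_cls}: {n} images")
```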
src/analysis/analyze_2.py ADDED
@@ -0,0 +1,210 @@
1
+ import os
2
+ import sys
3
+ import argparse
4
+ import time
5
+ from pathlib import Path
6
+ from typing import List, Dict
7
+
8
+ import numpy as np
9
+ from PIL import Image
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.utils.data import Dataset, DataLoader
14
+ import torchvision.transforms as T
15
+ import torchvision.models as tvmodels
16
+ import timm
17
+
18
+ from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
19
+ import cv2
20
+ import csv
21
+ import matplotlib.pyplot as plt
22
+
23
+ # Import necessary modules for Grad-CAM
24
+ from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
25
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
26
+ from pytorch_grad_cam.utils.image import show_cam_on_image
27
+
28
+ # Add parent directory to path for imports
29
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
30
+
31
+ from src.utils import get_device, get_model, get_transforms
32
+
33
+ DEVICE = get_device()
34
+ print(f"Using device: {DEVICE}")
35
+
36
+ # ----------------------------- Dataset (Reusing logic from pipeline.py) -----------------------------
37
+
38
+ class FractureDataset(Dataset):
39
+ def __init__(self, df, img_root: str = '.', transform=None):
40
+ self.entries = df
41
+ self.img_root = img_root
42
+ self.transform = transform
43
+ # CRITICAL PATH FIX: Define the redundant prefix
44
+ self.redundant_prefix = 'balanced_augmented_dataset/'
45
+ self.redundant_prefix_len = len(self.redundant_prefix)
46
+
47
+ def __len__(self):
48
+ return len(self.entries)
49
+
50
+ def __getitem__(self, idx):
51
+ row = self.entries[idx]
52
+ img_path = row['image_path']
53
+
54
+ # PATH CLEANING FIX: Strip the redundant prefix
55
+ if img_path.startswith(self.redundant_prefix):
56
+ img_path = img_path[self.redundant_prefix_len:]
57
+
58
+ if not os.path.isabs(img_path):
59
+ img_path = os.path.join(self.img_root, img_path)
60
+
61
+ img = Image.open(img_path).convert('RGB')
62
+
63
+ # NOTE: We return the raw image here for visualization purposes
64
+ raw_img = np.array(img).astype(np.float32) / 255.0
65
+
66
+ label = int(row['label'])
67
+ if self.transform:
68
+ img = self.transform(img)
69
+
70
+ return img, label, img_path, raw_img
71
+
72
+
73
+ # ----------------------------- Model selection with Grad-CAM target layers -----------------------------
74
+
75
+ def get_model_with_target_layer(name: str, num_classes: int, pretrained: bool=True):
76
+ """Get model and its target layer for Grad-CAM visualization."""
77
+ model = get_model(name, num_classes, pretrained=pretrained)
78
+ name = name.lower()
79
+
80
+ if name.startswith('swin'):
81
+ # Target layer for Swin: the last layer of the last stage (blocks[-1][-1])
82
+ target_layer = model.layers[-1].blocks[-1].norm2
83
+ return model, target_layer
84
+
85
+ if name.startswith('convnext'):
86
+ # Target layer for ConvNext: the last block of the feature extractor
87
+ target_layer = model.stages[-1]
88
+ return model, target_layer
89
+
90
+ if name.startswith('densenet'):
91
+ # Target layer for DenseNet: features.norm5
92
+ target_layer = model.features.norm5
93
+ return model, target_layer
94
+
95
+ raise ValueError(f'Unknown target layer for model: {name}')
96
+
97
+
98
+ # ----------------------------- Helpers: CSV loader -----------------------------
99
+
100
+ def load_csv_like(path: str) -> List[Dict]:
101
+ rows = []
102
+ with open(path, 'r', encoding='utf8') as f:
103
+ reader = csv.DictReader(f)
104
+ for r in reader:
105
+ rows.append(r)
106
+ return rows
107
+
108
+ # ----------------------------- Grad-CAM Analysis -----------------------------
109
+
110
+ def analyze(args):
111
+ device = DEVICE
112
+
113
+ # Load CSVs
114
+ test_rows = load_csv_like(args.test_csv)
115
+
116
+ # Get model and the target layer for Grad-CAM
117
+ model, target_layer = get_model_with_target_layer(args.model, args.num_classes, pretrained=False)
118
+ model.to(device)
119
+
120
+ # Load checkpoint weights
121
+ ck = torch.load(args.checkpoint, map_location=device)
122
+ model.load_state_dict(ck['model_state_dict'])
123
+ model.eval()
124
+ print(f'Loaded model from {args.checkpoint} onto {device}.')
125
+
126
+ # Data setup
127
+ test_tf = get_transforms('val', args.img_size)
128
+ test_ds = FractureDataset(test_rows, img_root=args.img_root, transform=test_tf)
129
+ test_loader = DataLoader(test_ds, batch_size=1, shuffle=False) # Use batch size 1 for accurate CAM per image
130
+
131
+ # Initialize Grad-CAM
132
+ cam = GradCAM(model=model, target_layers=[target_layer], use_cuda=(device.type == 'cuda'))
133
+
134
+ # Setup output directory
135
+ os.makedirs(args.out_dir, exist_ok=True)
136
+
137
+ class_names = [c.strip() for c in args.class_names.split(',')]
138
+
139
+ print(f"Starting Grad-CAM analysis on {len(test_ds)} images...")
140
+
141
+ for i, (imgs, labels, img_paths, raw_imgs) in enumerate(test_loader):
142
+ imgs = imgs.to(device)
143
+ true_label = labels.item()
144
+
145
+ # 1. Prediction and Target Setup
146
+ with torch.no_grad():
147
+ outputs = model(imgs)
148
+ predicted_label = outputs.softmax(dim=1).argmax(dim=1).item()
149
+
150
+ # Set the target to the PREDICTED class for visualization
151
+ targets = [ClassifierOutputTarget(predicted_label)]
152
+
153
+ # 2. Generate CAM
154
+ grayscale_cam = cam(input_tensor=imgs, targets=targets)
155
+ grayscale_cam = grayscale_cam[0, :]
156
+
157
+ # 3. Visualization
158
+ # raw_img is the unnormalized image [0, 1]
159
+ raw_img_for_viz = raw_imgs.squeeze(0).numpy()
160
+ visualization = show_cam_on_image(raw_img_for_viz, grayscale_cam, use_rgb=True)
161
+
162
+ # Convert to PIL Image for saving
163
+ visualization_pil = Image.fromarray(cv2.cvtColor((visualization * 255).astype(np.uint8), cv2.COLOR_RGB2BGR))
164
+
165
+ # 4. Save
166
+ path_obj = Path(img_paths[0])
167
+ class_name = class_names[true_label]
168
+
169
+ # Define saving path
170
+ save_dir = os.path.join(args.out_dir, class_name)
171
+ os.makedirs(save_dir, exist_ok=True)
172
+
173
+ # Determine the name with prediction/truth info
174
+ pred_class_name = class_names[predicted_label]
175
+ file_name = f'CAM_T{class_name}_P{pred_class_name}_{path_obj.name}'
176
+ save_path = os.path.join(save_dir, file_name)
177
+
178
+ visualization_pil.save(save_path)
179
+
180
+ if i % 10 == 0:
181
+ print(f"Processed {i+1}/{len(test_ds)}. Saved to: {save_path}")
182
+
183
+ print("Grad-CAM analysis complete. Results saved to:", args.out_dir)
184
+
185
+
186
+ # ----------------------------- Main -----------------------------
187
+
188
+ if __name__ == '__main__':
189
+ parser = argparse.ArgumentParser(description='Run Grad-CAM analysis on test data.')
190
+ parser.add_argument('--checkpoint', type=str, required=True, help='Path to the model checkpoint (e.g., outputs/swin_mps/best.pth)')
191
+ parser.add_argument('--test-csv', type=str, required=True, help='Path to the test CSV file.')
192
+ parser.add_argument('--img-root', type=str, default='.', help='Root directory for images.')
193
+ parser.add_argument('--model', type=str, default='swin', choices=['swin','convnext','densenet'])
194
+ parser.add_argument('--num-classes', type=int, default=8)
195
+ parser.add_argument('--img-size', type=int, default=224)
196
+ parser.add_argument('--out-dir', type=str, default='outputs/analysis', help='Directory to save CAM visualizations.')
197
+ parser.add_argument('--class-names', type=str, required=True,
198
+ help='Comma-separated list of class names (e.g., "A,B,C")')
199
+
200
+ args = parser.parse_args()
201
+
202
+ # Check for required library dependencies
203
+ try:
204
+ import pytorch_grad_cam
205
+ except ImportError:
206
+ print("ERROR: pytorch-grad-cam library not found. Please install it:")
207
+ print("pip install pytorch-grad-cam")
208
+ exit(1)
209
+
210
+ analyze(args)
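A stripped-down version of the per-image CAM step, reusing `get_model_with_target_layer` from this file with dummy inputs in place of a real X-ray. Two caveats on assumptions: the installed pytorch-grad-cam must still accept `use_cuda` (it was removed in the 1.5 line), and transformer backbones like Swin normally also need a `reshape_transform` for their token-shaped activations, which this script does not pass, so the sketch uses the CNN backbone:

```python
import numpy as np
import torch
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

# DenseNet's 2D feature maps work with GradCAM out of the box
model, target_layer = get_model_with_target_layer("densenet", num_classes=8, pretrained=False)

img_tensor = torch.randn(1, 3, 224, 224)                      # stand-in for a transformed X-ray
raw_img_01 = np.random.rand(224, 224, 3).astype(np.float32)   # stand-in [0, 1] RGB image

cam = GradCAM(model=model, target_layers=[target_layer], use_cuda=False)
grayscale_cam = cam(input_tensor=img_tensor,
                    targets=[ClassifierOutputTarget(0)])[0, :]  # CAM for class index 0
overlay = show_cam_on_image(raw_img_01, grayscale_cam, use_rgb=True)  # HxWx3 uint8 overlay
```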
src/analysis/visualize_gradcam.py ADDED
@@ -0,0 +1,314 @@
1
+ """
2
+ visualize_gradcam.py
3
+
4
+ Generates Grad-CAM overlays for misclassified examples listed in a CSV (format produced earlier):
5
+ image_path,true,pred,top1,top2
6
+
7
+ For each row this script saves a PNG with:
8
+ - original image
9
+ - Grad-CAM overlay for the **true** class
10
+ - Grad-CAM overlay for the **predicted** class
11
+ - difference overlay (pred - true)
12
+
13
+ Usage:
14
+ python src/analysis/visualize_gradcam.py \
15
+ --checkpoint outputs/swin_mps/best.pth \
16
+ --misclassified outputs/analysis/misclassified.csv \
17
+ --img-root . \
18
+ --model swin --img-size 224 --out-dir outputs/analysis/gradcam_overlays \
19
+ --class-names "Comminuted,Greenstick,Healthy,Oblique,Oblique Displaced,Spiral,Transverse,Transverse Displaced"
20
+
21
+ Notes:
22
+ - Script prefers MPS (Apple Silicon) if available; if Grad-CAM backward on MPS fails it will automatically fall back to CPU for CAM computation.
23
+ - Requires: torch, timm, torchvision, pillow, numpy, opencv-python
24
+
25
+ """
26
+
27
+ import os
28
+ import sys
29
+ import csv
30
+ import argparse
31
+ from pathlib import Path
32
+ from typing import Optional, List
33
+
34
+ import numpy as np
35
+ from PIL import Image
36
+ import cv2
37
+
38
+ import torch
39
+ import torch.nn as nn
40
+ import torchvision.transforms as T
41
+ import timm
42
+ import torchvision.models as tvmodels
43
+
44
+ # Add parent directory to path for imports
45
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
46
+
47
+ from src.utils import get_device, get_model, get_transforms
48
+
49
+ DEVICE = get_device()
50
+ print(f"Using device: {DEVICE}")
51
+
52
+ # ----------------------------- Grad-CAM Implementation -----------------------------
53
+
54
+ class GradCAM:
55
+ """Hook-based Grad-CAM. Call with a model (in eval mode) and a target conv layer name (optional).
56
+ If target_layer_name is None, the last nn.Conv2d module is chosen heuristically.
57
+ """
58
+ def __init__(self, model: nn.Module, target_layer_name: Optional[str] = None):
59
+ self.model = model
60
+ self.model.eval()
61
+ self.activations = None
62
+ self.gradients = None
63
+ self.handles = []
64
+
65
+ # pick target layer
66
+ if target_layer_name is None:
67
+ target_layer = None
68
+ for n, m in reversed(list(self.model.named_modules())):
69
+ if isinstance(m, nn.Conv2d):
70
+ target_layer_name = n
71
+ target_layer = m
72
+ break
73
+ if target_layer is None:
74
+ raise RuntimeError('No Conv2d layer found for Grad-CAM')
75
+ else:
76
+ target_layer = dict(self.model.named_modules()).get(target_layer_name, None)
77
+ if target_layer is None:
78
+ raise RuntimeError(f'layer name {target_layer_name} not found')
79
+
80
+ # register hooks
81
+ self.handles.append(target_layer.register_forward_hook(self._forward_hook))
82
+ # backward hook
83
+ try:
84
+ self.handles.append(target_layer.register_backward_hook(self._backward_hook))
85
+ except Exception:
86
+ # fallback for newer pytorch versions: use register_full_backward_hook if available
87
+ try:
88
+ self.handles.append(target_layer.register_full_backward_hook(self._backward_hook))
89
+ except Exception:
90
+ # some builds won't allow backward hooks; we'll compute gradients by retaining graph and reading .grad from activations
91
+ pass
92
+
93
+ def _forward_hook(self, module, inp, out):
94
+ # out: tensor shape (B,C,H,W)
95
+ self.activations = out.detach()
96
+
97
+ def _backward_hook(self, module, grad_in, grad_out):
98
+ # grad_out[0] shape (B,C,H,W)
99
+ self.gradients = grad_out[0].detach()
100
+
101
+ def clear(self):
102
+ for h in self.handles:
103
+ try:
104
+ h.remove()
105
+ except Exception:
106
+ pass
107
+ self.handles = []
108
+
109
+ def __call__(self, input_tensor: torch.Tensor, class_idx: Optional[int] = None, device: torch.device = torch.device('cpu')):
110
+ """Compute CAM for a single input tensor (1,C,H,W). Returns cam resized to input HxW in numpy [0,1]."""
111
+ self.model.zero_grad()
112
+ input_tensor = input_tensor.to(device)
113
+ input_tensor.requires_grad = True
114
+ outputs = self.model(input_tensor)
115
+ if class_idx is None:
116
+ class_idx = int(outputs.argmax(dim=1).item())
117
+ loss = outputs[0, class_idx]
118
+ loss.backward(retain_graph=True)
119
+
120
+ if self.gradients is None or self.activations is None:
121
+ raise RuntimeError('GradCAM failed to collect gradients/activations (hooks missing)')
122
+
123
+ grads = self.gradients[0] # C,H,W
124
+ acts = self.activations[0] # C,H,W
125
+ weights = grads.mean(dim=(1,2)) # C
126
+ cam = (weights[:, None, None] * acts).sum(dim=0).cpu().numpy()
127
+ cam = np.maximum(cam, 0)
128
+ cam = cam - cam.min()
129
+ if cam.max() > 0:
130
+ cam = cam / (cam.max() + 1e-8)
131
+ else:
132
+ cam = np.zeros_like(cam)
133
+ # resize to original input spatial size (assume square input)
134
+ H = input_tensor.shape[-2]; W = input_tensor.shape[-1]
135
+ cam = cv2.resize(cam, (W, H))
136
+ return cam
137
+
138
+
139
+ def apply_colormap_on_image(org_img: np.ndarray, activation: np.ndarray, colormap=cv2.COLORMAP_JET, alpha=0.5):
140
+ """Overlay heatmap on image (org_img: HxW x 3 uint8, activation: HxW float in [0,1])"""
141
+ if activation is None:
142
+ raise ValueError('activation is None')
143
+ # ensure activation is 2D and in [0,1]
144
+ activation = np.asarray(activation)
145
+ if activation.ndim == 3:
146
+ # if somehow a channel dim exists, reduce to single channel
147
+ activation = activation[..., 0]
148
+ activation = np.clip(activation, 0.0, 1.0)
149
+
150
+ # Convert activation -> heatmap (BGR) and resize heatmap to match original image
151
+ heatmap = np.uint8(255 * activation)
152
+ heatmap = cv2.applyColorMap(heatmap, colormap)
153
+
154
+ # Resize heatmap to original image spatial size before blending
155
+ h, w = org_img.shape[:2]
156
+ heatmap = cv2.resize(heatmap, (w, h), interpolation=cv2.INTER_LINEAR)
157
+
158
+ # convert heatmap to RGB to match org_img (which is RGB)
159
+ heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
160
+
161
+ # ensure types match for addWeighted
162
+ org_uint8 = org_img.astype('uint8')
163
+ heat_uint8 = heatmap.astype('uint8')
164
+ overlaid = cv2.addWeighted(org_uint8, 1.0 - alpha, heat_uint8, alpha, 0)
165
+ return overlaid
166
+
167
+
168
+ def pil_to_numpy(img: Image.Image):
169
+ arr = np.array(img.convert('RGB'))
170
+ return arr
171
+
172
+
173
+ def get_transform(img_size=224):
174
+ return T.Compose([
175
+ T.Resize((img_size, img_size)),
176
+ T.ToTensor(),
177
+ T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
178
+ ])
179
+
180
+
181
+ def main():
182
+ parser = argparse.ArgumentParser()
183
+ parser.add_argument('--checkpoint', required=True)
184
+ parser.add_argument('--misclassified', required=True)
185
+ parser.add_argument('--img-root', default='.')
186
+ parser.add_argument('--model', default='swin')
187
+ parser.add_argument('--img-size', type=int, default=224)
188
+ parser.add_argument('--out-dir', default='outputs/analysis/gradcam_overlays')
189
+ parser.add_argument('--class-names', required=True)
190
+ parser.add_argument('--target-layer', default=None)
191
+ parser.add_argument('--max-samples', type=int, default=200, help='max misclassified rows to process')
192
+ args = parser.parse_args()
193
+
194
+ class_names = [c.strip() for c in args.class_names.split(',')]
195
+ num_classes = len(class_names)
196
+
197
+ device_pref = get_device()
198
+ print('preferred device:', device_pref)
199
+
200
+ model = get_model(args.model, num_classes, pretrained=False)
201
+ ck = torch.load(args.checkpoint, map_location='cpu')
202
+ model.load_state_dict(ck['model_state_dict'])
203
+
204
+ # We'll run forward on preferred device, but if backward (for CAM) fails on MPS we'll move to CPU for CAM computation
205
+ model.to(device_pref)
206
+ model.eval()
207
+
208
+ transform = get_transform(args.img_size)
209
+
210
+ os.makedirs(args.out_dir, exist_ok=True)
211
+
212
+ rows = []
213
+ with open(args.misclassified, 'r') as f:
214
+ reader = csv.DictReader(f)
215
+ for r in reader:
216
+ rows.append(r)
217
+ rows = rows[:args.max_samples]
218
+
219
+ # initialize GradCAM on device_pref; if backward fails, we will retry on CPU
220
+ gradcam = None
221
+ try:
222
+ gradcam = GradCAM(model, target_layer_name=args.target_layer)
223
+ cam_device = device_pref
224
+ except Exception as e:
225
+ print('GradCAM init failed on preferred device; will try CPU. Error:', e)
226
+ cam_device = torch.device('cpu')
227
+ model_cpu = get_model(args.model, num_classes, pretrained=False)
228
+ model_cpu.load_state_dict(ck['model_state_dict'])
229
+ model_cpu.to(cam_device)
230
+ model_cpu.eval()
231
+ gradcam = GradCAM(model_cpu, target_layer_name=args.target_layer)
232
+
233
+ for i, r in enumerate(rows):
234
+ img_path = r['image_path'] if os.path.isabs(r['image_path']) else os.path.join(args.img_root, r['image_path'])
235
+ true_lbl = int(r['true'])
236
+ pred_lbl = int(r['pred'])
237
+ try:
238
+ pil = Image.open(img_path).convert('RGB')
239
+ except Exception as e:
240
+ print('failed to open', img_path, e); continue
241
+
242
+ org_np = pil_to_numpy(pil)
243
+ inp = transform(pil).unsqueeze(0)
244
+
245
+ # forward on preferred device to get outputs and predicted class
246
+ try:
247
+ inp_pref = inp.to(device_pref)
248
+ with torch.no_grad():
249
+ out_pref = model(inp_pref)
250
+ probs = torch.softmax(out_pref, dim=1).cpu().numpy()[0]
251
+ except Exception as e:
252
+ print('forward failed on preferred device:', e)
253
+ # fallback to CPU forward (move the model back afterwards so later iterations keep using device_pref)
254
+ model.cpu(); model.eval()
255
+ with torch.no_grad():
256
+ out_cpu = model(inp)
257
+ probs = torch.softmax(out_cpu, dim=1).numpy()[0]; model.to(device_pref)
258
+
259
+ # compute CAMs on gradcam.device (cam_device)
260
+ cam_true = None; cam_pred = None
261
+ try:
262
+ # gradcam.model already lives on cam_device; move the input there too
265
+ inp_cam = inp.to(cam_device)
266
+ cam_true = gradcam(inp_cam, class_idx=true_lbl, device=cam_device)
267
+ cam_pred = gradcam(inp_cam, class_idx=pred_lbl, device=cam_device)
268
+ except Exception as e:
269
+ print('Grad-CAM on preferred device failed for', img_path, 'error:', e)
270
+ # try CPU
271
+ try:
272
+ # rebuild cpu model if needed
273
+ cpu_dev = torch.device('cpu')
274
+ model_cpu = get_model(args.model, num_classes, pretrained=False)
275
+ model_cpu.load_state_dict(ck['model_state_dict'])
276
+ model_cpu.to(cpu_dev); model_cpu.eval()
277
+ gradcam_cpu = GradCAM(model_cpu, target_layer_name=args.target_layer)
278
+ cam_true = gradcam_cpu(inp.to(cpu_dev), class_idx=true_lbl, device=cpu_dev)
279
+ cam_pred = gradcam_cpu(inp.to(cpu_dev), class_idx=pred_lbl, device=cpu_dev)
280
+ gradcam_cpu.clear()
281
+ except Exception as e2:
282
+ print('Grad-CAM CPU retry failed for', img_path, e2)
283
+ continue
284
+
285
+ # overlay
286
+ try:
287
+ over_true = apply_colormap_on_image(org_np, cam_true, alpha=0.5)
288
+ over_pred = apply_colormap_on_image(org_np, cam_pred, alpha=0.5)
289
+ diff = cam_pred - cam_true
290
+ diff = (diff - diff.min()) / (diff.max() - diff.min() + 1e-8)
291
+ over_diff = apply_colormap_on_image(org_np, diff, alpha=0.6)
292
+
293
+ # concat: original | true | pred | diff
294
+ h, w, _ = org_np.shape
295
+ # resize overlays to original size if needed
296
+ over_true = cv2.resize(over_true, (w, h))
297
+ over_pred = cv2.resize(over_pred, (w, h))
298
+ over_diff = cv2.resize(over_diff, (w, h))
299
+ orig_bgr = cv2.cvtColor(org_np, cv2.COLOR_RGB2BGR)
300
+ grid = np.vstack([np.hstack([orig_bgr, cv2.cvtColor(over_true, cv2.COLOR_RGB2BGR)]),
301
+ np.hstack([cv2.cvtColor(over_pred, cv2.COLOR_RGB2BGR), cv2.cvtColor(over_diff, cv2.COLOR_RGB2BGR)])])
302
+
303
+ out_name = f"{i:04d}_true{true_lbl}_pred{pred_lbl}_{os.path.basename(img_path)}.png"
304
+ out_path = os.path.join(args.out_dir, out_name)
305
+ cv2.imwrite(out_path, grid)
306
+ except Exception as e:
307
+ print('failed to create overlay for', img_path, e)
308
+ continue
309
+
310
+ gradcam.clear()
311
+ print('Saved overlays to', args.out_dir)
312
+
313
+ if __name__ == '__main__':
314
+ main()
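For readers wiring this into their own analysis, here is a minimal sketch of driving the GradCAM class directly; the checkpoint and image paths are placeholders, and it assumes the hooked layer exposes (B, C, H, W) activations, as the class expects:

import torch
from PIL import Image

from src.analysis.visualize_gradcam import GradCAM, apply_colormap_on_image, get_transform, pil_to_numpy
from src.utils import get_model

# load a trained backbone on CPU (placeholder checkpoint path)
model = get_model('swin', num_classes=8, pretrained=False)
ck = torch.load('outputs/swin_mps/best.pth', map_location='cpu')
model.load_state_dict(ck['model_state_dict'])

pil = Image.open('xray.png').convert('RGB')   # placeholder image
inp = get_transform(224)(pil).unsqueeze(0)

cam_fn = GradCAM(model)                       # hooks the last Conv2d heuristically
cam = cam_fn(inp, class_idx=None)             # CAM for the argmax class, HxW floats in [0, 1]
overlay = apply_colormap_on_image(pil_to_numpy(pil), cam, alpha=0.5)
cam_fn.clear()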
src/config/cloud_deployment.py ADDED
@@ -0,0 +1,253 @@
1
+ """
2
+ Cloud deployment configuration for model storage and management.
3
+ Supports AWS S3, Google Cloud Storage, and other cloud providers.
4
+ """
5
+
6
+ import os
7
+ import json
8
+ from typing import Optional
9
+
10
+ # ============================================================================
11
+ # AWS S3 Configuration (if using S3 for model storage)
12
+ # ============================================================================
13
+
14
+ AWS_S3_CONFIG = {
15
+ "bucket": os.getenv("AWS_S3_BUCKET", "your-bucket-name"),
16
+ "region": os.getenv("AWS_REGION", "us-east-1"),
17
+ "access_key": os.getenv("AWS_ACCESS_KEY_ID", ""),
18
+ "secret_key": os.getenv("AWS_SECRET_ACCESS_KEY", ""),
19
+ }
20
+
21
+ # ============================================================================
22
+ # Google Cloud Storage Configuration
23
+ # ============================================================================
24
+
25
+ GCS_CONFIG = {
26
+ "project_id": os.getenv("GCP_PROJECT_ID", ""),
27
+ "bucket": os.getenv("GCP_BUCKET", ""),
28
+ "credentials_json": os.getenv("GOOGLE_APPLICATION_CREDENTIALS", ""),
29
+ }
30
+
31
+ # ============================================================================
32
+ # Model Download URLs
33
+ # ============================================================================
34
+
35
+ # These should be set as environment variables for security
36
+ # Example for AWS S3 pre-signed URLs:
37
+ # export SWIN_MODEL_URL="https://your-bucket.s3.amazonaws.com/best_swin.pth?..."
38
+
39
+ MODEL_DOWNLOAD_URLS = {
40
+ "best_swin.pth": os.getenv("SWIN_MODEL_URL", ""),
41
+ "best_mobilenetv2.pth": os.getenv("MOBILENETV2_MODEL_URL", ""),
42
+ "best_densenet169.pth": os.getenv("DENSENET_MODEL_URL", ""),
43
+ "best_efficientnetv2.pth": os.getenv("EFFICIENTNET_MODEL_URL", ""),
44
+ "best_maxvit.pth": os.getenv("MAXVIT_MODEL_URL", ""),
45
+ }
46
+
47
+ # ============================================================================
48
+ # Ollama Configuration for Cloud Deployment
49
+ # ============================================================================
50
+
51
+ OLLAMA_CONFIG = {
52
+ # For local deployment
53
+ "host": os.getenv("OLLAMA_HOST", "http://localhost:11434"),
54
+ "model": os.getenv("OLLAMA_MODEL", "llama3"),
55
+
56
+ # Alternative: Use cloud-hosted LLM API instead
57
+ "use_cloud_api": os.getenv("USE_CLOUD_LLM", "False").lower() == "true",
58
+ "cloud_api_provider": os.getenv("CLOUD_LLM_PROVIDER", "openai"), # openai, anthropic, etc
59
+ "cloud_api_key": os.getenv("CLOUD_LLM_API_KEY", ""),
60
+ }
61
+
62
+ # ============================================================================
63
+ # Streamlit Cloud Configuration
64
+ # ============================================================================
65
+
66
+ STREAMLIT_CLOUD_CONFIG = {
67
+ "deployment_mode": os.getenv("STREAMLIT_DEPLOYMENT", "False").lower() == "true",
68
+ "enable_model_download": os.getenv("ENABLE_MODEL_DOWNLOAD", "True").lower() == "true",
69
+ "model_cache_size_mb": int(os.getenv("MODEL_CACHE_SIZE_MB", "1000")),
70
+ }
71
+
72
+ # ============================================================================
73
+ # Helper Functions
74
+ # ============================================================================
75
+
76
+ def get_s3_client():
77
+ """Create AWS S3 client."""
78
+ try:
79
+ import boto3
80
+ return boto3.client(
81
+ 's3',
82
+ region_name=AWS_S3_CONFIG["region"],
83
+ aws_access_key_id=AWS_S3_CONFIG["access_key"],
84
+ aws_secret_access_key=AWS_S3_CONFIG["secret_key"],
85
+ )
86
+ except ImportError:
87
+ raise ImportError("boto3 not installed. Run: pip install boto3")
88
+
89
+
90
+ def get_gcs_client():
91
+ """Create Google Cloud Storage client."""
92
+ try:
93
+ from google.cloud import storage
94
+ return storage.Client(project=GCS_CONFIG["project_id"])
95
+ except ImportError:
96
+ raise ImportError("google-cloud-storage not installed. Run: pip install google-cloud-storage")
97
+
98
+
99
+ def upload_models_to_s3(local_model_dir: str = "./outputs") -> dict:
100
+ """
101
+ Upload local models to AWS S3.
102
+
103
+ Args:
104
+ local_model_dir: Directory containing model files
105
+
106
+ Returns:
107
+ Dictionary with upload results
108
+ """
109
+ from pathlib import Path
110
+
111
+ client = get_s3_client()
112
+ results = {}
113
+
114
+ for model_file in Path(local_model_dir).glob("best_*.pth"):
115
+ try:
116
+ key = f"models/{model_file.name}"
117
+ print(f"Uploading {model_file.name} to S3...")
118
+ client.upload_file(
119
+ str(model_file),
120
+ AWS_S3_CONFIG["bucket"],
121
+ key,
122
+ Callback=None
123
+ )
124
+ results[model_file.name] = {"status": "success", "s3_key": key}
125
+ print(f"✅ Uploaded {model_file.name}")
126
+ except Exception as e:
127
+ results[model_file.name] = {"status": "failed", "error": str(e)}
128
+ print(f"❌ Failed to upload {model_file.name}: {e}")
129
+
130
+ return results
131
+
132
+
133
+ def upload_models_to_gcs(local_model_dir: str = "./outputs") -> dict:
134
+ """
135
+ Upload local models to Google Cloud Storage.
136
+
137
+ Args:
138
+ local_model_dir: Directory containing model files
139
+
140
+ Returns:
141
+ Dictionary with upload results
142
+ """
143
+ from pathlib import Path
144
+
145
+ client = get_gcs_client()
146
+ bucket = client.bucket(GCS_CONFIG["bucket"])
147
+ results = {}
148
+
149
+ for model_file in Path(local_model_dir).glob("best_*.pth"):
150
+ try:
151
+ blob = bucket.blob(f"models/{model_file.name}")
152
+ print(f"Uploading {model_file.name} to GCS...")
153
+ blob.upload_from_filename(str(model_file))
154
+ results[model_file.name] = {"status": "success", "gs_url": blob.public_url}
155
+ print(f"✅ Uploaded {model_file.name}")
156
+ except Exception as e:
157
+ results[model_file.name] = {"status": "failed", "error": str(e)}
158
+ print(f"❌ Failed to upload {model_file.name}: {e}")
159
+
160
+ return results
161
+
162
+
163
+ def generate_s3_presigned_urls() -> dict:
164
+ """Generate S3 pre-signed URLs for models."""
165
+ client = get_s3_client()
166
+ urls = {}
167
+
168
+ for model_name in MODEL_DOWNLOAD_URLS.keys():
169
+ key = f"models/{model_name}"
170
+ try:
171
+ url = client.generate_presigned_url(
172
+ 'get_object',
173
+ Params={'Bucket': AWS_S3_CONFIG["bucket"], 'Key': key},
174
+ ExpiresIn=3600 * 24 * 7 # 7 days
175
+ )
176
+ urls[model_name] = url
177
+ except Exception as e:
178
+ print(f"Error generating URL for {model_name}: {e}")
179
+
180
+ return urls
181
+
182
+
183
+ def print_deployment_checklist():
184
+ """Print deployment checklist."""
185
+ print("""
186
+ ╔══════════════════════════════════════════════════════════════════════════════╗
187
+ ║ STREAMLIT CLOUD DEPLOYMENT CHECKLIST ║
188
+ ╚══════════════════════════════════════════════════════════════════════════════╝
189
+
190
+ 1. GITHUB SETUP
191
+ ☐ Repository pushed to GitHub
192
+ ☐ .gitignore excludes *.pth files
193
+ ☐ README.md describes the project
194
+ ☐ requirements-prod.txt is in root
195
+
196
+ 2. MODEL STORAGE (Choose one)
197
+ ☐ AWS S3 Setup:
198
+ - Created S3 bucket
199
+ - Uploaded models
200
+ - Generated pre-signed URLs
201
+ - Set environment variables (SWIN_MODEL_URL, etc.)
202
+
203
+ OR
204
+
205
+ ☐ Google Cloud Storage Setup:
206
+ - Created GCS bucket
207
+ - Uploaded models
208
+ - Set environment variables
209
+
210
+ OR
211
+
212
+ ☐ Manual Upload:
213
+ - Will upload models manually to Streamlit Cloud
214
+
215
+ 3. ENVIRONMENT VARIABLES (in Streamlit Cloud Secrets)
216
+ ☐ OLLAMA_HOST (if using external Ollama server)
217
+ ☐ OLLAMA_MODEL (default: llama3)
218
+ ☐ Model download URLs or credentials
219
+ ☐ Cloud provider credentials (if applicable)
220
+
221
+ 4. STREAMLIT CLOUD DEPLOYMENT
222
+ ☐ Created account at share.streamlit.io
223
+ ☐ Connected GitHub repository
224
+ ☐ Configured Secrets
225
+ ☐ Deployed app
226
+
227
+ 5. TESTING
228
+ ☐ App loads successfully
229
+ ☐ Models are available
230
+ ☐ Chat feature works (if Ollama is configured)
231
+ ☐ Workflow can run end-to-end
232
+
233
+ ═══════════════════════════════════════════════════════════════════════════════
234
+
235
+ IMPORTANT NOTES:
236
+ - Each model is ~200MB, total ~1GB
237
+ - Streamlit Cloud max storage is ~1GB
238
+ - Models must be downloaded/cached on startup
239
+ - Ollama requires external server (not available in Streamlit Cloud)
240
+ - For chat feature, consider using cloud APIs (OpenAI, Anthropic)
241
+
242
+ ═══════════════════════════════════════════════════════════════════════════════
243
+ """)
244
+
245
+
246
+ if __name__ == "__main__":
247
+ print("Cloud Deployment Configuration")
248
+ print_deployment_checklist()
249
+
250
+ print("\n📋 Current Configuration:")
251
+ print(f" Deployment Mode: {STREAMLIT_CLOUD_CONFIG['deployment_mode']}")
252
+ print(f" Ollama Host: {OLLAMA_CONFIG['host']}")
253
+ print(f" Use Cloud API: {OLLAMA_CONFIG['use_cloud_api']}")
src/training/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ """
2
+ Training pipeline modules for model training and fine-tuning.
3
+ """
4
+
5
+ # Training pipelines can be imported individually as needed
6
+ # from .pipeline import main as train_pipeline
7
+ # from .pipeline_2 import main as train_pipeline_2
8
+
9
+ __all__ = []
src/training/pipeline.py ADDED
@@ -0,0 +1,394 @@
1
+ """
2
+ Fracture classification pipeline for Mac (MPS only), with Weights & Biases logging
3
+
4
+ Features:
5
+ - Enforces MPS device on Apple Silicon (exits if not available).
6
+ - Supports three backbones: swin, convnext, densenet (via timm / torchvision).
7
+ - Local checkpointing (best.pth) and automatic upload of checkpoints to Weights & Biases using `wandb.save`.
8
+ - WandB logging of train/val metrics, lr, and confusion matrix artifact.
9
+ - Stage-2 Grad-CAM cropping and retrain supported.
10
+
11
+ Usage (example):
12
+ python src/training/pipeline.py \
13
+ --train-csv data/balanced_augmented_dataset/train.csv \
14
+ --val-csv data/balanced_augmented_dataset/val.csv \
15
+ --test-csv data/balanced_augmented_dataset/test.csv \
16
+ --model swin --num-classes 8 --epochs 20 --batch-size 6 --img-size 224 \
17
+ --out-dir outputs/swin_mps --wandb-project fracture-mps --wandb-entity your_entity
18
+
19
+ Notes:
20
+ - This script *requires* MPS (Apple Silicon). It will exit if MPS is unavailable.
21
+ - Use small batch sizes (4-8) depending on available memory. A Mac with an M4 Pro/Max and 36 GB of unified memory should handle moderate batch sizes, but training is slower than on CUDA GPUs.
22
+ - For WandB: run `wandb login` beforehand or set `WANDB_API_KEY` env var.
23
+
24
+ """
25
+
26
+ import os
27
+ import sys
28
+ import argparse
29
+ import time
30
+ import copy
31
+ from pathlib import Path
32
+ from typing import Optional, Tuple, List, Dict
33
+
34
+ import numpy as np
35
+ from PIL import Image
36
+
37
+ import torch
38
+ import torch.nn as nn
39
+ from torch.utils.data import Dataset, DataLoader
40
+ import torchvision.transforms as T
41
+ import torchvision.models as tvmodels
42
+ import timm
43
+
44
+ import wandb
45
+ from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
46
+ import cv2
47
+
48
+ # Add parent directory to path for imports
49
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
50
+
51
+ from src.utils import require_mps, get_model, get_transforms, FractureDataset
52
+
53
+ # ----------------------------- Device (MPS only) -----------------------------
54
+
55
+ DEVICE = require_mps()
56
+ print(f"Using device: {DEVICE}")
57
+
58
+ # ----------------------------- Training & Evaluation -----------------------------
59
+
60
+ def save_checkpoint(state, is_best, out_dir, name='checkpoint.pth', upload_to_wandb: bool=False):
61
+ os.makedirs(out_dir, exist_ok=True)
62
+ path = os.path.join(out_dir, name)
63
+ torch.save(state, path)
64
+ if is_best:
65
+ best_path = os.path.join(out_dir, 'best.pth')
66
+ torch.save(state, best_path)
67
+ if upload_to_wandb:
68
+ try:
69
+ wandb.save(best_path)
70
+ print('Uploaded best checkpoint to WandB:', best_path)
71
+ except Exception as e:
72
+ print('WandB save failed:', e)
73
+
74
+
75
+ def train_one_epoch(model, loader, optimizer, criterion, device):
76
+ model.train()
77
+ running_loss = 0.0
78
+ all_preds = []
79
+ all_targets = []
80
+ for imgs, labels, _ in loader:
81
+ imgs = imgs.to(device)
82
+ labels = labels.to(device)
83
+ optimizer.zero_grad()
84
+ outputs = model(imgs)
85
+ loss = criterion(outputs, labels)
86
+ loss.backward()
87
+ optimizer.step()
88
+ running_loss += loss.item() * imgs.size(0)
89
+ preds = outputs.softmax(dim=1).argmax(dim=1)
90
+ all_preds.extend(preds.detach().cpu().numpy().tolist())
91
+ all_targets.extend(labels.detach().cpu().numpy().tolist())
92
+ epoch_loss = running_loss / len(loader.dataset)
93
+ p, r, f1, _ = precision_recall_fscore_support(all_targets, all_preds, average='macro', zero_division=0)
94
+ return epoch_loss, p, r, f1
95
+
96
+
97
+ def validate(model, loader, criterion, device):
98
+ model.eval()
99
+ running_loss = 0.0
100
+ all_preds = []
101
+ all_targets = []
102
+ with torch.no_grad():
103
+ for imgs, labels, _ in loader:
104
+ imgs = imgs.to(device)
105
+ labels = labels.to(device)
106
+ outputs = model(imgs)
107
+ loss = criterion(outputs, labels)
108
+ running_loss += loss.item() * imgs.size(0)
109
+ preds = outputs.softmax(dim=1).argmax(dim=1)
110
+ all_preds.extend(preds.detach().cpu().numpy().tolist())
111
+ all_targets.extend(labels.detach().cpu().numpy().tolist())
112
+ epoch_loss = running_loss / len(loader.dataset)
113
+ p, r, f1, _ = precision_recall_fscore_support(all_targets, all_preds, average='macro', zero_division=0)
114
+ cm = confusion_matrix(all_targets, all_preds)
115
+ return epoch_loss, p, r, f1, cm
116
+
117
+ # ----------------------------- Grad-CAM utilities -----------------------------
118
+ class GradCAM:
119
+ def __init__(self, model: nn.Module, target_layer_name: str = None):
120
+ self.model = model
121
+ self.model.eval()
122
+ self.gradients = None
123
+ self.activations = None
124
+ self.hook_handles = []
125
+ if target_layer_name is None:
126
+ for n, m in reversed(list(self.model.named_modules())):
127
+ if isinstance(m, (nn.Conv2d,)):
128
+ target_layer_name = n
129
+ break
130
+ self.target_layer_name = target_layer_name
131
+ if target_layer_name is None:
132
+ raise ValueError('Cannot find a convolutional layer for Grad-CAM')
133
+ target_module = dict(self.model.named_modules())[self.target_layer_name]
134
+ self.hook_handles.append(target_module.register_forward_hook(self._forward_hook))
135
+ # Note: register_full_backward_hook not supported in all versions; use backward hook where available
136
+ try:
137
+ self.hook_handles.append(target_module.register_backward_hook(self._backward_hook))
138
+ except Exception:
139
+ pass
140
+
141
+ def _forward_hook(self, module, input, output):
142
+ self.activations = output.detach()
143
+
144
+ def _backward_hook(self, module, grad_in, grad_out):
145
+ self.gradients = grad_out[0].detach()
146
+
147
+ def __call__(self, input_tensor: torch.Tensor, class_idx: Optional[int] = None, device: torch.device = DEVICE):
148
+ self.model.zero_grad()
149
+ input_tensor = input_tensor.to(device)
150
+ input_tensor.requires_grad = True
151
+ outputs = self.model(input_tensor)
152
+ if class_idx is None:
153
+ class_idx = outputs.argmax(dim=1).item()
154
+ loss = outputs[0, class_idx]
155
+ loss.backward(retain_graph=True)
156
+ if self.gradients is None or self.activations is None:
157
+ raise RuntimeError('GradCAM failed to collect gradients/activations — try a different target layer name')
158
+ grads = self.gradients[0]
159
+ acts = self.activations[0]
160
+ weights = grads.mean(dim=(1,2))
161
+ cam = (weights[:, None, None] * acts).sum(dim=0)
162
+ cam = np.maximum(cam.cpu().numpy(), 0)
163
+ cam = cv2.resize(cam, (input_tensor.shape[-1], input_tensor.shape[-2]))
164
+ cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
165
+ return cam
166
+
167
+ def close(self):
168
+ for h in self.hook_handles:
169
+ try:
170
+ h.remove()
171
+ except Exception:
172
+ pass
173
+
174
+ # ----------------------------- Heatmap -> bbox -----------------------------
175
+
176
+ def heatmap_to_bbox(cam: np.ndarray, thr: float = 0.5, min_area: int = 100):
177
+ H, W = cam.shape
178
+ thr_val = cam.max() * thr
179
+ mask = (cam >= thr_val).astype('uint8') * 255
180
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
181
+ if not contours:
182
+ return None
183
+ contours = sorted(contours, key=cv2.contourArea, reverse=True)
184
+ for cnt in contours:
185
+ area = cv2.contourArea(cnt)
186
+ if area < min_area:
187
+ continue
188
+ x,y,w,h = cv2.boundingRect(cnt)
189
+ return (x, y, x+w, y+h)
190
+ return None
191
+
192
+ # ----------------------------- Generate crops from Grad-CAM (stage 2 prep) -----------------------------
193
+
194
+ def generate_crops_from_gradcam(model, entries: List[Dict], out_dir: str, transform_for_cam, device: torch.device, cam_layer: str=None, thr: float=0.5, padding: float=0.15):
195
+ os.makedirs(out_dir, exist_ok=True)
196
+ gradcam = GradCAM(model, target_layer_name=cam_layer)
197
+ new_entries = []
198
+ for i, row in enumerate(entries):
199
+ path = row['image_path']
200
+ img = Image.open(path).convert('RGB')
201
+ tensor = transform_for_cam(img).unsqueeze(0).to(device)
202
+ try:
203
+ cam = gradcam(tensor, class_idx=None, device=device)
204
+ except Exception as e:
205
+ print('GradCAM failed for', path, e)
206
+ continue
207
+ bbox = heatmap_to_bbox(cam, thr=thr)
208
+ if bbox is None:
209
+ w, h = img.size
210
+ cx, cy = w//2, h//2
211
+ side = int(min(w,h)*0.6)
212
+ xmin = max(0, cx-side//2); ymin = max(0, cy-side//2); xmax = min(w, cx+side//2); ymax = min(h, cy+side//2)
213
+ else:
214
+ xmin, ymin, xmax, ymax = bbox
215
+ w = xmax - xmin; h = ymax - ymin
216
+ px = int(w * padding); py = int(h * padding)
217
+ xmin = max(0, xmin - px); ymin = max(0, ymin - py); xmax = min(img.size[0], xmax + px); ymax = min(img.size[1], ymax + py)
218
+ crop = img.crop((xmin, ymin, xmax, ymax)).resize((224,224))
219
+ fname = f"crop_{i}_{os.path.basename(path)}"
220
+ out_path = os.path.join(out_dir, fname)
221
+ crop.save(out_path)
222
+ new_entries.append({'image_path': out_path, 'label': row['label']})
223
+ gradcam.close()
224
+ return new_entries
225
+
226
+ # ----------------------------- Inference with simple TTA -----------------------------
227
+
228
+ def tta_predict(model, pil_img: Image.Image, device, img_size=224):
229
+ base = T.Compose([T.Resize((img_size,img_size)), T.ToTensor(), T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])])
230
+ img1 = base(pil_img).unsqueeze(0).to(device)
231
+ img2 = base(pil_img.transpose(Image.FLIP_LEFT_RIGHT)).unsqueeze(0).to(device)
232
+ model.eval()
233
+ with torch.no_grad():
234
+ out1 = model(img1).softmax(dim=1)
235
+ out2 = model(img2).softmax(dim=1)
236
+ probs = (out1 + out2) / 2.0
237
+ return probs.squeeze(0).cpu().numpy()
238
+
239
+ # ----------------------------- Helpers: CSV loader -----------------------------
240
+
241
+ def load_csv_like(path: str) -> List[Dict]:
242
+ import csv
243
+ rows = []
244
+ with open(path, 'r') as f:
245
+ reader = csv.DictReader(f)
246
+ for r in reader:
247
+ rows.append(r)
248
+ return rows
249
+
250
+ # ----------------------------- Main -----------------------------
251
+
252
+ def main(argv=None):
253
+ parser = argparse.ArgumentParser()
254
+ parser.add_argument('--train-csv', type=str, help='train csv', required=True)
255
+ parser.add_argument('--val-csv', type=str, help='val csv', required=True)
256
+ parser.add_argument('--test-csv', type=str, help='test csv', required=True)
257
+ parser.add_argument('--img-root', type=str, default='.', help='root for images')
258
+ parser.add_argument('--model', type=str, default='swin', choices=['swin','convnext','densenet'])
259
+ parser.add_argument('--num-classes', type=int, default=8)
260
+ parser.add_argument('--img-size', type=int, default=224)
261
+ parser.add_argument('--epochs', type=int, default=20)
262
+ parser.add_argument('--batch-size', type=int, default=6)
263
+ parser.add_argument('--lr', type=float, default=1e-4)
264
+ parser.add_argument('--weight-decay', type=float, default=1e-2)
265
+ parser.add_argument('--out-dir', type=str, default='outputs')
266
+ parser.add_argument('--checkpoint', type=str, default=None)
267
+ parser.add_argument('--stage2', action='store_true', help='run stage 2: generate crops from gradcam and retrain')
268
+ parser.add_argument('--stage2-crop-dir', type=str, default='crops')
269
+ parser.add_argument('--cam-layer', type=str, default=None, help='module name for Grad-CAM hook (optional)')
270
+
271
+ # wandb args
272
+ parser.add_argument('--wandb-project', type=str, default='fracture-mps')
273
+ parser.add_argument('--wandb-entity', type=str, default=None)
274
+ parser.add_argument('--wandb-run-name', type=str, default=None)
275
+ parser.add_argument('--wandb-mode', type=str, default='online', choices=['online','offline','disabled'])
276
+
277
+ args = parser.parse_args(argv)
278
+
279
+ # initialize wandb
280
+ if args.wandb_mode != 'disabled':
281
+ wandb.init(project=args.wandb_project, entity=args.wandb_entity, name=args.wandb_run_name, mode=args.wandb_mode)
282
+ wandb.config.update(vars(args))
283
+ else:
284
+ wandb.init(mode='disabled')
285
+
286
+ device = DEVICE
287
+
288
+ # load CSVs
289
+ train_rows = load_csv_like(args.train_csv)
290
+ val_rows = load_csv_like(args.val_csv)
291
+ test_rows = load_csv_like(args.test_csv)
292
+
293
+ train_tf = get_transforms('train', img_size=args.img_size)
294
+ val_tf = get_transforms('val', img_size=args.img_size)
295
+
296
+ model = get_model(args.model, args.num_classes, pretrained=True).to(device)
297
+
298
+ if args.checkpoint:
299
+ ck = torch.load(args.checkpoint, map_location=device)
300
+ model.load_state_dict(ck['model_state_dict'])
301
+ print('Loaded checkpoint', args.checkpoint)
302
+
303
+ train_ds = FractureDataset(train_rows, img_root=args.img_root, transform=train_tf)
304
+ val_ds = FractureDataset(val_rows, img_root=args.img_root, transform=val_tf)
305
+ test_ds = FractureDataset(test_rows, img_root=args.img_root, transform=val_tf)
306
+
307
+ train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=False)
308
+ val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=False)
309
+ test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=False)
310
+
311
+ criterion = nn.CrossEntropyLoss()
312
+ optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
313
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(1,args.epochs))
314
+
315
+ best_f1 = 0.0
316
+ out_dir = args.out_dir
317
+ os.makedirs(out_dir, exist_ok=True)
318
+
319
+ for epoch in range(args.epochs):
320
+ start = time.time()
321
+ train_loss, train_p, train_r, train_f1 = train_one_epoch(model, train_loader, optimizer, criterion, device)
322
+ val_loss, val_p, val_r, val_f1, cm = validate(model, val_loader, criterion, device)
323
+ scheduler.step()
324
+ is_best = val_f1 > best_f1
325
+ if is_best:
326
+ best_f1 = val_f1
327
+ ck_name = f'epoch_{epoch}.pth'
328
+ save_checkpoint({'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'val_f1': val_f1}, is_best, out_dir, name=ck_name, upload_to_wandb=(args.wandb_mode!='disabled'))
329
+
330
+ # wandb logging
331
+ metrics = {'epoch': epoch, 'train_loss': train_loss, 'train_macro_f1': train_f1, 'val_loss': val_loss, 'val_macro_f1': val_f1, 'lr': scheduler.get_last_lr()[0]}
332
+ print(f"Epoch {epoch}/{args.epochs} time={time.time()-start:.1f}s")
333
+ print(metrics)
334
+ if args.wandb_mode != 'disabled':
335
+ wandb.log(metrics, step=epoch)
336
+ # log confusion matrix as an image
337
+ try:
338
+ import matplotlib.pyplot as plt
339
+ fig, ax = plt.subplots(figsize=(6,6))
340
+ ax.imshow(cm, interpolation='nearest')
341
+ ax.set_title('Confusion matrix')
342
+ wandb.log({"confusion_matrix": wandb.Image(fig)}, step=epoch)
343
+ plt.close(fig)
344
+ except Exception as e:
345
+ print('Failed to log confusion matrix plot to wandb:', e)
346
+
347
+ # load best and final test evaluation
348
+ best_ck = os.path.join(out_dir, 'best.pth')
349
+ if os.path.exists(best_ck):
350
+ ck = torch.load(best_ck, map_location=device)
351
+ model.load_state_dict(ck['model_state_dict'])
352
+ print('Loaded best checkpoint for final evaluation')
353
+
354
+ test_loss, test_p, test_r, test_f1, test_cm = validate(model, test_loader, criterion, device)
355
+ print('Test results:', test_loss, test_p, test_r, test_f1)
356
+ np.savetxt(os.path.join(out_dir, 'confusion_matrix.txt'), test_cm, fmt='%d')
357
+
358
+ if args.wandb_mode != 'disabled':
359
+ # save confusion matrix as artifact
360
+ try:
361
+ wandb.log({'test_macro_f1': test_f1})
362
+ wandb.save(os.path.join(out_dir, 'confusion_matrix.txt'))
363
+ except Exception as e:
364
+ print('WandB final save failed:', e)
365
+
366
+ # Stage 2: Grad-CAM cropping and retrain
367
+ if args.stage2:
368
+ print('Starting Stage-2: generating crops via Grad-CAM and retraining on cropped ROIs')
369
+ cam_transform = T.Compose([T.Resize((224,224)), T.ToTensor(), T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])])
370
+ crops_out = args.stage2_crop_dir
371
+ new_train = generate_crops_from_gradcam(model, train_rows, out_dir=crops_out, transform_for_cam=cam_transform, device=device, cam_layer=args.cam_layer or None, thr=0.5)
372
+ train_ds2 = FractureDataset(new_train, transform=get_transforms('train', img_size=args.img_size))
373
+ train_loader2 = DataLoader(train_ds2, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=False)
374
+ optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
375
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(1,args.epochs//2))
376
+ best_f1_stage2 = 0.0
377
+ for epoch in range(max(5, args.epochs//2)):
378
+ train_loss, train_p, train_r, train_f1 = train_one_epoch(model, train_loader2, optimizer, criterion, device)
379
+ val_loss, val_p, val_r, val_f1, cm = validate(model, val_loader, criterion, device)
380
+ is_best = val_f1 > best_f1_stage2
381
+ if is_best:
382
+ best_f1_stage2 = val_f1
383
+ save_checkpoint({'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'val_f1': val_f1}, is_best, out_dir, name=f'stage2_epoch_{epoch}.pth', upload_to_wandb=(args.wandb_mode!='disabled'))
384
+ scheduler.step()
385
+ if args.wandb_mode != 'disabled':
386
+ wandb.log({'stage2_epoch': epoch, 'stage2_val_macro_f1': val_f1, 'stage2_train_macro_f1': train_f1}, step=epoch)
387
+ print('Stage-2 finished. Best val macro-F1:', best_f1_stage2)
388
+
389
+ print('Finished.')
390
+ if args.wandb_mode != 'disabled':
391
+ wandb.finish()
392
+
393
+ if __name__ == '__main__':
394
+ main()
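After training, `tta_predict` can be exercised directly as a sanity check. Note that importing this module calls require_mps(), so the sketch below only runs on Apple Silicon; the checkpoint and image paths are placeholders:

import torch
from PIL import Image

from src.training.pipeline import tta_predict
from src.utils import get_model

model = get_model('swin', num_classes=8, pretrained=False)
ck = torch.load('outputs/swin_mps/best.pth', map_location='cpu')   # placeholder checkpoint
model.load_state_dict(ck['model_state_dict'])

# averages softmax over the original and the horizontally flipped image
probs = tta_predict(model, Image.open('xray.png').convert('RGB'), device=torch.device('cpu'))
print(int(probs.argmax()), float(probs.max()))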
src/training/pipeline_2.py ADDED
@@ -0,0 +1,225 @@
1
+ import os
2
+ import sys
3
+ import argparse
4
+ import time
5
+ import copy
6
+ from pathlib import Path
7
+ from typing import Optional, Tuple, List, Dict
8
+
9
+ import numpy as np
10
+ from PIL import Image
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.utils.data import Dataset, DataLoader
15
+ import torchvision.transforms as T
16
+ import torchvision.models as tvmodels
17
+ import timm
18
+
19
+ import wandb
20
+ from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
21
+ import cv2
22
+ import csv
23
+
24
+ # Add parent directory to path for imports
25
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
26
+
27
+ from src.utils import get_device, get_model, get_transforms, FractureDataset
28
+
29
+ # ----------------------------- Device Selection -----------------------------
30
+
31
+ DEVICE = get_device()
32
+ print(f"Using device: {DEVICE}")
33
+
34
+ # ----------------------------- Training & Evaluation -----------------------------
36
+ def save_checkpoint(state, is_best, out_dir, name='checkpoint.pth', upload_to_wandb: bool=False):
37
+ os.makedirs(out_dir, exist_ok=True)
38
+ path = os.path.join(out_dir, name)
39
+ torch.save(state, path)
40
+ if is_best:
41
+ best_path = os.path.join(out_dir, 'best.pth')
42
+ torch.save(state, best_path)
43
+ if upload_to_wandb:
44
+ try:
45
+ wandb.save(best_path)
46
+ print('Uploaded best checkpoint to WandB:', best_path)
47
+ except Exception as e:
48
+ print('WandB save failed:', e)
49
+
50
+ def train_one_epoch(model, loader, optimizer, criterion, device):
51
+ model.train()
52
+ running_loss = 0.0
53
+ all_preds = []
54
+ all_targets = []
55
+ for imgs, labels, _ in loader:
56
+ imgs = imgs.to(device)
57
+ labels = labels.to(device)
58
+ optimizer.zero_grad()
59
+ outputs = model(imgs)
60
+ loss = criterion(outputs, labels)
61
+ loss.backward()
62
+ optimizer.step()
63
+ running_loss += loss.item() * imgs.size(0)
64
+ preds = outputs.softmax(dim=1).argmax(dim=1)
65
+ all_preds.extend(preds.detach().cpu().numpy().tolist())
66
+ all_targets.extend(labels.detach().cpu().numpy().tolist())
67
+ epoch_loss = running_loss / len(loader.dataset)
68
+ p, r, f1, _ = precision_recall_fscore_support(all_targets, all_preds, average='macro', zero_division=0)
69
+ return epoch_loss, p, r, f1
70
+
71
+ def validate(model, loader, criterion, device):
72
+ model.eval()
73
+ running_loss = 0.0
74
+ all_preds = []
75
+ all_targets = []
76
+ with torch.no_grad():
77
+ for imgs, labels, _ in loader:
78
+ imgs = imgs.to(device)
79
+ labels = labels.to(device)
80
+ outputs = model(imgs)
81
+ loss = criterion(outputs, labels)
82
+ running_loss += loss.item() * imgs.size(0)
83
+ preds = outputs.softmax(dim=1).argmax(dim=1)
84
+ all_preds.extend(preds.detach().cpu().numpy().tolist())
85
+ all_targets.extend(labels.detach().cpu().numpy().tolist())
86
+ epoch_loss = running_loss / len(loader.dataset)
87
+ p, r, f1, _ = precision_recall_fscore_support(all_targets, all_preds, average='macro', labels=list(range(outputs.shape[1])), zero_division=0)
88
+ cm = confusion_matrix(all_targets, all_preds, labels=list(range(outputs.shape[1])))
89
+ return epoch_loss, p, r, f1, cm
90
+
91
+ # ----------------------------- Helpers: CSV loader -----------------------------
93
+ def load_csv_like(path: str) -> List[Dict]:
94
+ rows = []
95
+ with open(path, 'r', encoding='utf8') as f:
96
+ reader = csv.DictReader(f)
97
+ for r in reader:
98
+ rows.append(r)
99
+ return rows
100
+
101
+ # ----------------------------- Main -----------------------------
102
+
103
+ def main(argv=None):
104
+ parser = argparse.ArgumentParser()
105
+ parser.add_argument('--train-csv', type=str, help='train csv', required=True)
106
+ parser.add_argument('--val-csv', type=str, help='val csv', required=True)
107
+ parser.add_argument('--test-csv', type=str, help='test csv', required=True)
108
+ parser.add_argument('--img-root', type=str, default='.', help='root for images')
109
+ parser.add_argument('--model', type=str, default='swin', choices=['swin','convnext','densenet'])
110
+ parser.add_argument('--num-classes', type=int, default=8)
111
+ parser.add_argument('--img-size', type=int, default=224)
112
+ parser.add_argument('--epochs', type=int, default=20)
113
+ parser.add_argument('--batch-size', type=int, default=6)
114
+ parser.add_argument('--lr', type=float, default=1e-4)
115
+ parser.add_argument('--weight-decay', type=float, default=1e-2)
116
+ parser.add_argument('--out-dir', type=str, default='outputs')
117
+ parser.add_argument('--checkpoint', type=str, default=None)
118
+ parser.add_argument('--stage2', action='store_true', help='run stage 2: generate crops from gradcam and retrain')
119
+ parser.add_argument('--stage2-crop-dir', type=str, default='crops')
120
+ parser.add_argument('--cam-layer', type=str, default=None, help='module name for Grad-CAM hook (optional)')
121
+
122
+ # wandb args
123
+ parser.add_argument('--wandb-project', type=str, default='fracture-mps')
124
+ parser.add_argument('--wandb-entity', type=str, default=None)
125
+ parser.add_argument('--wandb-run-name', type=str, default=None)
126
+ parser.add_argument('--wandb-mode', type=str, default='online', choices=['online','offline','disabled'])
127
+
128
+ args = parser.parse_args(argv)
129
+
130
+ if args.wandb_mode != 'disabled':
131
+ wandb.init(project=args.wandb_project, entity=args.wandb_entity, name=args.wandb_run_name, mode=args.wandb_mode)
132
+ wandb.config.update(vars(args))
133
+ else:
134
+ wandb.init(mode='disabled')
135
+
136
+ device = DEVICE
137
+
138
+ train_rows = load_csv_like(args.train_csv)
139
+ val_rows = load_csv_like(args.val_csv)
140
+ test_rows = load_csv_like(args.test_csv)
141
+
142
+ train_tf = get_transforms('train', img_size=args.img_size)
143
+ val_tf = get_transforms('val', img_size=args.img_size)
144
+
145
+ model = get_model(args.model, args.num_classes, pretrained=True).to(device)
146
+
147
+ if args.checkpoint:
148
+ ck = torch.load(args.checkpoint, map_location='cpu')
149
+ state_dict = ck.get('model_state_dict', ck)
150
+ model.load_state_dict(state_dict)
151
+ print('Loaded checkpoint', args.checkpoint)
152
+
153
+ pin_memory = device.type == 'cuda'
154
+ num_workers = 0 if device.type == 'cuda' else 4
155
+
156
+ train_ds = FractureDataset(train_rows, img_root=args.img_root, transform=train_tf)
157
+ val_ds = FractureDataset(val_rows, img_root=args.img_root, transform=val_tf)
158
+ test_ds = FractureDataset(test_rows, img_root=args.img_root, transform=val_tf)
159
+
160
+ train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
161
+ val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
163
+ test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
164
+
165
+ criterion = nn.CrossEntropyLoss()
166
+ optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
167
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(1,args.epochs))
168
+
169
+ best_f1 = 0.0
170
+ out_dir = args.out_dir
171
+ os.makedirs(out_dir, exist_ok=True)
172
+
173
+ for epoch in range(args.epochs):
174
+ start = time.time()
175
+ train_loss, train_p, train_r, train_f1 = train_one_epoch(model, train_loader, optimizer, criterion, device)
176
+ val_loss, val_p, val_r, val_f1, cm = validate(model, val_loader, criterion, device)
177
+ scheduler.step()
178
+ is_best = val_f1 > best_f1
179
+ if is_best:
180
+ best_f1 = val_f1
181
+ ck_name = f'epoch_{epoch}.pth'
182
+
183
+ save_checkpoint({'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'val_f1': val_f1}, is_best, out_dir, name=ck_name, upload_to_wandb=(args.wandb_mode!='disabled'))
184
+
185
+ # wandb logging
186
+ metrics = {'epoch': epoch, 'train_loss': train_loss, 'train_macro_f1': train_f1, 'val_loss': val_loss, 'val_macro_f1': val_f1, 'lr': scheduler.get_last_lr()[0]}
187
+ print(f"Epoch {epoch}/{args.epochs} time={time.time()-start:.1f}s")
188
+ print(metrics)
189
+ if args.wandb_mode != 'disabled':
190
+ wandb.log(metrics, step=epoch)
191
+ # log confusion matrix as an image
192
+ try:
193
+ import matplotlib.pyplot as plt
194
+ fig, ax = plt.subplots(figsize=(6,6))
195
+ ax.imshow(cm, interpolation='nearest')
196
+ ax.set_title('Confusion matrix')
197
+ wandb.log({"confusion_matrix": wandb.Image(fig)}, step=epoch)
198
+ plt.close(fig)
199
+ except Exception as e:
200
+ print('Failed to log confusion matrix plot to wandb:', e)
201
+
202
+ # load best and final test evaluation
203
+ best_ck = os.path.join(out_dir, 'best.pth')
204
+ if os.path.exists(best_ck):
205
+ ck = torch.load(best_ck, map_location=device)
206
+ model.load_state_dict(ck['model_state_dict'])
207
+ print('Loaded best checkpoint for final evaluation')
208
+
209
+ test_loss, test_p, test_r, test_f1, test_cm = validate(model, test_loader, criterion, device)
210
+ print('Test results:', test_loss, test_p, test_r, test_f1)
211
+ np.savetxt(os.path.join(out_dir, 'confusion_matrix.txt'), test_cm, fmt='%d')
212
+
213
+ if args.wandb_mode != 'disabled':
214
+ try:
215
+ wandb.log({'test_macro_f1': test_f1})
216
+ wandb.save(os.path.join(out_dir, 'confusion_matrix.txt'))
217
+ except Exception as e:
218
+ print('WandB final save failed:', e)
219
+
220
+ print('Finished.')
221
+ if args.wandb_mode != 'disabled':
222
+ wandb.finish()
223
+
224
+ if __name__ == '__main__':
225
+ main()
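Because `main` accepts an argv list, this device-agnostic variant can also be launched from a notebook or a smoke test without touching the shell. The CSV paths below are the ones from the pipeline docstring and may need adjusting:

from src.training.pipeline_2 import main

main([
    '--train-csv', 'data/balanced_augmented_dataset/train.csv',
    '--val-csv', 'data/balanced_augmented_dataset/val.csv',
    '--test-csv', 'data/balanced_augmented_dataset/test.csv',
    '--model', 'swin', '--epochs', '1', '--batch-size', '4',
    '--out-dir', 'outputs/smoke_test',
    '--wandb-mode', 'disabled',    # skip Weights & Biases entirely for a quick run
])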
src/utils/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ from .device_utils import get_device, require_mps, DEVICE
2
+ from .model_utils import get_model
3
+ from .data_utils import get_transforms, FractureDataset
4
+
5
+ __all__ = [
6
+ 'get_device',
7
+ 'require_mps',
8
+ 'DEVICE',
9
+ 'get_model',
10
+ 'get_transforms',
11
+ 'FractureDataset'
12
+ ]
src/utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (508 Bytes). View file
 
src/utils/__pycache__/data_utils.cpython-311.pyc ADDED
Binary file (4.1 kB). View file
 
src/utils/__pycache__/device_utils.cpython-311.pyc ADDED
Binary file (1.41 kB). View file
 
src/utils/__pycache__/model_utils.cpython-311.pyc ADDED
Binary file (3.09 kB). View file
 
src/utils/data_utils.py ADDED
@@ -0,0 +1,59 @@
1
+ import os
2
+ from PIL import Image
3
+ import torch
4
+ from torch.utils.data import Dataset
5
+ import torchvision.transforms as T
6
+
7
+ def get_transforms(split: str, img_size: int = 224):
8
+ """Returns train or val/test transforms."""
9
+ if split == 'train':
10
+ return T.Compose([
11
+ T.Resize((int(img_size*1.1), int(img_size*1.1))),
12
+ T.RandomResizedCrop(img_size, scale=(0.8, 1.0)),
13
+ T.RandomRotation(15),
14
+ T.RandomHorizontalFlip(),
15
+ T.ToTensor(),
16
+ T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
17
+ ])
18
+ else:
19
+ return T.Compose([
20
+ T.Resize((img_size, img_size)),
21
+ T.CenterCrop(img_size),
22
+ T.ToTensor(),
23
+ T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
24
+ ])
25
+
26
+ class FractureDataset(Dataset):
27
+ """Dataset for fracture images with optional bounding box cropping."""
28
+
29
+ def __init__(self, df, img_root: str = '.', transform=None, use_bbox: bool = False):
30
+ self.entries = df
31
+ self.img_root = img_root
32
+ self.transform = transform
33
+ self.use_bbox = use_bbox
34
+
35
+ def __len__(self):
36
+ return len(self.entries)
37
+
38
+ def __getitem__(self, idx):
39
+ row = self.entries[idx]
40
+ img_path = row['image_path']
41
+
42
+ if not os.path.isabs(img_path):
43
+ img_path = os.path.join(self.img_root, img_path)
44
+
45
+ img = Image.open(img_path).convert('RGB')
46
+
47
+ if self.use_bbox and all(k in row for k in ('bbox_xmin','bbox_ymin','bbox_xmax','bbox_ymax')):
48
+ xmin = int(row['bbox_xmin'])
49
+ ymin = int(row['bbox_ymin'])
50
+ xmax = int(row['bbox_xmax'])
51
+ ymax = int(row['bbox_ymax'])
52
+ img = img.crop((xmin, ymin, xmax, ymax))
53
+
54
+ label = int(row['label'])
55
+
56
+ if self.transform:
57
+ img = self.transform(img)
58
+
59
+ return img, label, img_path
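A minimal sketch of wiring these helpers together; in the pipelines the row dicts come from a CSV via load_csv_like, so the inline row and image path below are placeholders:

from torch.utils.data import DataLoader

from src.utils.data_utils import FractureDataset, get_transforms

rows = [{'image_path': 'images/xray_001.png', 'label': '3'}]   # placeholder row
ds = FractureDataset(rows, img_root='.', transform=get_transforms('train', img_size=224))
loader = DataLoader(ds, batch_size=4, shuffle=True)

imgs, labels, paths = next(iter(loader))   # batched image tensors, label tensor, tuple of source paths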
src/utils/device_utils.py ADDED
@@ -0,0 +1,18 @@
1
+ import torch
2
+
3
+ def get_device():
4
+ """Dynamically selects CUDA, MPS, or falls back to CPU."""
5
+ if torch.cuda.is_available():
6
+ return torch.device('cuda')
7
+ elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
8
+ return torch.device('mps')
9
+ else:
10
+ return torch.device('cpu')
11
+
12
+ def require_mps():
13
+ """Enforces MPS device (for Mac-only scripts)."""
14
+ if getattr(torch.backends, 'mps', None) is None or not torch.backends.mps.is_available():
15
+ raise RuntimeError('MPS (Apple Silicon) is required but not available.')
16
+ return torch.device('mps')
17
+
18
+ DEVICE = get_device()
src/utils/model_manager.py ADDED
@@ -0,0 +1,190 @@
1
+ """
2
+ Model management utility for cloud deployments.
3
+ Handles downloading and caching models from cloud storage.
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import json
9
+ import hashlib
10
+ from pathlib import Path
11
+ from typing import Dict, Optional
12
+ import requests
13
+
14
+ # Add parent directory to path
15
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
16
+
17
+ # Model registry - Update these URLs with your cloud storage
18
+ MODEL_REGISTRY = {
19
+ "best_swin.pth": {
20
+ "size_mb": 200,
21
+ # Replace with your actual cloud storage URL
22
+ "url": os.getenv("SWIN_MODEL_URL", ""),
23
+ "hash": "", # Optional: SHA256 hash for verification
24
+ },
25
+ "best_mobilenetv2.pth": {
26
+ "size_mb": 100,
27
+ "url": os.getenv("MOBILENETV2_MODEL_URL", ""),
28
+ "hash": "",
29
+ },
30
+ "best_densenet169.pth": {
31
+ "size_mb": 200,
32
+ "url": os.getenv("DENSENET_MODEL_URL", ""),
33
+ "hash": "",
34
+ },
35
+ "best_efficientnetv2.pth": {
36
+ "size_mb": 180,
37
+ "url": os.getenv("EFFICIENTNET_MODEL_URL", ""),
38
+ "hash": "",
39
+ },
40
+ "best_maxvit.pth": {
41
+ "size_mb": 220,
42
+ "url": os.getenv("MAXVIT_MODEL_URL", ""),
43
+ "hash": "",
44
+ },
45
+ }
46
+
47
+ MODELS_DIR = Path("./outputs")
48
+ MODELS_DIR.mkdir(exist_ok=True)
49
+
50
+
51
+ def check_model_exists(model_name: str) -> bool:
52
+ """Check if a model file exists locally."""
53
+ model_path = MODELS_DIR / model_name
54
+ return model_path.exists()
55
+
56
+
57
+ def get_all_models_status() -> Dict[str, Dict]:
58
+ """Get status of all models."""
59
+ status = {}
60
+ for model_name, config in MODEL_REGISTRY.items():
61
+ exists = check_model_exists(model_name)
62
+ status[model_name] = {
63
+ "exists": exists,
64
+ "size_mb": config["size_mb"],
65
+ "url": config["url"],
66
+ }
67
+ return status
68
+
69
+
70
+ def download_model(model_name: str, force: bool = False) -> bool:
71
+ """
72
+ Download a model from cloud storage.
73
+
74
+ Args:
75
+ model_name: Name of the model file
76
+ force: Force download even if file exists
77
+
78
+ Returns:
79
+ True if successful, False otherwise
80
+ """
81
+ if not force and check_model_exists(model_name):
82
+ print(f"✅ {model_name} already exists locally")
83
+ return True
84
+
85
+ if model_name not in MODEL_REGISTRY:
86
+ print(f"❌ {model_name} not found in registry")
87
+ return False
88
+
89
+ config = MODEL_REGISTRY[model_name]
90
+ url = config.get("url")
91
+
92
+ if not url:
93
+ print(f"⚠️ No download URL configured for {model_name}")
94
+ print(f" Set environment variable or update MODEL_REGISTRY")
95
+ return False
96
+
97
+ try:
98
+ print(f"📥 Downloading {model_name} from cloud storage...")
99
+ response = requests.get(url, timeout=300, stream=True)
100
+ response.raise_for_status()
101
+
102
+ model_path = MODELS_DIR / model_name
103
+ total_size = int(response.headers.get('content-length', 0))
104
+
105
+ with open(model_path, 'wb') as f:
106
+ downloaded = 0
107
+ for chunk in response.iter_content(chunk_size=8192):
108
+ if chunk:
109
+ f.write(chunk)
110
+ downloaded += len(chunk)
111
+ if total_size:
112
+ percent = (downloaded / total_size) * 100
113
+ print(f" Progress: {percent:.1f}%", end='\r')
114
+
115
+ print(f"\n✅ Successfully downloaded {model_name}")
116
+ return True
117
+
118
+ except Exception as e:
119
+ print(f"❌ Failed to download {model_name}: {e}")
120
+ return False
121
+
122
+
123
+ def download_all_models() -> Dict[str, bool]:
124
+ """Download all models that have URLs configured."""
125
+ results = {}
126
+ for model_name in MODEL_REGISTRY:
127
+ results[model_name] = download_model(model_name)
128
+ return results
129
+
130
+
131
+ def initialize_models_for_deployment() -> bool:
132
+ """
133
+ Initialize models for deployment.
134
+ Checks if models exist, attempts download if needed.
135
+
136
+ Returns:
137
+ True if all models are available, False otherwise
138
+ """
139
+ print("\n🔍 Checking model availability...")
140
+ status = get_all_models_status()
141
+
142
+ all_available = True
143
+ for model_name, info in status.items():
144
+ if info["exists"]:
145
+ print(f" ✅ {model_name}")
146
+ else:
147
+ print(f" ❌ {model_name} - NOT FOUND")
148
+ if info["url"]:
149
+ print(f" URL configured: {info['url'][:50]}...")
150
+ else:
151
+ print(f" No download URL - configure via environment variables")
152
+ all_available = False
153
+
154
+ if not all_available:
155
+ print("\n⚠️ Some models are missing!")
156
+ print(" Option 1: Configure cloud storage URLs and run: python -c 'from src.utils.model_manager import download_all_models; download_all_models()'")
157
+ print(" Option 2: Upload models manually to ./outputs/")
158
+ return False
159
+
160
+ print("\n✅ All models are available!")
161
+ return True
162
+
163
+
164
+ if __name__ == "__main__":
165
+ print("Model Manager - Cloud Deployment Utility")
166
+ print("=" * 50)
167
+
168
+ if len(sys.argv) > 1:
169
+ command = sys.argv[1]
170
+
171
+ if command == "status":
172
+ status = get_all_models_status()
173
+ print(json.dumps(status, indent=2))
174
+
175
+ elif command == "download-all":
176
+ results = download_all_models()
177
+ print("\nDownload Results:")
178
+ print(json.dumps(results, indent=2))
179
+
180
+ elif command == "check":
181
+ success = initialize_models_for_deployment()
182
+ sys.exit(0 if success else 1)
183
+
184
+ else:
185
+ print(f"Unknown command: {command}")
186
+ print("Available commands: status, download-all, check")
187
+
188
+ else:
189
+ # Default: check status
190
+ initialize_models_for_deployment()
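A typical deployment bootstrap, assuming the download URLs are provided as environment variables; the URL below is a placeholder, and since MODEL_REGISTRY reads the environment at import time, variables must be set before the import:

import os

os.environ.setdefault("SWIN_MODEL_URL", "https://example.com/models/best_swin.pth")  # placeholder URL

from src.utils.model_manager import download_model, get_all_models_status

download_model("best_swin.pth")                   # no-op if outputs/best_swin.pth already exists
print(get_all_models_status()["best_swin.pth"])   # {'exists': ..., 'size_mb': 200, 'url': ...}

# the same checks are available from the shell:
#   python src/utils/model_manager.py status
#   python src/utils/model_manager.py download-all
#   python src/utils/model_manager.py check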
src/utils/model_utils.py ADDED
@@ -0,0 +1,54 @@
+ import torch.nn as nn
+ import timm
+ import torchvision.models as tvmodels
+
+ def get_model(name: str, num_classes: int, pretrained: bool = True):
+     """Load a backbone by name and adapt its classifier head to num_classes."""
+     name = name.lower()
+
+     if name.startswith('swin'):
+         model = timm.create_model('swin_small_patch4_window7_224', pretrained=pretrained)
+         if hasattr(model, 'reset_classifier'):
+             model.reset_classifier(num_classes=num_classes)
+         else:
+             model.head = nn.Linear(model.head.in_features, num_classes)
+         return model
+
+     if name.startswith('convnext'):
+         model = timm.create_model('convnext_tiny', pretrained=pretrained)
+         if hasattr(model, 'reset_classifier'):
+             model.reset_classifier(num_classes=num_classes)
+         else:
+             model.head.fc = nn.Linear(model.head.fc.in_features, num_classes)
+         return model
+
+     if name.startswith('densenet'):
+         model = tvmodels.densenet169(pretrained=pretrained)
+         model.classifier = nn.Linear(model.classifier.in_features, num_classes)
+         return model
+
+     if name.startswith('mobilenet'):
+         model = timm.create_model('mobilenetv2_100', pretrained=pretrained)
+         if hasattr(model, 'reset_classifier'):
+             model.reset_classifier(num_classes=num_classes)
+         else:
+             model.classifier = nn.Linear(model.classifier.in_features, num_classes)
+         return model
+
+     if name.startswith('efficientnet'):
+         model = timm.create_model('efficientnet_b0', pretrained=pretrained)
+         if hasattr(model, 'reset_classifier'):
+             model.reset_classifier(num_classes=num_classes)
+         else:
+             model.classifier = nn.Linear(model.classifier.in_features, num_classes)
+         return model
+
+     if name.startswith('maxvit'):
+         model = timm.create_model('maxvit_tiny_tf_224', pretrained=pretrained)
+         if hasattr(model, 'reset_classifier'):
+             model.reset_classifier(num_classes=num_classes)
+         else:
+             model.head.fc = nn.Linear(model.head.fc.in_features, num_classes)
+         return model
+
+     raise ValueError(f'Unknown model: {name}')
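A quick smoke test for `get_model`; the class count is illustrative (not taken from this repo's dataset config) and `pretrained=False` avoids downloading weights:

```python
import torch
from src.utils.model_utils import get_model

# Instantiate each registered backbone with a hypothetical 4-class head.
for name in ['swin', 'convnext', 'densenet', 'mobilenet', 'efficientnet', 'maxvit']:
    model = get_model(name, num_classes=4, pretrained=False)
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))  # every variant here accepts 224x224
    assert logits.shape == (1, 4), name
```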
streamlit_app.py ADDED
@@ -0,0 +1,66 @@
+ """
+ Main entry point for Streamlit Cloud deployment.
+ Streamlit Cloud looks for streamlit_app.py or app.py in the root directory.
+
+ Uses the cloud-optimized version backed by the Hugging Face Inference API.
+ For local development with Ollama, use: streamlit run apps/patient_chat_app_local.py
+ """
+
+ import os
+ import sys
+ import streamlit as st
+
+ # Add the repository root to the Python path so `src` and `apps` are importable
+ sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
+
+ # Pre-flight model check (st.cache_resource makes this run once per process)
+ @st.cache_resource
+ def initialize_deployment():
+     """Initialize the deployment environment and verify model availability."""
+     from src.utils.model_manager import initialize_models_for_deployment
+
+     try:
+         models_ready = initialize_models_for_deployment()
+         return models_ready
+     except Exception as e:
+         st.error(f"Error checking models: {e}")
+         return False
+
+ if __name__ == "__main__":
+     # Check model availability (currently disabled)
+     # models_ready = initialize_deployment()
+
+     # Import and run the cloud version with Hugging Face
+     from apps.patient_chat_app_cloud import main
+     main()
+
+ # import os
+ # import sys
+ # import streamlit as st
+
+ # # Add src directory to Python path
+ # sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
+
+ # # Check if we're in deployment mode
+ # IS_STREAMLIT_CLOUD = os.getenv("STREAMLIT_DEPLOYMENT", "False").lower() == "true"
+
+ # # Pre-initialize models check (runs once at app startup)
+ # @st.cache_resource
+ # def initialize_deployment():
+ #     """Initialize deployment environment and models."""
+ #     from src.utils.model_manager import initialize_models_for_deployment
+
+ #     try:
+ #         models_ready = initialize_models_for_deployment()
+ #         return models_ready
+ #     except Exception as e:
+ #         st.error(f"Error checking models: {e}")
+ #         return False
+
+ # if __name__ == "__main__":
+ #     # Check model availability
+ #     # models_ready = initialize_deployment()
+
+ #     # Import and run the main app
+ #     from apps.patient_chat_app_local import main
+ #     main()
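The dead code above hints at a `STREAMLIT_DEPLOYMENT` flag for switching backends. If it is ever revived, one way to consolidate both variants into this single entry point is an environment-based dispatch — a sketch only, since the default value and exact semantics are assumptions, though both app modules and their `main()` functions ship in this commit:

```python
import os

# Sketch: pick the backend from the STREAMLIT_DEPLOYMENT flag that appears in
# the commented-out block above (assumed default: cloud).
if os.getenv("STREAMLIT_DEPLOYMENT", "true").lower() == "true":
    from apps.patient_chat_app_cloud import main  # Hugging Face Inference API backend
else:
    from apps.patient_chat_app_local import main  # local Ollama backend

main()
```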