edeler committed on
Commit
7056797
·
verified ·
1 Parent(s): 44e8cba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +500 -500
app.py CHANGED
@@ -1,500 +1,500 @@
1
- import os
2
- import json
3
- import gc
4
- import time
5
- import traceback
6
- from typing import Dict, List, Optional, Tuple, Callable, Any
7
-
8
- import torch
9
- import gradio as gr
10
- import supervision as sv
11
- from PIL import Image
12
-
13
- # Try to import optional dependencies
14
- try:
15
- from transformers import (
16
- AutoModelForCausalLM,
17
- AutoTokenizer,
18
- AutoModelForImageTextToText,
19
- AutoProcessor,
20
- BitsAndBytesConfig,
21
- )
22
- except Exception:
23
- AutoModelForCausalLM = None
24
- AutoTokenizer = None
25
- AutoModelForImageTextToText = None
26
- AutoProcessor = None
27
- BitsAndBytesConfig = None
28
-
29
- # Import RF-DETR (assumes it's in the same directory or installed)
30
- try:
31
- from rfdetr import RFDETRMedium
32
- except ImportError:
33
- print("Warning: RF-DETR not found. Please ensure it's properly installed.")
34
- RFDETRMedium = None
35
-
36
- # ============================================================================
37
- # Configuration for Hugging Face Spaces
38
- # ============================================================================
39
-
40
- class SpacesConfig:
41
- """Configuration optimized for Hugging Face Spaces."""
42
-
43
- def __init__(self):
44
- self.settings = {
45
- 'results_dir': '/tmp/results',
46
- 'checkpoint': None,
47
- 'resolution': 576,
48
- 'threshold': 0.7,
49
- 'use_llm': True,
50
- 'llm_model_id': 'google/medgemma-4b-it',
51
- 'llm_max_new_tokens': 200,
52
- 'llm_temperature': 0.2,
53
- 'llm_4bit': True,
54
- 'enable_caching': True,
55
- 'max_cache_size': 100,
56
- }
57
-
58
- def get(self, key: str, default: Any = None) -> Any:
59
- return self.settings.get(key, default)
60
-
61
- # ============================================================================
62
- # Memory Management (simplified for Spaces)
63
- # ============================================================================
64
-
65
- class MemoryManager:
66
- """Simplified memory management for Spaces."""
67
-
68
- def __init__(self):
69
- self.memory_thresholds = {
70
- 'gpu_warning': 0.8,
71
- 'system_warning': 0.85,
72
- }
73
-
74
- def cleanup_memory(self, force: bool = False) -> None:
75
- """Perform memory cleanup."""
76
- try:
77
- gc.collect()
78
- if torch and torch.cuda.is_available():
79
- torch.cuda.empty_cache()
80
- torch.cuda.synchronize()
81
- except Exception as e:
82
- print(f"Memory cleanup error: {e}")
83
-
84
- # Global memory manager
85
- memory_manager = MemoryManager()
86
-
87
- # ============================================================================
88
- # Model Loading
89
- # ============================================================================
90
-
91
- def find_checkpoint() -> Optional[str]:
92
- """Find RF-DETR checkpoint in various locations."""
93
- candidates = [
94
- "rf-detr-medium.pth", # Current directory
95
- "/tmp/results/checkpoint_best_total.pth",
96
- "/tmp/results/checkpoint_best_ema.pth",
97
- "/tmp/results/checkpoint_best_regular.pth",
98
- "/tmp/results/checkpoint.pth",
99
- ]
100
-
101
- for path in candidates:
102
- if os.path.isfile(path):
103
- return path
104
- return None
105
-
106
- def load_model(checkpoint_path: str, resolution: int):
107
- """Load RF-DETR model."""
108
- if RFDETRMedium is None:
109
- raise RuntimeError("RF-DETR not available. Please install it properly.")
110
-
111
- model = RFDETRMedium(pretrain_weights=checkpoint_path, resolution=resolution)
112
- try:
113
- model.optimize_for_inference()
114
- except Exception:
115
- pass
116
- return model
117
-
118
- # ============================================================================
119
- # LLM Integration
120
- # ============================================================================
121
-
122
- class TextGenerator:
123
- """Simplified text generator for Spaces."""
124
-
125
- def __init__(self, model_id: str, max_tokens: int = 200, temperature: float = 0.2):
126
- self.model_id = model_id
127
- self.max_tokens = max_tokens
128
- self.temperature = temperature
129
- self.model = None
130
- self.tokenizer = None
131
- self.processor = None
132
- self.is_multimodal = False
133
-
134
- def load_model(self):
135
- """Load the LLM model."""
136
- if self.model is not None:
137
- return
138
-
139
- if (AutoModelForCausalLM is None and AutoModelForImageTextToText is None):
140
- raise RuntimeError("Transformers not available")
141
-
142
- # Clear memory before loading
143
- memory_manager.cleanup_memory()
144
-
145
- print(f"Loading model: {self.model_id}")
146
-
147
- model_kwargs = {
148
- "device_map": "auto",
149
- "low_cpu_mem_usage": True,
150
- }
151
-
152
- if torch and torch.cuda.is_available():
153
- model_kwargs["torch_dtype"] = torch.bfloat16
154
-
155
- # Use 4-bit quantization if available
156
- if BitsAndBytesConfig is not None:
157
- try:
158
- compute_dtype = torch.bfloat16 if torch and torch.cuda.is_available() else torch.float16
159
- model_kwargs["quantization_config"] = BitsAndBytesConfig(
160
- load_in_4bit=True,
161
- bnb_4bit_compute_dtype=compute_dtype,
162
- bnb_4bit_use_double_quant=True,
163
- bnb_4bit_quant_type="nf4"
164
- )
165
- model_kwargs["torch_dtype"] = compute_dtype
166
- except Exception:
167
- pass
168
-
169
- # Check if it's a multimodal model
170
- is_multimodal = "medgemma" in self.model_id.lower()
171
-
172
- if is_multimodal and AutoModelForImageTextToText is not None and AutoProcessor is not None:
173
- self.processor = AutoProcessor.from_pretrained(self.model_id)
174
- self.model = AutoModelForImageTextToText.from_pretrained(self.model_id, **model_kwargs)
175
- self.is_multimodal = True
176
- elif AutoModelForCausalLM is not None and AutoTokenizer is not None:
177
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
178
- self.model = AutoModelForCausalLM.from_pretrained(self.model_id, **model_kwargs)
179
- self.is_multimodal = False
180
- else:
181
- raise RuntimeError("Required model classes not available")
182
-
183
- print("✓ Model loaded successfully")
184
-
185
- def generate(self, text: str, image: Optional[Image.Image] = None) -> str:
186
- """Generate text using the loaded model."""
187
- self.load_model()
188
-
189
- if self.model is None:
190
- return f"[Model not loaded: {text}]"
191
-
192
- try:
193
- # Create messages
194
- system_text = "You are a concise medical assistant. Provide a brief, clear summary of detection results. Avoid repetition and be direct. Do not give medical advice."
195
- user_text = f"Summarize these detection results in 3 clear sentences:\n\n{text}"
196
-
197
- if self.is_multimodal:
198
- # Multimodal model
199
- user_content = [{"type": "text", "text": user_text}]
200
- if image is not None:
201
- user_content.append({"type": "image", "image": image})
202
-
203
- messages = [
204
- {"role": "system", "content": [{"type": "text", "text": system_text}]},
205
- {"role": "user", "content": user_content},
206
- ]
207
-
208
- inputs = self.processor.apply_chat_template(
209
- messages,
210
- add_generation_prompt=True,
211
- tokenize=True,
212
- return_dict=True,
213
- return_tensors="pt",
214
- )
215
-
216
- if torch:
217
- inputs = inputs.to(self.model.device, dtype=torch.bfloat16)
218
-
219
- with torch.inference_mode():
220
- generation = self.model.generate(
221
- **inputs,
222
- max_new_tokens=self.max_tokens,
223
- do_sample=self.temperature > 0,
224
- temperature=max(0.01, self.temperature) if self.temperature > 0 else None,
225
- use_cache=False,
226
- )
227
-
228
- input_len = inputs["input_ids"].shape[-1]
229
- generation = generation[0][input_len:]
230
- decoded = self.processor.decode(generation, skip_special_tokens=True)
231
- return decoded.strip()
232
-
233
- else:
234
- # Text-only model
235
- messages = [
236
- {"role": "system", "content": system_text},
237
- {"role": "user", "content": user_text},
238
- ]
239
-
240
- inputs = self.tokenizer.apply_chat_template(
241
- messages,
242
- add_generation_prompt=True,
243
- tokenize=True,
244
- return_dict=True,
245
- return_tensors="pt",
246
- )
247
-
248
- inputs = inputs.to(self.model.device)
249
-
250
- with torch.inference_mode():
251
- generation = self.model.generate(
252
- **inputs,
253
- max_new_tokens=self.max_tokens,
254
- do_sample=self.temperature > 0,
255
- temperature=max(0.01, self.temperature) if self.temperature > 0 else None,
256
- use_cache=False,
257
- )
258
-
259
- input_len = inputs["input_ids"].shape[-1]
260
- generation = generation[0][input_len:]
261
- decoded = self.tokenizer.decode(generation, skip_special_tokens=True)
262
- return decoded.strip()
263
-
264
- except Exception as e:
265
- error_msg = f"[Generation error: {e}]"
266
- print(f"Generation error: {traceback.format_exc()}")
267
- return f"{error_msg}\n\n{text}"
268
-
269
- # ============================================================================
270
- # Application State
271
- # ============================================================================
272
-
273
- class AppState:
274
- """Application state for Spaces."""
275
-
276
- def __init__(self):
277
- self.config = SpacesConfig()
278
- self.model = None
279
- self.class_names = None
280
- self.text_generator = None
281
-
282
- def load_model(self):
283
- """Load the detection model."""
284
- if self.model is not None:
285
- return
286
-
287
- checkpoint = find_checkpoint()
288
- if not checkpoint:
289
- raise FileNotFoundError(
290
- "No RF-DETR checkpoint found. Please upload rf-detr-medium.pth to your Space."
291
- )
292
-
293
- print(f"Loading RF-DETR from: {checkpoint}")
294
- self.model = load_model(checkpoint, self.config.get('resolution'))
295
-
296
- # Try to load class names
297
- try:
298
- results_json = "/tmp/results/results.json"
299
- if os.path.isfile(results_json):
300
- with open(results_json, 'r') as f:
301
- data = json.load(f)
302
- classes = []
303
- for split in ("valid", "test", "train"):
304
- if "class_map" in data and split in data["class_map"]:
305
- for item in data["class_map"][split]:
306
- name = item.get("class")
307
- if name and name != "all" and name not in classes:
308
- classes.append(name)
309
- self.class_names = classes if classes else None
310
- except Exception:
311
- pass
312
-
313
- print("✓ RF-DETR model loaded")
314
-
315
- def get_text_generator(self, model_size: str = "4B") -> TextGenerator:
316
- """Get or create text generator."""
317
- # Determine model ID based on size selection
318
- model_id = 'google/medgemma-27b-it' if model_size == "27B" else 'google/medgemma-4b-it'
319
-
320
- # Check if we need to create a new generator for different model size
321
- if (self.text_generator is None or
322
- hasattr(self.text_generator, 'model_id') and
323
- self.text_generator.model_id != model_id):
324
-
325
- max_tokens = self.config.get('llm_max_new_tokens')
326
- temperature = self.config.get('llm_temperature')
327
-
328
- self.text_generator = TextGenerator(model_id, max_tokens, temperature)
329
- return self.text_generator
330
-
331
- # ============================================================================
332
- # UI and Inference
333
- # ============================================================================
334
-
335
- def create_detection_interface():
336
- """Create the Gradio interface."""
337
-
338
- # Color palette for annotations
339
- COLOR_PALETTE = sv.ColorPalette.from_hex([
340
- "#ffff00", "#ff9b00", "#ff66ff", "#3399ff", "#ff66b2",
341
- "#ff8080", "#b266ff", "#9999ff", "#66ffff", "#33ff99",
342
- "#66ff66", "#99ff00",
343
- ])
344
-
345
- def annotate_image(image: Image.Image, threshold: float, model_size: str = "4B") -> Tuple[Image.Image, str]:
346
- """Process an image and return annotated version with description."""
347
-
348
- if image is None:
349
- return None, "Please upload an image."
350
-
351
- try:
352
- # Load model if needed
353
- app_state.load_model()
354
-
355
- # Run detection
356
- detections = app_state.model.predict(image, threshold=threshold)
357
-
358
- # Annotate image
359
- bbox_annotator = sv.BoxAnnotator(color=COLOR_PALETTE, thickness=2)
360
- label_annotator = sv.LabelAnnotator(text_scale=0.5, text_color=sv.Color.BLACK)
361
-
362
- labels = []
363
- for i in range(len(detections)):
364
- class_id = int(detections.class_id[i]) if detections.class_id is not None else None
365
- conf = float(detections.confidence[i]) if detections.confidence is not None else 0.0
366
-
367
- if app_state.class_names and class_id is not None:
368
- if 0 <= class_id < len(app_state.class_names):
369
- label_name = app_state.class_names[class_id]
370
- else:
371
- label_name = str(class_id)
372
- else:
373
- label_name = str(class_id) if class_id is not None else "object"
374
-
375
- labels.append(f"{label_name} {conf:.2f}")
376
-
377
- annotated = image.copy()
378
- annotated = bbox_annotator.annotate(annotated, detections)
379
- annotated = label_annotator.annotate(annotated, detections, labels)
380
-
381
- # Generate description
382
- description = f"Found {len(detections)} detections above threshold {threshold}:\n\n"
383
-
384
- if len(detections) > 0:
385
- counts = {}
386
- for i in range(len(detections)):
387
- class_id = int(detections.class_id[i]) if detections.class_id is not None else None
388
- if app_state.class_names and class_id is not None:
389
- if 0 <= class_id < len(app_state.class_names):
390
- name = app_state.class_names[class_id]
391
- else:
392
- name = str(class_id)
393
- else:
394
- name = str(class_id) if class_id is not None else "object"
395
- counts[name] = counts.get(name, 0) + 1
396
-
397
- for name, count in counts.items():
398
- description += f"- {count}× {name}\n"
399
-
400
- # Use LLM for description if enabled
401
- if app_state.config.get('use_llm'):
402
- try:
403
- generator = app_state.get_text_generator(model_size)
404
- llm_description = generator.generate(description, image=annotated)
405
- description = llm_description
406
- except Exception as e:
407
- description = f"[LLM error: {e}]\n\n{description}"
408
- else:
409
- description += "No objects detected above the confidence threshold."
410
-
411
- return annotated, description
412
-
413
- except Exception as e:
414
- error_msg = f"Error processing image: {str(e)}"
415
- print(f"Processing error: {traceback.format_exc()}")
416
- return None, error_msg
417
-
418
- # Create the interface
419
- with gr.Blocks(title="Medical Image Analysis", theme=gr.themes.Soft()) as demo:
420
- gr.Markdown("# 🏥 Medical Image Analysis")
421
- gr.Markdown("Upload a medical image to detect and analyze findings using AI.")
422
-
423
- with gr.Row():
424
- with gr.Column():
425
- input_image = gr.Image(type="pil", label="Upload Image", height=400)
426
- threshold_slider = gr.Slider(
427
- minimum=0.1,
428
- maximum=1.0,
429
- value=0.7,
430
- step=0.05,
431
- label="Confidence Threshold",
432
- info="Higher values = fewer but more confident detections"
433
- )
434
-
435
- model_size_radio = gr.Radio(
436
- choices=["4B", "27B"],
437
- value="4B",
438
- label="MedGemma Model Size",
439
- info="4B: Faster, less memory | 27B: More accurate, more memory"
440
- )
441
-
442
- analyze_btn = gr.Button("🔍 Analyze Image", variant="primary")
443
-
444
- with gr.Column():
445
- output_image = gr.Image(type="pil", label="Results", height=400)
446
- output_text = gr.Textbox(
447
- label="Analysis Results",
448
- lines=8,
449
- max_lines=15,
450
- show_copy_button=True
451
- )
452
-
453
- # Wire up the interface
454
- analyze_btn.click(
455
- fn=annotate_image,
456
- inputs=[input_image, threshold_slider, model_size_radio],
457
- outputs=[output_image, output_text]
458
- )
459
-
460
- # Also run when image is uploaded
461
- input_image.change(
462
- fn=annotate_image,
463
- inputs=[input_image, threshold_slider, model_size_radio],
464
- outputs=[output_image, output_text]
465
- )
466
-
467
- # Footer
468
- gr.Markdown("---")
469
- gr.Markdown("*Powered by RF-DETR and MedGemma • Built for Hugging Face Spaces*")
470
-
471
- return demo
472
-
473
- # ============================================================================
474
- # Main Application
475
- # ============================================================================
476
-
477
- # Global app state
478
- app_state = AppState()
479
-
480
- def main():
481
- """Main entry point for the Spaces app."""
482
- print("🚀 Starting Medical Image Analysis App")
483
-
484
- # Ensure results directory exists
485
- os.makedirs(app_state.config.get('results_dir'), exist_ok=True)
486
-
487
- # Create and launch the interface
488
- demo = create_detection_interface()
489
-
490
- # Launch with Spaces-optimized settings
491
- demo.launch(
492
- server_name="0.0.0.0",
493
- server_port=7860,
494
- share=False, # Spaces handles this
495
- show_error=True,
496
- show_api=False,
497
- )
498
-
499
- if __name__ == "__main__":
500
- main()
 
1
+ import os
2
+ import json
3
+ import gc
4
+ import time
5
+ import traceback
6
+ from typing import Dict, List, Optional, Tuple, Callable, Any
7
+
8
+ import torch
9
+ import gradio as gr
10
+ import supervision as sv
11
+ from PIL import Image
12
+
13
+ # Try to import optional dependencies
14
+ try:
15
+ from transformers import (
16
+ AutoModelForCausalLM,
17
+ AutoTokenizer,
18
+ AutoModelForImageTextToText,
19
+ AutoProcessor,
20
+ BitsAndBytesConfig,
21
+ )
22
+ except Exception:
23
+ AutoModelForCausalLM = None
24
+ AutoTokenizer = None
25
+ AutoModelForImageTextToText = None
26
+ AutoProcessor = None
27
+ BitsAndBytesConfig = None
28
+
29
+ # Import RF-DETR (assumes it's in the same directory or installed)
30
+ try:
31
+ from rfdetr import RFDETRMedium
32
+ except ImportError:
33
+ print("Warning: RF-DETR not found. Please ensure it's properly installed.")
34
+ RFDETRMedium = None
35
+
36
+ # ============================================================================
37
+ # Configuration for Hugging Face Spaces
38
+ # ============================================================================
39
+
40
+ class SpacesConfig:
41
+ """Configuration optimized for Hugging Face Spaces."""
42
+
43
+ def __init__(self):
44
+ self.settings = {
45
+ 'results_dir': '/tmp/results',
46
+ 'checkpoint': None,
47
+ 'resolution': 576,
48
+ 'threshold': 0.7,
49
+ 'use_llm': True,
50
+ 'llm_model_id': 'google/medgemma-4b-it',
51
+ 'llm_max_new_tokens': 200,
52
+ 'llm_temperature': 0.2,
53
+ 'llm_4bit': True,
54
+ 'enable_caching': True,
55
+ 'max_cache_size': 100,
56
+ }
57
+
58
+ def get(self, key: str, default: Any = None) -> Any:
59
+ return self.settings.get(key, default)
60
+
61
+ # ============================================================================
62
+ # Memory Management (simplified for Spaces)
63
+ # ============================================================================
64
+
65
+ class MemoryManager:
66
+ """Simplified memory management for Spaces."""
67
+
68
+ def __init__(self):
69
+ self.memory_thresholds = {
70
+ 'gpu_warning': 0.8,
71
+ 'system_warning': 0.85,
72
+ }
73
+
74
+ def cleanup_memory(self, force: bool = False) -> None:
75
+ """Perform memory cleanup."""
76
+ try:
77
+ gc.collect()
78
+ if torch and torch.cuda.is_available():
79
+ torch.cuda.empty_cache()
80
+ torch.cuda.synchronize()
81
+ except Exception as e:
82
+ print(f"Memory cleanup error: {e}")
83
+
84
+ # Global memory manager
85
+ memory_manager = MemoryManager()
86
+
87
+ # ============================================================================
88
+ # Model Loading
89
+ # ============================================================================
90
+
91
+ def find_checkpoint() -> Optional[str]:
92
+ """Find RF-DETR checkpoint in various locations."""
93
+ candidates = [
94
+ "edeler/rf-detr-lorai/rf-detr-medium.pth", # Current directory
95
+ "/tmp/results/checkpoint_best_total.pth",
96
+ "/tmp/results/checkpoint_best_ema.pth",
97
+ "/tmp/results/checkpoint_best_regular.pth",
98
+ "/tmp/results/checkpoint.pth",
99
+ ]
100
+
101
+ for path in candidates:
102
+ if os.path.isfile(path):
103
+ return path
104
+ return None
105
+
106
+ def load_model(checkpoint_path: str, resolution: int):
107
+ """Load RF-DETR model."""
108
+ if RFDETRMedium is None:
109
+ raise RuntimeError("RF-DETR not available. Please install it properly.")
110
+
111
+ model = RFDETRMedium(pretrain_weights=checkpoint_path, resolution=resolution)
112
+ try:
113
+ model.optimize_for_inference()
114
+ except Exception:
115
+ pass
116
+ return model
117
+
118
+ # ============================================================================
119
+ # LLM Integration
120
+ # ============================================================================
121
+
122
+ class TextGenerator:
123
+ """Simplified text generator for Spaces."""
124
+
125
+ def __init__(self, model_id: str, max_tokens: int = 200, temperature: float = 0.2):
126
+ self.model_id = model_id
127
+ self.max_tokens = max_tokens
128
+ self.temperature = temperature
129
+ self.model = None
130
+ self.tokenizer = None
131
+ self.processor = None
132
+ self.is_multimodal = False
133
+
134
+ def load_model(self):
135
+ """Load the LLM model."""
136
+ if self.model is not None:
137
+ return
138
+
139
+ if (AutoModelForCausalLM is None and AutoModelForImageTextToText is None):
140
+ raise RuntimeError("Transformers not available")
141
+
142
+ # Clear memory before loading
143
+ memory_manager.cleanup_memory()
144
+
145
+ print(f"Loading model: {self.model_id}")
146
+
147
+ model_kwargs = {
148
+ "device_map": "auto",
149
+ "low_cpu_mem_usage": True,
150
+ }
151
+
152
+ if torch and torch.cuda.is_available():
153
+ model_kwargs["torch_dtype"] = torch.bfloat16
154
+
155
+ # Use 4-bit quantization if available
156
+ if BitsAndBytesConfig is not None:
157
+ try:
158
+ compute_dtype = torch.bfloat16 if torch and torch.cuda.is_available() else torch.float16
159
+ model_kwargs["quantization_config"] = BitsAndBytesConfig(
160
+ load_in_4bit=True,
161
+ bnb_4bit_compute_dtype=compute_dtype,
162
+ bnb_4bit_use_double_quant=True,
163
+ bnb_4bit_quant_type="nf4"
164
+ )
165
+ model_kwargs["torch_dtype"] = compute_dtype
166
+ except Exception:
167
+ pass
168
+
169
+ # Check if it's a multimodal model
170
+ is_multimodal = "medgemma" in self.model_id.lower()
171
+
172
+ if is_multimodal and AutoModelForImageTextToText is not None and AutoProcessor is not None:
173
+ self.processor = AutoProcessor.from_pretrained(self.model_id)
174
+ self.model = AutoModelForImageTextToText.from_pretrained(self.model_id, **model_kwargs)
175
+ self.is_multimodal = True
176
+ elif AutoModelForCausalLM is not None and AutoTokenizer is not None:
177
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
178
+ self.model = AutoModelForCausalLM.from_pretrained(self.model_id, **model_kwargs)
179
+ self.is_multimodal = False
180
+ else:
181
+ raise RuntimeError("Required model classes not available")
182
+
183
+ print("✓ Model loaded successfully")
184
+
185
+ def generate(self, text: str, image: Optional[Image.Image] = None) -> str:
186
+ """Generate text using the loaded model."""
187
+ self.load_model()
188
+
189
+ if self.model is None:
190
+ return f"[Model not loaded: {text}]"
191
+
192
+ try:
193
+ # Create messages
194
+ system_text = "You are a concise medical assistant. Provide a brief, clear summary of detection results. Avoid repetition and be direct. Do not give medical advice."
195
+ user_text = f"Summarize these detection results in 3 clear sentences:\n\n{text}"
196
+
197
+ if self.is_multimodal:
198
+ # Multimodal model
199
+ user_content = [{"type": "text", "text": user_text}]
200
+ if image is not None:
201
+ user_content.append({"type": "image", "image": image})
202
+
203
+ messages = [
204
+ {"role": "system", "content": [{"type": "text", "text": system_text}]},
205
+ {"role": "user", "content": user_content},
206
+ ]
207
+
208
+ inputs = self.processor.apply_chat_template(
209
+ messages,
210
+ add_generation_prompt=True,
211
+ tokenize=True,
212
+ return_dict=True,
213
+ return_tensors="pt",
214
+ )
215
+
216
+ if torch:
217
+ inputs = inputs.to(self.model.device, dtype=torch.bfloat16)
218
+
219
+ with torch.inference_mode():
220
+ generation = self.model.generate(
221
+ **inputs,
222
+ max_new_tokens=self.max_tokens,
223
+ do_sample=self.temperature > 0,
224
+ temperature=max(0.01, self.temperature) if self.temperature > 0 else None,
225
+ use_cache=False,
226
+ )
227
+
228
+ input_len = inputs["input_ids"].shape[-1]
229
+ generation = generation[0][input_len:]
230
+ decoded = self.processor.decode(generation, skip_special_tokens=True)
231
+ return decoded.strip()
232
+
233
+ else:
234
+ # Text-only model
235
+ messages = [
236
+ {"role": "system", "content": system_text},
237
+ {"role": "user", "content": user_text},
238
+ ]
239
+
240
+ inputs = self.tokenizer.apply_chat_template(
241
+ messages,
242
+ add_generation_prompt=True,
243
+ tokenize=True,
244
+ return_dict=True,
245
+ return_tensors="pt",
246
+ )
247
+
248
+ inputs = inputs.to(self.model.device)
249
+
250
+ with torch.inference_mode():
251
+ generation = self.model.generate(
252
+ **inputs,
253
+ max_new_tokens=self.max_tokens,
254
+ do_sample=self.temperature > 0,
255
+ temperature=max(0.01, self.temperature) if self.temperature > 0 else None,
256
+ use_cache=False,
257
+ )
258
+
259
+ input_len = inputs["input_ids"].shape[-1]
260
+ generation = generation[0][input_len:]
261
+ decoded = self.tokenizer.decode(generation, skip_special_tokens=True)
262
+ return decoded.strip()
263
+
264
+ except Exception as e:
265
+ error_msg = f"[Generation error: {e}]"
266
+ print(f"Generation error: {traceback.format_exc()}")
267
+ return f"{error_msg}\n\n{text}"
268
+
269
+ # ============================================================================
270
+ # Application State
271
+ # ============================================================================
272
+
273
+ class AppState:
274
+ """Application state for Spaces."""
275
+
276
+ def __init__(self):
277
+ self.config = SpacesConfig()
278
+ self.model = None
279
+ self.class_names = None
280
+ self.text_generator = None
281
+
282
+ def load_model(self):
283
+ """Load the detection model."""
284
+ if self.model is not None:
285
+ return
286
+
287
+ checkpoint = find_checkpoint()
288
+ if not checkpoint:
289
+ raise FileNotFoundError(
290
+ "No RF-DETR checkpoint found. Please upload rf-detr-medium.pth to your Space."
291
+ )
292
+
293
+ print(f"Loading RF-DETR from: {checkpoint}")
294
+ self.model = load_model(checkpoint, self.config.get('resolution'))
295
+
296
+ # Try to load class names
297
+ try:
298
+ results_json = "/tmp/results/results.json"
299
+ if os.path.isfile(results_json):
300
+ with open(results_json, 'r') as f:
301
+ data = json.load(f)
302
+ classes = []
303
+ for split in ("valid", "test", "train"):
304
+ if "class_map" in data and split in data["class_map"]:
305
+ for item in data["class_map"][split]:
306
+ name = item.get("class")
307
+ if name and name != "all" and name not in classes:
308
+ classes.append(name)
309
+ self.class_names = classes if classes else None
310
+ except Exception:
311
+ pass
312
+
313
+ print("✓ RF-DETR model loaded")
314
+
315
+ def get_text_generator(self, model_size: str = "4B") -> TextGenerator:
316
+ """Get or create text generator."""
317
+ # Determine model ID based on size selection
318
+ model_id = 'google/medgemma-27b-it' if model_size == "27B" else 'google/medgemma-4b-it'
319
+
320
+ # Check if we need to create a new generator for different model size
321
+ if (self.text_generator is None or
322
+ hasattr(self.text_generator, 'model_id') and
323
+ self.text_generator.model_id != model_id):
324
+
325
+ max_tokens = self.config.get('llm_max_new_tokens')
326
+ temperature = self.config.get('llm_temperature')
327
+
328
+ self.text_generator = TextGenerator(model_id, max_tokens, temperature)
329
+ return self.text_generator
330
+
331
+ # ============================================================================
332
+ # UI and Inference
333
+ # ============================================================================
334
+
335
+ def create_detection_interface():
336
+ """Create the Gradio interface."""
337
+
338
+ # Color palette for annotations
339
+ COLOR_PALETTE = sv.ColorPalette.from_hex([
340
+ "#ffff00", "#ff9b00", "#ff66ff", "#3399ff", "#ff66b2",
341
+ "#ff8080", "#b266ff", "#9999ff", "#66ffff", "#33ff99",
342
+ "#66ff66", "#99ff00",
343
+ ])
344
+
345
+ def annotate_image(image: Image.Image, threshold: float, model_size: str = "4B") -> Tuple[Image.Image, str]:
346
+ """Process an image and return annotated version with description."""
347
+
348
+ if image is None:
349
+ return None, "Please upload an image."
350
+
351
+ try:
352
+ # Load model if needed
353
+ app_state.load_model()
354
+
355
+ # Run detection
356
+ detections = app_state.model.predict(image, threshold=threshold)
357
+
358
+ # Annotate image
359
+ bbox_annotator = sv.BoxAnnotator(color=COLOR_PALETTE, thickness=2)
360
+ label_annotator = sv.LabelAnnotator(text_scale=0.5, text_color=sv.Color.BLACK)
361
+
362
+ labels = []
363
+ for i in range(len(detections)):
364
+ class_id = int(detections.class_id[i]) if detections.class_id is not None else None
365
+ conf = float(detections.confidence[i]) if detections.confidence is not None else 0.0
366
+
367
+ if app_state.class_names and class_id is not None:
368
+ if 0 <= class_id < len(app_state.class_names):
369
+ label_name = app_state.class_names[class_id]
370
+ else:
371
+ label_name = str(class_id)
372
+ else:
373
+ label_name = str(class_id) if class_id is not None else "object"
374
+
375
+ labels.append(f"{label_name} {conf:.2f}")
376
+
377
+ annotated = image.copy()
378
+ annotated = bbox_annotator.annotate(annotated, detections)
379
+ annotated = label_annotator.annotate(annotated, detections, labels)
380
+
381
+ # Generate description
382
+ description = f"Found {len(detections)} detections above threshold {threshold}:\n\n"
383
+
384
+ if len(detections) > 0:
385
+ counts = {}
386
+ for i in range(len(detections)):
387
+ class_id = int(detections.class_id[i]) if detections.class_id is not None else None
388
+ if app_state.class_names and class_id is not None:
389
+ if 0 <= class_id < len(app_state.class_names):
390
+ name = app_state.class_names[class_id]
391
+ else:
392
+ name = str(class_id)
393
+ else:
394
+ name = str(class_id) if class_id is not None else "object"
395
+ counts[name] = counts.get(name, 0) + 1
396
+
397
+ for name, count in counts.items():
398
+ description += f"- {count}× {name}\n"
399
+
400
+ # Use LLM for description if enabled
401
+ if app_state.config.get('use_llm'):
402
+ try:
403
+ generator = app_state.get_text_generator(model_size)
404
+ llm_description = generator.generate(description, image=annotated)
405
+ description = llm_description
406
+ except Exception as e:
407
+ description = f"[LLM error: {e}]\n\n{description}"
408
+ else:
409
+ description += "No objects detected above the confidence threshold."
410
+
411
+ return annotated, description
412
+
413
+ except Exception as e:
414
+ error_msg = f"Error processing image: {str(e)}"
415
+ print(f"Processing error: {traceback.format_exc()}")
416
+ return None, error_msg
417
+
418
+ # Create the interface
419
+ with gr.Blocks(title="Medical Image Analysis", theme=gr.themes.Soft()) as demo:
420
+ gr.Markdown("# 🏥 Medical Image Analysis")
421
+ gr.Markdown("Upload a medical image to detect and analyze findings using AI.")
422
+
423
+ with gr.Row():
424
+ with gr.Column():
425
+ input_image = gr.Image(type="pil", label="Upload Image", height=400)
426
+ threshold_slider = gr.Slider(
427
+ minimum=0.1,
428
+ maximum=1.0,
429
+ value=0.7,
430
+ step=0.05,
431
+ label="Confidence Threshold",
432
+ info="Higher values = fewer but more confident detections"
433
+ )
434
+
435
+ model_size_radio = gr.Radio(
436
+ choices=["4B", "27B"],
437
+ value="4B",
438
+ label="MedGemma Model Size",
439
+ info="4B: Faster, less memory | 27B: More accurate, more memory"
440
+ )
441
+
442
+ analyze_btn = gr.Button("🔍 Analyze Image", variant="primary")
443
+
444
+ with gr.Column():
445
+ output_image = gr.Image(type="pil", label="Results", height=400)
446
+ output_text = gr.Textbox(
447
+ label="Analysis Results",
448
+ lines=8,
449
+ max_lines=15,
450
+ show_copy_button=True
451
+ )
452
+
453
+ # Wire up the interface
454
+ analyze_btn.click(
455
+ fn=annotate_image,
456
+ inputs=[input_image, threshold_slider, model_size_radio],
457
+ outputs=[output_image, output_text]
458
+ )
459
+
460
+ # Also run when image is uploaded
461
+ input_image.change(
462
+ fn=annotate_image,
463
+ inputs=[input_image, threshold_slider, model_size_radio],
464
+ outputs=[output_image, output_text]
465
+ )
466
+
467
+ # Footer
468
+ gr.Markdown("---")
469
+ gr.Markdown("*Powered by RF-DETR and MedGemma • Built for Hugging Face Spaces*")
470
+
471
+ return demo
472
+
473
+ # ============================================================================
474
+ # Main Application
475
+ # ============================================================================
476
+
477
+ # Global app state
478
+ app_state = AppState()
479
+
480
+ def main():
481
+ """Main entry point for the Spaces app."""
482
+ print("🚀 Starting Medical Image Analysis App")
483
+
484
+ # Ensure results directory exists
485
+ os.makedirs(app_state.config.get('results_dir'), exist_ok=True)
486
+
487
+ # Create and launch the interface
488
+ demo = create_detection_interface()
489
+
490
+ # Launch with Spaces-optimized settings
491
+ demo.launch(
492
+ server_name="0.0.0.0",
493
+ server_port=7860,
494
+ share=False, # Spaces handles this
495
+ show_error=True,
496
+ show_api=False,
497
+ )
498
+
499
+ if __name__ == "__main__":
500
+ main()