edeler committed
Commit 53dfff0 · verified · 1 Parent(s): 7056797

Upload app.py

Files changed (1)
  1. app.py +545 -500
app.py CHANGED
@@ -1,500 +1,545 @@
- import os
- import json
- import gc
- import time
- import traceback
- from typing import Dict, List, Optional, Tuple, Callable, Any
-
- import torch
- import gradio as gr
- import supervision as sv
- from PIL import Image
-
- # Try to import optional dependencies
- try:
-     from transformers import (
-         AutoModelForCausalLM,
-         AutoTokenizer,
-         AutoModelForImageTextToText,
-         AutoProcessor,
-         BitsAndBytesConfig,
-     )
- except Exception:
-     AutoModelForCausalLM = None
-     AutoTokenizer = None
-     AutoModelForImageTextToText = None
-     AutoProcessor = None
-     BitsAndBytesConfig = None
-
- # Import RF-DETR (assumes it's in the same directory or installed)
- try:
-     from rfdetr import RFDETRMedium
- except ImportError:
-     print("Warning: RF-DETR not found. Please ensure it's properly installed.")
-     RFDETRMedium = None
-
- # ============================================================================
- # Configuration for Hugging Face Spaces
- # ============================================================================
-
- class SpacesConfig:
-     """Configuration optimized for Hugging Face Spaces."""
-
-     def __init__(self):
-         self.settings = {
-             'results_dir': '/tmp/results',
-             'checkpoint': None,
-             'resolution': 576,
-             'threshold': 0.7,
-             'use_llm': True,
-             'llm_model_id': 'google/medgemma-4b-it',
-             'llm_max_new_tokens': 200,
-             'llm_temperature': 0.2,
-             'llm_4bit': True,
-             'enable_caching': True,
-             'max_cache_size': 100,
-         }
-
-     def get(self, key: str, default: Any = None) -> Any:
-         return self.settings.get(key, default)
-
- # ============================================================================
- # Memory Management (simplified for Spaces)
- # ============================================================================
-
- class MemoryManager:
-     """Simplified memory management for Spaces."""
-
-     def __init__(self):
-         self.memory_thresholds = {
-             'gpu_warning': 0.8,
-             'system_warning': 0.85,
-         }
-
-     def cleanup_memory(self, force: bool = False) -> None:
-         """Perform memory cleanup."""
-         try:
-             gc.collect()
-             if torch and torch.cuda.is_available():
-                 torch.cuda.empty_cache()
-                 torch.cuda.synchronize()
-         except Exception as e:
-             print(f"Memory cleanup error: {e}")
-
- # Global memory manager
- memory_manager = MemoryManager()
-
- # ============================================================================
- # Model Loading
- # ============================================================================
-
- def find_checkpoint() -> Optional[str]:
-     """Find RF-DETR checkpoint in various locations."""
-     candidates = [
-         "edeler/rf-detr-lorai/rf-detr-medium.pth",  # Current directory
-         "/tmp/results/checkpoint_best_total.pth",
-         "/tmp/results/checkpoint_best_ema.pth",
-         "/tmp/results/checkpoint_best_regular.pth",
-         "/tmp/results/checkpoint.pth",
-     ]
-
-     for path in candidates:
-         if os.path.isfile(path):
-             return path
-     return None
-
- def load_model(checkpoint_path: str, resolution: int):
-     """Load RF-DETR model."""
-     if RFDETRMedium is None:
-         raise RuntimeError("RF-DETR not available. Please install it properly.")
-
-     model = RFDETRMedium(pretrain_weights=checkpoint_path, resolution=resolution)
-     try:
-         model.optimize_for_inference()
-     except Exception:
-         pass
-     return model
-
- # ============================================================================
- # LLM Integration
- # ============================================================================
-
- class TextGenerator:
-     """Simplified text generator for Spaces."""
-
-     def __init__(self, model_id: str, max_tokens: int = 200, temperature: float = 0.2):
-         self.model_id = model_id
-         self.max_tokens = max_tokens
-         self.temperature = temperature
-         self.model = None
-         self.tokenizer = None
-         self.processor = None
-         self.is_multimodal = False
-
-     def load_model(self):
-         """Load the LLM model."""
-         if self.model is not None:
-             return
-
-         if (AutoModelForCausalLM is None and AutoModelForImageTextToText is None):
-             raise RuntimeError("Transformers not available")
-
-         # Clear memory before loading
-         memory_manager.cleanup_memory()
-
-         print(f"Loading model: {self.model_id}")
-
-         model_kwargs = {
-             "device_map": "auto",
-             "low_cpu_mem_usage": True,
-         }
-
-         if torch and torch.cuda.is_available():
-             model_kwargs["torch_dtype"] = torch.bfloat16
-
-         # Use 4-bit quantization if available
-         if BitsAndBytesConfig is not None:
-             try:
-                 compute_dtype = torch.bfloat16 if torch and torch.cuda.is_available() else torch.float16
-                 model_kwargs["quantization_config"] = BitsAndBytesConfig(
-                     load_in_4bit=True,
-                     bnb_4bit_compute_dtype=compute_dtype,
-                     bnb_4bit_use_double_quant=True,
-                     bnb_4bit_quant_type="nf4"
-                 )
-                 model_kwargs["torch_dtype"] = compute_dtype
-             except Exception:
-                 pass
-
-         # Check if it's a multimodal model
-         is_multimodal = "medgemma" in self.model_id.lower()
-
-         if is_multimodal and AutoModelForImageTextToText is not None and AutoProcessor is not None:
-             self.processor = AutoProcessor.from_pretrained(self.model_id)
-             self.model = AutoModelForImageTextToText.from_pretrained(self.model_id, **model_kwargs)
-             self.is_multimodal = True
-         elif AutoModelForCausalLM is not None and AutoTokenizer is not None:
-             self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
-             self.model = AutoModelForCausalLM.from_pretrained(self.model_id, **model_kwargs)
-             self.is_multimodal = False
-         else:
-             raise RuntimeError("Required model classes not available")
-
-         print("✓ Model loaded successfully")
-
-     def generate(self, text: str, image: Optional[Image.Image] = None) -> str:
-         """Generate text using the loaded model."""
-         self.load_model()
-
-         if self.model is None:
-             return f"[Model not loaded: {text}]"
-
-         try:
-             # Create messages
-             system_text = "You are a concise medical assistant. Provide a brief, clear summary of detection results. Avoid repetition and be direct. Do not give medical advice."
-             user_text = f"Summarize these detection results in 3 clear sentences:\n\n{text}"
-
-             if self.is_multimodal:
-                 # Multimodal model
-                 user_content = [{"type": "text", "text": user_text}]
-                 if image is not None:
-                     user_content.append({"type": "image", "image": image})
-
-                 messages = [
-                     {"role": "system", "content": [{"type": "text", "text": system_text}]},
-                     {"role": "user", "content": user_content},
-                 ]
-
-                 inputs = self.processor.apply_chat_template(
-                     messages,
-                     add_generation_prompt=True,
-                     tokenize=True,
-                     return_dict=True,
-                     return_tensors="pt",
-                 )
-
-                 if torch:
-                     inputs = inputs.to(self.model.device, dtype=torch.bfloat16)
-
-                 with torch.inference_mode():
-                     generation = self.model.generate(
-                         **inputs,
-                         max_new_tokens=self.max_tokens,
-                         do_sample=self.temperature > 0,
-                         temperature=max(0.01, self.temperature) if self.temperature > 0 else None,
-                         use_cache=False,
-                     )
-
-                 input_len = inputs["input_ids"].shape[-1]
-                 generation = generation[0][input_len:]
-                 decoded = self.processor.decode(generation, skip_special_tokens=True)
-                 return decoded.strip()
-
-             else:
-                 # Text-only model
-                 messages = [
-                     {"role": "system", "content": system_text},
-                     {"role": "user", "content": user_text},
-                 ]
-
-                 inputs = self.tokenizer.apply_chat_template(
-                     messages,
-                     add_generation_prompt=True,
-                     tokenize=True,
-                     return_dict=True,
-                     return_tensors="pt",
-                 )
-
-                 inputs = inputs.to(self.model.device)
-
-                 with torch.inference_mode():
-                     generation = self.model.generate(
-                         **inputs,
-                         max_new_tokens=self.max_tokens,
-                         do_sample=self.temperature > 0,
-                         temperature=max(0.01, self.temperature) if self.temperature > 0 else None,
-                         use_cache=False,
-                     )
-
-                 input_len = inputs["input_ids"].shape[-1]
-                 generation = generation[0][input_len:]
-                 decoded = self.tokenizer.decode(generation, skip_special_tokens=True)
-                 return decoded.strip()
-
-         except Exception as e:
-             error_msg = f"[Generation error: {e}]"
-             print(f"Generation error: {traceback.format_exc()}")
-             return f"{error_msg}\n\n{text}"
-
- # ============================================================================
- # Application State
- # ============================================================================
-
- class AppState:
-     """Application state for Spaces."""
-
-     def __init__(self):
-         self.config = SpacesConfig()
-         self.model = None
-         self.class_names = None
-         self.text_generator = None
-
-     def load_model(self):
-         """Load the detection model."""
-         if self.model is not None:
-             return
-
-         checkpoint = find_checkpoint()
-         if not checkpoint:
-             raise FileNotFoundError(
-                 "No RF-DETR checkpoint found. Please upload rf-detr-medium.pth to your Space."
-             )
-
-         print(f"Loading RF-DETR from: {checkpoint}")
-         self.model = load_model(checkpoint, self.config.get('resolution'))
-
-         # Try to load class names
-         try:
-             results_json = "/tmp/results/results.json"
-             if os.path.isfile(results_json):
-                 with open(results_json, 'r') as f:
-                     data = json.load(f)
-                 classes = []
-                 for split in ("valid", "test", "train"):
-                     if "class_map" in data and split in data["class_map"]:
-                         for item in data["class_map"][split]:
-                             name = item.get("class")
-                             if name and name != "all" and name not in classes:
-                                 classes.append(name)
-                 self.class_names = classes if classes else None
-         except Exception:
-             pass
-
-         print("✓ RF-DETR model loaded")
-
-     def get_text_generator(self, model_size: str = "4B") -> TextGenerator:
-         """Get or create text generator."""
-         # Determine model ID based on size selection
-         model_id = 'google/medgemma-27b-it' if model_size == "27B" else 'google/medgemma-4b-it'
-
-         # Check if we need to create a new generator for different model size
-         if (self.text_generator is None or
-                 hasattr(self.text_generator, 'model_id') and
-                 self.text_generator.model_id != model_id):
-
-             max_tokens = self.config.get('llm_max_new_tokens')
-             temperature = self.config.get('llm_temperature')
-
-             self.text_generator = TextGenerator(model_id, max_tokens, temperature)
-         return self.text_generator
-
- # ============================================================================
- # UI and Inference
- # ============================================================================
-
- def create_detection_interface():
-     """Create the Gradio interface."""
-
-     # Color palette for annotations
-     COLOR_PALETTE = sv.ColorPalette.from_hex([
-         "#ffff00", "#ff9b00", "#ff66ff", "#3399ff", "#ff66b2",
-         "#ff8080", "#b266ff", "#9999ff", "#66ffff", "#33ff99",
-         "#66ff66", "#99ff00",
-     ])
-
-     def annotate_image(image: Image.Image, threshold: float, model_size: str = "4B") -> Tuple[Image.Image, str]:
-         """Process an image and return annotated version with description."""
-
-         if image is None:
-             return None, "Please upload an image."
-
-         try:
-             # Load model if needed
-             app_state.load_model()
-
-             # Run detection
-             detections = app_state.model.predict(image, threshold=threshold)
-
-             # Annotate image
-             bbox_annotator = sv.BoxAnnotator(color=COLOR_PALETTE, thickness=2)
-             label_annotator = sv.LabelAnnotator(text_scale=0.5, text_color=sv.Color.BLACK)
-
-             labels = []
-             for i in range(len(detections)):
-                 class_id = int(detections.class_id[i]) if detections.class_id is not None else None
-                 conf = float(detections.confidence[i]) if detections.confidence is not None else 0.0
-
-                 if app_state.class_names and class_id is not None:
-                     if 0 <= class_id < len(app_state.class_names):
-                         label_name = app_state.class_names[class_id]
-                     else:
-                         label_name = str(class_id)
-                 else:
-                     label_name = str(class_id) if class_id is not None else "object"
-
-                 labels.append(f"{label_name} {conf:.2f}")
-
-             annotated = image.copy()
-             annotated = bbox_annotator.annotate(annotated, detections)
-             annotated = label_annotator.annotate(annotated, detections, labels)
-
-             # Generate description
-             description = f"Found {len(detections)} detections above threshold {threshold}:\n\n"
-
-             if len(detections) > 0:
-                 counts = {}
-                 for i in range(len(detections)):
-                     class_id = int(detections.class_id[i]) if detections.class_id is not None else None
-                     if app_state.class_names and class_id is not None:
-                         if 0 <= class_id < len(app_state.class_names):
-                             name = app_state.class_names[class_id]
-                         else:
-                             name = str(class_id)
-                     else:
-                         name = str(class_id) if class_id is not None else "object"
-                     counts[name] = counts.get(name, 0) + 1
-
-                 for name, count in counts.items():
-                     description += f"- {count}× {name}\n"
-
-                 # Use LLM for description if enabled
-                 if app_state.config.get('use_llm'):
-                     try:
-                         generator = app_state.get_text_generator(model_size)
-                         llm_description = generator.generate(description, image=annotated)
-                         description = llm_description
-                     except Exception as e:
-                         description = f"[LLM error: {e}]\n\n{description}"
-             else:
-                 description += "No objects detected above the confidence threshold."
-
-             return annotated, description
-
-         except Exception as e:
-             error_msg = f"Error processing image: {str(e)}"
-             print(f"Processing error: {traceback.format_exc()}")
-             return None, error_msg
-
-     # Create the interface
-     with gr.Blocks(title="Medical Image Analysis", theme=gr.themes.Soft()) as demo:
-         gr.Markdown("# 🏥 Medical Image Analysis")
-         gr.Markdown("Upload a medical image to detect and analyze findings using AI.")
-
-         with gr.Row():
-             with gr.Column():
-                 input_image = gr.Image(type="pil", label="Upload Image", height=400)
-                 threshold_slider = gr.Slider(
-                     minimum=0.1,
-                     maximum=1.0,
-                     value=0.7,
-                     step=0.05,
-                     label="Confidence Threshold",
-                     info="Higher values = fewer but more confident detections"
-                 )
-
-                 model_size_radio = gr.Radio(
-                     choices=["4B", "27B"],
-                     value="4B",
-                     label="MedGemma Model Size",
-                     info="4B: Faster, less memory | 27B: More accurate, more memory"
-                 )
-
-                 analyze_btn = gr.Button("🔍 Analyze Image", variant="primary")
-
-             with gr.Column():
-                 output_image = gr.Image(type="pil", label="Results", height=400)
-                 output_text = gr.Textbox(
-                     label="Analysis Results",
-                     lines=8,
-                     max_lines=15,
-                     show_copy_button=True
-                 )
-
-         # Wire up the interface
-         analyze_btn.click(
-             fn=annotate_image,
-             inputs=[input_image, threshold_slider, model_size_radio],
-             outputs=[output_image, output_text]
-         )
-
-         # Also run when image is uploaded
-         input_image.change(
-             fn=annotate_image,
-             inputs=[input_image, threshold_slider, model_size_radio],
-             outputs=[output_image, output_text]
-         )
-
-         # Footer
-         gr.Markdown("---")
-         gr.Markdown("*Powered by RF-DETR and MedGemma • Built for Hugging Face Spaces*")
-
-     return demo
-
- # ============================================================================
- # Main Application
- # ============================================================================
-
- # Global app state
- app_state = AppState()
-
- def main():
-     """Main entry point for the Spaces app."""
-     print("🚀 Starting Medical Image Analysis App")
-
-     # Ensure results directory exists
-     os.makedirs(app_state.config.get('results_dir'), exist_ok=True)
-
-     # Create and launch the interface
-     demo = create_detection_interface()
-
-     # Launch with Spaces-optimized settings
-     demo.launch(
-         server_name="0.0.0.0",
-         server_port=7860,
-         share=False,  # Spaces handles this
-         show_error=True,
-         show_api=False,
-     )
-
- if __name__ == "__main__":
-     main()
+ import os
+ import json
+ import gc
+ import traceback
+ from typing import Optional, Tuple, Any
+
+ import torch
+ import gradio as gr
+ import supervision as sv
+ from PIL import Image
+
+ # Try to import optional dependencies
+ try:
+     from transformers import (
+         AutoModelForCausalLM,
+         AutoTokenizer,
+         AutoModelForImageTextToText,
+         AutoProcessor,
+         BitsAndBytesConfig,
+     )
+ except Exception:
+     AutoModelForCausalLM = None
+     AutoTokenizer = None
+     AutoModelForImageTextToText = None
+     AutoProcessor = None
+     BitsAndBytesConfig = None
+
+ # Try to import huggingface_hub for model downloading
+ try:
+     from huggingface_hub import hf_hub_download
+ except ImportError:
+     hf_hub_download = None
+
+ # Import RF-DETR (assumes it's in the same directory or installed)
+ try:
+     from rfdetr import RFDETRMedium
+ except ImportError:
+     print("Warning: RF-DETR not found. Please ensure it's properly installed.")
+     RFDETRMedium = None
+
+ # ============================================================================
+ # Configuration for Hugging Face Spaces
+ # ============================================================================
+
+ class SpacesConfig:
+     """Configuration optimized for Hugging Face Spaces."""
+
+     def __init__(self):
+         self.settings = {
+             'results_dir': '/tmp/results',
+             'checkpoint': None,
+             'hf_model_repo': 'edeler/lorai',  # Hugging Face model repository
+             'hf_model_filename': 'lorai.pth',
+             'resolution': 576,
+             'threshold': 0.7,
+             'use_llm': True,
+             'llm_model_id': 'google/medgemma-4b-it',
+             'llm_max_new_tokens': 200,
+             'llm_temperature': 0.2,
+             'llm_4bit': True,
+             'enable_caching': True,
+             'max_cache_size': 100,
+         }
+
+     def get(self, key: str, default: Any = None) -> Any:
+         return self.settings.get(key, default)
+
+     def set_hf_model_repo(self, repo_id: str, filename: str = 'lorai.pth'):
+         """Set Hugging Face model repository."""
+         self.settings['hf_model_repo'] = repo_id
+         self.settings['hf_model_filename'] = filename
+
+ # ============================================================================
+ # Memory Management (simplified for Spaces)
+ # ============================================================================
+
+ class MemoryManager:
+     """Simplified memory management for Spaces."""
+
+     def __init__(self):
+         self.memory_thresholds = {
+             'gpu_warning': 0.8,
+             'system_warning': 0.85,
+         }
+
+     def cleanup_memory(self, force: bool = False) -> None:
+         """Perform memory cleanup."""
+         try:
+             gc.collect()
+             if torch and torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+                 torch.cuda.synchronize()
+         except Exception as e:
+             print(f"Memory cleanup error: {e}")
+
+ # Global memory manager
+ memory_manager = MemoryManager()
+
+ # ============================================================================
+ # Model Loading
+ # ============================================================================
+
+ def find_checkpoint(hf_repo: Optional[str] = None, hf_filename: str = 'lorai.pth') -> Optional[str]:
+     """Find RF-DETR checkpoint in various locations or download from Hugging Face Hub."""
+
+     # First check if we should download from Hugging Face
+     repo_id = hf_repo or os.environ.get('HF_MODEL_REPO')
+
+     if repo_id and hf_hub_download is not None:
+         try:
+             print(f"Downloading checkpoint from Hugging Face Hub: {repo_id}/{hf_filename}")
+             checkpoint_path = hf_hub_download(
+                 repo_id=repo_id,
+                 filename=hf_filename,
+                 cache_dir="/tmp/hf_cache"
+             )
+             print(f"✓ Downloaded checkpoint to: {checkpoint_path}")
+             return checkpoint_path
+         except Exception as e:
+             print(f"Warning: Failed to download from Hugging Face Hub: {e}")
+             print("Falling back to local checkpoints...")
+
+     # Fall back to local file search
+     candidates = [
+         "lorai.pth",  # Current directory
+         "rf-detr-medium.pth",
+         "/tmp/results/checkpoint_best_total.pth",
+         "/tmp/results/checkpoint_best_ema.pth",
+         "/tmp/results/checkpoint_best_regular.pth",
+         "/tmp/results/checkpoint.pth",
+     ]
+
+     for path in candidates:
+         if os.path.isfile(path):
+             print(f"Found local checkpoint: {path}")
+             return path
+
+     return None
+
+ def load_model(checkpoint_path: str, resolution: int):
+     """Load RF-DETR model."""
+     if RFDETRMedium is None:
+         raise RuntimeError("RF-DETR not available. Please install it properly.")
+
+     model = RFDETRMedium(pretrain_weights=checkpoint_path, resolution=resolution)
+     try:
+         model.optimize_for_inference()
+     except Exception:
+         pass
+     return model
+
+ # ============================================================================
+ # LLM Integration
+ # ============================================================================
+
+ class TextGenerator:
+     """Simplified text generator for Spaces."""
+
+     def __init__(self, model_id: str, max_tokens: int = 200, temperature: float = 0.2):
+         self.model_id = model_id
+         self.max_tokens = max_tokens
+         self.temperature = temperature
+         self.model = None
+         self.tokenizer = None
+         self.processor = None
+         self.is_multimodal = False
+
+     def load_model(self):
+         """Load the LLM model."""
+         if self.model is not None:
+             return
+
+         if (AutoModelForCausalLM is None and AutoModelForImageTextToText is None):
+             raise RuntimeError("Transformers not available")
+
+         # Clear memory before loading
+         memory_manager.cleanup_memory()
+
+         print(f"Loading model: {self.model_id}")
+
+         model_kwargs = {
+             "device_map": "auto",
+             "low_cpu_mem_usage": True,
+         }
+
+         if torch and torch.cuda.is_available():
+             model_kwargs["torch_dtype"] = torch.bfloat16
+
+         # Use 4-bit quantization if available
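+         # (4-bit NF4 weights with double quantization cut LLM weight memory to roughly a quarter of fp16)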
+         if BitsAndBytesConfig is not None:
+             try:
+                 compute_dtype = torch.bfloat16 if torch and torch.cuda.is_available() else torch.float16
+                 model_kwargs["quantization_config"] = BitsAndBytesConfig(
+                     load_in_4bit=True,
+                     bnb_4bit_compute_dtype=compute_dtype,
+                     bnb_4bit_use_double_quant=True,
+                     bnb_4bit_quant_type="nf4"
+                 )
+                 model_kwargs["torch_dtype"] = compute_dtype
+             except Exception:
+                 pass
+
+         # Check if it's a multimodal model
+         is_multimodal = "medgemma" in self.model_id.lower()
+
+         if is_multimodal and AutoModelForImageTextToText is not None and AutoProcessor is not None:
+             self.processor = AutoProcessor.from_pretrained(self.model_id)
+             self.model = AutoModelForImageTextToText.from_pretrained(self.model_id, **model_kwargs)
+             self.is_multimodal = True
+         elif AutoModelForCausalLM is not None and AutoTokenizer is not None:
+             self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
+             self.model = AutoModelForCausalLM.from_pretrained(self.model_id, **model_kwargs)
+             self.is_multimodal = False
+         else:
+             raise RuntimeError("Required model classes not available")
+
+         print("✓ Model loaded successfully")
+
+     def generate(self, text: str, image: Optional[Image.Image] = None) -> str:
+         """Generate text using the loaded model."""
+         self.load_model()
+
+         if self.model is None:
+             return f"[Model not loaded: {text}]"
+
+         try:
+             # Create messages
+             system_text = "You are a concise medical assistant. Provide a brief, clear summary of detection results. Avoid repetition and be direct. Do not give medical advice."
+             user_text = f"Summarize these detection results in 3 clear sentences:\n\n{text}"
+
+             if self.is_multimodal:
+                 # Multimodal model
+                 user_content = [{"type": "text", "text": user_text}]
+                 if image is not None:
+                     user_content.append({"type": "image", "image": image})
+
+                 messages = [
+                     {"role": "system", "content": [{"type": "text", "text": system_text}]},
+                     {"role": "user", "content": user_content},
+                 ]
+
+                 inputs = self.processor.apply_chat_template(
+                     messages,
+                     add_generation_prompt=True,
+                     tokenize=True,
+                     return_dict=True,
+                     return_tensors="pt",
+                 )
+
+                 if torch:
+                     inputs = inputs.to(self.model.device, dtype=torch.bfloat16)
+
+                 with torch.inference_mode():
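+                     # use_cache=False below disables the KV cache: lower peak memory, slower decoding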
+                     generation = self.model.generate(
+                         **inputs,
+                         max_new_tokens=self.max_tokens,
+                         do_sample=self.temperature > 0,
+                         temperature=max(0.01, self.temperature) if self.temperature > 0 else None,
+                         use_cache=False,
+                     )
+
+                 input_len = inputs["input_ids"].shape[-1]
+                 generation = generation[0][input_len:]
+                 decoded = self.processor.decode(generation, skip_special_tokens=True)
+                 return decoded.strip()
+
+             else:
+                 # Text-only model
+                 messages = [
+                     {"role": "system", "content": system_text},
+                     {"role": "user", "content": user_text},
+                 ]
+
+                 inputs = self.tokenizer.apply_chat_template(
+                     messages,
+                     add_generation_prompt=True,
+                     tokenize=True,
+                     return_dict=True,
+                     return_tensors="pt",
+                 )
+
+                 inputs = inputs.to(self.model.device)
+
+                 with torch.inference_mode():
+                     generation = self.model.generate(
+                         **inputs,
+                         max_new_tokens=self.max_tokens,
+                         do_sample=self.temperature > 0,
+                         temperature=max(0.01, self.temperature) if self.temperature > 0 else None,
+                         use_cache=False,
+                     )
+
+                 input_len = inputs["input_ids"].shape[-1]
+                 generation = generation[0][input_len:]
+                 decoded = self.tokenizer.decode(generation, skip_special_tokens=True)
+                 return decoded.strip()
+
+         except Exception as e:
+             error_msg = f"[Generation error: {e}]"
+             print(f"Generation error: {traceback.format_exc()}")
+             return f"{error_msg}\n\n{text}"
+
+ # ============================================================================
+ # Application State
+ # ============================================================================
+
+ class AppState:
+     """Application state for Spaces."""
+
+     def __init__(self):
+         self.config = SpacesConfig()
+         self.model = None
+         self.class_names = None
+         self.text_generator = None
+
+     def load_model(self):
+         """Load the detection model."""
+         if self.model is not None:
+             return
+
+         checkpoint = find_checkpoint(
+             hf_repo=self.config.get('hf_model_repo'),
+             hf_filename=self.config.get('hf_model_filename', 'lorai.pth')
+         )
+         if not checkpoint:
+             hf_repo = self.config.get('hf_model_repo') or os.environ.get('HF_MODEL_REPO')
+             if hf_repo:
+                 raise FileNotFoundError(
+                     f"No RF-DETR checkpoint found. Could not download from '{hf_repo}'. "
+                     "Please check the repository ID and ensure the model file exists."
+                 )
+             else:
+                 raise FileNotFoundError(
+                     "No RF-DETR checkpoint found. Please either:\n"
+                     "1. Set HF_MODEL_REPO environment variable (e.g., 'edeler/lorai'), or\n"
+                     "2. Upload lorai.pth to your Space's root directory"
+                 )
+
+         print(f"Loading RF-DETR from: {checkpoint}")
+         self.model = load_model(checkpoint, self.config.get('resolution'))
+
+         # Try to load class names
+         try:
+             results_json = "/tmp/results/results.json"
+             if os.path.isfile(results_json):
+                 with open(results_json, 'r') as f:
+                     data = json.load(f)
+                 classes = []
+                 for split in ("valid", "test", "train"):
+                     if "class_map" in data and split in data["class_map"]:
+                         for item in data["class_map"][split]:
+                             name = item.get("class")
+                             if name and name != "all" and name not in classes:
+                                 classes.append(name)
+                 self.class_names = classes if classes else None
+         except Exception:
+             pass
+
+         print("✓ RF-DETR model loaded")
+
+     def get_text_generator(self, model_size: str = "4B") -> TextGenerator:
+         """Get or create text generator."""
+         # Determine model ID based on size selection
+         model_id = 'google/medgemma-27b-it' if model_size == "27B" else 'google/medgemma-4b-it'
+
+         # Check if we need to create a new generator for different model size
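+         # ('and' binds tighter than 'or', so a new generator is built whenever the requested model id differs)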
+         if (self.text_generator is None or
+                 hasattr(self.text_generator, 'model_id') and
+                 self.text_generator.model_id != model_id):
+
+             max_tokens = self.config.get('llm_max_new_tokens')
+             temperature = self.config.get('llm_temperature')
+
+             self.text_generator = TextGenerator(model_id, max_tokens, temperature)
+         return self.text_generator
+
+ # ============================================================================
+ # UI and Inference
+ # ============================================================================
+
+ def create_detection_interface():
+     """Create the Gradio interface."""
+
+     # Color palette for annotations
+     COLOR_PALETTE = sv.ColorPalette.from_hex([
+         "#ffff00", "#ff9b00", "#ff66ff", "#3399ff", "#ff66b2",
+         "#ff8080", "#b266ff", "#9999ff", "#66ffff", "#33ff99",
+         "#66ff66", "#99ff00",
+     ])
+
+     def annotate_image(image: Image.Image, threshold: float, model_size: str = "4B") -> Tuple[Image.Image, str]:
+         """Process an image and return annotated version with description."""
+
+         if image is None:
+             return None, "Please upload an image."
+
+         try:
+             # Load model if needed
+             app_state.load_model()
+
+             # Run detection
+             detections = app_state.model.predict(image, threshold=threshold)
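+             # detections is a supervision Detections object (xyxy, class_id, confidence)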
+
+             # Annotate image
+             bbox_annotator = sv.BoxAnnotator(color=COLOR_PALETTE, thickness=2)
+             label_annotator = sv.LabelAnnotator(text_scale=0.5, text_color=sv.Color.BLACK)
+
+             labels = []
+             for i in range(len(detections)):
+                 class_id = int(detections.class_id[i]) if detections.class_id is not None else None
+                 conf = float(detections.confidence[i]) if detections.confidence is not None else 0.0
+
+                 if app_state.class_names and class_id is not None:
+                     if 0 <= class_id < len(app_state.class_names):
+                         label_name = app_state.class_names[class_id]
+                     else:
+                         label_name = str(class_id)
+                 else:
+                     label_name = str(class_id) if class_id is not None else "object"
+
+                 labels.append(f"{label_name} {conf:.2f}")
+
+             annotated = image.copy()
+             annotated = bbox_annotator.annotate(annotated, detections)
+             annotated = label_annotator.annotate(annotated, detections, labels)
+
+             # Generate description
+             description = f"Found {len(detections)} detections above threshold {threshold}:\n\n"
+
+             if len(detections) > 0:
+                 counts = {}
+                 for i in range(len(detections)):
+                     class_id = int(detections.class_id[i]) if detections.class_id is not None else None
+                     if app_state.class_names and class_id is not None:
+                         if 0 <= class_id < len(app_state.class_names):
+                             name = app_state.class_names[class_id]
+                         else:
+                             name = str(class_id)
+                     else:
+                         name = str(class_id) if class_id is not None else "object"
+                     counts[name] = counts.get(name, 0) + 1
+
+                 for name, count in counts.items():
+                     description += f"- {count}× {name}\n"
+
+                 # Use LLM for description if enabled
+                 if app_state.config.get('use_llm'):
+                     try:
+                         generator = app_state.get_text_generator(model_size)
+                         llm_description = generator.generate(description, image=annotated)
+                         description = llm_description
+                     except Exception as e:
+                         description = f"[LLM error: {e}]\n\n{description}"
+             else:
+                 description += "No objects detected above the confidence threshold."
+
+             return annotated, description
+
+         except Exception as e:
+             error_msg = f"Error processing image: {str(e)}"
+             print(f"Processing error: {traceback.format_exc()}")
+             return None, error_msg
+
+     # Create the interface
+     with gr.Blocks(title="Medical Image Analysis", theme=gr.themes.Soft()) as demo:
+         gr.Markdown("# 🏥 Medical Image Analysis")
+         gr.Markdown("Upload a medical image to detect and analyze findings using AI.")
+
+         with gr.Row():
+             with gr.Column():
+                 input_image = gr.Image(type="pil", label="Upload Image", height=400)
+                 threshold_slider = gr.Slider(
+                     minimum=0.1,
+                     maximum=1.0,
+                     value=0.7,
+                     step=0.05,
+                     label="Confidence Threshold",
+                     info="Higher values = fewer but more confident detections"
+                 )
+
+                 model_size_radio = gr.Radio(
+                     choices=["4B", "27B"],
+                     value="4B",
+                     label="MedGemma Model Size",
+                     info="4B: Faster, less memory | 27B: More accurate, more memory"
+                 )
+
+                 analyze_btn = gr.Button("🔍 Analyze Image", variant="primary")
+
+             with gr.Column():
+                 output_image = gr.Image(type="pil", label="Results", height=400)
+                 output_text = gr.Textbox(
+                     label="Analysis Results",
+                     lines=8,
+                     max_lines=15,
+                     show_copy_button=True
+                 )
+
+         # Wire up the interface
+         analyze_btn.click(
+             fn=annotate_image,
+             inputs=[input_image, threshold_slider, model_size_radio],
+             outputs=[output_image, output_text]
+         )
+
+         # Also run when image is uploaded
+         input_image.change(
+             fn=annotate_image,
+             inputs=[input_image, threshold_slider, model_size_radio],
+             outputs=[output_image, output_text]
+         )
+
+         # Footer
+         gr.Markdown("---")
+
+     return demo
+
+ # ============================================================================
+ # Main Application
+ # ============================================================================
+
+ # Global app state
+ app_state = AppState()
+
+ def main():
+     """Main entry point for the Spaces app."""
+     print("🚀 Starting Medical Image Analysis App")
+
+     # Ensure results directory exists
+     os.makedirs(app_state.config.get('results_dir'), exist_ok=True)
+
+     # Create and launch the interface
+     demo = create_detection_interface()
+
+     # Launch with Spaces-optimized settings
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=False,  # Spaces handles this
+         show_error=True,
+         show_api=False,
+     )
+
+ if __name__ == "__main__":
+     main()