eduardofarina committed on
Commit
fa5d00b
·
verified ·
1 Parent(s): a27703f

Upload folder using huggingface_hub

Files changed (6)
  1. .gitignore +30 -0
  2. README.md +74 -8
  3. app.py +422 -0
  4. dicom_processor.py +255 -0
  5. model_handler.py +173 -0
  6. requirements.txt +10 -0
.gitignore ADDED
@@ -0,0 +1,30 @@
+ # Environment variables (contains secrets)
+ .env
+
+ # Downloaded models (large files)
+ models/
+ .env.local
+ .env.*
+
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ *.egg-info/
+ dist/
+ build/
+
+ # IDE
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+
+ # OS
+ .DS_Store
+ Thumbs.db
+
+ # Logs
+ *.log
README.md CHANGED
@@ -1,15 +1,81 @@
  ---
- title: MedGemma1.5ReportGenerator
- emoji: 📚
- colorFrom: purple
- colorTo: yellow
+ title: MedGemma 1.5 Report Generator
+ emoji: 🏥
+ colorFrom: blue
+ colorTo: green
  sdk: gradio
- sdk_version: 6.4.0
- python_version: '3.12'
+ sdk_version: 5.23.3
  app_file: app.py
  pinned: false
  license: mit
- short_description: Generate radiology reports with MedGemma 1.5
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # MedGemma 1.5 DICOM Report Generator
+
+ A Gradio-based web application that uses Google's MedGemma 1.5 model to automatically generate structured radiology reports from DICOM medical images.
+
+ ![Python](https://img.shields.io/badge/python-3.10+-blue.svg)
+ ![License](https://img.shields.io/badge/license-MIT-green.svg)
+
+ ## Features
+
+ - **DICOM Processing**: Upload ZIP files containing DICOM images from CT, MR, CR, or DX studies
+ - **Smart Sampling**: Configurable slice sampling per series to manage GPU memory
+ - **DICOM Windowing**: Auto or manual window/level controls with CT presets (Brain, Lung, Bone, etc.)
+ - **Image Preview**: Built-in gallery to visualize sampled slices before inference
+ - **VRAM Estimation**: Real-time estimation of GPU memory usage based on settings
+ - **Configurable Generation**: Adjustable temperature, top-p, top-k, and max tokens
+ - **Custom Prompts**: Editable prompts for tailored report generation
+
+ ## Requirements
+
+ - Python 3.10+
+ - NVIDIA GPU with CUDA support (recommended: 12GB+ VRAM)
+ - Hugging Face account with access to [google/medgemma-1.5-4b-it](https://huggingface.co/google/medgemma-1.5-4b-it)
+
+ ## Usage
+
+ 1. Upload a ZIP file containing DICOM images
+
+ 2. Adjust settings:
+ - **Max Slices Per Series**: Reduce for less VRAM usage
+ - **Image Size**: Smaller images use less VRAM
+ - **Windowing**: Use presets or manual WC/WW for CT images
+
+ 3. Click "Process & Preview" to see the sampled images and VRAM estimate
+
+ 4. Click "Generate Report" to create the radiology report
+
+ ## Window Presets
+
+ | Preset | Window Center | Window Width | Use Case |
+ |--------|--------------|--------------|----------|
+ | Brain | 40 | 80 | Brain parenchyma |
+ | Subdural | 75 | 215 | Subdural hematoma |
+ | Stroke | 32 | 8 | Acute stroke |
+ | Lung | -600 | 1500 | Lung parenchyma |
+ | Mediastinum | 50 | 350 | Mediastinal structures |
+ | Bone | 400 | 1800 | Bone windows |
+ | Abdomen | 40 | 400 | Abdominal soft tissue |
+ | Liver | 60 | 150 | Liver lesions |
+
+ ## Tips for Low VRAM
+
+ - Use **Max Slices Per Series = 5-10** instead of all slices
+ - Reduce **Image Size** to 256-384 pixels
+ - Process one series at a time for very large studies
+
+ ## Disclaimer
+
+ This tool is for research and educational purposes only. It is NOT intended for clinical use or medical diagnosis. Always consult qualified healthcare professionals for medical decisions.
+
+ ## License
+
+ MIT License
+
+ ## Acknowledgments
+
+ - [Google MedGemma](https://huggingface.co/google/medgemma-1.5-4b-it) for the medical vision-language model
+ - [Gradio](https://gradio.app/) for the web interface framework
+ - [PyDICOM](https://pydicom.github.io/) for DICOM file processing
+ - **Claude Opus** (Anthropic) for assistance in creating this demo
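Note on the Window Presets table in the README: each preset is a (center, width) pair in Hounsfield units. As a rough illustration of what a preset does to one pixel value, the sketch below clips to the window and rescales to 0-255, following the same clip-and-scale idea as `apply_windowing` in `dicom_processor.py`. `PRESETS` and `hu_to_gray` are hypothetical names used only for this note, and the per-value pure-Python form is an assumption chosen to keep the example dependency-free (the commit itself works on NumPy arrays).

```python
# Hypothetical helper names; values copied from the README preset table.
PRESETS = {
    "Brain": (40, 80),
    "Lung": (-600, 1500),
    "Bone": (400, 1800),
}


def hu_to_gray(hu: float, center: float, width: float) -> int:
    """Map one Hounsfield value to an 8-bit gray level for a given window."""
    if width <= 0:
        raise ValueError("window width must be positive")
    low = center - width / 2
    high = center + width / 2
    # Values outside the window saturate to black or white.
    clipped = min(max(hu, low), high)
    # int() truncates, matching astype(np.uint8) for non-negative values.
    return int((clipped - low) / (high - low) * 255)


if __name__ == "__main__":
    for name, (wc, ww) in PRESETS.items():
        print(name, [hu_to_gray(v, wc, ww) for v in (-1000, 0, 40, 1000)])
```

A narrow window such as Brain (width 80) spreads a small HU range over the full gray scale, which is why soft-tissue contrast improves while bone and air saturate.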
app.py ADDED
@@ -0,0 +1,422 @@
+ """
+ Main Gradio application for MedGemma DICOM report drafting.
+ """
+ import traceback
+ from typing import Optional, Tuple, List
+ import gradio as gr
+ from PIL import Image
+
+ from dicom_processor import process_dicom_study
+ from model_handler import MedGemmaHandler
+
+ model_handler: Optional[MedGemmaHandler] = None
+ # Store processed data for reuse
+ cached_data = {
+     "zip_bytes": None,
+     "images": None,
+     "modality": None,
+     "study_info": None
+ }
+
+
+ def load_model():
+     """Load the MedGemma model."""
+     global model_handler
+     if model_handler is None:
+         model_handler = MedGemmaHandler()
+         model_handler.load_model()
+     return model_handler
+
+
+ def process_dicom_file(
+     file_path: str,
+     max_slices_per_series: int,
+     image_size: int,
+     window_center: float,
+     window_width: float,
+     use_auto_window: bool
+ ) -> Tuple[str, str, List[Image.Image]]:
+     """Process uploaded DICOM ZIP file and return preview images."""
+     global cached_data
+
+     try:
+         if file_path is None:
+             return "No file uploaded", "", []
+
+         with open(file_path, 'rb') as f:
+             zip_bytes = f.read()
+
+         # Use per-series sampling if max_slices_per_series > 0
+         slices_per_series = max_slices_per_series if max_slices_per_series > 0 else None
+
+         # Use auto window if checkbox is checked
+         wc = None if use_auto_window else window_center
+         ww = None if use_auto_window else window_width
+
+         modality, images, study_info = process_dicom_study(
+             zip_bytes,
+             max_slices_per_series=slices_per_series,
+             image_size=image_size,
+             window_center=wc,
+             window_width=ww
+         )
+
+         # Cache for later use in report generation
+         cached_data["zip_bytes"] = zip_bytes
+         cached_data["images"] = images
+         cached_data["modality"] = modality
+         cached_data["study_info"] = study_info
+
+         max_per_series = study_info.get('MaxSlicesPerSeries', None)
+         sampling_info = f"Max Slices Per Series: {max_per_series}" if max_per_series else "Sampling: Global (all series combined)"
+
+         # Get window info
+         default_wc = study_info.get('DefaultWindowCenter', 'N/A')
+         default_ww = study_info.get('DefaultWindowWidth', 'N/A')
+         window_info = f"Window: Auto (WC={default_wc}, WW={default_ww})" if use_auto_window else f"Window: Manual (WC={window_center}, WW={window_width})"
+
+         # Estimate VRAM usage based on actual image size
+         num_images = study_info.get('ProcessedImages', 0)
+         img_size = study_info.get('ImageSize', 896)
+         # Model base: ~8GB, per image scales with size squared
+         model_vram_gb = 8.0
+         # Base estimate for 896x896 is ~50MB, scale proportionally
+         base_per_image_mb = 50
+         size_factor = (img_size / 896) ** 2
+         per_image_vram_mb = base_per_image_mb * size_factor
+         images_vram_gb = (num_images * per_image_vram_mb) / 1024
+         total_vram_gb = model_vram_gb + images_vram_gb
+
+         info_text = f"""Study Information:
+
+ Modality: {study_info['Modality']}
+ Study Description: {study_info['StudyDescription']}
+ Study Date: {study_info['StudyDate']}
+ Patient ID: {study_info['PatientID']}
+
+ Series Count: {study_info.get('SeriesCount', 'N/A')}
+ Total Original Slices: {study_info.get('TotalOriginalSlices', 'N/A')}
+ {sampling_info}
+ Processed Images: {num_images}
+ Image Size: {img_size}x{img_size}
+ {window_info}
+
+ --- VRAM Estimate ---
+ Model: ~{model_vram_gb:.1f} GB
+ Images ({num_images} x {img_size}x{img_size}): ~{images_vram_gb:.1f} GB
+ Total Estimated: ~{total_vram_gb:.1f} GB
+ """
+
+         status = f"✓ Processed {len(images)} images from {study_info['Modality']} study"
+
+         return status, info_text, images
+
+     except Exception as e:
+         error_msg = f"Error processing DICOM: {str(e)}"
+         print(error_msg)
+         print(traceback.format_exc())
+         return error_msg, "", []
+
+
+ def generate_report(
+     file_path: str,
+     max_slices_per_series: int,
+     image_size: int,
+     window_center: float,
+     window_width: float,
+     use_auto_window: bool,
+     prompt: str,
+     max_tokens: int,
+     temperature: float,
+     top_p: float,
+     top_k: int,
+     do_sample: bool,
+     progress=gr.Progress(track_tqdm=True)
+ ) -> str:
+     """Generate radiology report using MedGemma."""
+     global cached_data
+
+     try:
+         if file_path is None:
+             return "Please upload a DICOM ZIP file first."
+
+         progress(0, desc="Loading model...")
+
+         global model_handler
+         if model_handler is None:
+             model_handler = load_model()
+
+         # Check if we can use cached images
+         use_cache = (
+             cached_data["images"] is not None and
+             cached_data["zip_bytes"] is not None
+         )
+
+         if use_cache:
+             progress(0.4, desc="Using cached images...")
+             images = cached_data["images"]
+             modality = cached_data["modality"]
+         else:
+             progress(0.2, desc="Reading DICOM files...")
+
+             with open(file_path, 'rb') as f:
+                 zip_bytes = f.read()
+
+             progress(0.4, desc="Processing images...")
+             slices_per_series = max_slices_per_series if max_slices_per_series > 0 else None
+             wc = None if use_auto_window else window_center
+             ww = None if use_auto_window else window_width
+
+             modality, images, study_info = process_dicom_study(
+                 zip_bytes,
+                 max_slices_per_series=slices_per_series,
+                 image_size=image_size,
+                 window_center=wc,
+                 window_width=ww
+             )
+
+         progress(0.6, desc=f"Generating report with MedGemma 1.5 ({len(images)} images)...")
+
+         # Use custom prompt or default
+         if not prompt.strip():
+             prompt = f"You are a radiologist, please draft the full structured report for the following {modality} exam. Include the following sections: Technique, Findings, and Impression."
+
+         report = model_handler.generate_report(
+             images=images,
+             prompt=prompt,
+             max_new_tokens=max_tokens,
+             temperature=temperature,
+             top_p=top_p,
+             top_k=top_k,
+             do_sample=do_sample,
+         )
+
+         progress(1.0, desc="Complete!")
+
+         return report
+
+     except Exception as e:
+         error_msg = f"Error generating report: {str(e)}\n\n{traceback.format_exc()}"
+         print(error_msg)
+         return error_msg
+
+
+ def create_interface():
+     """Create the Gradio interface."""
+
+     with gr.Blocks(title="MedGemma 1.5 DICOM Report Generator", theme=gr.themes.Soft()) as demo:
+         gr.Markdown("# 🏥 MedGemma 1.5 DICOM Report Generator")
+         gr.Markdown("Upload a ZIP file containing DICOM images to generate a structured radiology report.")
+
+         with gr.Row():
+             # Left column: Upload and settings
+             with gr.Column(scale=1):
+                 file_input = gr.File(
+                     label="Upload DICOM ZIP",
+                     file_types=[".zip"],
+                     type="filepath"
+                 )
+
+                 with gr.Accordion("Image Processing Settings", open=True):
+                     max_slices_slider = gr.Slider(
+                         minimum=0,
+                         maximum=50,
+                         value=10,
+                         step=1,
+                         label="Max Slices Per Series",
+                         info="0 = use all slices. Reduce to save VRAM."
+                     )
+
+                     image_size_slider = gr.Slider(
+                         minimum=224,
+                         maximum=1024,
+                         value=512,
+                         step=32,
+                         label="Image Size",
+                         info="Smaller = less VRAM, lower quality"
+                     )
+
+                     gr.Markdown("**Windowing (for CT/X-ray)**")
+                     use_auto_window = gr.Checkbox(
+                         label="Use Auto Window (from DICOM metadata)",
+                         value=True
+                     )
+                     with gr.Row():
+                         window_center_slider = gr.Slider(
+                             minimum=-1000,
+                             maximum=3000,
+                             value=40,
+                             step=10,
+                             label="Window Center (WC)",
+                             info="e.g., Brain=40, Lung=-600, Bone=400"
+                         )
+                         window_width_slider = gr.Slider(
+                             minimum=1,
+                             maximum=4000,
+                             value=400,
+                             step=10,
+                             label="Window Width (WW)",
+                             info="e.g., Brain=80, Lung=1500, Bone=1800"
+                         )
+
+                 process_btn = gr.Button("Process & Preview", variant="primary", size="lg")
+
+                 status_output = gr.Textbox(
+                     label="Status",
+                     interactive=False
+                 )
+
+                 study_info_box = gr.Textbox(
+                     label="Study Information & VRAM Estimate",
+                     interactive=False,
+                     lines=14
+                 )
+
+             # Middle column: Image preview
+             with gr.Column(scale=1):
+                 gr.Markdown("### 🖼️ Image Preview")
+                 gr.Markdown("*Preview of sampled slices that will be sent to the model*")
+
+                 image_gallery = gr.Gallery(
+                     label="Sampled Slices",
+                     show_label=False,
+                     columns=4,
+                     rows=3,
+                     height=400,
+                     object_fit="contain",
+                     preview=True
+                 )
+
+             # Right column: Generation settings and output
+             with gr.Column(scale=1):
+                 prompt_input = gr.Textbox(
+                     label="Prompt",
+                     lines=3,
+                     value="You are a radiologist, please draft the full structured report for this exam. Include: Technique, Findings, and Impression.",
+                     info="Customize the prompt. Leave empty for default."
+                 )
+
+                 with gr.Accordion("Model Settings", open=False):
+                     with gr.Row():
+                         max_tokens_slider = gr.Slider(
+                             minimum=50,
+                             maximum=1000,
+                             value=350,
+                             step=10,
+                             label="Max Tokens"
+                         )
+                         temperature_slider = gr.Slider(
+                             minimum=0.0,
+                             maximum=2.0,
+                             value=0.7,
+                             step=0.1,
+                             label="Temperature"
+                         )
+                     with gr.Row():
+                         top_p_slider = gr.Slider(
+                             minimum=0.0,
+                             maximum=1.0,
+                             value=0.9,
+                             step=0.05,
+                             label="Top P"
+                         )
+                         top_k_slider = gr.Slider(
+                             minimum=1,
+                             maximum=100,
+                             value=50,
+                             step=1,
+                             label="Top K"
+                         )
+                     do_sample_checkbox = gr.Checkbox(
+                         label="Enable Sampling",
+                         value=True,
+                         info="Uncheck for deterministic output"
+                     )
+
+                 generate_btn = gr.Button("🚀 Generate Report", variant="primary", size="lg")
+
+                 report_output = gr.Textbox(
+                     label="Generated Report",
+                     interactive=False,
+                     lines=18,
+                     placeholder="Report will appear here..."
+                 )
+
+         # Common window presets
+         with gr.Accordion("Window Presets (click to apply)", open=False):
+             gr.Markdown("**CT Presets:**")
+             with gr.Row():
+                 brain_btn = gr.Button("Brain (40/80)", size="sm")
+                 subdural_btn = gr.Button("Subdural (75/215)", size="sm")
+                 stroke_btn = gr.Button("Stroke (32/8)", size="sm")
+                 lung_btn = gr.Button("Lung (-600/1500)", size="sm")
+                 mediastinum_btn = gr.Button("Mediastinum (50/350)", size="sm")
+                 bone_btn = gr.Button("Bone (400/1800)", size="sm")
+                 abdomen_btn = gr.Button("Abdomen (40/400)", size="sm")
+                 liver_btn = gr.Button("Liver (60/150)", size="sm")
+
+         # Event handlers for presets
+         brain_btn.click(lambda: (40, 80, False), outputs=[window_center_slider, window_width_slider, use_auto_window])
+         subdural_btn.click(lambda: (75, 215, False), outputs=[window_center_slider, window_width_slider, use_auto_window])
+         stroke_btn.click(lambda: (32, 8, False), outputs=[window_center_slider, window_width_slider, use_auto_window])
+         lung_btn.click(lambda: (-600, 1500, False), outputs=[window_center_slider, window_width_slider, use_auto_window])
+         mediastinum_btn.click(lambda: (50, 350, False), outputs=[window_center_slider, window_width_slider, use_auto_window])
+         bone_btn.click(lambda: (400, 1800, False), outputs=[window_center_slider, window_width_slider, use_auto_window])
+         abdomen_btn.click(lambda: (40, 400, False), outputs=[window_center_slider, window_width_slider, use_auto_window])
+         liver_btn.click(lambda: (60, 150, False), outputs=[window_center_slider, window_width_slider, use_auto_window])
+
+         # Main event handlers
+         process_btn.click(
+             fn=process_dicom_file,
+             inputs=[
+                 file_input,
+                 max_slices_slider,
+                 image_size_slider,
+                 window_center_slider,
+                 window_width_slider,
+                 use_auto_window
+             ],
+             outputs=[status_output, study_info_box, image_gallery]
+         )
+
+         generate_btn.click(
+             fn=generate_report,
+             inputs=[
+                 file_input,
+                 max_slices_slider,
+                 image_size_slider,
+                 window_center_slider,
+                 window_width_slider,
+                 use_auto_window,
+                 prompt_input,
+                 max_tokens_slider,
+                 temperature_slider,
+                 top_p_slider,
+                 top_k_slider,
+                 do_sample_checkbox
+             ],
+             outputs=[report_output]
+         )
+
+         gr.Markdown("---")
+         gr.Markdown("**Supported Modalities:** CT, MR, CR, DX | **Tip:** Use fewer slices and smaller image size to reduce VRAM usage")
+
+     return demo
+
+
+ def main():
+     """Main entry point."""
+     print("Starting MedGemma 1.5 DICOM Report Generator...")
+     print("Note: The model will be loaded on first report generation.")
+
+     demo = create_interface()
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=False,
+         show_error=True
+     )
+
+
+ if __name__ == "__main__":
+     main()
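Note on the VRAM estimate in `process_dicom_file`: the arithmetic is a fixed model cost plus a per-image cost that scales with the square of the image side, relative to 896 pixels. A minimal standalone sketch of that formula follows. `estimate_vram_gb` and its parameter names are hypothetical, and the 8 GB and 50 MB defaults are the rough figures hard-coded in `app.py`, not measurements.

```python
def estimate_vram_gb(
    num_images: int,
    image_size: int,
    model_gb: float = 8.0,               # rough model footprint used in app.py
    per_image_mb_at_896: float = 50.0,   # rough per-image cost at 896x896 used in app.py
) -> float:
    """Return a rough total VRAM estimate in GB."""
    if num_images < 0 or image_size <= 0:
        raise ValueError("num_images must be >= 0 and image_size must be > 0")
    # Per-image cost scales with pixel count, i.e. the side length squared.
    size_factor = (image_size / 896) ** 2
    images_gb = num_images * per_image_mb_at_896 * size_factor / 1024
    return model_gb + images_gb


if __name__ == "__main__":
    for n, size in [(10, 512), (50, 896)]:
        print(f"{n} images at {size}px: ~{estimate_vram_gb(n, size):.2f} GB")
```

Because the per-image term is small next to the fixed model cost, halving the image size reduces only the image share of the estimate by a factor of four; the total moves far less.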
dicom_processor.py ADDED
@@ -0,0 +1,255 @@
+ """
+ DICOM utilities for processing medical imaging studies.
+ """
+ import io
+ import zipfile
+ from typing import List, Tuple, Dict, Optional
+ import numpy as np
+ from PIL import Image
+ import pydicom
+
+
+ def has_pixel_data(ds: pydicom.Dataset) -> bool:
+     """Check if DICOM dataset has pixel data."""
+     return (
+         'PixelData' in ds or
+         'FloatPixelData' in ds or
+         'DoubleFloatPixelData' in ds
+     )
+
+
+ def extract_dicom_from_zip(zip_bytes: bytes) -> List[Tuple[str, pydicom.Dataset]]:
+     """Extract DICOM files from a ZIP archive, filtering out non-image files."""
+     dicom_files = []
+
+     with zipfile.ZipFile(io.BytesIO(zip_bytes), 'r') as zip_ref:
+         for filename in zip_ref.namelist():
+             if filename.lower().endswith('.dcm'):
+                 try:
+                     file_bytes = zip_ref.read(filename)
+                     ds = pydicom.dcmread(io.BytesIO(file_bytes))
+
+                     # Skip files without pixel data (SR, reports, dose records, etc.)
+                     if has_pixel_data(ds):
+                         dicom_files.append((filename, ds))
+                     else:
+                         print(f"Skipping {filename}: No pixel data (likely SR or report)")
+
+                 except Exception as e:
+                     print(f"Error reading {filename}: {e}")
+
+     return dicom_files
+
+
+ def get_modality(ds: pydicom.Dataset) -> str:
+     return getattr(ds, 'Modality', 'Unknown')
+
+
+ def get_study_info(ds: pydicom.Dataset, total_slices: int) -> Dict:
+     return {
+         'StudyInstanceUID': getattr(ds, 'StudyInstanceUID', 'Unknown'),
+         'StudyDescription': getattr(ds, 'StudyDescription', 'Unknown'),
+         'Modality': get_modality(ds),
+         'TotalSlices': total_slices,
+         'StudyDate': getattr(ds, 'StudyDate', 'Unknown'),
+         'PatientID': getattr(ds, 'PatientID', 'Unknown'),
+     }
+
+
+ def get_default_window(ds: pydicom.Dataset) -> Tuple[Optional[float], Optional[float]]:
+     """Get default window center and width from DICOM metadata."""
+     wc = getattr(ds, 'WindowCenter', None)
+     ww = getattr(ds, 'WindowWidth', None)
+
+     # Handle multi-valued windows (take first)
+     if wc is not None:
+         wc = float(wc[0]) if hasattr(wc, '__iter__') and not isinstance(wc, str) else float(wc)
+     if ww is not None:
+         ww = float(ww[0]) if hasattr(ww, '__iter__') and not isinstance(ww, str) else float(ww)
+
+     return wc, ww
+
+
+ def apply_windowing(
+     pixel_array: np.ndarray,
+     ds: pydicom.Dataset,
+     window_center: Optional[float] = None,
+     window_width: Optional[float] = None
+ ) -> np.ndarray:
+     """Apply rescale slope/intercept and windowing to pixel array."""
+     # Apply rescale slope and intercept (converts to HU for CT)
+     slope = getattr(ds, 'RescaleSlope', 1)
+     intercept = getattr(ds, 'RescaleIntercept', 0)
+     pixel_array = pixel_array.astype(np.float32) * slope + intercept
+
+     # Get window values
+     if window_center is None or window_width is None:
+         default_wc, default_ww = get_default_window(ds)
+         if window_center is None:
+             window_center = default_wc
+         if window_width is None:
+             window_width = default_ww
+
+     # Apply windowing if we have valid values
+     if window_center is not None and window_width is not None and window_width > 0:
+         min_val = window_center - window_width / 2
+         max_val = window_center + window_width / 2
+         pixel_array = np.clip(pixel_array, min_val, max_val)
+         normalized = ((pixel_array - min_val) / (max_val - min_val) * 255).astype(np.uint8)
+     else:
+         # Fallback: normalize to full range
+         pixel_min = pixel_array.min()
+         pixel_max = pixel_array.max()
+         if pixel_max > pixel_min:
+             normalized = ((pixel_array - pixel_min) / (pixel_max - pixel_min) * 255).astype(np.uint8)
+         else:
+             normalized = np.zeros_like(pixel_array, dtype=np.uint8)
+
+     return normalized
+
+
+ def dicom_to_pil(
+     ds: pydicom.Dataset,
+     size: Tuple[int, int] = (896, 896),
+     window_center: Optional[float] = None,
+     window_width: Optional[float] = None
+ ) -> Image.Image:
+     """Convert DICOM dataset to PIL Image with optional windowing and resizing."""
+     pixel_array = ds.pixel_array
+     normalized = apply_windowing(pixel_array, ds, window_center, window_width)
+
+     if len(normalized.shape) == 2:
+         pil_image = Image.fromarray(normalized, mode='L')
+     elif len(normalized.shape) == 3 and normalized.shape[2] <= 4:
+         if normalized.shape[2] == 1:
+             pil_image = Image.fromarray(normalized[:, :, 0], mode='L')
+         elif normalized.shape[2] == 3:
+             pil_image = Image.fromarray(normalized, mode='RGB')
+         elif normalized.shape[2] == 4:
+             pil_image = Image.fromarray(normalized[:, :, :3], mode='RGB')
+         else:
+             pil_image = Image.fromarray(normalized[:, :, 0], mode='L')
+     else:
+         pil_image = Image.fromarray(normalized[0], mode='L')
+
+     if pil_image.mode != 'RGB':
+         pil_image = pil_image.convert('RGB')
+
+     pil_image = pil_image.resize(size, Image.LANCZOS)
+
+     return pil_image
+
+
+ def organize_by_series(dicom_files: List[Tuple[str, pydicom.Dataset]]) -> Dict[str, List[Tuple[str, pydicom.Dataset]]]:
+     series_dict = {}
+     for filename, ds in dicom_files:
+         series_uid = getattr(ds, 'SeriesInstanceUID', 'Unknown')
+         if series_uid not in series_dict:
+             series_dict[series_uid] = []
+         series_dict[series_uid].append((filename, ds))
+     return series_dict
+
+
+ def sort_slices_by_position(series_files: List[Tuple[str, pydicom.Dataset]]) -> List[Tuple[str, pydicom.Dataset]]:
+     def get_sort_key(item):
+         filename, ds = item
+         instance_num = getattr(ds, 'InstanceNumber', None)
+         if instance_num is not None:
+             return (0, int(instance_num))
+
+         slice_loc = getattr(ds, 'SliceLocation', None)
+         if slice_loc is not None:
+             return (1, float(slice_loc))
+
+         return (2, filename)
+
+     return sorted(series_files, key=get_sort_key)
+
+
+ def sample_slices_evenly(all_slices: List[Tuple[str, pydicom.Dataset]], max_slices: int = 500) -> List[Tuple[str, pydicom.Dataset]]:
+     if len(all_slices) <= max_slices:
+         return all_slices
+
+     # max(..., 1) avoids a ZeroDivisionError when max_slices == 1
+     indices = [int(i * (len(all_slices) - 1) / max(max_slices - 1, 1)) for i in range(max_slices)]
+     return [all_slices[i] for i in indices]
+
+
+ def process_dicom_study(
+     zip_bytes: bytes,
+     max_slices: int = 500,
+     max_slices_per_series: Optional[int] = None,
+     image_size: int = 896,
+     window_center: Optional[float] = None,
+     window_width: Optional[float] = None
+ ) -> Tuple[str, List[Image.Image], Dict]:
+     """
+     Process a DICOM study from a ZIP file.
+
+     Args:
+         zip_bytes: ZIP file contents
+         max_slices: Maximum total slices across all series (used if max_slices_per_series is None)
+         max_slices_per_series: If set, sample this many slices evenly from each series
+         image_size: Output image size (square, e.g., 896 for 896x896)
+         window_center: Window center for display (None = use DICOM default or auto)
+         window_width: Window width for display (None = use DICOM default or auto)
+     """
+     dicom_files = extract_dicom_from_zip(zip_bytes)
+
+     if not dicom_files:
+         raise ValueError("No valid DICOM files found in the ZIP archive")
+
+     first_ds = dicom_files[0][1]
+     modality = get_modality(first_ds)
+
+     # Get default window from first image
+     default_wc, default_ww = get_default_window(first_ds)
+
+     series_dict = organize_by_series(dicom_files)
+
+     # Count total original slices
+     total_original_slices = sum(len(files) for files in series_dict.values())
+
+     # Sample slices per series or globally
+     sampled_slices = []
+     if max_slices_per_series is not None:
+         # Sample evenly from each series
+         for series_uid, series_files in series_dict.items():
+             sorted_slices = sort_slices_by_position(series_files)
+             series_sampled = sample_slices_evenly(sorted_slices, max_slices_per_series)
+             sampled_slices.extend(series_sampled)
+     else:
+         # Original behavior: sample globally
+         all_sorted_slices = []
+         for series_uid, series_files in series_dict.items():
+             sorted_slices = sort_slices_by_position(series_files)
+             all_sorted_slices.extend(sorted_slices)
+         sampled_slices = sample_slices_evenly(all_sorted_slices, max_slices)
+
+     sampled_count = len(sampled_slices)
+
+     study_info = get_study_info(first_ds, sampled_count)
+     study_info['SeriesCount'] = len(series_dict)
+     study_info['TotalOriginalSlices'] = total_original_slices
+     study_info['SampledSlices'] = sampled_count
+     study_info['ImageSize'] = image_size
+     study_info['DefaultWindowCenter'] = default_wc
+     study_info['DefaultWindowWidth'] = default_ww
+     if max_slices_per_series is not None:
+         study_info['MaxSlicesPerSeries'] = max_slices_per_series
+
+     images = []
+     for filename, ds in sampled_slices:
+         try:
+             pil_image = dicom_to_pil(
+                 ds,
+                 size=(image_size, image_size),
+                 window_center=window_center,
+                 window_width=window_width
+             )
+             images.append(pil_image)
+         except Exception as e:
+             print(f"Error converting {filename}: {e}")
+
+     study_info['ProcessedImages'] = len(images)
+
+     return modality, images, study_info
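Note on the even-sampling step used by `process_dicom_study`: it picks `max_slices` indices spread linearly between the first and last slice of a sorted series. The sketch below shows the same index formula on a plain list so the selected positions are easy to inspect. `sample_evenly` is a hypothetical stand-in, not the commit's function. Returning the first item when only one is requested, and an empty list for non-positive limits, are assumptions about the desired edge-case behavior.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def sample_evenly(items: Sequence[T], max_items: int) -> List[T]:
    """Pick up to max_items elements spread evenly from first to last."""
    if max_items <= 0:
        return []  # assumption: a non-positive limit selects nothing
    if len(items) <= max_items:
        return list(items)
    if max_items == 1:
        return [items[0]]  # assumption: avoids dividing by max_items - 1 == 0
    last = len(items) - 1
    # Index i maps linearly onto [0, last]; int() truncates toward zero.
    return [items[int(i * last / (max_items - 1))] for i in range(max_items)]


if __name__ == "__main__":
    print(sample_evenly(list(range(10)), 4))
    print(sample_evenly(list(range(100)), 5))
```

The first and last slices are always included when more than one item is requested, so the sampled set spans the full extent of the series.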
model_handler.py ADDED
@@ -0,0 +1,173 @@
+ """
+ Model handler for MedGemma 1.5 inference.
+ """
+ import os
+ import torch
+ from PIL import Image
+ from typing import List, Optional
+ from dotenv import load_dotenv
+ from transformers import AutoProcessor, AutoModelForImageTextToText
+
+ # Load environment variables from .env file
+ load_dotenv()
+
+
+ def check_gpu_availability():
+     """Check GPU availability and print diagnostics."""
+     print("=" * 60)
+     print("GPU Availability Check")
+     print("=" * 60)
+
+     cuda_available = torch.cuda.is_available()
+     print(f"CUDA available: {cuda_available}")
+
+     if cuda_available:
+         device_count = torch.cuda.device_count()
+         print(f"Number of GPUs: {device_count}")
+         for i in range(device_count):
+             device_name = torch.cuda.get_device_name(i)
+             print(f"  GPU {i}: {device_name}")
+         print(f"Current GPU: {torch.cuda.current_device()}")
+     else:
+         print("CUDA is not available. Model will use CPU (slow).")
+         print("\nTo use GPU, ensure you have:")
+         print("1. NVIDIA GPU with CUDA support")
+         print("2. CUDA toolkit installed")
+         print("3. PyTorch with CUDA support: pip install torch --index-url https://download.pytorch.org/whl/cu118")
+
+     print("=" * 60)
+
+     return cuda_available
+
+
+ class MedGemmaHandler:
+     """Handler for MedGemma 1.5 model inference."""
+
+     def __init__(self, model_id: str = "google/medgemma-1.5-4b-it", device: Optional[str] = None):
+         self.model_id = model_id
+         self.device = device
+         self.processor = None
+         self.model = None
+
+         # Check for local model path (useful for local development)
+         local_model_path = os.path.join(os.path.dirname(__file__), "models", "medgemma-1.5-4b-it")
+         if os.path.exists(local_model_path) and os.path.isfile(os.path.join(local_model_path, "config.json")):
+             self.model_id = local_model_path
+             print(f"Using local model from: {local_model_path}")
+         else:
+             print(f"Using model from Hugging Face Hub: {self.model_id}")
+
+     def load_model(self):
+         """Load the MedGemma 1.5 model and processor."""
+         print(f"Loading MedGemma model: {self.model_id}")
+
+         # Check GPU availability
+         cuda_available = check_gpu_availability()
+
+         # Determine device
+         if self.device is None:
+             if cuda_available:
+                 self.device = "cuda"
+                 print(f"Using GPU: {torch.cuda.get_device_name(0)}")
+             else:
+                 self.device = "cpu"
+                 print("WARNING: Using CPU - this will be very slow!")
+         else:
+             print(f"Using device: {self.device}")
+
+         # Get HF token from environment
+         hf_token = os.getenv("HF_TOKEN")
+         if hf_token:
+             print("Using Hugging Face token from .env file")
+         else:
+             print("Warning: No HF_TOKEN found in .env file")
+
+         self.processor = AutoProcessor.from_pretrained(self.model_id, token=hf_token)
+
+         # Load model with proper device configuration
+         if self.device == "cuda" and torch.cuda.is_available():
+             print("Loading model on GPU with bfloat16...")
+             self.model = AutoModelForImageTextToText.from_pretrained(
+                 self.model_id,
+                 torch_dtype=torch.bfloat16,
+                 device_map="cuda",
+                 token=hf_token,
+             )
+         else:
+             print("Loading model on CPU (this may take a while)...")
+             self.model = AutoModelForImageTextToText.from_pretrained(
+                 self.model_id,
+                 torch_dtype=torch.float32,  # Use float32 for CPU
+                 device_map="cpu",
+                 token=hf_token,
+             )
+
+         print(f"Model loaded on device: {next(self.model.parameters()).device}")
+         print("Model loaded successfully!")
+
+     def generate_report(
+         self,
+         images: List[Image.Image],
+         prompt: str,
+         max_new_tokens: int = 350,
+         temperature: float = 0.7,
+         top_p: float = 0.9,
+         top_k: int = 50,
+         do_sample: bool = True,
+     ) -> str:
+         """Generate a radiology report from medical images."""
+         if self.model is None or self.processor is None:
+             raise RuntimeError("Model not loaded. Call load_model() first.")
+
+         content = [{"type": "image", "image": img} for img in images]
+         content.append({"type": "text", "text": prompt})
+
+         messages = [
+             {
+                 "role": "user",
+                 "content": content
+             }
+         ]
+
+         inputs = self.processor.apply_chat_template(
+             messages,
+             add_generation_prompt=True,
+             tokenize=True,
+             return_dict=True,
+             return_tensors="pt"
+         )
+
+         # Move to device with proper dtype
+         if self.device == "cuda":
+             inputs = inputs.to(self.model.device, dtype=torch.bfloat16)
+         else:
+             inputs = inputs.to(self.model.device)
+
+         input_len = inputs["input_ids"].shape[-1]
+
+         with torch.inference_mode():
+             if do_sample and temperature > 0:
+                 generation = self.model.generate(
+                     **inputs,
+                     max_new_tokens=max_new_tokens,
+                     do_sample=True,
+                     temperature=temperature,
+                     top_p=top_p,
+                     top_k=top_k,
+                 )
+             else:
+                 generation = self.model.generate(
+                     **inputs,
+                     max_new_tokens=max_new_tokens,
+                     do_sample=False,
+                 )
+             generation = generation[0][input_len:]
+
+         report = self.processor.decode(generation, skip_special_tokens=True)
+
+         # Clear GPU cache after inference
+         if self.device == "cuda":
+             torch.cuda.empty_cache()
+             print("GPU cache cleared.")
+
+         return report
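Note on the chat payload built in `MedGemmaHandler.generate_report`: all images go into a single user turn, one `{"type": "image"}` entry per slice, followed by one `{"type": "text"}` entry carrying the prompt. The sketch below isolates that structure so it can be inspected without loading the model. `build_messages` is a hypothetical helper name, and the string placeholders stand in for the PIL images the commit passes. Nothing here calls the processor or the model.

```python
from typing import Any, Dict, List, Sequence


def build_messages(images: Sequence[Any], prompt: str) -> List[Dict[str, Any]]:
    """Build a single-turn message list: every image first, then the text prompt."""
    content: List[Dict[str, Any]] = [{"type": "image", "image": img} for img in images]
    content.append({"type": "text", "text": prompt})
    return [{"role": "user", "content": content}]


if __name__ == "__main__":
    # Placeholder strings instead of PIL images, only to show the layout.
    messages = build_messages(["slice_0", "slice_1"], "Draft the report.")
    for entry in messages[0]["content"]:
        print(entry["type"])
```

Keeping the prompt as the last content entry means the instruction follows the images in the single user turn, which is the ordering the commit uses.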
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ gradio>=4.0.0
+ transformers>=4.50.0
+ torch>=2.2.0
+ torchvision
+ accelerate
+ pydicom>=2.4.0
+ Pillow>=10.0.0
+ numpy>=1.24.0,<2.0
+ python-dotenv>=1.0.0
+ scipy