| import os
|
| import json
|
| import gc
|
| import time
|
| import traceback
|
| from typing import Dict, List, Optional, Tuple, Callable, Any
|
|
|
| import torch
|
| import gradio as gr
|
| import supervision as sv
|
| from PIL import Image
|
|
|
|
|
| try:
|
| from transformers import (
|
| AutoModelForCausalLM,
|
| AutoTokenizer,
|
| AutoModelForImageTextToText,
|
| AutoProcessor,
|
| BitsAndBytesConfig,
|
| )
|
| except Exception:
|
| AutoModelForCausalLM = None
|
| AutoTokenizer = None
|
| AutoModelForImageTextToText = None
|
| AutoProcessor = None
|
| BitsAndBytesConfig = None
|
|
|
|
|
| try:
|
| from rfdetr import RFDETRMedium
|
| except ImportError:
|
| print("Warning: RF-DETR not found. Please ensure it's properly installed.")
|
| RFDETRMedium = None
|
|
|
|
|
|
|
|
|
|
|
class SpacesConfig:
    """Static settings for running the app on Hugging Face Spaces.

    All values live in the ``settings`` dict and are read through :meth:`get`.
    """

    def __init__(self):
        self.settings = dict(
            # Filesystem
            results_dir='/tmp/results',
            checkpoint=None,
            # Detector
            resolution=576,
            threshold=0.7,
            # LLM summary
            use_llm=True,
            llm_model_id='google/medgemma-4b-it',
            llm_max_new_tokens=200,
            llm_temperature=0.2,
            llm_4bit=True,
            # Caching
            enable_caching=True,
            max_cache_size=100,
        )

    def get(self, key: str, default: Any = None) -> Any:
        """Return the setting stored under *key*, or *default* if it is absent."""
        try:
            return self.settings[key]
        except KeyError:
            return default
|
|
|
|
|
|
|
|
|
|
|
class MemoryManager:
    """Small helper that frees Python garbage and cached CUDA memory."""

    def __init__(self):
        # Usage fractions regarded as "high"; kept as configuration data.
        self.memory_thresholds = dict(gpu_warning=0.8, system_warning=0.85)

    def cleanup_memory(self, force: bool = False) -> None:
        """Run the garbage collector and, when CUDA is present, empty its cache.

        Errors are printed rather than raised. ``force`` is accepted for API
        compatibility but does not change the behavior.
        """
        try:
            gc.collect()
            cuda_ready = bool(torch) and torch.cuda.is_available()
            if cuda_ready:
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
        except Exception as err:
            print(f"Memory cleanup error: {err}")


# Module-wide instance used by the model loaders in this file.
memory_manager = MemoryManager()
|
|
|
|
|
|
|
|
|
|
|
def find_checkpoint() -> Optional[str]:
    """Locate an RF-DETR checkpoint file.

    Probes a fixed list of paths in priority order: the working directory
    first, then the training outputs under /tmp/results.

    Returns:
        The first path that exists as a regular file, or None if none do.
    """
    search_order = (
        "rf-detr-medium.pth",
        "/tmp/results/checkpoint_best_total.pth",
        "/tmp/results/checkpoint_best_ema.pth",
        "/tmp/results/checkpoint_best_regular.pth",
        "/tmp/results/checkpoint.pth",
    )
    return next((candidate for candidate in search_order if os.path.isfile(candidate)), None)
|
|
|
def load_model(checkpoint_path: str, resolution: int):
    """Build an RF-DETR Medium detector from a checkpoint.

    Args:
        checkpoint_path: Path to the ``.pth`` weights passed as ``pretrain_weights``.
        resolution: Input resolution forwarded to ``RFDETRMedium``.

    Returns:
        The ``RFDETRMedium`` model, optimized for inference when that succeeds.

    Raises:
        RuntimeError: The ``rfdetr`` package could not be imported.
    """
    if RFDETRMedium is None:
        raise RuntimeError("RF-DETR not available. Please install it properly.")

    model = RFDETRMedium(pretrain_weights=checkpoint_path, resolution=resolution)
    try:
        model.optimize_for_inference()
    except Exception as exc:
        # Optimization is best-effort: the unoptimized model still predicts,
        # so keep going, but surface the failure instead of hiding it.
        print(f"Warning: RF-DETR optimize_for_inference failed ({exc}); using unoptimized model.")
    return model
|
|
|
|
|
|
|
|
|
|
|
class TextGenerator:
    """Lazily loaded LLM wrapper that turns detection results into a short summary.

    MedGemma ids are loaded as image-text-to-text models (with a processor).
    Any other id is loaded as a causal LM (with a tokenizer).
    """

    def __init__(self, model_id: str, max_tokens: int = 200, temperature: float = 0.2,
                 load_in_4bit: bool = True):
        self.model_id = model_id
        self.max_tokens = max_tokens
        self.temperature = temperature
        # Request 4-bit NF4 quantization; only applied when a CUDA device exists.
        self.load_in_4bit = load_in_4bit
        self.model = None
        self.tokenizer = None
        self.processor = None
        self.is_multimodal = False

    def _model_kwargs(self) -> Dict[str, Any]:
        """Build the keyword arguments passed to ``from_pretrained``."""
        kwargs: Dict[str, Any] = {
            "device_map": "auto",
            "low_cpu_mem_usage": True,
        }
        has_cuda = torch.cuda.is_available()
        if has_cuda:
            kwargs["torch_dtype"] = torch.bfloat16

        # bitsandbytes 4-bit loading needs a CUDA device. Requesting it on a
        # CPU-only host makes from_pretrained fail, so only quantize on GPU.
        if self.load_in_4bit and has_cuda and BitsAndBytesConfig is not None:
            try:
                kwargs["quantization_config"] = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.bfloat16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                )
            except Exception as exc:
                # Quantization is an optimization; fall back to unquantized weights.
                print(f"Warning: 4-bit quantization unavailable ({exc}); loading unquantized model.")
        return kwargs

    def load_model(self):
        """Load the weights and tokenizer/processor once; later calls do nothing.

        Raises:
            RuntimeError: transformers is missing or lacks the required model classes.
        """
        if self.model is not None:
            return

        if AutoModelForCausalLM is None and AutoModelForImageTextToText is None:
            raise RuntimeError("Transformers not available")

        # Reclaim memory before allocating a new set of weights.
        memory_manager.cleanup_memory()

        print(f"Loading model: {self.model_id}")
        model_kwargs = self._model_kwargs()

        wants_multimodal = "medgemma" in self.model_id.lower()
        if wants_multimodal and AutoModelForImageTextToText is not None and AutoProcessor is not None:
            self.processor = AutoProcessor.from_pretrained(self.model_id)
            self.model = AutoModelForImageTextToText.from_pretrained(self.model_id, **model_kwargs)
            self.is_multimodal = True
        elif AutoModelForCausalLM is not None and AutoTokenizer is not None:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
            self.model = AutoModelForCausalLM.from_pretrained(self.model_id, **model_kwargs)
            self.is_multimodal = False
        else:
            raise RuntimeError("Required model classes not available")

        print("✓ Model loaded successfully")

    def _generation_kwargs(self) -> Dict[str, Any]:
        """Return the decoding settings for ``model.generate``.

        Sampling is used when ``temperature > 0`` (clamped to at least 0.01);
        otherwise decoding is greedy and no temperature is passed.
        """
        sample = self.temperature > 0
        return {
            "max_new_tokens": self.max_tokens,
            "do_sample": sample,
            "temperature": max(0.01, self.temperature) if sample else None,
            # NOTE(review): KV cache disabled as in the original code. This makes
            # decoding markedly slower; presumably a memory/compat workaround, confirm.
            "use_cache": False,
        }

    def _generate_and_decode(self, inputs, decoder) -> str:
        """Generate from the prepared *inputs* and decode only the new tokens.

        *decoder* is the processor or tokenizer that produced *inputs*.
        """
        with torch.inference_mode():
            output = self.model.generate(**inputs, **self._generation_kwargs())
        # generate() echoes the prompt tokens, so drop them before decoding.
        prompt_len = inputs["input_ids"].shape[-1]
        decoded = decoder.decode(output[0][prompt_len:], skip_special_tokens=True)
        return decoded.strip()

    def generate(self, text: str, image: "Optional[Image.Image]" = None) -> str:
        """Summarize detection-result *text*, optionally grounded on *image*.

        Args:
            text: Plain-text detection summary to condense.
            image: Optional PIL image; only used by multimodal models.

        Returns:
            The stripped model completion. If generation fails, an error marker
            followed by the original *text*.

        Raises:
            RuntimeError: Propagated from :meth:`load_model` when loading fails.
        """
        self.load_model()

        if self.model is None:
            return f"[Model not loaded: {text}]"

        try:
            system_text = "You are a concise medical assistant. Provide a brief, clear summary of detection results. Avoid repetition and be direct. Do not give medical advice."
            user_text = f"Summarize these detection results in 3 clear sentences:\n\n{text}"

            if self.is_multimodal:
                user_content = [{"type": "text", "text": user_text}]
                if image is not None:
                    user_content.append({"type": "image", "image": image})

                messages = [
                    {"role": "system", "content": [{"type": "text", "text": system_text}]},
                    {"role": "user", "content": user_content},
                ]

                inputs = self.processor.apply_chat_template(
                    messages,
                    add_generation_prompt=True,
                    tokenize=True,
                    return_dict=True,
                    return_tensors="pt",
                )
                # Cast floating inputs (pixel values) to the model's own dtype.
                # A hard-coded bfloat16 mismatches float32 weights on CPU hosts.
                inputs = inputs.to(self.model.device, dtype=self.model.dtype)
                return self._generate_and_decode(inputs, self.processor)

            messages = [
                {"role": "system", "content": system_text},
                {"role": "user", "content": user_text},
            ]

            inputs = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            )
            inputs = inputs.to(self.model.device)
            return self._generate_and_decode(inputs, self.tokenizer)

        except Exception as e:
            error_msg = f"[Generation error: {e}]"
            print(f"Generation error: {traceback.format_exc()}")
            return f"{error_msg}\n\n{text}"
|
|
|
|
|
|
|
|
|
|
|
class AppState:
    """Process-wide app state: config, lazily loaded detector, class names, LLM."""

    def __init__(self):
        self.config = SpacesConfig()
        self.model = None
        self.class_names = None
        self.text_generator = None

    def _read_class_names(self) -> Optional[List[str]]:
        """Collect class names from ``<results_dir>/results.json``, if present.

        Unique ``class`` entries are gathered from ``class_map`` in split order
        valid, test, train, skipping the aggregate ``"all"`` row.

        Returns:
            The ordered list of names, or None when the file is missing,
            malformed, or yields no names.
        """
        results_dir = self.config.get('results_dir', '/tmp/results')
        results_json = os.path.join(results_dir, "results.json")
        if not os.path.isfile(results_json):
            return None

        try:
            with open(results_json, 'r', encoding='utf-8') as f:
                data = json.load(f)
            class_map = data.get("class_map", {})
            classes: List[str] = []
            for split in ("valid", "test", "train"):
                for item in class_map.get(split, []):
                    name = item.get("class")
                    if name and name != "all" and name not in classes:
                        classes.append(name)
        except (OSError, ValueError, TypeError, AttributeError, KeyError) as exc:
            # Class names are optional (numeric ids are shown instead), so
            # report the problem rather than failing model loading.
            print(f"Warning: could not read class names from {results_json}: {exc}")
            return None

        return classes or None

    def load_model(self):
        """Load the RF-DETR detector once, plus class names when available.

        Raises:
            FileNotFoundError: No checkpoint was found.
            RuntimeError: Propagated from ``load_model`` when rfdetr is missing.
        """
        if self.model is not None:
            return

        checkpoint = find_checkpoint()
        if not checkpoint:
            raise FileNotFoundError(
                "No RF-DETR checkpoint found. Please upload rf-detr-medium.pth to your Space."
            )

        print(f"Loading RF-DETR from: {checkpoint}")
        self.model = load_model(checkpoint, self.config.get('resolution'))

        self.class_names = self._read_class_names()

        print("✓ RF-DETR model loaded")

    def get_text_generator(self, model_size: str = "4B") -> TextGenerator:
        """Return a text generator for the requested MedGemma size.

        "27B" selects ``google/medgemma-27b-it``. Any other value selects the
        configured ``llm_model_id`` (``google/medgemma-4b-it`` by default).
        A new generator replaces the cached one when the model id changes.
        """
        if model_size == "27B":
            model_id = 'google/medgemma-27b-it'
        else:
            model_id = self.config.get('llm_model_id') or 'google/medgemma-4b-it'

        if getattr(self.text_generator, 'model_id', None) != model_id:
            max_tokens = self.config.get('llm_max_new_tokens')
            temperature = self.config.get('llm_temperature')
            # Rebinding drops the previous generator so its weights can be
            # reclaimed before the new one loads lazily on first generate().
            self.text_generator = TextGenerator(model_id, max_tokens, temperature)
        return self.text_generator
|
|
|
|
|
|
|
|
|
|
|
def create_detection_interface():
    """Build the Gradio UI for detection with an optional LLM summary.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    from collections import Counter

    # Shared by the box and label annotators so their colours agree.
    COLOR_PALETTE = sv.ColorPalette.from_hex([
        "#ffff00", "#ff9b00", "#ff66ff", "#3399ff", "#ff66b2",
        "#ff8080", "#b266ff", "#9999ff", "#66ffff", "#33ff99",
        "#66ff66", "#99ff00",
    ])

    def _label_for(class_id: Optional[int]) -> str:
        """Map a detector class id to a display name.

        Uses ``app_state.class_names`` when loaded and the id is in range.
        Otherwise returns the id as text, or "object" when there is no id.
        """
        if class_id is None:
            return "object"
        names = app_state.class_names
        if names and 0 <= class_id < len(names):
            return names[class_id]
        return str(class_id)

    def annotate_image(image: Image.Image, threshold: float, model_size: str = "4B") -> Tuple[Optional[Image.Image], str]:
        """Detect findings in *image*, draw them, and describe them.

        Args:
            image: Uploaded PIL image, or None when the input was cleared.
            threshold: Minimum detector confidence for a box to be kept.
            model_size: MedGemma variant ("4B" or "27B") used for the summary.

        Returns:
            ``(annotated_image, description)``. The image is None when no input
            was given or processing failed.
        """
        if image is None:
            return None, "Please upload an image."

        try:
            app_state.load_model()

            detections = app_state.model.predict(image, threshold=threshold)

            names: List[str] = []
            labels: List[str] = []
            for i in range(len(detections)):
                class_id = int(detections.class_id[i]) if detections.class_id is not None else None
                conf = float(detections.confidence[i]) if detections.confidence is not None else 0.0
                name = _label_for(class_id)
                names.append(name)
                labels.append(f"{name} {conf:.2f}")

            # Pass the same palette to the label annotator. Without it, labels
            # use supervision's default palette and don't match their boxes.
            bbox_annotator = sv.BoxAnnotator(color=COLOR_PALETTE, thickness=2)
            label_annotator = sv.LabelAnnotator(
                color=COLOR_PALETTE, text_scale=0.5, text_color=sv.Color.BLACK
            )

            annotated = bbox_annotator.annotate(image.copy(), detections)
            annotated = label_annotator.annotate(annotated, detections, labels)

            description = f"Found {len(detections)} detections above threshold {threshold}:\n\n"

            if names:
                # Counter keeps first-seen order, so classes list in detection order.
                description += "".join(
                    f"- {count}× {name}\n" for name, count in Counter(names).items()
                )

                if app_state.config.get('use_llm'):
                    try:
                        generator = app_state.get_text_generator(model_size)
                        description = generator.generate(description, image=annotated)
                    except Exception as e:
                        description = f"[LLM error: {e}]\n\n{description}"
            else:
                description += "No objects detected above the confidence threshold."

            return annotated, description

        except Exception as e:
            error_msg = f"Error processing image: {str(e)}"
            print(f"Processing error: {traceback.format_exc()}")
            return None, error_msg

    with gr.Blocks(title="Medical Image Analysis", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🏥 Medical Image Analysis")
        gr.Markdown("Upload a medical image to detect and analyze findings using AI.")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Upload Image", height=400)
                threshold_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    # Start from the configured detector threshold.
                    value=app_state.config.get('threshold', 0.7),
                    step=0.05,
                    label="Confidence Threshold",
                    info="Higher values = fewer but more confident detections"
                )

                model_size_radio = gr.Radio(
                    choices=["4B", "27B"],
                    value="4B",
                    label="MedGemma Model Size",
                    info="4B: Faster, less memory | 27B: More accurate, more memory"
                )

                analyze_btn = gr.Button("🔍 Analyze Image", variant="primary")

            with gr.Column():
                output_image = gr.Image(type="pil", label="Results", height=400)
                output_text = gr.Textbox(
                    label="Analysis Results",
                    lines=8,
                    max_lines=15,
                    show_copy_button=True
                )

        analyze_btn.click(
            fn=annotate_image,
            inputs=[input_image, threshold_slider, model_size_radio],
            outputs=[output_image, output_text]
        )

        # Also analyze automatically whenever the input image changes.
        input_image.change(
            fn=annotate_image,
            inputs=[input_image, threshold_slider, model_size_radio],
            outputs=[output_image, output_text]
        )

        gr.Markdown("---")
        gr.Markdown("*Powered by RF-DETR and MedGemma • Built for Hugging Face Spaces*")

    return demo
|
|
|
|
|
|
|
|
|
|
|
|
|
# Single application state shared by the Gradio callbacks in this module.
app_state = AppState()


def main():
    """Start the app: ensure the results directory exists, then serve the UI."""
    print("🚀 Starting Medical Image Analysis App")

    results_dir = app_state.config.get('results_dir')
    os.makedirs(results_dir, exist_ok=True)

    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        show_api=False,
    )
    create_detection_interface().launch(**launch_options)


if __name__ == "__main__":
    main()
|
|
|