|
|
# Use the real `spaces` package when it is installed; otherwise define a
# pass-through stand-in so `@spaces.GPU` still works.
try:
    import spaces
    HF_SPACES_GPU = True
except ImportError:

    class spaces:
        """Stand-in for the `spaces` package, used when it cannot be imported."""

        @staticmethod
        def GPU(func=None, **_options):
            """Return the decorated function unchanged.

            Works both as a bare decorator (``@spaces.GPU``) and as a
            parameterized one (``@spaces.GPU(duration=60)``); keyword options
            are accepted and ignored.
            """
            if func is None:
                # Parameterized form: hand back a decorator that does nothing.
                return lambda wrapped: wrapped
            return func

    HF_SPACES_GPU = False
|
|
|
| import os
|
| import json
|
| import gc
|
| import traceback
|
| from typing import Optional, Tuple, Any
|
|
|
| import torch
|
| import gradio as gr
|
| import supervision as sv
|
| from PIL import Image
|
|
|
|
|
| try:
|
| from transformers import (
|
| AutoModelForCausalLM,
|
| AutoTokenizer,
|
| AutoModelForImageTextToText,
|
| AutoProcessor,
|
| BitsAndBytesConfig,
|
| )
|
| except Exception:
|
| AutoModelForCausalLM = None
|
| AutoTokenizer = None
|
| AutoModelForImageTextToText = None
|
| AutoProcessor = None
|
| BitsAndBytesConfig = None
|
|
|
|
|
| try:
|
| from huggingface_hub import hf_hub_download
|
| except ImportError:
|
| hf_hub_download = None
|
|
|
|
|
| try:
|
| from rfdetr import RFDETRMedium
|
| except ImportError:
|
| print("Warning: RF-DETR not found. Please ensure it's properly installed.")
|
| RFDETRMedium = None
|
|
|
|
|
|
|
|
|
|
|
class SpacesConfig:
    """Settings container holding the app's default configuration.

    Every value lives in the ``settings`` dict. The access token is read from
    the ``HF_TOKEN`` environment variable, falling back to
    ``HUGGINGFACE_TOKEN``.
    """

    def __init__(self):
        env = os.environ
        token = env.get('HF_TOKEN') or env.get('HUGGINGFACE_TOKEN')

        # Grouped only for readability; merged below in the same key order.
        detection_defaults = dict(
            results_dir='/tmp/results',
            checkpoint=None,
            hf_model_repo='edeler/lorai',
            hf_model_filename='lorai.pth',
            hf_token=token,
            resolution=576,
            threshold=0.7,
        )
        llm_defaults = dict(
            use_llm=True,
            llm_model_id='google/medgemma-4b-it',
            llm_max_new_tokens=200,
            llm_temperature=0.2,
            llm_4bit=True,
        )
        cache_defaults = dict(
            enable_caching=True,
            max_cache_size=100,
        )
        self.settings = {**detection_defaults, **llm_defaults, **cache_defaults}

    def get(self, key: str, default: Any = None) -> Any:
        """Return the setting stored under *key*, or *default* if it is absent."""
        try:
            return self.settings[key]
        except KeyError:
            return default

    def set_hf_model_repo(self, repo_id: str, filename: str = 'lorai.pth'):
        """Point the config at a different Hub repository and weights file."""
        self.settings.update(hf_model_repo=repo_id, hf_model_filename=filename)
|
|
|
|
|
|
|
|
|
|
|
class MemoryManager:
    """Best-effort helper that releases Python and CUDA memory."""

    def __init__(self):
        # Warning thresholds; nothing in this class reads them.
        self.memory_thresholds = dict(gpu_warning=0.8, system_warning=0.85)

    def cleanup_memory(self, force: bool = False) -> None:
        """Run the garbage collector and, when CUDA is available, flush its cache.

        ``force`` is accepted but not used. Any exception raised during cleanup
        is printed rather than propagated.
        """
        try:
            gc.collect()
            cuda_ready = torch and torch.cuda.is_available()
            if not cuda_ready:
                return
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        except Exception as exc:
            print(f"Memory cleanup error: {exc}")


# Shared instance used by the rest of the module.
memory_manager = MemoryManager()
|
|
|
|
|
|
|
|
|
|
|
def find_checkpoint(hf_repo: Optional[str] = None, hf_filename: str = 'lorai.pth') -> Optional[str]:
    """Resolve the path of an RF-DETR checkpoint.

    The Hugging Face Hub is tried first, using ``hf_repo`` or, when that is
    empty, the ``HF_MODEL_REPO`` environment variable. If no repository is
    configured, the Hub client is unavailable, or the download fails, a fixed
    list of local paths is searched in order.

    Returns:
        The checkpoint path, or ``None`` when nothing was found.
    """
    repo_id = hf_repo or os.environ.get('HF_MODEL_REPO')
    hub_usable = bool(repo_id) and hf_hub_download is not None

    if hub_usable:
        try:
            print(f"Downloading checkpoint from Hugging Face Hub: {repo_id}/{hf_filename}")
            downloaded = hf_hub_download(
                repo_id=repo_id,
                filename=hf_filename,
                cache_dir="/tmp/hf_cache"
            )
        except Exception as exc:
            print(f"Warning: Failed to download from Hugging Face Hub: {exc}")
            print("Falling back to local checkpoints...")
        else:
            print(f"✓ Downloaded checkpoint to: {downloaded}")
            return downloaded

    # Local fallbacks, checked in priority order.
    local_candidates = (
        "lorai.pth",
        "rf-detr-medium.pth",
        "/tmp/results/checkpoint_best_total.pth",
        "/tmp/results/checkpoint_best_ema.pth",
        "/tmp/results/checkpoint_best_regular.pth",
        "/tmp/results/checkpoint.pth",
    )
    found = next(filter(os.path.isfile, local_candidates), None)
    if found is not None:
        print(f"Found local checkpoint: {found}")
    return found
|
|
|
def load_model(checkpoint_path: str, resolution: int):
    """Build an RF-DETR Medium model from a checkpoint.

    Args:
        checkpoint_path: Path passed to ``RFDETRMedium`` as ``pretrain_weights``.
        resolution: Value passed to ``RFDETRMedium`` as ``resolution``.

    Returns:
        The constructed ``RFDETRMedium`` instance.

    Raises:
        RuntimeError: If the ``rfdetr`` import failed at module load.
    """
    if RFDETRMedium is None:
        raise RuntimeError("RF-DETR not available. Please install it properly.")

    model = RFDETRMedium(pretrain_weights=checkpoint_path, resolution=resolution)
    try:
        model.optimize_for_inference()
    except Exception as exc:
        # Optimization is best-effort: the model is still returned, but the
        # failure is reported instead of being hidden.
        print(f"Warning: RF-DETR inference optimization skipped: {exc}")
    return model
|
|
|
|
|
|
|
|
|
|
|
class TextGenerator:
    """Lazily loaded chat model that summarizes detection results.

    Model ids containing "medgemma" are loaded as image-text-to-text models
    with a processor; every other id is loaded as a causal LM with a tokenizer.
    """

    def __init__(self, model_id: str, max_tokens: int = 200, temperature: float = 0.2):
        """Store generation settings; nothing is loaded until first use."""
        self.model_id = model_id
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.model = None
        self.tokenizer = None
        self.processor = None
        self.is_multimodal = False

    def load_model(self, hf_token: Optional[str] = None):
        """Load the model and its tokenizer/processor unless already loaded.

        When ``BitsAndBytesConfig`` is importable, a 4-bit (NF4) quantized load
        is attempted first; if that load raises, the model is loaded again
        without quantization.

        Args:
            hf_token: Optional access token forwarded to ``from_pretrained``.

        Raises:
            RuntimeError: If the required transformers classes are unavailable.
        """
        if self.model is not None:
            return

        if AutoModelForCausalLM is None and AutoModelForImageTextToText is None:
            raise RuntimeError("Transformers not available")

        # Resolve the classes up front so a missing class fails before any
        # download or load is attempted.
        model_cls, codec_cls, multimodal = self._select_classes()

        memory_manager.cleanup_memory()
        print(f"Loading model: {self.model_id}")

        cuda_available = torch.cuda.is_available()
        plain_kwargs = {"device_map": "auto", "low_cpu_mem_usage": True}
        if hf_token:
            plain_kwargs["token"] = hf_token
        if cuda_available:
            plain_kwargs["torch_dtype"] = torch.bfloat16

        compute_dtype = torch.bfloat16 if cuda_available else torch.float16
        quant_config = self._build_quantization_config(compute_dtype)

        codec = codec_cls.from_pretrained(self.model_id, token=hf_token)
        if quant_config is None:
            model = model_cls.from_pretrained(self.model_id, **plain_kwargs)
        else:
            quant_kwargs = dict(
                plain_kwargs,
                quantization_config=quant_config,
                torch_dtype=compute_dtype,
            )
            try:
                model = model_cls.from_pretrained(self.model_id, **quant_kwargs)
            except Exception as exc:
                # The quantized load is where a missing or unsupported
                # quantization backend shows up, so fall back here rather than
                # only around the config construction.
                print(f"Warning: 4-bit quantized load failed ({exc}); retrying without quantization")
                memory_manager.cleanup_memory()
                model = model_cls.from_pretrained(self.model_id, **plain_kwargs)

        # Publish state only after everything loaded successfully.
        if multimodal:
            self.processor = codec
        else:
            self.tokenizer = codec
        self.model = model
        self.is_multimodal = multimodal

        print("✓ Model loaded successfully")

    def generate(self, text: str, image: "Optional[Image.Image]" = None, hf_token: Optional[str] = None) -> str:
        """Summarize *text* (and optionally *image*) with the loaded model.

        Args:
            text: Detection summary to be rewritten by the model.
            image: Optional image attached to the prompt for multimodal models.
            hf_token: Optional access token used if the model must be loaded.

        Returns:
            The generated text. If generation raises, a
            ``[Generation error: ...]`` marker followed by the original *text*.
        """
        self.load_model(hf_token)

        if self.model is None:
            return f"[Model not loaded: {text}]"

        try:
            system_text = "You are a concise medical assistant. Provide a brief, clear summary of detection results. Avoid repetition and be direct. Do not give medical advice."
            user_text = f"Summarize these detection results in 3 clear sentences:\n\n{text}"

            if self.is_multimodal:
                user_content = [{"type": "text", "text": user_text}]
                if image is not None:
                    user_content.append({"type": "image", "image": image})
                messages = [
                    {"role": "system", "content": [{"type": "text", "text": system_text}]},
                    {"role": "user", "content": user_content},
                ]
                codec = self.processor
            else:
                messages = [
                    {"role": "system", "content": system_text},
                    {"role": "user", "content": user_text},
                ]
                codec = self.tokenizer

            inputs = codec.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            )

            if self.is_multimodal:
                # Cast to the dtype the model was actually loaded in instead of
                # assuming bfloat16.
                target_dtype = getattr(self.model, "dtype", torch.bfloat16)
                inputs = inputs.to(self.model.device, dtype=target_dtype)
            else:
                inputs = inputs.to(self.model.device)

            return self._generate_and_decode(codec, inputs)

        except Exception as e:
            error_msg = f"[Generation error: {e}]"
            print(f"Generation error: {traceback.format_exc()}")
            return f"{error_msg}\n\n{text}"

    def _select_classes(self):
        """Return ``(model_class, codec_class, is_multimodal)`` for this model id."""
        wants_multimodal = "medgemma" in self.model_id.lower()
        if wants_multimodal and AutoModelForImageTextToText is not None and AutoProcessor is not None:
            return AutoModelForImageTextToText, AutoProcessor, True
        if AutoModelForCausalLM is not None and AutoTokenizer is not None:
            return AutoModelForCausalLM, AutoTokenizer, False
        raise RuntimeError("Required model classes not available")

    @staticmethod
    def _build_quantization_config(compute_dtype):
        """Return a 4-bit NF4 quantization config, or ``None`` if it cannot be built."""
        if BitsAndBytesConfig is None:
            return None
        try:
            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
        except Exception as exc:
            print(f"Warning: could not create 4-bit quantization config: {exc}")
            return None

    def _generate_and_decode(self, codec, inputs) -> str:
        """Run generation on *inputs* and decode only the newly produced tokens."""
        sampling = self.temperature > 0
        with torch.inference_mode():
            generation = self.model.generate(
                **inputs,
                max_new_tokens=self.max_tokens,
                do_sample=sampling,
                temperature=max(0.01, self.temperature) if sampling else None,
                use_cache=False,
            )

        # Drop the prompt tokens so only the model's continuation is decoded.
        input_len = inputs["input_ids"].shape[-1]
        new_tokens = generation[0][input_len:]
        return codec.decode(new_tokens, skip_special_tokens=True).strip()
|
|
|
|
|
|
|
|
|
|
|
class AppState:
    """Holds the configuration and the lazily created models used by the app."""

    def __init__(self):
        self.config = SpacesConfig()
        self.model = None
        self.class_names = None
        self.text_generator = None

    def load_model(self):
        """Load the detection model and, if available, its class names.

        Does nothing when a model is already loaded.

        Raises:
            FileNotFoundError: If no checkpoint could be downloaded or found
                locally.
        """
        if self.model is not None:
            return

        configured_repo = self.config.get('hf_model_repo')
        checkpoint = find_checkpoint(
            hf_repo=configured_repo,
            hf_filename=self.config.get('hf_model_filename', 'lorai.pth')
        )
        if not checkpoint:
            hf_repo = configured_repo or os.environ.get('HF_MODEL_REPO')
            if hf_repo:
                raise FileNotFoundError(
                    f"No RF-DETR checkpoint found. Could not download from '{hf_repo}'. "
                    "Please check the repository ID and ensure the model file exists."
                )
            raise FileNotFoundError(
                "No RF-DETR checkpoint found. Please either:\n"
                "1. Set HF_MODEL_REPO environment variable (e.g., 'edeler/lorai'), or\n"
                "2. Upload lorai.pth to your Space's root directory"
            )

        print(f"Loading RF-DETR from: {checkpoint}")
        self.model = load_model(checkpoint, self.config.get('resolution'))

        self._load_class_names()

        print("✓ RF-DETR model loaded")

    def _load_class_names(self) -> None:
        """Read class names from ``results.json`` in the results directory, if present.

        Names are collected from the "valid", "test" and "train" entries of the
        file's ``class_map``, skipping "all" and duplicates. This is
        best-effort: a missing file leaves ``class_names`` untouched, and a
        read/parse error is reported and otherwise ignored.
        """
        results_dir = self.config.get('results_dir') or '/tmp/results'
        results_json = os.path.join(results_dir, 'results.json')
        if not os.path.isfile(results_json):
            return

        try:
            with open(results_json, 'r', encoding='utf-8') as f:
                data = json.load(f)
            classes = []
            for split in ("valid", "test", "train"):
                if "class_map" in data and split in data["class_map"]:
                    for item in data["class_map"][split]:
                        name = item.get("class")
                        if name and name != "all" and name not in classes:
                            classes.append(name)
            self.class_names = classes if classes else None
        except Exception as exc:
            print(f"Warning: could not read class names from {results_json}: {exc}")

    def preload_all_models(self):
        """Load the detection model and, when enabled, the LLM ahead of first use.

        A failure to preload the LLM is reported and tolerated; a failure to
        load the detection model propagates to the caller.
        """
        print("=" * 60)
        print("Preloading all models into VRAM...")
        print("=" * 60)

        print("\n[1/2] Loading RF-DETR detection model...")
        self.load_model()

        llm_deferred = False
        if self.config.get('use_llm'):
            print("\n[2/2] Loading MedGemma LLM model...")
            try:
                model_size = "4B"
                generator = self.get_text_generator(model_size)
                hf_token = self.config.get('hf_token')
                generator.load_model(hf_token)
                print("✓ MedGemma model loaded and ready")
            except Exception as e:
                llm_deferred = True
                print(f"⚠️ Warning: Could not preload LLM model: {e}")
                print("LLM will be loaded on first use instead")

        print("\n" + "=" * 60)
        if llm_deferred:
            # Do not claim everything is loaded when the LLM preload failed.
            print("✓ Detection model loaded; LLM will load on first use")
        else:
            print("✓ All models loaded and ready in VRAM!")
        print("=" * 60 + "\n")

    def get_text_generator(self, model_size: str = "4B") -> "TextGenerator":
        """Return the cached text generator, recreating it if the model id changed.

        Args:
            model_size: "27B" selects 'google/medgemma-27b-it'; any other value
                selects 'google/medgemma-4b-it'.
        """
        model_id = 'google/medgemma-27b-it' if model_size == "27B" else 'google/medgemma-4b-it'

        current = self.text_generator
        # A cached generator without a model_id attribute is kept as-is.
        if current is None or getattr(current, 'model_id', model_id) != model_id:
            max_tokens = self.config.get('llm_max_new_tokens')
            temperature = self.config.get('llm_temperature')
            self.text_generator = TextGenerator(model_id, max_tokens, temperature)
        return self.text_generator
|
|
|
|
|
|
|
|
|
|
|
def create_detection_interface():
    """Build the Gradio UI for image upload, detection and description.

    Returns:
        The ``gr.Blocks`` application (not launched).
    """
    COLOR_PALETTE = sv.ColorPalette.from_hex([
        "#ffff00", "#ff9b00", "#ff66ff", "#3399ff", "#ff66b2",
        "#ff8080", "#b266ff", "#9999ff", "#66ffff", "#33ff99",
        "#66ff66", "#99ff00",
    ])

    def resolve_label(class_id: Optional[int]) -> str:
        """Map a class id to a display name, falling back to the id itself."""
        if class_id is None:
            return "object"
        names = app_state.class_names
        if names and 0 <= class_id < len(names):
            return names[class_id]
        return str(class_id)

    @spaces.GPU
    def annotate_image(image: Image.Image, threshold: float, model_size: str = "4B") -> Tuple[Optional[Image.Image], str]:
        """Run detection on *image* and return ``(annotated_image, description)``.

        Returns ``(None, message)`` when no image is given or processing fails.
        """
        if image is None:
            return None, "Please upload an image."

        try:
            if app_state.model is None:
                app_state.load_model()

            detections = app_state.model.predict(image, threshold=threshold)
            total = len(detections)

            # Resolve each detection's name once; reused for labels and counts.
            class_ids = detections.class_id
            confidences = detections.confidence
            names = [
                resolve_label(int(class_ids[i]) if class_ids is not None else None)
                for i in range(total)
            ]
            scores = [
                float(confidences[i]) if confidences is not None else 0.0
                for i in range(total)
            ]
            labels = [f"{name} {score:.2f}" for name, score in zip(names, scores)]

            bbox_annotator = sv.BoxAnnotator(color=COLOR_PALETTE, thickness=2)
            label_annotator = sv.LabelAnnotator(text_scale=0.5, text_color=sv.Color.BLACK)

            annotated = image.copy()
            annotated = bbox_annotator.annotate(annotated, detections)
            annotated = label_annotator.annotate(annotated, detections, labels)

            description = f"Found {total} detections above threshold {threshold}:\n\n"

            if total == 0:
                description += "No objects detected above the confidence threshold."
                return annotated, description

            counts = {}
            for name in names:
                counts[name] = counts.get(name, 0) + 1
            description += "".join(f"- {count}× {name}\n" for name, count in counts.items())

            if app_state.config.get('use_llm'):
                try:
                    generator = app_state.get_text_generator(model_size)
                    hf_token = app_state.config.get('hf_token')
                    description = generator.generate(description, image=annotated, hf_token=hf_token)
                except Exception as e:
                    # Keep the plain count-based description if the LLM fails.
                    print(f"LLM generation failed: {e}")

            return annotated, description

        except Exception as e:
            error_msg = f"Error processing image: {str(e)}"
            print(f"Processing error: {traceback.format_exc()}")
            return None, error_msg

    with gr.Blocks(title="Medical Image Analysis", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🏥 Medical Image Analysis")
        gr.Markdown("Upload a medical image to detect and analyze findings using AI.")

        hf_token = app_state.config.get('hf_token')
        if not hf_token:
            gr.Markdown("⚠️ **Note:** HF_TOKEN not set. AI text generation will be disabled. Detection will still work.")
        else:
            gr.Markdown("✅ **AI-powered analysis enabled** using MedGemma 4B")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Upload Image", height=400)
                threshold_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    # Default comes from the shared config rather than a second literal.
                    value=app_state.config.get('threshold', 0.7),
                    step=0.05,
                    label="Confidence Threshold",
                    info="Higher values = fewer but more confident detections"
                )

                model_size_radio = gr.Radio(
                    choices=["4B"],
                    value="4B",
                    label="MedGemma Model Size",
                    info="Using MedGemma 4B for AI-generated analysis",
                    visible=False
                )

                analyze_btn = gr.Button("🔍 Analyze Image", variant="primary")

                # Only offer example images that are actually present on disk.
                example_files = [
                    [path] for path in ("1.jpg", "2.jpg", "3.jpg") if os.path.isfile(path)
                ]
                if example_files:
                    gr.Examples(
                        examples=example_files,
                        inputs=input_image,
                        label="Example Images",
                        examples_per_page=3
                    )

            with gr.Column():
                output_image = gr.Image(type="pil", label="Results", height=400)
                output_text = gr.Textbox(
                    label="Analysis Results",
                    lines=8,
                    max_lines=15,
                    show_copy_button=True
                )

        # Run the analysis on button click and whenever the input image changes.
        analyze_btn.click(
            fn=annotate_image,
            inputs=[input_image, threshold_slider, model_size_radio],
            outputs=[output_image, output_text]
        )

        input_image.change(
            fn=annotate_image,
            inputs=[input_image, threshold_slider, model_size_radio],
            outputs=[output_image, output_text]
        )

        gr.Markdown("---")

    return demo
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level state shared by the UI callbacks and main().
app_state = AppState()


def main():
    """Create the results directory, optionally preload models, and launch the UI."""
    print("🚀 Starting Medical Image Analysis App")

    results_dir = app_state.config.get('results_dir')
    os.makedirs(results_dir, exist_ok=True)

    if HF_SPACES_GPU:
        # With the `spaces` package present, loading is deferred to first inference.
        print("Running on HF Spaces - models will load on first inference (via @spaces.GPU)")
        print("This is the recommended approach for Spaces GPU management.")
    else:
        print("Running locally - preloading models into VRAM...")
        try:
            app_state.preload_all_models()
        except Exception as exc:
            print(f"⚠️ Warning: Failed to preload models: {exc}")
            print("Models will be loaded on first use instead")

    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        show_api=False,
    )
    interface = create_detection_interface()
    interface.launch(**launch_options)


if __name__ == "__main__":
    main()
|
|
|