Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from PIL import Image | |
| import torch | |
| import numpy as np | |
| import os | |
| from io import StringIO | |
| import sys | |
| import torch.nn as nn | |
# --- TorchDynamo configuration for Unsloth/MedGemma ---
import torch._dynamo

# Per the PyTorch dynamo config docs, this makes TorchDynamo capture ops that
# return Python scalars (e.g. ``Tensor.item()``) instead of graph-breaking on them.
torch._dynamo.config.capture_scalar_outputs = True

# NOTE(review): a bare ``torch.compiler.disable()`` call was removed from here.
# ``torch.compiler.disable`` is a decorator factory: called without a function it
# only returns a decorator/context object and does not disable compilation
# globally, so the call had no effect. If compilation really has to be turned off
# process-wide, that is presumably done via ``torch._dynamo.config.disable = True``
# or the TORCH_COMPILE_DISABLE=1 environment variable — verify against the
# installed PyTorch version, and note it would change current runtime behavior.
| # --- Dependency Handling --- | |
| try: | |
| from monai.networks.nets import SwinUNETR | |
| import torchvision.transforms as T | |
| from unsloth import FastVisionModel | |
| from transformers import TextStreamer | |
| from s2wrapper import forward as multiscale_forward | |
| except ImportError as e: | |
| st.error(f"A required library is not installed. Please install dependencies. Error: {e}") | |
| st.stop() | |
| # --- Config and Model Definition --- | |
class Config:
    """Static settings for the SwinUNETR segmentation model and its inputs."""

    # Original (sparse) label values: 0, 3, 6, ..., 60 — 21 values in total.
    ORIGINAL_LABELS = list(range(0, 61, 3))
    # Maps every original label value to a contiguous class index 0..20.
    LABEL_MAP = dict(zip(ORIGINAL_LABELS, range(len(ORIGINAL_LABELS))))
    NUM_CLASSES = len(ORIGINAL_LABELS)
    # Target size handed to T.Resize during preprocessing.
    IMG_SIZE = (256, 256)
    # Base channel width passed to SwinUNETR as ``feature_size``.
    FEATURE_SIZE = 48
    # CUDA when available, otherwise CPU.
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class multiscaleSwinUNETR(nn.Module):
    """2-D SwinUNETR run through s2wrapper's ``multiscale_forward`` plus a fusion head.

    Args:
        num_classes: Number of output segmentation classes.
        scales: Scales forwarded to ``multiscale_forward``. Defaults to ``[1]``
            (single scale). The fusion head expects ``len(scales) * num_classes``
            input channels.
    """

    def __init__(self, num_classes, scales=None):
        super().__init__()
        # Build a fresh list per instance: a list literal as the default value
        # would be shared by every model, and copying also decouples the model
        # from later mutation of the caller's list.
        self.scales = [1] if scales is None else list(scales)
        self.num_classes = num_classes
        self.model = SwinUNETR(
            spatial_dims=2,
            in_channels=3,
            out_channels=num_classes,
            feature_size=Config.FEATURE_SIZE,
            drop_rate=0.0,
            attn_drop_rate=0.0,
            dropout_path_rate=0.0,
            use_checkpoint=True,
            use_v2=True,
        )
        # Fuses the per-scale outputs (concatenated along dim 1) back into
        # num_classes channels.
        self.segmentation_head = nn.Sequential(
            nn.Conv2d(len(self.scales) * num_classes, num_classes, 3, padding=1),
            nn.BatchNorm2d(num_classes),
            nn.ReLU(inplace=True),
            nn.Conv2d(num_classes, num_classes, 1),
        )

    @staticmethod
    def _std_normalize(feature):
        """Divide a 4-D tensor by its std over dims (2, 3); 1e-6 guards against zero."""
        return feature / (feature.std(dim=(2, 3), keepdim=True) + 1e-6)

    def forward(self, x):
        """Return segmentation logits for the batch ``x``.

        Raises:
            ValueError: if ``multiscale_forward`` returns neither a list/tuple nor
                a 4-D tensor.
        """
        outs = multiscale_forward(self.model, x, scales=self.scales, output_shape="bchw")
        if isinstance(outs, (list, tuple)):
            feats = torch.cat([self._std_normalize(f) for f in outs], dim=1)
        elif isinstance(outs, torch.Tensor) and outs.dim() == 4:
            if len(self.scales) == 1:
                # Single scale: the backbone output is returned as-is and
                # segmentation_head is skipped. Its parameters still exist in the
                # state_dict but are unused on this path.
                return outs
            feats = self._std_normalize(outs)
        else:
            raise ValueError(f"Unexpected output shape/type from multiscale_forward: {type(outs)}, {getattr(outs,'shape',None)}")
        return self.segmentation_head(feats)
| # --- Model Loading --- | |
def load_swinunetr_model(model_path='s2-swinunetr-weights.pth'):
    """Load the single-scale multiscaleSwinUNETR segmentation model from a checkpoint.

    Args:
        model_path: Path to a ``state_dict`` file for ``multiscaleSwinUNETR``.
            The default is resolved relative to the current working directory.

    Returns:
        ``(model, Config)`` with the model on ``Config.DEVICE`` and in eval mode,
        or ``(None, None)`` if the file is missing or loading fails. In both
        failure cases the error is reported in the UI via ``st.error``.
    """
    if not os.path.exists(model_path):
        st.error(f"Segmentation model file not found at {model_path}")
        return None, None
    try:
        model = multiscaleSwinUNETR(num_classes=Config.NUM_CLASSES, scales=[1])
        # The checkpoint only needs tensors; weights_only=True stops torch.load
        # from unpickling arbitrary Python objects out of the file.
        state_dict = torch.load(model_path, map_location=Config.DEVICE, weights_only=True)
        model.load_state_dict(state_dict)
        # load_state_dict copies into the module's existing parameters, so the
        # module itself must still be moved to the target device.
        model.to(Config.DEVICE)
        model.eval()
    except Exception as e:
        st.error(f"Error loading segmentation model: {e}")
        return None, None
    return model, Config
def load_medgemma_model(model_id="fiqqy/MedGemma-MM-OR-FT10"):
    """Load the fine-tuned MedGemma vision-language model and its processor via Unsloth.

    Args:
        model_id: Model identifier passed to ``FastVisionModel.from_pretrained``.
            Defaults to the fine-tuned checkpoint this app was built for.

    Returns:
        ``(model, processor)`` ready for ``generate``, or ``(None, None)`` if
        loading fails (the error is reported in the UI via ``st.error``).
    """
    try:
        model, processor = FastVisionModel.from_pretrained(
            model_id,
            load_in_4bit=False,
            use_gradient_checkpointing="unsloth",
        )
        # The model is only used for generation in this app; switch Unsloth's
        # wrapper into inference mode, as its docs recommend before generate().
        FastVisionModel.for_inference(model)
    except Exception as e:
        st.error(f"Error loading MedGemma model: {e}")
        return None, None
    return model, processor
| # --- Preprocessing --- | |
def preprocess_frames(frames, config):
    """Turn a list of PIL frames into one normalized float batch for the segmenter.

    Each frame is forced to RGB, resized to ``config.IMG_SIZE``, converted to a
    tensor and normalized per channel with the standard ImageNet mean/std. The
    per-frame tensors are stacked along a new leading batch dimension.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    pipeline = T.Compose(
        [
            T.Resize(config.IMG_SIZE, antialias=True),
            T.ToTensor(),
            T.Normalize(mean=imagenet_mean, std=imagenet_std),
        ]
    )
    return torch.stack([pipeline(img.convert("RGB")) for img in frames])
| # --- Color Palette for Mask Visualization --- | |
def make_palette(num_classes):
    """Return a deterministic ``(num_classes, 3)`` uint8 color table; class 0 is black.

    Colors come from a NumPy generator seeded with 0, so every call yields the
    same palette. Channel values are drawn from [0, 255).
    """
    generator = np.random.default_rng(seed=0)
    palette = generator.integers(low=0, high=255, size=(num_classes, 3), dtype=np.uint8)
    palette[0, :] = 0  # background / class 0 rendered as black
    return palette
| # --- Inference --- | |
def run_segmentation(model, config, frames, frame_index=0):
    """Segment one uploaded frame and render the prediction as a color image.

    Only ``frames[frame_index]`` is sent through the network. The previous code
    ran every frame but kept only the first prediction, so the other frames were
    wasted compute. NOTE(review): this assumes the model is in eval mode (as
    ``load_swinunetr_model`` leaves it), so one sample's prediction does not
    depend on the rest of the batch.

    Args:
        model: Segmentation model returning ``(N, num_classes, H, W)`` logits.
        config: Settings object providing ``DEVICE``, ``IMG_SIZE`` and ``NUM_CLASSES``.
        frames: Sequence of PIL images.
        frame_index: Which frame to segment. Defaults to 0, matching the
            previously displayed mask.

    Returns:
        A PIL RGB image with each class drawn in its palette color. Its size is
        the model's output size, presumably ``config.IMG_SIZE``.
    """
    st.write("Running segmentation...")
    device = config.DEVICE
    batch = preprocess_frames([frames[frame_index]], config).to(device)
    model = model.to(device)
    with torch.no_grad():
        logits = model(batch)
    # Per-pixel class index for the single frame in the batch.
    mask = torch.argmax(logits, 1)[0].cpu().numpy()
    st.write(f"Mask unique values: {np.unique(mask)}")
    palette = make_palette(config.NUM_CLASSES)
    color_mask = palette[mask]
    return Image.fromarray(color_mask.astype(np.uint8))
| # --- MedGemma Captioning --- | |
def run_captioning(medgemma_model, processor, frames, mask_img, instruction, max_new_tokens=768):
    """Generate a MedGemma scene description from video frames, a mask and an instruction.

    The chat prompt gets one image placeholder per image actually passed to the
    processor (in this app: 3 frames followed by 1 segmentation mask), then the
    instruction text.

    Args:
        medgemma_model: Model exposing ``generate`` and ``device``.
        processor: Matching processor (chat template, image/text encoding, decoding).
        frames: PIL frames, converted to RGB here.
        mask_img: PIL segmentation mask, converted to RGB and appended after the frames.
        instruction: User instruction text.
        max_new_tokens: Generation budget; the default matches the previous hard-coded value.

    Returns:
        The generated text only (prompt excluded, special tokens stripped).
    """
    st.write("Preparing inputs for MedGemma...")
    all_images = [f.convert("RGB") for f in frames] + [mask_img.convert("RGB")]
    # Keep the number of image placeholders in sync with the images supplied.
    content = [{"type": "image"} for _ in all_images]
    content.append({"type": "text", "text": instruction})
    messages = [{"role": "user", "content": content}]
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
    # Send inputs to wherever the model actually lives, not to a guessed device.
    inputs = processor(
        all_images, input_text, add_special_tokens=False, return_tensors="pt",
    ).to(medgemma_model.device)
    st.write("Running MedGemma Analysis...")
    # Decode the returned ids directly instead of streaming into a redirected
    # sys.stdout. sys.stdout is process-global (shared by every Streamlit session
    # thread), and the old redirect was never restored if generate() raised.
    # NOTE(review): temperature/top_p/top_k only take effect if sampling is
    # enabled in the model's generation config — confirm.
    output_ids = medgemma_model.generate(
        **inputs, max_new_tokens=max_new_tokens,
        use_cache=True, temperature=1.0, top_p=0.95, top_k=64,
    )
    prompt_len = inputs["input_ids"].shape[-1]
    generated = output_ids[:, prompt_len:]
    return processor.batch_decode(generated, skip_special_tokens=True)[0]
| # --- Streamlit UI --- | |
def _medgemma_ready():
    """Return True when both the MedGemma model and its processor are loaded."""
    return (
        st.session_state.get("medgemma_model") is not None
        and st.session_state.get("processor") is not None
    )


def _frames_signature(uploaded_files):
    """Return a hashable key describing which files occupy the upload slots."""
    # (name, size) per slot and None for an empty slot, so replacing a file or
    # moving it to another slot both count as a change.
    return tuple(None if f is None else (f.name, f.size) for f in uploaded_files)


def _render_model_loading():
    """Render section 1: buttons that load each model into ``st.session_state``."""
    st.header("1. Load Models")
    if "seg_model" not in st.session_state or "seg_config" not in st.session_state:
        st.session_state.seg_model, st.session_state.seg_config = None, None
    if st.button("Load Segmentation Model"):
        with st.spinner("Loading SwinUNETR..."):
            st.session_state.seg_model, st.session_state.seg_config = load_swinunetr_model()
    if st.session_state.seg_model is not None:
        st.success("Segmentation model is loaded.")
    else:
        st.warning("Segmentation model is not loaded.")

    if "medgemma_model" not in st.session_state or "processor" not in st.session_state:
        st.session_state.medgemma_model, st.session_state.processor = None, None
    if st.button("Load MedGemma Model"):
        with st.spinner("Loading MedGemma... This can take several minutes."):
            st.session_state.medgemma_model, st.session_state.processor = load_medgemma_model()
    if _medgemma_ready():
        st.success("MedGemma model is loaded.")
    else:
        st.warning("MedGemma model is not loaded.")


def _render_upload_and_segmentation():
    """Render section 2: frame uploaders, previews and the segmentation step.

    Returns:
        The uploaded frames opened as PIL images, in slot order, skipping empty slots.
    """
    st.header("2. Upload Data & Generate Mask")
    st.subheader("Upload Three Sequential Surgical Video Frames")
    col1, col2, col3 = st.columns(3)
    uploaded_files = [
        col1.file_uploader("Upload Frame 1", type=["png", "jpg", "jpeg"], key="frame1"),
        col2.file_uploader("Upload Frame 2", type=["png", "jpg", "jpeg"], key="frame2"),
        col3.file_uploader("Upload Frame 3", type=["png", "jpg", "jpeg"], key="frame3"),
    ]
    frames = [Image.open(f) for f in uploaded_files if f is not None]
    display_size = (256, 256)

    if "mask_img" not in st.session_state:
        st.session_state.mask_img = None
    # A mask is only valid for the uploads it was computed from. Without this
    # check, replacing a frame kept the old mask, and "Run Analysis" would pair
    # the new frames with a stale mask.
    signature = _frames_signature(uploaded_files)
    if st.session_state.get("mask_frames_signature") != signature:
        st.session_state.mask_img = None
        st.session_state.mask_frames_signature = signature

    if len(frames) != 3:
        st.info("Please upload all three frames to proceed.")
        return frames

    st.success("All three frames have been uploaded successfully.")
    img_cols = st.columns(4)
    for i, frame in enumerate(frames):
        img_cols[i].image(frame.resize(display_size), caption=f"Frame {i+1}", use_container_width=True)
    seg_model = st.session_state.get("seg_model")
    seg_config = st.session_state.get("seg_config")
    # The button is only rendered once a segmentation model is available.
    if seg_model is not None and seg_config is not None and st.button("Run Segmentation"):
        with st.spinner("Generating segmentation mask..."):
            st.session_state.mask_img = run_segmentation(seg_model, seg_config, frames)
    if st.session_state.mask_img is not None:
        img_cols[3].image(st.session_state.mask_img.resize(display_size), caption="Segmentation Mask", use_container_width=True)
    return frames


def _render_analysis(frames):
    """Render section 3: the instruction prompt and the MedGemma analysis step."""
    st.header("3. Generate Scene Analysis")
    instruction_prompt = st.text_area(
        "Enter your custom instruction prompt:",
        "Provide a detailed summary of the surgical action, noting the instruments used and their interactions."
    )
    can_run_analysis = (
        _medgemma_ready()
        and len(frames) == 3
        and st.session_state.get("mask_img") is not None
        and bool(instruction_prompt.strip())
    )
    if st.button("Run Analysis", disabled=not can_run_analysis):
        with st.spinner("Running MedGemma analysis... This may take a moment."):
            result = run_captioning(
                st.session_state.medgemma_model, st.session_state.processor,
                frames, st.session_state.mask_img, instruction_prompt
            )
        st.subheader("Analysis Result")
        st.write(result)
    if not can_run_analysis:
        st.warning("Please ensure the MedGemma model is loaded, three frames are uploaded, segmentation is complete, and a prompt is provided.")


def show():
    """Render the whole app: model loading, upload/segmentation, then analysis."""
    st.title("Surgical Scene Analysis System")
    st.write("A system to test surgical scene segmentation and captioning models.")
    _render_model_loading()
    frames = _render_upload_and_segmentation()
    _render_analysis(frames)


if __name__ == "__main__":
    show()