Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import json | |
| import time | |
| from PIL import Image | |
| from transformers import AutoProcessor, AutoModel | |
| from huggingface_hub import InferenceClient | |
# ==========================================
# Phase 2: Edge-Side Tensor Router (SiGLIP)
# ==========================================
class SiGLIPRouter:
    """Filter candidate images by SigLIP image-text match against a text query.

    The processor and model are loaded once in ``__init__`` and reused by every
    call to :meth:`route_evidence`.
    """

    DEFAULT_MODEL_ID = "google/siglip-so400m-patch14-384"

    def __init__(self, model_id=DEFAULT_MODEL_ID):
        """Load the SigLIP processor and model.

        Args:
            model_id: Hugging Face Hub id (or local path) of a SigLIP
                checkpoint. Defaults to the checkpoint this demo was built on.
        """
        print("Loading Edge Routing Engine (SiGLIP-So400M)...")
        # Load locally for edge-simulated routing
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = AutoModel.from_pretrained(model_id)

    def route_evidence(self, visual_query, image_paths, margin_ratio=0.15, absolute_floor=0.1):
        """Keep the images whose match probability is close to the best match.

        Relative Margin Thresholding: each image is scored against
        ``visual_query`` with the sigmoid of SigLIP's image-text logit. The
        best score is the anchor, and an image is retained when its score is
        at least ``max(anchor * (1 - margin_ratio), absolute_floor)``.

        Args:
            visual_query: Text query to match the images against. It is
                truncated to the text tower's maximum length.
            image_paths: Sequence of image file paths.
            margin_ratio: Fractional tolerance below the anchor probability.
            absolute_floor: Minimum probability a retained image must reach.
                If even the anchor is below it, nothing is retained.

        Returns:
            ``(routed_images, metrics)``: the retained paths (input order
            preserved) and a dict of routing diagnostics. Returns ``([], {})``
            when ``image_paths`` is empty.
        """
        if not image_paths:
            return [], {}

        # Decode each file and close its handle immediately. convert() loads
        # the pixel data, so the returned RGB image does not need the file.
        images = []
        for path in image_paths:
            with Image.open(path) as img:
                images.append(img.convert("RGB"))

        # SigLIP expects max_length padding. truncation=True keeps long
        # LLM-generated queries within the text model's position embeddings
        # instead of raising an indexing error.
        inputs = self.processor(
            text=[visual_query],
            images=images,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        with torch.no_grad():
            outputs = self.model(**inputs)

        # logits_per_image is (num_images, num_texts). There is exactly one
        # text, so column 0 holds each image's logit. Indexing (unlike
        # squeeze()) always yields a 1-D tensor, even for a single image.
        probs = torch.sigmoid(outputs.logits_per_image[:, 0])

        # Anchor the dynamic threshold on the best-matching image.
        max_prob = probs.max().item()
        dynamic_threshold = max(max_prob * (1.0 - margin_ratio), absolute_floor)

        # Precision pruning: keep every path at or above the threshold.
        routed_images = [
            path
            for path, prob in zip(image_paths, probs.tolist())
            if prob >= dynamic_threshold
        ]

        metrics = {
            "Anchor_Probability (Max)": round(max_prob, 4),
            "Dynamic_Threshold": round(dynamic_threshold, 4),
            "Total_Candidates": len(image_paths),
            "Images_Retained": len(routed_images),
        }
        return routed_images, metrics
# ==========================================
# Phase 1: Cloud-Side Distillation (Qwen)
# ==========================================
def _parse_llm_json(content):
    """Decode the JSON object embedded in an LLM reply.

    Only the span from the first ``{`` to the last ``}`` is parsed. This
    tolerates Markdown code fences and prose before or after the object.
    Raises ``ValueError`` (``json.JSONDecodeError`` is a subclass) when no
    object can be decoded.
    """
    content = content or ""
    start = content.find("{")
    end = content.rfind("}")
    if start == -1 or end < start:
        raise ValueError("no JSON object found in model reply")
    return json.loads(content[start:end + 1])


def _coerce_visual_query(value, default="medical scan"):
    """Return *value* as a non-empty query string for the SigLIP text tower.

    A JSON array of keywords is joined with commas. Empty or non-text values
    fall back to *default*.
    """
    if isinstance(value, (list, tuple)):
        parts = [str(item).strip() for item in value]
        value = ", ".join(part for part in parts if part)
    if isinstance(value, str) and value.strip():
        return value.strip()
    return default


def extract_intent_and_query(clinical_text, hf_token, *, model_id="Qwen/Qwen2.5-72B-Instruct", max_tokens=300):
    """Dual-track extraction via the Hugging Face Inference API.

    Translates a messy narrative into a Clinical Intent and a Visual Query.

    Args:
        clinical_text: Free-text patient narrative.
        hf_token: Hugging Face API token. Surrounding whitespace is ignored.
        model_id: Chat model served by the Inference API.
        max_tokens: Generation budget for the JSON reply.

    Returns:
        ``(result, visual_query)``. ``result`` is the parsed JSON dict from the
        model, and ``visual_query`` is always a non-empty string. On failure
        ``result`` is ``{"error": ...}``; the caller checks for that key.
    """
    token = (hf_token or "").strip()
    if not token:
        return {"error": "Missing Hugging Face API Token."}, "Error"
    client = InferenceClient(model_id, token=token)
    system_prompt = """You are an expert Clinical Triage AI and a Multimodal Routing Specialist.
Your task is to analyze messy, noisy patient narratives and output a strictly formatted JSON object with two fields.
1. "Clinical_Intent": A purified medical summary of the core issue.
2. "Visual_Query": An extreme extraction of visual, anatomical, and radiological keywords relevant ONLY to the core issue. Think like an image-recognition model. Use nouns and modalities (e.g., 'Brain MRA, skull, Willis circle'). DO NOT use abstract symptoms like 'headache'.
Output ONLY valid JSON."""
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Patient Narrative:\n{clinical_text}"}
    ]
    try:
        response = client.chat_completion(messages=messages, max_tokens=max_tokens)
        result = _parse_llm_json(response.choices[0].message.content)
    except Exception as e:
        # UI boundary: surface API, network and parse failures to the user
        # as an error payload instead of crashing the Gradio handler.
        return {"error": f"Cloud Distillation Failed: {str(e)}"}, "medical scan"
    # The model may emit Visual_Query as a list or leave it empty; the
    # SigLIP processor needs a single non-empty string.
    return result, _coerce_visual_query(result.get("Visual_Query"))
# Module-level SigLIP router: the weights are loaded once, when this module
# is imported, and the same instance serves every execute_pipeline call.
router: "SiGLIPRouter" = SiGLIPRouter()
# ==========================================
# Main Execution Pipeline
# ==========================================
def execute_pipeline(hf_token, narrative, images, margin_ratio):
    """Run cloud distillation, then edge routing, for one UI request.

    Args:
        hf_token: Hugging Face API token forwarded to the cloud LLM call.
        narrative: Free-text clinical narrative.
        images: Uploaded image paths from the file input.
        margin_ratio: Routing tolerance forwarded to ``route_evidence``.

    Returns:
        ``(cloud_result, routed_images, metrics)``, in the order of the three
        output components. On failure the routed list is empty and the
        error is reported in ``cloud_result`` or ``metrics``.
    """
    # perf_counter is monotonic and high-resolution, unlike wall-clock time().
    start_time = time.perf_counter()
    if not images:
        return {"Error": "No images uploaded."}, [], {"Status": "Failed"}
    # Don't spend a large-LLM call on an empty narrative.
    if not (narrative or "").strip():
        return {"Error": "No clinical narrative provided."}, [], {"Status": "Failed"}

    # 1. Cloud Intelligence (Linguistic Distillation)
    cloud_result, visual_query = extract_intent_and_query(narrative, hf_token)
    if "error" in cloud_result:
        return cloud_result, [], {"Status": "Cloud API Error"}

    # 2. Edge Intelligence (Tensor Routing)
    try:
        routed_imgs, metrics = router.route_evidence(
            visual_query, images, margin_ratio=margin_ratio
        )
    except OSError as e:
        # Image.open raises OSError subclasses (e.g. PIL's
        # UnidentifiedImageError) for unreadable or non-image uploads.
        # Report it alongside the cloud result instead of failing the request.
        return cloud_result, [], {"Status": "Edge Routing Error", "Detail": str(e)}

    metrics["Total_Latency (s)"] = round(time.perf_counter() - start_time, 2)
    metrics["Active_Margin_Ratio"] = margin_ratio
    return cloud_result, routed_imgs, metrics
# ==========================================
# Gradio UI Design
# ==========================================
# Two-column layout: settings and inputs on the left, the three pipeline
# outputs on the right. The nesting of the `with` blocks defines the layout.
with gr.Blocks(theme=gr.themes.Base(), title="MANN-Engram Showcase") as demo:
    gr.Markdown("# 🧠 MANN-Engram: Edge-Cloud Multimodal Semantic Router")
    gr.Markdown("> **A Privacy-First, Zero-Hallucination Shield for Clinical Vision-Language Models.**")

    with gr.Row():
        with gr.Column(scale=1):  # left: engine settings and patient inputs
            gr.Markdown("### ⚙️ Engine Settings")
            hf_token = gr.Textbox(
                type="password",
                label="Hugging Face API Token (For Cloud Brain)",
                placeholder="hf_xxxxxxxx...",
            )
            margin_slider = gr.Slider(
                label="Routing Tolerance (Margin Ratio)",
                info="Lower (0.05) = Sniper Mode (Extreme Precision). Higher (0.30) = Cluster Mode (Recalls multiple related views).",
                value=0.15,
                minimum=0.05,
                maximum=0.40,
                step=0.05,
            )

            gr.Markdown("### 📥 Patient Data Dump")
            narrative_input = gr.Textbox(
                lines=8,
                label="Messy Clinical Narrative",
                placeholder="Paste the chaotic patient complaint and history here...",
            )
            image_input = gr.File(
                type="filepath",
                file_count="multiple",
                label="Upload Unorganized Scans (Images)",
            )
            run_btn = gr.Button("🚀 Execute Routing Pipeline", variant="primary")

        with gr.Column(scale=1):  # right: cloud result, routed scans, telemetry
            gr.Markdown("### ☁️ Cloud Output: Dual-Track Distillation")
            cloud_output = gr.JSON(label="Purified Intent & Visual Query")

            gr.Markdown("### 🛡️ Edge Output: Routed Core Evidence")
            routed_gallery = gr.Gallery(
                label="Surgically Selected Scans",
                height=400,
                columns=2,
                object_fit="contain",
            )

            gr.Markdown("### 📊 Telemetry & Metrics")
            metrics_output = gr.JSON(label="Routing Diagnostics")

    # inputs follow execute_pipeline's parameter order; outputs follow the
    # order of its returned (cloud_result, routed_imgs, metrics) tuple.
    run_btn.click(
        inputs=[hf_token, narrative_input, image_input, margin_slider],
        outputs=[cloud_output, routed_gallery, metrics_output],
        fn=execute_pipeline,
    )

if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    demo.launch()