import json
import time

import gradio as gr
import torch
from huggingface_hub import InferenceClient
from PIL import Image
from transformers import AutoModel, AutoProcessor


# ==========================================
# Phase 2: Edge-Side Tensor Router (SiGLIP)
# ==========================================
class SiGLIPRouter:
    """Score candidate images against a text query with SigLIP and keep the best ones.

    The processor and model are loaded once at construction time and reused
    for every call to :meth:`route_evidence`.
    """

    # Single source of truth for the checkpoint used by processor and model.
    MODEL_ID = "google/siglip-so400m-patch14-384"

    def __init__(self):
        """Load the SigLIP processor and model from the Hugging Face hub/cache."""
        print("Loading Edge Routing Engine (SiGLIP-So400M)...")
        # Load locally for edge-simulated routing
        self.processor = AutoProcessor.from_pretrained(self.MODEL_ID)
        self.model = AutoModel.from_pretrained(self.MODEL_ID)

    @staticmethod
    def _load_rgb(path):
        """Open the image at *path*, return an RGB copy, and close the file."""
        # Image.open is lazy and keeps the file open; convert() forces the
        # pixel data to be read and returns an independent image, so the
        # source handle can be released as soon as the block exits.
        with Image.open(path) as source:
            return source.convert("RGB")

    def route_evidence(self, visual_query, image_paths, margin_ratio=0.15, absolute_floor=0.1):
        """Relative Margin Thresholding Engine.

        Routes images based on a dynamic confidence window relative to the
        best match: an image is kept when its sigmoid match probability is at
        least ``max(best * (1 - margin_ratio), absolute_floor)``.

        Args:
            visual_query: Text describing what the relevant images show.
            image_paths: Paths of the candidate image files.
            margin_ratio: Fraction below the best probability that is still
                accepted (0.15 keeps everything within 15% of the best).
            absolute_floor: Minimum probability an image must reach no matter
                how low the best score is. If even the best image is below
                this floor, nothing is retained.

        Returns:
            A ``(routed_images, metrics)`` tuple: the retained paths in their
            original order and a dict of routing diagnostics. For an empty
            ``image_paths`` the result is ``([], {})``.

        Raises:
            OSError: If a path cannot be opened or decoded as an image
                (raised by Pillow).
        """
        if not image_paths:
            return [], {}

        # Convert file paths to PIL Images (file handles are closed eagerly)
        images = [self._load_rgb(path) for path in image_paths]

        # Extract multimodal tensors. SigLIP expects max-length padding;
        # truncation keeps an over-long query within the text tower's fixed
        # sequence length instead of failing inside the model.
        inputs = self.processor(
            text=[visual_query],
            images=images,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        with torch.no_grad():
            outputs = self.model(**inputs)

        # Sigmoid matching probabilities. Per the SigLIP model documentation
        # logits_per_image is (num_images, num_texts); with a single query,
        # flattening yields one probability per image, including the
        # single-image case.
        probs = torch.sigmoid(outputs.logits_per_image).reshape(-1)

        # Dynamic threshold based on the Anchor (highest probability)
        max_prob = probs.max().item()
        dynamic_threshold = max(max_prob * (1.0 - margin_ratio), absolute_floor)

        # Precision Pruning
        routed_images = [
            path
            for path, prob in zip(image_paths, probs.tolist())
            if prob >= dynamic_threshold
        ]

        metrics = {
            "Anchor_Probability (Max)": round(max_prob, 4),
            "Dynamic_Threshold": round(dynamic_threshold, 4),
            "Total_Candidates": len(image_paths),
            "Images_Retained": len(routed_images),
        }
        return routed_images, metrics


# ==========================================
# Phase 1: Cloud-Side Distillation (Qwen)
# ==========================================
# Fallback query handed to the router when the model gives no usable one.
DEFAULT_VISUAL_QUERY = "medical scan"


def _parse_json_object(raw_text):
    """Return the dict encoded by the outermost ``{...}`` span of *raw_text*.

    Chat models often wrap JSON in markdown fences or add a sentence around
    it; slicing from the first ``{`` to the last ``}`` tolerates both without
    touching backticks that appear inside string values.

    Raises:
        ValueError: If no JSON object is present or it does not parse
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    text = raw_text or ""
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end < start:
        raise ValueError("model response contains no JSON object")
    return json.loads(text[start:end + 1])


def _normalize_visual_query(value):
    """Coerce the model's ``Visual_Query`` field to a non-empty string.

    A list/tuple of keywords is joined with ``", "``; anything missing,
    blank, or of another type falls back to ``DEFAULT_VISUAL_QUERY``.
    """
    if isinstance(value, (list, tuple)):
        value = ", ".join(str(item) for item in value if item)
    if not isinstance(value, str) or not value.strip():
        return DEFAULT_VISUAL_QUERY
    return value.strip()


def extract_intent_and_query(clinical_text, hf_token):
    """Dual-track extraction using Qwen-72B via the Hugging Face Inference API.

    Translates messy text into a Clinical Intent and a Visual Target.

    Args:
        clinical_text: Free-form patient narrative.
        hf_token: Hugging Face API token used for the inference call.

    Returns:
        A ``(result, visual_query)`` tuple. On success ``result`` is the dict
        parsed from the model's JSON answer and ``visual_query`` is a
        non-empty string. On failure ``result`` is ``{"error": <message>}``;
        callers detect failure by the presence of the ``"error"`` key.
    """
    if not hf_token or not hf_token.strip():
        return {"error": "Missing Hugging Face API Token."}, "Error"
    if not clinical_text or not clinical_text.strip():
        # Nothing to distil; avoid a paid round-trip that can only hallucinate.
        return {"error": "Empty clinical narrative."}, "Error"

    client = InferenceClient("Qwen/Qwen2.5-72B-Instruct", token=hf_token.strip())

    system_prompt = """You are an expert Clinical Triage AI and a Multimodal Routing Specialist.
Your task is to analyze messy, noisy patient narratives and output a strictly formatted JSON object with two fields.
1. "Clinical_Intent": A purified medical summary of the core issue.
2. "Visual_Query": An extreme extraction of visual, anatomical, and radiological keywords relevant ONLY to the core issue. Think like an image-recognition model. Use nouns and modalities (e.g., 'Brain MRA, skull, Willis circle'). DO NOT use abstract symptoms like 'headache'.
Output ONLY valid JSON."""

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Patient Narrative:\n{clinical_text}"},
    ]

    # Boundary with a remote service: network, auth, and parsing failures are
    # all converted into an error payload the UI can display.
    try:
        response = client.chat_completion(messages=messages, max_tokens=300)
        result = _parse_json_object(response.choices[0].message.content)
    except Exception as e:
        return {"error": f"Cloud Distillation Failed: {str(e)}"}, DEFAULT_VISUAL_QUERY

    return result, _normalize_visual_query(result.get("Visual_Query"))


# Initialize local router
router = SiGLIPRouter()


# ==========================================
# Main Execution Pipeline
# ==========================================
def execute_pipeline(hf_token, narrative, images, margin_ratio):
    """Run cloud distillation followed by edge routing for one request.

    Args:
        hf_token: Hugging Face API token for the cloud step.
        narrative: Patient narrative text.
        images: Paths of the uploaded candidate images.
        margin_ratio: Routing tolerance forwarded to the router.

    Returns:
        A ``(cloud_result, routed_images, metrics)`` tuple matching the three
        output components of the UI. Failures are reported through the
        returned dicts rather than raised.
    """
    # perf_counter is monotonic, so the latency cannot be skewed by clock
    # adjustments during the request.
    started = time.perf_counter()

    if not images:
        return {"Error": "No images uploaded."}, [], {"Status": "Failed"}

    # 1. Cloud Intelligence (Linguistic Distillation)
    cloud_result, visual_query = extract_intent_and_query(narrative, hf_token)
    if "error" in cloud_result:
        return cloud_result, [], {"Status": "Cloud API Error"}

    # 2. Edge Intelligence (Tensor Routing)
    try:
        routed_imgs, metrics = router.route_evidence(
            visual_query, images, margin_ratio=margin_ratio
        )
    except (OSError, ValueError, RuntimeError) as exc:
        # e.g. an uploaded file that is not a decodable image; keep the cloud
        # output visible and surface the failure in the metrics panel.
        return cloud_result, [], {"Status": "Edge Routing Error", "Detail": str(exc)}

    metrics["Total_Latency (s)"] = round(time.perf_counter() - started, 2)
    metrics["Active_Margin_Ratio"] = margin_ratio

    return cloud_result, routed_imgs, metrics


# ==========================================
# Gradio UI Design
# ==========================================
with gr.Blocks(title="MANN-Engram Showcase", theme=gr.themes.Base()) as demo:
    gr.Markdown("# 🧠 MANN-Engram: Edge-Cloud Multimodal Semantic Router")
    gr.Markdown("> **A Privacy-First, Zero-Hallucination Shield for Clinical Vision-Language Models.**")

    with gr.Row():
        # Left Column: Inputs & Settings
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Engine Settings")
            hf_token = gr.Textbox(
                label="Hugging Face API Token (For Cloud Brain)",
                type="password",
                placeholder="hf_xxxxxxxx...",
            )
            margin_slider = gr.Slider(
                minimum=0.05,
                maximum=0.40,
                step=0.05,
                value=0.15,
                label="Routing Tolerance (Margin Ratio)",
                info="Lower (0.05) = Sniper Mode (Extreme Precision). Higher (0.30) = Cluster Mode (Recalls multiple related views).",
            )

            gr.Markdown("### 📥 Patient Data Dump")
            narrative_input = gr.Textbox(
                label="Messy Clinical Narrative",
                lines=8,
                placeholder="Paste the chaotic patient complaint and history here...",
            )
            image_input = gr.File(
                label="Upload Unorganized Scans (Images)",
                file_count="multiple",
                type="filepath",
            )

            run_btn = gr.Button("🚀 Execute Routing Pipeline", variant="primary")

        # Right Column: Outputs
        with gr.Column(scale=1):
            gr.Markdown("### ☁️ Cloud Output: Dual-Track Distillation")
            cloud_output = gr.JSON(label="Purified Intent & Visual Query")

            gr.Markdown("### 🛡️ Edge Output: Routed Core Evidence")
            routed_gallery = gr.Gallery(
                label="Surgically Selected Scans",
                columns=2,
                object_fit="contain",
                height=400,
            )

            gr.Markdown("### 📊 Telemetry & Metrics")
            metrics_output = gr.JSON(label="Routing Diagnostics")

    # Wire up the button
    run_btn.click(
        fn=execute_pipeline,
        inputs=[hf_token, narrative_input, image_input, margin_slider],
        outputs=[cloud_output, routed_gallery, metrics_output],
    )

if __name__ == "__main__":
    demo.launch()