# Source: Hugging Face Space app file (file size: 7,323 bytes; Space status at capture: Sleeping).
import gradio as gr
import torch
import json
import time
from PIL import Image
from transformers import AutoProcessor, AutoModel
from huggingface_hub import InferenceClient
# ==========================================
# Phase 2: Edge-Side Tensor Router (SiGLIP)
# ==========================================
class SiGLIPRouter:
    """Score candidate images against a text query with SigLIP and keep the best ones.

    The processor and model are loaded once in the constructor and reused by
    every call to :meth:`route_evidence`.
    """

    # Single source of truth for the checkpoint used by both processor and model.
    MODEL_ID = "google/siglip-so400m-patch14-384"

    def __init__(self):
        """Load the SigLIP processor and model from the Hugging Face Hub."""
        print("Loading Edge Routing Engine (SiGLIP-So400M)...")
        # Load locally for edge-simulated routing
        self.processor = AutoProcessor.from_pretrained(self.MODEL_ID)
        self.model = AutoModel.from_pretrained(self.MODEL_ID)

    def route_evidence(self, visual_query, image_paths, margin_ratio=0.15, absolute_floor=0.1):
        """Select the images whose match probability is close to the best match.

        An image is retained when its sigmoid match probability is at least
        ``max(best_probability * (1 - margin_ratio), absolute_floor)``.

        Args:
            visual_query: Text describing what the relevant images should show.
            image_paths: Paths of the candidate image files.
            margin_ratio: Fraction below the best probability that is still
                accepted (0.15 keeps everything within 15% of the best match).
            absolute_floor: Lower bound on the threshold, applied when the
                relative threshold would fall below it.

        Returns:
            A ``(routed_images, metrics)`` tuple: the retained paths in their
            original order, and a dict of routing diagnostics. Returns
            ``([], {})`` when ``image_paths`` is empty or ``None``.
        """
        if not image_paths:
            return [], {}

        # Decode each file inside a context manager so the underlying file
        # handle is closed as soon as the RGB copy has been made.
        images = []
        for path in image_paths:
            with Image.open(path) as handle:
                images.append(handle.convert("RGB"))

        # truncation=True keeps an over-long query within the text encoder's
        # fixed sequence length instead of failing inside the forward pass.
        inputs = self.processor(
            text=[visual_query],
            images=images,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        with torch.no_grad():
            outputs = self.model(**inputs)

        # One text query is scored against N images; flatten to a 1-D vector of
        # N probabilities, which also covers the single-image case.
        probs = torch.sigmoid(outputs.logits_per_image).reshape(-1)

        # Dynamic threshold anchored on the highest probability.
        max_prob = torch.max(probs).item()
        dynamic_threshold = max(max_prob * (1.0 - margin_ratio), absolute_floor)

        # Precision pruning: keep only candidates at or above the threshold.
        routed_images = [
            path
            for path, prob in zip(image_paths, probs.tolist())
            if prob >= dynamic_threshold
        ]

        metrics = {
            "Anchor_Probability (Max)": round(max_prob, 4),
            "Dynamic_Threshold": round(dynamic_threshold, 4),
            "Total_Candidates": len(image_paths),
            "Images_Retained": len(routed_images),
        }
        return routed_images, metrics
# ==========================================
# Phase 1: Cloud-Side Distillation (Qwen)
# ==========================================
# Fallback query used whenever the cloud model yields no usable visual query.
_DEFAULT_VISUAL_QUERY = "medical scan"

_DISTILLATION_SYSTEM_PROMPT = """You are an expert Clinical Triage AI and a Multimodal Routing Specialist.
Your task is to analyze messy, noisy patient narratives and output a strictly formatted JSON object with two fields.
1. "Clinical_Intent": A purified medical summary of the core issue.
2. "Visual_Query": An extreme extraction of visual, anatomical, and radiological keywords relevant ONLY to the core issue. Think like an image-recognition model. Use nouns and modalities (e.g., 'Brain MRA, skull, Willis circle'). DO NOT use abstract symptoms like 'headache'.
Output ONLY valid JSON."""


def _extract_json_object(content):
    """Parse the JSON object contained in a chat reply, tolerating fences and stray prose."""
    cleaned = content.replace("```json", "").replace("```", "").strip()
    # Models sometimes wrap the object in a sentence; keep only the outermost braces.
    start, end = cleaned.find("{"), cleaned.rfind("}")
    if start != -1 and end > start:
        cleaned = cleaned[start:end + 1]
    parsed = json.loads(cleaned)
    if not isinstance(parsed, dict):
        raise ValueError("model reply is not a JSON object")
    return parsed


def _coerce_visual_query(value):
    """Turn the model's Visual_Query value into a non-empty string, or the default."""
    if isinstance(value, (list, tuple)):
        value = ", ".join(str(item) for item in value if item)
    elif value is not None and not isinstance(value, str):
        value = str(value)
    return (value or "").strip() or _DEFAULT_VISUAL_QUERY


def extract_intent_and_query(clinical_text, hf_token):
    """Distill a patient narrative into a clinical intent and a visual query.

    Sends the narrative to Qwen2.5-72B-Instruct through the Hugging Face
    Inference API and parses the JSON object it returns.

    Args:
        clinical_text: Free-form patient narrative.
        hf_token: Hugging Face API token used to authenticate the request.

    Returns:
        A ``(result, visual_query)`` tuple. On success ``result`` is the parsed
        JSON object and ``visual_query`` is its ``"Visual_Query"`` value as a
        non-empty string (``"medical scan"`` when missing or empty). On failure
        ``result`` is ``{"error": <message>}``; ``visual_query`` is ``"Error"``
        for a missing token and ``"medical scan"`` for any other failure.
    """
    if not hf_token:
        return {"error": "Missing Hugging Face API Token."}, "Error"

    messages = [
        {"role": "system", "content": _DISTILLATION_SYSTEM_PROMPT},
        {"role": "user", "content": f"Patient Narrative:\n{clinical_text}"},
    ]
    # Boundary with a remote service: any failure (client setup, network, auth,
    # malformed reply) is reported to the caller as an error dict, not raised.
    try:
        client = InferenceClient("Qwen/Qwen2.5-72B-Instruct", token=hf_token)
        response = client.chat_completion(messages=messages, max_tokens=300)
        result = _extract_json_object(response.choices[0].message.content)
    except Exception as e:
        return {"error": f"Cloud Distillation Failed: {str(e)}"}, _DEFAULT_VISUAL_QUERY
    return result, _coerce_visual_query(result.get("Visual_Query"))
# Initialize local router
# Built once at module level so every pipeline run reuses the processor and
# model loaded by the constructor instead of loading them again per request.
router = SiGLIPRouter()
# ==========================================
# Main Execution Pipeline
# ==========================================
def execute_pipeline(hf_token, narrative, images, margin_ratio):
    """Run cloud distillation followed by edge image routing.

    Args:
        hf_token: Hugging Face API token for the cloud distillation step.
        narrative: Free-form patient narrative.
        images: Paths of the uploaded candidate images.
        margin_ratio: Routing tolerance forwarded to the router.

    Returns:
        A ``(cloud_result, routed_images, metrics)`` tuple for the three output
        panels. On any failure ``routed_images`` is empty and ``metrics``
        carries a ``"Status"`` entry describing which stage failed.
    """
    # perf_counter is monotonic, so the reported latency cannot be skewed by
    # system clock adjustments during the run.
    start_time = time.perf_counter()

    if not images:
        return {"Error": "No images uploaded."}, [], {"Status": "Failed"}
    # Skip the cloud round-trip when there is nothing to distill.
    if not narrative or not narrative.strip():
        return {"Error": "No clinical narrative provided."}, [], {"Status": "Failed"}

    # 1. Cloud Intelligence (Linguistic Distillation)
    cloud_result, visual_query = extract_intent_and_query(narrative, hf_token)
    if "error" in cloud_result:
        return cloud_result, [], {"Status": "Cloud API Error"}

    # 2. Edge Intelligence (Tensor Routing)
    # Report unreadable uploads or model failures in the metrics panel, in the
    # same way cloud failures are reported, rather than raising to the UI.
    try:
        routed_imgs, metrics = router.route_evidence(
            visual_query, images, margin_ratio=margin_ratio
        )
    except (OSError, ValueError, RuntimeError) as exc:
        return cloud_result, [], {"Status": "Edge Routing Error", "Detail": str(exc)}

    metrics["Total_Latency (s)"] = round(time.perf_counter() - start_time, 2)
    metrics["Active_Margin_Ratio"] = margin_ratio
    return cloud_result, routed_imgs, metrics
# ==========================================
# Gradio UI Design
# ==========================================
# Two-column layout: settings and patient inputs on the left, the three
# pipeline outputs (cloud JSON, routed gallery, metrics JSON) on the right.
with gr.Blocks(title="MANN-Engram Showcase", theme=gr.themes.Base()) as demo:
    gr.Markdown("# 🧠 MANN-Engram: Edge-Cloud Multimodal Semantic Router")
    gr.Markdown("> **A Privacy-First, Zero-Hallucination Shield for Clinical Vision-Language Models.**")
    with gr.Row():
        # Left Column: Inputs & Settings
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Engine Settings")
            hf_token = gr.Textbox(
                label="Hugging Face API Token (For Cloud Brain)",
                type="password",
                placeholder="hf_xxxxxxxx...",
            )
            margin_slider = gr.Slider(
                minimum=0.05, maximum=0.40, step=0.05, value=0.15,
                label="Routing Tolerance (Margin Ratio)",
                info="Lower (0.05) = Sniper Mode (Extreme Precision). Higher (0.30) = Cluster Mode (Recalls multiple related views).",
            )
            gr.Markdown("### 📥 Patient Data Dump")
            narrative_input = gr.Textbox(
                label="Messy Clinical Narrative", lines=8,
                placeholder="Paste the chaotic patient complaint and history here...",
            )
            # Every upload is opened as an image by the router, so restrict the
            # picker to image files instead of accepting arbitrary uploads.
            image_input = gr.File(
                label="Upload Unorganized Scans (Images)",
                file_count="multiple",
                file_types=["image"],
                type="filepath",
            )
            run_btn = gr.Button("🚀 Execute Routing Pipeline", variant="primary")
        # Right Column: Outputs
        with gr.Column(scale=1):
            gr.Markdown("### ☁️ Cloud Output: Dual-Track Distillation")
            cloud_output = gr.JSON(label="Purified Intent & Visual Query")
            gr.Markdown("### 🛡️ Edge Output: Routed Core Evidence")
            routed_gallery = gr.Gallery(
                label="Surgically Selected Scans",
                columns=2,
                object_fit="contain",
                height=400,
            )
            gr.Markdown("### 📊 Telemetry & Metrics")
            metrics_output = gr.JSON(label="Routing Diagnostics")
    # Wire up the button: input order matches execute_pipeline's parameters,
    # output order matches its returned tuple.
    run_btn.click(
        fn=execute_pipeline,
        inputs=[hf_token, narrative_input, image_input, margin_slider],
        outputs=[cloud_output, routed_gallery, metrics_output],
    )
if __name__ == "__main__":
    # Start the Gradio server only when the file is run as a script.
    demo.launch()