# wuff-mann's picture
# Update app.py
# 304d804 verified
import gradio as gr
import torch
import json
import time
from PIL import Image
from transformers import AutoProcessor, AutoModel
from huggingface_hub import InferenceClient
# ==========================================
# Phase 2: Edge-Side Tensor Router (SiGLIP)
# ==========================================
class SiGLIPRouter:
    """Edge-side router that ranks candidate images against a text query.

    The SigLIP checkpoint named by ``MODEL_ID`` is loaded once in the
    constructor.  ``route_evidence`` scores every candidate image against the
    query and keeps only those whose probability falls inside a window
    relative to the best-scoring image.
    """

    MODEL_ID = "google/siglip-so400m-patch14-384"

    def __init__(self):
        print("Loading Edge Routing Engine (SiGLIP-So400M)...")
        # Load locally for edge-simulated routing
        self.processor = AutoProcessor.from_pretrained(self.MODEL_ID)
        self.model = AutoModel.from_pretrained(self.MODEL_ID)

    def route_evidence(self, visual_query, image_paths, margin_ratio=0.15, absolute_floor=0.1):
        """Relative Margin Thresholding Engine.

        Routes images based on a dynamic confidence window relative to the
        best match.

        Args:
            visual_query: Text describing what the relevant images show.
            image_paths: Paths of the candidate image files.
            margin_ratio: Fraction below the best probability that is still
                accepted (0.15 keeps images scoring >= 85% of the best one).
            absolute_floor: Lower bound for the threshold; when even the best
                image scores below it, nothing is retained.

        Returns:
            ``(routed_images, metrics)`` where ``routed_images`` is the subset
            of ``image_paths`` that passed (input order preserved) and
            ``metrics`` is a dict of routing diagnostics.  Empty/``None``
            input yields ``([], {})``.
        """
        if not image_paths:
            return [], {}

        images = [self._load_rgb(path) for path in image_paths]

        # truncation=True clips an over-long query to the tokenizer's maximum
        # length so the text input never exceeds the padded length.
        inputs = self.processor(
            text=[visual_query],
            images=images,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        with torch.no_grad():
            outputs = self.model(**inputs)

        # One text query is scored against every image, so flattening the
        # logits gives one sigmoid probability per image (a single image
        # simply yields a one-element list).
        probs = torch.sigmoid(outputs.logits_per_image).reshape(-1).tolist()
        return self._select_by_margin(probs, image_paths, margin_ratio, absolute_floor)

    @staticmethod
    def _load_rgb(path):
        """Decode ``path`` into an in-memory RGB image and close the file."""
        # convert() returns a fully loaded copy, so the underlying file handle
        # can be released as soon as the ``with`` block exits.
        with Image.open(path) as source:
            return source.convert("RGB")

    @staticmethod
    def _select_by_margin(probs, image_paths, margin_ratio, absolute_floor):
        """Keep the paths whose probability reaches the dynamic threshold.

        ``probs`` is a non-empty sequence of floats aligned with
        ``image_paths``.  The threshold is anchored on the highest
        probability and never drops below ``absolute_floor``.
        """
        max_prob = max(probs)
        dynamic_threshold = max(max_prob * (1.0 - margin_ratio), absolute_floor)

        # Precision pruning
        routed_images = [
            path for path, prob in zip(image_paths, probs) if prob >= dynamic_threshold
        ]
        metrics = {
            "Anchor_Probability (Max)": round(max_prob, 4),
            "Dynamic_Threshold": round(dynamic_threshold, 4),
            "Total_Candidates": len(image_paths),
            "Images_Retained": len(routed_images),
        }
        return routed_images, metrics
# ==========================================
# Phase 1: Cloud-Side Distillation (Qwen)
# ==========================================
def _parse_json_object(content):
    """Return the first JSON object embedded in ``content``.

    Tolerates markdown fences or prose around the object by decoding from the
    first ``{`` and ignoring anything after the object ends.  Raises
    ``ValueError`` when no object is present or it is malformed.
    """
    start = content.find("{")
    if start < 0:
        raise ValueError("model response contains no JSON object")
    parsed, _end = json.JSONDecoder().raw_decode(content, start)
    return parsed


def _coerce_visual_query(value, fallback="medical scan"):
    """Turn the model's ``Visual_Query`` value into a non-empty string."""
    # The model sometimes answers with a list of keywords instead of a string.
    if isinstance(value, (list, tuple)):
        value = ", ".join(str(item) for item in value)
    text = "" if value is None else str(value).strip()
    return text or fallback


def extract_intent_and_query(clinical_text, hf_token):
    """Dual-track extraction using Qwen-72B via the Hugging Face Inference API.

    Translates messy text into Clinical Intent and Visual Target.

    Args:
        clinical_text: Free-form patient narrative.
        hf_token: Hugging Face API token; surrounding whitespace is ignored.

    Returns:
        ``(result, visual_query)``.  On success ``result`` is the parsed JSON
        object from the model and ``visual_query`` its ``"Visual_Query"``
        value as a string (``"medical scan"`` when absent or empty).  On
        failure ``result`` is ``{"error": ...}``; the second element is
        ``"Error"`` for a missing token or empty narrative and
        ``"medical scan"`` when the cloud call or parsing fails.
    """
    token = (hf_token or "").strip()
    if not token:
        return {"error": "Missing Hugging Face API Token."}, "Error"
    # An empty narrative gives the model nothing to distil, so reject it
    # before spending an API call.
    if not (clinical_text or "").strip():
        return {"error": "Empty clinical narrative."}, "Error"

    client = InferenceClient("Qwen/Qwen2.5-72B-Instruct", token=token)
    system_prompt = """You are an expert Clinical Triage AI and a Multimodal Routing Specialist.
Your task is to analyze messy, noisy patient narratives and output a strictly formatted JSON object with two fields.
1. "Clinical_Intent": A purified medical summary of the core issue.
2. "Visual_Query": An extreme extraction of visual, anatomical, and radiological keywords relevant ONLY to the core issue. Think like an image-recognition model. Use nouns and modalities (e.g., 'Brain MRA, skull, Willis circle'). DO NOT use abstract symptoms like 'headache'.
Output ONLY valid JSON."""
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Patient Narrative:\n{clinical_text}"},
    ]

    # Broad catch on purpose: this is the boundary where network, API and
    # parsing failures are all converted into an error payload for the UI.
    try:
        response = client.chat_completion(messages=messages, max_tokens=300)
        content = response.choices[0].message.content
        result = _parse_json_object(content)
        visual_query = _coerce_visual_query(result.get("Visual_Query"))
    except Exception as e:
        return {"error": f"Cloud Distillation Failed: {str(e)}"}, "medical scan"
    return result, visual_query
# Initialize local router
# Created once at import time: the SiGLIPRouter constructor loads the SigLIP
# processor and model, and execute_pipeline reuses this single instance.
router = SiGLIPRouter()
# ==========================================
# Main Execution Pipeline
# ==========================================
def execute_pipeline(hf_token, narrative, images, margin_ratio):
    """Run cloud distillation followed by edge routing for one request.

    Args:
        hf_token: Hugging Face API token forwarded to the cloud step.
        narrative: Free-form patient narrative.
        images: Paths of the uploaded candidate images.
        margin_ratio: Relative margin forwarded to the router.

    Returns:
        ``(cloud_result, routed_images, metrics)`` for the three UI outputs.
        Failures are reported inside the returned dicts rather than raised:
        no images, a cloud error, or an edge-routing error each produce an
        empty image list and a ``"Status"`` entry in ``metrics``.
    """
    # perf_counter is monotonic, so the latency cannot be skewed by system
    # clock adjustments during the request.
    started = time.perf_counter()

    if not images:
        return {"Error": "No images uploaded."}, [], {"Status": "Failed"}

    # 1. Cloud Intelligence (Linguistic Distillation)
    cloud_result, visual_query = extract_intent_and_query(narrative, hf_token)
    if "error" in cloud_result:
        return cloud_result, [], {"Status": "Cloud API Error"}

    # 2. Edge Intelligence (Tensor Routing)
    # Broad catch on purpose: this is the UI boundary, and an unreadable
    # upload or model failure should surface as diagnostics, mirroring how
    # cloud failures are reported, instead of an unhandled exception.
    try:
        routed_imgs, metrics = router.route_evidence(
            visual_query, images, margin_ratio=margin_ratio
        )
    except Exception as exc:
        return cloud_result, [], {"Status": "Edge Routing Error", "Detail": str(exc)}

    metrics["Total_Latency (s)"] = round(time.perf_counter() - started, 2)
    metrics["Active_Margin_Ratio"] = margin_ratio
    return cloud_result, routed_imgs, metrics
# ==========================================
# Gradio UI Design
# ==========================================
# Two-column layout: inputs and settings on the left, the three pipeline
# outputs (cloud JSON, routed gallery, metrics JSON) on the right.
with gr.Blocks(title="MANN-Engram Showcase", theme=gr.themes.Base()) as demo:
    gr.Markdown("# 🧠 MANN-Engram: Edge-Cloud Multimodal Semantic Router")
    gr.Markdown("> **A Privacy-First, Zero-Hallucination Shield for Clinical Vision-Language Models.**")
    with gr.Row():
        # Left Column: Inputs & Settings
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Engine Settings")
            hf_token = gr.Textbox(label="Hugging Face API Token (For Cloud Brain)", type="password", placeholder="hf_xxxxxxxx...")
            margin_slider = gr.Slider(
                minimum=0.05, maximum=0.40, step=0.05, value=0.15,
                label="Routing Tolerance (Margin Ratio)",
                info="Lower (0.05) = Sniper Mode (Extreme Precision). Higher (0.30) = Cluster Mode (Recalls multiple related views)."
            )
            gr.Markdown("### 📥 Patient Data Dump")
            narrative_input = gr.Textbox(
                label="Messy Clinical Narrative", lines=8,
                placeholder="Paste the chaotic patient complaint and history here..."
            )
            # Every upload is opened as an image by the router, so restrict
            # the picker to image files instead of accepting arbitrary types.
            image_input = gr.File(
                label="Upload Unorganized Scans (Images)",
                file_count="multiple",
                file_types=["image"],
                type="filepath",
            )
            run_btn = gr.Button("🚀 Execute Routing Pipeline", variant="primary")
        # Right Column: Outputs
        with gr.Column(scale=1):
            gr.Markdown("### ☁️ Cloud Output: Dual-Track Distillation")
            cloud_output = gr.JSON(label="Purified Intent & Visual Query")
            gr.Markdown("### 🛡️ Edge Output: Routed Core Evidence")
            routed_gallery = gr.Gallery(label="Surgically Selected Scans", columns=2, object_fit="contain", height=400)
            gr.Markdown("### 📊 Telemetry & Metrics")
            metrics_output = gr.JSON(label="Routing Diagnostics")
    # Wire up the button: input order matches execute_pipeline's parameters,
    # output order matches its three return values.
    run_btn.click(
        fn=execute_pipeline,
        inputs=[hf_token, narrative_input, image_input, margin_slider],
        outputs=[cloud_output, routed_gallery, metrics_output]
    )

if __name__ == "__main__":
    demo.launch()