File size: 7,323 Bytes
b662a8d
304d804
 
 
 
 
1bde015
b662a8d
304d804
 
 
 
 
 
 
 
 
b662a8d
304d804
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b662a8d
4e0ab32
304d804
 
 
 
 
 
 
 
 
 
 
 
1bde015
304d804
 
 
 
 
 
 
 
1bde015
304d804
1bde015
304d804
 
b662a8d
304d804
 
 
 
b662a8d
304d804
 
 
b662a8d
304d804
 
b662a8d
304d804
 
 
 
 
d4a27b4
304d804
 
 
b662a8d
304d804
b662a8d
 
304d804
b662a8d
304d804
 
 
b662a8d
 
304d804
 
 
 
6b2e38e
304d804
 
 
 
4e0ab32
6b2e38e
304d804
 
 
 
4e0ab32
304d804
d4a27b4
304d804
4e0ab32
304d804
 
 
 
 
 
 
 
 
 
b662a8d
304d804
 
 
 
 
b662a8d
 
 
1bde015
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import gradio as gr
import torch
import json
import time
from PIL import Image
from transformers import AutoProcessor, AutoModel
from huggingface_hub import InferenceClient

# ==========================================
# Phase 2: Edge-Side Tensor Router (SiGLIP)
# ==========================================
class SiGLIPRouter:
    """Edge-side image router built on SigLIP (google/siglip-so400m-patch14-384).

    The processor and model are loaded once in ``__init__`` and reused by
    :meth:`route_evidence`. That method scores every candidate image against a
    single text query and keeps only the images whose match probability is
    close to the best one.
    """

    def __init__(self):
        print("Loading Edge Routing Engine (SiGLIP-So400M)...")
        # Load locally for edge-simulated routing
        self.processor = AutoProcessor.from_pretrained("google/siglip-so400m-patch14-384")
        self.model = AutoModel.from_pretrained("google/siglip-so400m-patch14-384")

    @staticmethod
    def _load_images(image_paths):
        """Open each path as an RGB PIL image and close the underlying file."""
        images = []
        for path in image_paths:
            # convert() loads the pixels into a new image object, so the file
            # handle can be released as soon as it returns.
            with Image.open(path) as img:
                images.append(img.convert("RGB"))
        return images

    @staticmethod
    def _select_by_margin(probs, image_paths, margin_ratio, absolute_floor):
        """Apply the relative-margin rule to per-image probabilities.

        Args:
            probs: Tensor holding one probability per entry of ``image_paths``.
            image_paths: Candidate paths, aligned element-wise with ``probs``.
            margin_ratio: Fraction below the best probability that is still
                accepted (0.15 keeps anything >= 85% of the best score).
            absolute_floor: Minimum threshold, applied regardless of margin.

        Returns:
            ``(routed_paths, max_prob, threshold)``. ``routed_paths`` preserves
            input order and holds every path whose probability is >= threshold.
        """
        scores = probs.reshape(-1).tolist()
        max_prob = max(scores)
        threshold = max(max_prob * (1.0 - margin_ratio), absolute_floor)
        routed = [path for path, score in zip(image_paths, scores) if score >= threshold]
        return routed, max_prob, threshold

    def route_evidence(self, visual_query, image_paths, margin_ratio=0.15, absolute_floor=0.1):
        """Relative Margin Thresholding Engine.

        Routes images using a dynamic confidence window anchored on the best match.

        Args:
            visual_query: Text query that every image is scored against.
            image_paths: File paths of the candidate images.
            margin_ratio: Tolerance below the best probability (see
                ``_select_by_margin``).
            absolute_floor: Lower bound on the acceptance threshold.

        Returns:
            ``(routed_images, metrics)``. ``routed_images`` is the list of
            retained paths. ``metrics`` reports the anchor probability, the
            threshold and the counts. ``([], {})`` is returned when
            ``image_paths`` is empty.
        """
        if not image_paths:
            return [], {}

        images = self._load_images(image_paths)

        # truncation=True: with padding="max_length" alone, a query longer than
        # the text encoder's maximum length would make the forward pass fail.
        inputs = self.processor(
            text=[visual_query],
            images=images,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        with torch.no_grad():
            outputs = self.model(**inputs)

        # logits_per_image is (num_images, num_texts); exactly one text was sent,
        # so column 0 gives a 1-D vector even for a single image.
        probs = torch.sigmoid(outputs.logits_per_image[:, 0])

        routed_images, max_prob, dynamic_threshold = self._select_by_margin(
            probs, image_paths, margin_ratio, absolute_floor
        )

        metrics = {
            "Anchor_Probability (Max)": round(max_prob, 4),
            "Dynamic_Threshold": round(dynamic_threshold, 4),
            "Total_Candidates": len(image_paths),
            "Images_Retained": len(routed_images)
        }

        return routed_images, metrics

# ==========================================
# Phase 1: Cloud-Side Distillation (Qwen)
# ==========================================
def extract_intent_and_query(clinical_text, hf_token, client=None):
    """Distil a noisy patient narrative into a clinical intent and a visual query.

    Dual-track extraction using Qwen2.5-72B-Instruct via the Hugging Face
    Inference API. The model is instructed to return a JSON object with
    ``"Clinical_Intent"`` and ``"Visual_Query"`` fields.

    Args:
        clinical_text: Raw patient narrative entered by the user.
        hf_token: Hugging Face API token. Required unless ``client`` is given.
        client: Optional pre-built object exposing ``chat_completion``. When
            omitted, an ``InferenceClient`` for Qwen2.5-72B-Instruct is created
            from ``hf_token``.

    Returns:
        ``(result, visual_query)``.
        On success, ``result`` is the parsed JSON dict and ``visual_query`` is
        its ``"Visual_Query"`` value as a single string (``"medical scan"`` if
        the value is missing or empty).
        On failure, ``result`` is ``{"error": ...}``. ``visual_query`` is then
        ``"Error"`` for a missing token and ``"medical scan"`` otherwise.
    """
    if client is None:
        if not hf_token:
            return {"error": "Missing Hugging Face API Token."}, "Error"
        client = InferenceClient("Qwen/Qwen2.5-72B-Instruct", token=hf_token)

    system_prompt = """You are an expert Clinical Triage AI and a Multimodal Routing Specialist. 
    Your task is to analyze messy, noisy patient narratives and output a strictly formatted JSON object with two fields.
    1. "Clinical_Intent": A purified medical summary of the core issue.
    2. "Visual_Query": An extreme extraction of visual, anatomical, and radiological keywords relevant ONLY to the core issue. Think like an image-recognition model. Use nouns and modalities (e.g., 'Brain MRA, skull, Willis circle'). DO NOT use abstract symptoms like 'headache'.
    Output ONLY valid JSON."""

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Patient Narrative:\n{clinical_text}"}
    ]

    try:
        response = client.chat_completion(messages=messages, max_tokens=300)
        content = response.choices[0].message.content or ""

        # The model may wrap the JSON in markdown fences (any tag casing) or add
        # prose around it, so keep only the outermost {...} span.
        start, end = content.find("{"), content.rfind("}")
        if start == -1 or end < start:
            raise ValueError("model response contains no JSON object")
        result = json.loads(content[start:end + 1])
        if not isinstance(result, dict):
            raise ValueError("model response is not a JSON object")

        visual_query = result.get("Visual_Query")
        # The SigLIP router needs one text string; join list-valued keyword output.
        if isinstance(visual_query, (list, tuple)):
            visual_query = ", ".join(str(item) for item in visual_query)

        return result, (str(visual_query) if visual_query else "medical scan")
    except Exception as e:
        return {"error": f"Cloud Distillation Failed: {str(e)}"}, "medical scan"

# Single module-level SigLIP router, constructed when this module is imported;
# execute_pipeline routes every request through this shared instance.
router: SiGLIPRouter = SiGLIPRouter()

# ==========================================
# Main Execution Pipeline
# ==========================================
def execute_pipeline(hf_token, narrative, images, margin_ratio):
    """Run cloud distillation followed by edge routing for one UI request.

    Args:
        hf_token: Hugging Face API token forwarded to the cloud distillation step.
        narrative: Free-text clinical narrative from the UI.
        images: Uploaded file paths from the ``gr.File`` component. May be
            ``None`` or empty when nothing was uploaded.
        margin_ratio: Routing tolerance from the UI slider.

    Returns:
        ``(cloud_result, routed_images, metrics)``, matching the three Gradio
        outputs: the distillation JSON, the retained image paths, and a
        telemetry dict. On failure the second item is ``[]`` and the metrics
        dict carries a ``"Status"`` describing where the pipeline stopped.
    """
    # perf_counter is monotonic, so the latency figure is immune to wall-clock jumps.
    start_time = time.perf_counter()

    if not images:
        return {"Error": "No images uploaded."}, [], {"Status": "Failed"}
    # Avoid spending a cloud call on an empty narrative.
    if not narrative or not narrative.strip():
        return {"Error": "No clinical narrative provided."}, [], {"Status": "Failed"}

    # 1. Cloud Intelligence (Linguistic Distillation)
    cloud_result, visual_query = extract_intent_and_query(narrative, hf_token)

    if "error" in cloud_result:
        return cloud_result, [], {"Status": "Cloud API Error"}

    # 2. Edge Intelligence (Tensor Routing)
    try:
        routed_imgs, metrics = router.route_evidence(visual_query, images, margin_ratio=margin_ratio)
    except Exception as e:
        # gr.File accepts any file type, so unreadable uploads (or model failures)
        # are reported in the telemetry panel instead of as an opaque UI error.
        return cloud_result, [], {"Status": "Edge Routing Error", "Detail": str(e)}

    metrics["Total_Latency (s)"] = round(time.perf_counter() - start_time, 2)
    metrics["Active_Margin_Ratio"] = margin_ratio

    return cloud_result, routed_imgs, metrics

# ==========================================
# Gradio UI Design
# ==========================================
# Gradio UI: a two-column layout (inputs/settings on the left, outputs on the
# right) declared top-to-bottom inside the Blocks/Row/Column contexts.
with gr.Blocks(title="MANN-Engram Showcase", theme=gr.themes.Base()) as demo:
    gr.Markdown("# 🧠 MANN-Engram: Edge-Cloud Multimodal Semantic Router")
    gr.Markdown("> **A Privacy-First, Zero-Hallucination Shield for Clinical Vision-Language Models.**")
    
    with gr.Row():
        # Left Column: Inputs & Settings
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Engine Settings")
            # Token is forwarded to extract_intent_and_query via execute_pipeline.
            hf_token = gr.Textbox(label="Hugging Face API Token (For Cloud Brain)", type="password", placeholder="hf_xxxxxxxx...")
            
            # Value becomes execute_pipeline's margin_ratio, i.e. the fraction
            # below the best SigLIP probability that is still routed.
            margin_slider = gr.Slider(
                minimum=0.05, maximum=0.40, step=0.05, value=0.15, 
                label="Routing Tolerance (Margin Ratio)",
                info="Lower (0.05) = Sniper Mode (Extreme Precision). Higher (0.30) = Cluster Mode (Recalls multiple related views)."
            )
            
            gr.Markdown("### 📥 Patient Data Dump")
            narrative_input = gr.Textbox(
                label="Messy Clinical Narrative", lines=8,
                placeholder="Paste the chaotic patient complaint and history here..."
            )
            # type="filepath" so the handler receives paths, which the router
            # opens itself. No file_types filter is set, so any file can be uploaded.
            image_input = gr.File(label="Upload Unorganized Scans (Images)", file_count="multiple", type="filepath")
            
            run_btn = gr.Button("🚀 Execute Routing Pipeline", variant="primary")

        # Right Column: Outputs
        with gr.Column(scale=1):
            gr.Markdown("### ☁️ Cloud Output: Dual-Track Distillation")
            cloud_output = gr.JSON(label="Purified Intent & Visual Query")
            
            gr.Markdown("### 🛡️ Edge Output: Routed Core Evidence")
            # Shows the image paths that survived margin-based routing.
            routed_gallery = gr.Gallery(label="Surgically Selected Scans", columns=2, object_fit="contain", height=400)
            
            gr.Markdown("### 📊 Telemetry & Metrics")
            metrics_output = gr.JSON(label="Routing Diagnostics")

    # Wire up the button. The inputs map positionally onto
    # execute_pipeline(hf_token, narrative, images, margin_ratio), and its
    # 3-tuple return fills the three output components in order.
    run_btn.click(
        fn=execute_pipeline,
        inputs=[hf_token, narrative_input, image_input, margin_slider],
        outputs=[cloud_output, routed_gallery, metrics_output]
    )

# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()