SaniaE committed on
Commit
8b9f879
·
verified ·
1 Parent(s): 4debe0a

updated endpoint logic

Browse files
Files changed (1) hide show
  1. app.py +73 -57
app.py CHANGED
@@ -83,94 +83,110 @@ def _generate_sync(m_name, image, temp=0.7):
83
 
84
  # --- Endpoint 1: The Multi-Perspective Generator ---
85
 
86
- @app.post("/generate-caption")
87
- async def generate_caption(file: UploadFile = File(...), temp: float = Query(0.7)):
88
- image_bytes = await file.read()
89
- image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
 
 
 
 
 
90
 
91
- # Run both architectures in parallel
92
- tasks = [
93
- asyncio.to_thread(_generate_sync, "blip", image, temp),
94
- asyncio.to_thread(_generate_sync, "vit", image, temp)
95
- ]
96
  captions = await asyncio.gather(*tasks)
97
 
98
- return {
99
- "blip_caption": captions[0],
100
- "vit_git_caption": captions[1]
101
- }
102
 
103
- # --- Endpoint 2: The Saliency Explorer (XAI Glow) ---
104
 
105
- @app.post("/saliency-explorer")
106
- async def get_saliency_map(file: UploadFile = File(...), query_text: str = Query(...)):
107
  image_bytes = await file.read()
108
  orig_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
109
 
110
  blip = MODELS["blip"]
111
- inputs = blip["processor"](images=orig_img, text=query_text, return_tensors="pt").to(DEVICE)
112
 
113
  with torch.no_grad():
114
- vision_hidden = blip["model"].vision_model(inputs.pixel_values).last_hidden_state
115
- outputs = blip["model"].text_decoder(
116
- input_ids=inputs.input_ids,
117
- attention_mask=inputs.attention_mask,
118
- encoder_hidden_states=vision_hidden,
119
  output_attentions=True
120
  )
121
 
122
- # Slicing out the [CLS] token from cross-attentions
123
- cross_attentions = outputs.cross_attentions[-1]
124
- mask_1d = cross_attentions[0, :, 1:-1, 1:].mean(dim=(0, 1))
 
 
 
 
 
 
125
  grid_size = int(np.sqrt(mask_1d.shape[-1]))
126
  mask = mask_1d.view(grid_size, grid_size).cpu().numpy()
127
 
128
- # Normalization & XAI Glow Application
129
  mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
130
  mask_pill = Image.fromarray((mask * 255).astype('uint8')).resize(orig_img.size, resample=Image.BICUBIC)
131
- mask_pill = mask_pill.filter(ImageFilter.GaussianBlur(radius=12))
132
 
133
- heatmap_rgba = plt.get_cmap('jet')(np.array(mask_pill)/255.0)
134
  heatmap_img = Image.fromarray((heatmap_rgba[:, :, :3] * 255).astype('uint8')).convert("RGB")
135
- blended_img = Image.blend(orig_img, heatmap_img, alpha=0.5)
 
 
136
 
137
  buf = io.BytesIO()
138
  blended_img.save(buf, format="PNG")
139
  buf.seek(0)
140
  return StreamingResponse(buf, media_type="image/png")
141
 
 
142
  # --- Endpoint 3: Internal Debate (Audit Mode) ---
143
 
144
- @app.post("/internal-debate")
145
- async def internal_debate(file: UploadFile = File(...), user_prompt: str = Query(...)):
146
- image_bytes = await file.read()
147
- image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
148
 
149
- # 1. Gather model perceptions
150
- blip_caption = await asyncio.to_thread(_generate_sync, "blip", image)
151
- vit_caption = await asyncio.to_thread(_generate_sync, "vit", image)
152
-
153
- # 2. Semantic Embedding Logic
154
- blip_data = MODELS["blip"]
155
- def get_emb(text):
156
- inputs = blip_data["processor"](text=text, return_tensors="pt", padding=True).to(DEVICE)
 
 
157
  with torch.no_grad():
158
- return F.normalize(blip_data["model"].text_decoder.bert(**inputs).last_hidden_state.mean(dim=1), p=2, dim=-1)
 
 
 
159
 
160
- u_emb = get_emb(user_prompt)
161
- b_emb = get_emb(blip_caption)
162
- v_emb = get_emb(vit_caption)
 
 
 
163
 
164
- # 3. MLE Calibration (Jaccard Weighting)
165
- def calibrate(emb1, emb2, t1, t2):
166
- s1, s2 = set(t1.lower().split()), set(t2.lower().split())
167
- jaccard = len(s1 & s2) / len(s1 | s2) if s1 | s2 else 0
168
- cosine = torch.matmul(emb1, emb2.T).item()
169
- return (cosine * 0.4) + (jaccard * 0.6)
170
 
171
- score_blip = calibrate(u_emb, b_emb, user_prompt, blip_caption)
172
- score_vit = calibrate(u_emb, v_emb, user_prompt, vit_caption)
173
- consensus = calibrate(b_emb, v_emb, blip_caption, vit_caption)
 
 
 
 
174
 
175
  return {
176
  "perspectives": {
@@ -179,9 +195,9 @@ async def internal_debate(file: UploadFile = File(...), user_prompt: str = Query
179
  "vit_git_view": vit_caption
180
  },
181
  "audit_metrics": {
182
- "user_vs_blip": round(score_blip, 4),
183
- "user_vs_vit": round(score_vit, 4),
184
  "inter_model_consensus": round(consensus, 4)
185
  },
186
- "verdict": "Consensus" if consensus > 0.65 else "Perspective Divergence"
187
  }
 
83
 
84
  # --- Endpoint 1: The Multi-Perspective Generator ---
85
 
86
@app.post("/generate")
async def generate_endpoint(
    file: UploadFile = File(...),
    temp: float = Query(0.8),
    top_k: int = Query(50),
    top_p: float = Query(0.9),
):
    """Generate five captions for an uploaded image from a random mix of models.

    Args:
        file: Uploaded image; opened from its underlying file object and
            converted to RGB.
        temp: Sampling temperature forwarded to ``_generate_sync``.
        top_k: Accepted for API compatibility; currently NOT forwarded
            (see the note in the body).
        top_p: Accepted for API compatibility; currently NOT forwarded
            (see the note in the body).

    Returns:
        A dict with ``"captions"`` (the five generated captions, in task
        order) and ``"architectures"`` (the model key used for each caption).
    """
    image = Image.open(file.file).convert("RGB")
    available = ["blip", "vit"]

    # Draw five model keys with replacement, so the same model may be used
    # several times within one request.
    model_selection = random.choices(available, k=5)

    # NOTE: _generate_sync is declared as (m_name, image, temp=0.7). Passing
    # top_k / top_p positionally raised
    # "TypeError: takes from 2 to 3 positional arguments but 5 were given"
    # on every request. Only `temp` is forwarded until the helper grows
    # matching parameters; forward top_k / top_p here once it does.
    tasks = [
        asyncio.to_thread(_generate_sync, m, image, temp)
        for m in model_selection
    ]
    captions = await asyncio.gather(*tasks)

    return {"captions": captions, "architectures": model_selection}
 
 
 
102
 
103
+ # --- Endpoint 2: Objective Vision Saliency (Static Image Perception) ---
104
 
105
@app.post("/saliency-explorer/vision")
async def get_objective_saliency(file: UploadFile = File(...)):
    """Return a PNG overlaying the BLIP vision encoder's attention on the image.

    The heat map is built from the last vision-encoder attention layer alone:
    the attention row at index 0 (treated as the CLS token by the original
    comments) over the remaining positions, averaged across heads. No text
    prompt is involved.

    Args:
        file: Uploaded image; read fully into memory and converted to RGB.

    Returns:
        A ``StreamingResponse`` with media type ``image/png`` containing the
        original image blended with the colour-mapped attention mask.
    """
    raw_bytes = await file.read()
    source_img = Image.open(io.BytesIO(raw_bytes)).convert("RGB")

    blip_bundle = MODELS["blip"]
    encoded = blip_bundle["processor"](images=source_img, return_tensors="pt").to(DEVICE)

    # Self-attention of the vision encoder itself, independent of any prompt.
    with torch.no_grad():
        vision_out = blip_bundle["model"].vision_model(
            encoded.pixel_values,
            output_attentions=True,
        )

    # Last layer attention; the original notes describe it as
    # (batch, heads, patches, patches).
    last_layer = vision_out.attentions[-1]
    head_count = last_layer.shape[1]

    # Row 0 attending to every position after it, one row per head, then the
    # mean over heads. The original notes a typical 24x24 patch grid
    # (576 patches + 1 CLS) for BLIP-Large.
    per_head = last_layer[0, :, 0, 1:].reshape(head_count, -1)
    patch_scores = per_head.mean(dim=0)

    # Fold the flat patch scores back into a square grid.
    side = int(np.sqrt(patch_scores.shape[-1]))
    grid = patch_scores.view(side, side).cpu().numpy()

    # Min-max normalise; the epsilon guards against a constant map.
    grid = (grid - grid.min()) / (grid.max() - grid.min() + 1e-8)

    # Upscale to the image size and soften the patch edges.
    grey_mask = Image.fromarray((grid * 255).astype('uint8'))
    grey_mask = grey_mask.resize(source_img.size, resample=Image.BICUBIC)
    grey_mask = grey_mask.filter(ImageFilter.GaussianBlur(radius=10))

    # Colour-map the mask and drop the alpha channel.
    colour_map = plt.get_cmap('magma')
    rgba = colour_map(np.array(grey_mask) / 255.0)
    heat_img = Image.fromarray((rgba[:, :, :3] * 255).astype('uint8')).convert("RGB")

    # alpha=0.6 weights the heat map above the photo so the focus stands out.
    overlay = Image.blend(source_img, heat_img, alpha=0.6)

    png_buffer = io.BytesIO()
    overlay.save(png_buffer, format="PNG")
    png_buffer.seek(0)
    return StreamingResponse(png_buffer, media_type="image/png")
148
 
149
+ # --- Endpoint 3: Perspective Auditor (Internal Debate) ---
150
  # --- Endpoint 3: Internal Debate (Audit Mode) ---
151
 
152
+ @app.post("/audit-perspective")
153
+ async def audit_perspective(file: UploadFile = File(...), user_prompt: str = Query(...)):
154
+ image = Image.open(file.file).convert("RGB")
 
155
 
156
+ # Run both models to get the "Internal Debate"
157
+ blip_caption = await asyncio.to_thread(_generate_sync, "blip", image, 0.7, 50, 0.9)
158
+ vit_caption = await asyncio.to_thread(_generate_sync, "vit", image, 0.7, 50, 0.9)
159
+
160
+ def get_metrics(target, reference):
161
+ # 1. Semantic Embedding (The "Vibe" check)
162
+ blip = MODELS["blip"]
163
+ t_in = blip["processor"](text=target, return_tensors="pt", padding=True).to(DEVICE)
164
+ r_in = blip["processor"](text=reference, return_tensors="pt", padding=True).to(DEVICE)
165
+
166
  with torch.no_grad():
167
+ t_emb = F.normalize(blip["model"].text_decoder.bert(**t_in).last_hidden_state.mean(dim=1), p=2, dim=-1)
168
+ r_emb = F.normalize(blip["model"].text_decoder.bert(**r_in).last_hidden_state.mean(dim=1), p=2, dim=-1)
169
+
170
+ cosine_sim = torch.matmul(t_emb, r_emb.T).item()
171
 
172
+ # 2. Jaccard Calibration (The "Accuracy" check - 70% weight)
173
+ t_words = set(target.lower().replace(",", "").split())
174
+ r_words = set(reference.lower().replace(",", "").split())
175
+ jaccard = len(t_words & r_words) / len(t_words | r_words) if t_words | r_words else 0
176
+
177
+ return (cosine_sim * 0.3) + (jaccard * 0.7)
178
 
179
+ user_vs_blip = get_metrics(user_prompt, blip_caption)
180
+ user_vs_vit = get_metrics(user_prompt, vit_caption)
181
+ consensus = get_metrics(blip_caption, vit_caption)
 
 
 
182
 
183
+ # XAI Verdict Logic
184
+ if consensus < 0.5:
185
+ verdict = "Model Confusion: High Uncertainty"
186
+ elif user_vs_blip < 0.6:
187
+ verdict = "Perspective Divergence: Prompt Mismatch"
188
+ else:
189
+ verdict = "Verified: Strong Alignment"
190
 
191
  return {
192
  "perspectives": {
 
195
  "vit_git_view": vit_caption
196
  },
197
  "audit_metrics": {
198
+ "user_vs_blip": round(user_vs_blip, 4),
199
+ "user_vs_vit": round(user_vs_vit, 4),
200
  "inter_model_consensus": round(consensus, 4)
201
  },
202
+ "verdict": verdict
203
  }