SaniaE commited on
Commit
282248d
·
verified ·
1 Parent(s): 40e2d6d

updated endpoints

Browse files
Files changed (1) hide show
  1. app.py +43 -73
app.py CHANGED
@@ -15,6 +15,7 @@ import numpy as np
15
  import cv2
16
  import io
17
  from fastapi.responses import StreamingResponse
 
18
 
19
 
20
  app = FastAPI()
@@ -155,65 +156,34 @@ async def ui_tester(file: UploadFile = File(...), description: str = Query(...))
155
  "is_valid": confidence_score > 55
156
  }
157
 
158
-
159
- @app.post("/saliency-explorer")
160
- async def saliency_explorer(file: UploadFile = File(...), query_text: str = Query(...)):
161
- image = Image.open(file.file).convert("RGB")
162
- blip = MODELS["blip"]
163
-
164
- # Process inputs
165
- inputs = blip["processor"](images=image, text=query_text, return_tensors="pt").to(DEVICE)
166
- inputs.requires_grad = True # Enable gradients for saliency mapping
167
-
168
- # Forward pass through the vision-language projector
169
- outputs = blip["model"](**inputs, labels=inputs["input_ids"])
170
- loss = outputs.loss
171
- loss.backward()
172
-
173
- # Extract gradients from the vision encoder's last layer
174
- # Note: Using the last hidden state as a proxy for spatial importance
175
- gradients = blip["model"].vision_model.embeddings.patch_embedding.weight.grad
176
- pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])
177
-
178
- # Generate heatmap
179
- # In a real implementation, you would use Grad-CAM on the attention layers
180
- # Here we simplify the spatial mapping for the demo response
181
- heatmap = torch.mean(torch.abs(gradients), dim=1).squeeze().cpu().numpy()
182
- heatmap = cv2.resize(heatmap, (image.size[0], image.size[1]))
183
- heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())
184
-
185
- return {
186
- "query": query_text,
187
- "heatmap_data": heatmap.tolist(), # Send to frontend to overlay with CSS/Canvas
188
- "explanation": f"Highlighted regions show where the model focused to validate '{query_text}'"
189
- }
190
-
191
  @app.post("/concept-ensemble")
192
  async def concept_ensemble(file: UploadFile = File(...), user_prompt: str = Query(...)):
193
  image = Image.open(file.file).convert("RGB")
194
  blip = MODELS["blip"]
195
 
196
- # 1. Get Model's Perceived Caption (Baseline)
197
  inputs_gen = blip["processor"](images=image, return_tensors="pt").to(DEVICE)
198
- generated_ids = blip["model"].generate(**inputs_gen, max_length=40)
199
- model_caption = blip["processor"].decode(generated_ids[0], skip_special_tokens=True)
 
200
 
201
- # 2. Generate Embeddings for the Matrix
202
- # We compare User Prompt, Model Caption, and a 'Ground Truth' Visual Vector
203
  texts = [user_prompt, model_caption]
204
  inputs_text = blip["processor"](text=texts, return_tensors="pt", padding=True).to(DEVICE)
205
 
206
  with torch.no_grad():
207
- # Get text features
208
- text_embeds = blip["model"].text_encoder(**inputs_text).last_hidden_state[:, 0, :]
209
- # Get image features
210
- image_embeds = blip["model"].vision_model(inputs_gen["pixel_values"]).last_hidden_state[:, 0, :]
 
 
211
 
212
- # Normalize for Cosine Similarity
213
  text_embeds = F.normalize(text_embeds, p=2, dim=-1)
214
  image_embeds = F.normalize(image_embeds, p=2, dim=-1)
215
 
216
- # Calculate Matrix: [Image vs User, Image vs Model, User vs Model]
217
  sim_image_user = torch.matmul(image_embeds, text_embeds[0].T).item()
218
  sim_image_model = torch.matmul(image_embeds, text_embeds[1].T).item()
219
  sim_user_model = torch.matmul(text_embeds[0], text_embeds[1].T).item()
@@ -221,53 +191,53 @@ async def concept_ensemble(file: UploadFile = File(...), user_prompt: str = Quer
221
  return {
222
  "captions": {
223
  "user": user_prompt,
224
- "model": model_caption
225
  },
226
- "similarity_matrix": {
227
- "visual_alignment_user": round(sim_image_user, 4),
228
- "visual_alignment_model": round(sim_image_model, 4),
229
- "semantic_overlap": round(sim_user_model, 4)
230
  },
231
- "ensemble_verdict": "Consensus" if sim_user_model > 0.8 else "Perspective Divergence"
232
  }
233
 
234
-
235
  @app.post("/saliency-explorer/image")
236
  async def get_saliency_heatmap(file: UploadFile = File(...), query_text: str = Query(...)):
237
- # 1. Load and process image
238
- contents = await file.read()
239
- nparr = np.frombuffer(contents, np.uint8)
240
- orig_img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
241
- image_rgb = cv2.cvtColor(orig_img, cv2.COLOR_BGR2RGB)
242
- pil_img = Image.fromarray(image_rgb)
243
 
244
  blip = MODELS["blip"]
245
- inputs = blip["processor"](images=pil_img, text=query_text, return_tensors="pt").to(DEVICE)
246
 
247
- # 2. Extract Attention/Gradients
248
- # We target the cross-attention layer to see where the text 'queries' the image
249
  inputs.pixel_values.requires_grad = True
250
  outputs = blip["model"](**inputs, labels=inputs["input_ids"])
251
  loss = outputs.loss
252
  loss.backward()
253
 
254
- # Generate Saliency from gradients
255
  grad = inputs.pixel_values.grad.abs().max(dim=1)[0][0].cpu().numpy()
256
 
257
- # 3. Create Heatmap Overlay
258
- # Normalize gradients to 0-255
259
  grad = (grad - grad.min()) / (grad.max() - grad.min() + 1e-8)
260
- grad = (grad * 255).astype(np.uint8)
261
 
262
- # Resize to original image size
263
- heatmap = cv2.resize(grad, (orig_img.shape[1], orig_img.shape[0]))
 
 
 
 
 
264
 
265
- # Apply Color Map (JET or VIRIDIS look very 'Pinterest-chic' / Pro)
266
- heatmap_color = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
 
267
 
268
- # Superimpose heatmap onto original image (0.6 original, 0.4 heatmap)
269
- result_img = cv2.addWeighted(orig_img, 0.6, heatmap_color, 0.4, 0)
 
 
270
 
271
- # 4. Stream the image back
272
- res, im_png = cv2.imencode(".png", result_img)
273
- return StreamingResponse(io.BytesIO(im_png.tobytes()), media_type="image/png")
 
15
  import cv2
16
  import io
17
  from fastapi.responses import StreamingResponse
18
+ import matplotlib.pyplot as plt
19
 
20
 
21
  app = FastAPI()
 
156
  "is_valid": confidence_score > 55
157
  }
158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  @app.post("/concept-ensemble")
160
  async def concept_ensemble(file: UploadFile = File(...), user_prompt: str = Query(...)):
161
  image = Image.open(file.file).convert("RGB")
162
  blip = MODELS["blip"]
163
 
164
+ # 1. Model Baseline (Generating its own perception)
165
  inputs_gen = blip["processor"](images=image, return_tensors="pt").to(DEVICE)
166
+ with torch.no_grad():
167
+ generated_ids = blip["model"].generate(**inputs_gen, max_length=40)
168
+ model_caption = blip["processor"].decode(generated_ids[0], skip_special_tokens=True)
169
 
170
+ # 2. Embedding Calculation
 
171
  texts = [user_prompt, model_caption]
172
  inputs_text = blip["processor"](text=texts, return_tensors="pt", padding=True).to(DEVICE)
173
 
174
  with torch.no_grad():
175
+ # Get pooled text and vision features
176
+ text_outputs = blip["model"].text_encoder(**inputs_text)
177
+ text_embeds = text_outputs.last_hidden_state[:, 0, :] # Use [CLS] token
178
+
179
+ vision_outputs = blip["model"].vision_model(inputs_gen["pixel_values"])
180
+ image_embeds = vision_outputs.last_hidden_state[:, 0, :]
181
 
182
+ # Normalize vectors for cosine similarity
183
  text_embeds = F.normalize(text_embeds, p=2, dim=-1)
184
  image_embeds = F.normalize(image_embeds, p=2, dim=-1)
185
 
186
+ # Similarity Matrix calculation
187
  sim_image_user = torch.matmul(image_embeds, text_embeds[0].T).item()
188
  sim_image_model = torch.matmul(image_embeds, text_embeds[1].T).item()
189
  sim_user_model = torch.matmul(text_embeds[0], text_embeds[1].T).item()
 
191
  return {
192
  "captions": {
193
  "user": user_prompt,
194
+ "model_best_guess": model_caption
195
  },
196
+ "similarity_scores": {
197
+ "visual_alignment_user": round(float(sim_image_user), 4),
198
+ "visual_alignment_model": round(float(sim_image_model), 4),
199
+ "semantic_overlap": round(float(sim_user_model), 4)
200
  },
201
+ "interpretation": "Strong Agreement" if sim_user_model > 0.85 else "Diverse Perspectives"
202
  }
203
 
 
204
  @app.post("/saliency-explorer/image")
205
  async def get_saliency_heatmap(file: UploadFile = File(...), query_text: str = Query(...)):
206
+ # 1. Load Image
207
+ image_bytes = await file.read()
208
+ orig_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
 
 
 
209
 
210
  blip = MODELS["blip"]
211
+ inputs = blip["processor"](images=orig_img, text=query_text, return_tensors="pt").to(DEVICE)
212
 
213
+ # 2. Extract Gradients for Saliency
 
214
  inputs.pixel_values.requires_grad = True
215
  outputs = blip["model"](**inputs, labels=inputs["input_ids"])
216
  loss = outputs.loss
217
  loss.backward()
218
 
219
+ # Get max gradient across channels
220
  grad = inputs.pixel_values.grad.abs().max(dim=1)[0][0].cpu().numpy()
221
 
222
+ # 3. Create Heatmap with Matplotlib
223
+ # Normalize to [0, 1]
224
  grad = (grad - grad.min()) / (grad.max() - grad.min() + 1e-8)
 
225
 
226
+ # Apply color map (jet) and convert to RGBA
227
+ cm = plt.get_cmap('jet')
228
+ heatmap_rgba = cm(grad) # This creates an NxMx4 array
229
+
230
+ # Convert heatmap to PIL Image and resize to match original
231
+ heatmap_img = Image.fromarray((heatmap_rgba[:, :, :3] * 255).astype('uint8')).convert("RGB")
232
+ heatmap_img = heatmap_img.resize(orig_img.size, resample=Image.BILINEAR)
233
 
234
+ # 4. Blend Original + Heatmap
235
+ # 0.6 alpha for original, 0.4 for heatmap
236
+ blended_img = Image.blend(orig_img, heatmap_img, alpha=0.4)
237
 
238
+ # 5. Stream back
239
+ buf = io.BytesIO()
240
+ blended_img.save(buf, format="PNG")
241
+ buf.seek(0)
242
 
243
+ return StreamingResponse(buf, media_type="image/png")