SaniaE committed on
Commit
40e2d6d
·
verified ·
1 Parent(s): edccc41

added more endpoints

Browse files
Files changed (1) hide show
  1. app.py +124 -1
app.py CHANGED
@@ -10,6 +10,12 @@ from transformers import (
10
  BlipProcessor, BlipForConditionalGeneration,
11
  ViTImageProcessor, AutoProcessor, AutoModelForCausalLM
12
  )
 
 
 
 
 
 
13
 
14
  app = FastAPI()
15
 
@@ -147,4 +153,121 @@ async def ui_tester(file: UploadFile = File(...), description: str = Query(...))
147
  },
148
  "status": "Match Found" if confidence_score > 55 else "Partial Match" if confidence_score > 30 else "No Match",
149
  "is_valid": confidence_score > 55
150
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  BlipProcessor, BlipForConditionalGeneration,
11
  ViTImageProcessor, AutoProcessor, AutoModelForCausalLM
12
  )
13
+ import torch.nn.functional as F
14
+ import numpy as np
15
+ import cv2
16
+ import io
17
+ from fastapi.responses import StreamingResponse
18
+
19
 
20
  app = FastAPI()
21
 
 
153
  },
154
  "status": "Match Found" if confidence_score > 55 else "Partial Match" if confidence_score > 30 else "No Match",
155
  "is_valid": confidence_score > 55
156
+ }
157
+
158
+
159
@app.post("/saliency-explorer")
async def saliency_explorer(file: UploadFile = File(...), query_text: str = Query(...)):
    """Return a normalized input-gradient saliency map for an image/text pair.

    The uploaded image and ``query_text`` are run through the BLIP processor
    and model stored under ``MODELS["blip"]``, using the query tokens as
    labels. The gradient of the resulting loss with respect to the input
    pixels is reduced over the channel axis, resized to the original image
    size and min-max normalized.

    Args:
        file: Uploaded image; opened with PIL and converted to RGB.
        query_text: Text whose language-modeling loss drives the saliency map.

    Returns:
        A dict with the echoed ``query``, ``heatmap_data`` (nested list with
        one row per image row, values scaled towards [0, 1]) and a
        human-readable ``explanation``.
    """
    image = Image.open(file.file).convert("RGB")
    blip = MODELS["blip"]

    # Process inputs
    inputs = blip["processor"](images=image, text=query_text, return_tensors="pt").to(DEVICE)

    # Gradients must be enabled on the pixel tensor itself; setting
    # ``requires_grad`` on the processor's container object has no effect.
    pixel_values = inputs["pixel_values"].requires_grad_(True)

    outputs = blip["model"](**inputs, labels=inputs["input_ids"])

    # Differentiate w.r.t. the input only. Unlike ``loss.backward()`` this
    # does not populate (and keep accumulating across requests) ``.grad`` on
    # the model parameters, and the result is spatially aligned with the image
    # instead of being a convolution-kernel gradient.
    (pixel_grad,) = torch.autograd.grad(outputs.loss, pixel_values)

    # Strongest absolute response across colour channels for the first (only)
    # batch item -> 2-D map at the processor's resolution.
    saliency = pixel_grad.detach().abs().max(dim=1)[0][0].float().cpu().numpy()

    # PIL's ``size`` is (width, height), which is the order cv2.resize expects.
    heatmap = cv2.resize(saliency, (image.size[0], image.size[1]))

    # Epsilon keeps a constant map from causing a division by zero.
    low = heatmap.min()
    heatmap = (heatmap - low) / (heatmap.max() - low + 1e-8)

    return {
        "query": query_text,
        "heatmap_data": heatmap.tolist(),  # Send to frontend to overlay with CSS/Canvas
        "explanation": f"Highlighted regions show where the model focused to validate '{query_text}'"
    }
190
+
191
@app.post("/concept-ensemble")
async def concept_ensemble(file: UploadFile = File(...), user_prompt: str = Query(...)):
    """Compare a user prompt, the model's own caption and the image embedding.

    A caption is generated for the uploaded image, then cosine similarities
    are computed between the image embedding, the embedding of
    ``user_prompt`` and the embedding of the generated caption.

    Args:
        file: Uploaded image; opened with PIL and converted to RGB.
        user_prompt: The caller's description of the image.

    Returns:
        A dict with both ``captions``, a ``similarity_matrix`` of three
        rounded cosine similarities, and an ``ensemble_verdict`` that is
        "Consensus" when the prompt/caption similarity exceeds 0.8.
    """
    image = Image.open(file.file).convert("RGB")
    blip = MODELS["blip"]

    # 1. Get Model's Perceived Caption (Baseline)
    inputs_gen = blip["processor"](images=image, return_tensors="pt").to(DEVICE)
    generated_ids = blip["model"].generate(**inputs_gen, max_length=40)
    model_caption = blip["processor"].decode(generated_ids[0], skip_special_tokens=True)

    # 2. Generate Embeddings for the Matrix
    texts = [user_prompt, model_caption]
    inputs_text = blip["processor"](text=texts, return_tensors="pt", padding=True).to(DEVICE)

    with torch.no_grad():
        # First-token hidden state of each tower is used as the embedding.
        # NOTE(review): ``text_encoder`` must exist on the loaded model class;
        # confirm against how MODELS["blip"] is constructed.
        # NOTE(review): the two towers' raw states are compared without any
        # projection head - confirm they live in a comparable space.
        text_embeds = blip["model"].text_encoder(**inputs_text).last_hidden_state[:, 0, :]
        image_embeds = blip["model"].vision_model(inputs_gen["pixel_values"]).last_hidden_state[:, 0, :]

    # Normalize for Cosine Similarity
    text_embeds = F.normalize(text_embeds, p=2, dim=-1)
    image_embeds = F.normalize(image_embeds, p=2, dim=-1)

    # Work on explicit 1-D vectors: ``.T`` on a 1-D tensor is deprecated in
    # PyTorch and slated to become an error, and a dot product of unit
    # vectors is exactly the cosine similarity wanted here.
    image_vec = image_embeds[0]
    user_vec, caption_vec = text_embeds[0], text_embeds[1]

    # Calculate Matrix: [Image vs User, Image vs Model, User vs Model]
    sim_image_user = torch.dot(image_vec, user_vec).item()
    sim_image_model = torch.dot(image_vec, caption_vec).item()
    sim_user_model = torch.dot(user_vec, caption_vec).item()

    return {
        "captions": {
            "user": user_prompt,
            "model": model_caption
        },
        "similarity_matrix": {
            "visual_alignment_user": round(sim_image_user, 4),
            "visual_alignment_model": round(sim_image_model, 4),
            "semantic_overlap": round(sim_user_model, 4)
        },
        "ensemble_verdict": "Consensus" if sim_user_model > 0.8 else "Perspective Divergence"
    }
233
+
234
+
235
@app.post("/saliency-explorer/image")
async def get_saliency_heatmap(file: UploadFile = File(...), query_text: str = Query(...)):
    """Stream back the uploaded image with a gradient saliency overlay as PNG.

    The image and ``query_text`` are run through the BLIP processor and model
    stored under ``MODELS["blip"]`` with the query tokens as labels. The
    absolute input-pixel gradient of the loss is colour-mapped and blended
    onto the original image.

    Args:
        file: Uploaded image, decoded with OpenCV.
        query_text: Text whose language-modeling loss drives the saliency map.

    Returns:
        A ``StreamingResponse`` carrying the blended PNG (``image/png``).

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded as an
            image; 500 if the result cannot be PNG-encoded.
    """
    # Local import keeps this block self-contained regardless of which
    # fastapi names the module header already pulls in.
    from fastapi import HTTPException

    # 1. Load and process image
    contents = await file.read()
    nparr = np.frombuffer(contents, np.uint8)
    # imdecode yields None for undecodable data; guard before cvtColor so the
    # client gets a clear 400 instead of an opaque server error.
    orig_img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) if contents else None
    if orig_img is None:
        raise HTTPException(status_code=400, detail="Uploaded file is not a decodable image")
    image_rgb = cv2.cvtColor(orig_img, cv2.COLOR_BGR2RGB)
    pil_img = Image.fromarray(image_rgb)

    blip = MODELS["blip"]
    inputs = blip["processor"](images=pil_img, text=query_text, return_tensors="pt").to(DEVICE)

    # 2. Extract gradients of the loss with respect to the input pixels
    pixel_values = inputs["pixel_values"].requires_grad_(True)
    outputs = blip["model"](**inputs, labels=inputs["input_ids"])
    # autograd.grad returns only the requested gradient; loss.backward() would
    # also allocate and keep accumulating ``.grad`` on every model parameter
    # from one request to the next.
    (pixel_grad,) = torch.autograd.grad(outputs.loss, pixel_values)

    # Strongest absolute response across colour channels, first batch item
    grad = pixel_grad.detach().abs().max(dim=1)[0][0].float().cpu().numpy()

    # 3. Create Heatmap Overlay
    # Normalize gradients to 0-255 (epsilon guards a constant map)
    grad = (grad - grad.min()) / (grad.max() - grad.min() + 1e-8)
    grad = (grad * 255).astype(np.uint8)

    # Resize to original image size; cv2 expects (width, height)
    heatmap = cv2.resize(grad, (orig_img.shape[1], orig_img.shape[0]))

    # Apply colour map
    heatmap_color = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)

    # Superimpose heatmap onto original image (0.6 original, 0.4 heatmap)
    result_img = cv2.addWeighted(orig_img, 0.6, heatmap_color, 0.4, 0)

    # 4. Stream the image back
    encoded_ok, im_png = cv2.imencode(".png", result_img)
    if not encoded_ok:
        raise HTTPException(status_code=500, detail="Failed to encode saliency overlay as PNG")
    return StreamingResponse(io.BytesIO(im_png.tobytes()), media_type="image/png")