SaniaE committed on
Commit
77a0b2f
·
verified ·
1 Parent(s): ace7c16

added optimizations

Browse files
Files changed (1) hide show
  1. app.py +61 -39
app.py CHANGED
@@ -6,11 +6,11 @@ import random
6
  import numpy as np
7
  import torch
8
  import torch.nn.functional as F
9
- import matplotlib.pyplot as plt
10
- from PIL import Image, ImageFilter
11
  from fastapi import FastAPI, UploadFile, File, Query
12
  from fastapi.middleware.cors import CORSMiddleware
13
- from fastapi.responses import StreamingResponse, Response
14
  from huggingface_hub import snapshot_download, login
15
 
16
  from transformers import (
@@ -21,14 +21,13 @@ from transformers import (
21
 
22
  app = FastAPI(title="XAI Auditor Ensemble with CLIP Jury")
23
 
24
- # Enable smooth frontend cross-origin header interceptions for performance metrics
25
  app.add_middleware(
26
  CORSMiddleware,
27
  allow_origins=["*"],
28
  allow_credentials=True,
29
  allow_methods=["*"],
30
  allow_headers=["*"],
31
- expose_headers=["X-Processing-Time", "X-Audit-Time", "X-Grounding-Verdict"]
32
  )
33
 
34
  # --- Configuration & Paths ---
@@ -81,47 +80,63 @@ async def startup_event():
81
  )
82
  }
83
 
84
- # 3. Load Fine-Tuned CLIP (Your Jury)
85
  cfg_c = MODEL_CONFIGS["clip"]
86
  MODELS["clip"] = {
87
  "model": CLIPModel.from_pretrained(os.path.join(local_dir, cfg_c["model_subfolder"])).to(DEVICE),
88
  "processor": CLIPProcessor.from_pretrained(os.path.join(local_dir, cfg_c["proc_subfolder"]))
89
  }
90
 
91
- print("All models synchronized. Auditor is active.")
92
 
93
  # --- Utilities ---
94
 
95
- def _generate_sync(m_name, image, temp, top_k, top_p):
96
- m_data = MODELS[m_name]
97
- if m_name == "vit":
98
- i_proc, t_proc = m_data["processor"]
99
- inputs = i_proc(images=image, return_tensors="pt").to(DEVICE)
100
- ids = m_data["model"].generate(**inputs, max_length=80, do_sample=True, temperature=temp, top_k=top_k, top_p=top_p)
101
- return t_proc.batch_decode(ids, skip_special_tokens=True)[0].strip()
102
- else:
103
- proc = m_data["processor"]
104
- inputs = proc(images=image, return_tensors="pt").to(DEVICE)
105
- ids = m_data["model"].generate(**inputs, max_length=80, do_sample=True, temperature=temp, top_k=top_k, top_p=top_p)
106
- return proc.batch_decode(ids, skip_special_tokens=True)[0].strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
  # --- Endpoints ---
109
 
110
  @app.post("/generate")
111
  async def generate_captions(
112
  file: UploadFile = File(...),
113
- temp: float = Query(0.8),
114
- top_k: int = Query(50),
115
  top_p: float = Query(0.9)
116
  ):
 
117
  start_time = time.perf_counter()
118
  image = Image.open(file.file).convert("RGB")
119
  architectures = ["blip", "vit"]
120
  selection = random.choices(architectures, k=5)
121
 
122
- # Offload generative sampling loop to a worker thread pool
123
- tasks = [asyncio.to_thread(_generate_sync, m, image, temp, top_k, top_p) for m in selection]
124
- captions = await asyncio.gather(*tasks)
125
 
126
  elapsed_time = time.perf_counter() - start_time
127
  print(f"[BENCHMARK] /generate ensemble turnaround: {elapsed_time:.4f}s")
@@ -137,6 +152,7 @@ async def generate_captions(
137
 
138
  @app.post("/saliency")
139
  async def get_vision_saliency(file: UploadFile = File(...)):
 
140
  start_time = time.perf_counter()
141
  image_bytes = await file.read()
142
  orig_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
@@ -151,37 +167,44 @@ async def get_vision_saliency(file: UploadFile = File(...)):
151
  grid_size = int(np.sqrt(mask_1d.shape[-1]))
152
  mask = mask_1d.view(grid_size, grid_size).cpu().numpy()
153
 
 
154
  mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
155
- mask_img = Image.fromarray((mask * 255).astype('uint8')).resize(orig_img.size, resample=Image.BICUBIC)
156
- mask_img = mask_img.filter(ImageFilter.GaussianBlur(radius=10))
157
 
158
- heatmap = plt.get_cmap('magma')(np.array(mask_img)/255.0)
159
- heatmap_img = Image.fromarray((heatmap[:, :, :3] * 255).astype('uint8')).convert("RGB")
160
- blended = Image.blend(orig_img, heatmap_img, alpha=0.6)
 
 
 
 
 
 
 
 
 
 
161
 
 
162
  buf = io.BytesIO()
163
- blended.save(buf, format="PNG")
164
  buf.seek(0)
165
 
166
  elapsed_time = time.perf_counter() - start_time
167
  print(f"[BENCHMARK] /saliency last-layer map turnaround: {elapsed_time:.4f}s")
168
 
169
- return StreamingResponse(
170
- buf,
171
- media_type="image/png",
172
- headers={"X-Processing-Time": f"{elapsed_time:.4f}"}
173
- )
174
 
175
  @app.post("/audit")
176
  async def internal_debate_audit(file: UploadFile = File(...), user_prompt: str = Query(...)):
 
177
  start_time = time.perf_counter()
178
  image_bytes = await file.read()
179
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
180
 
181
- # 1. Model Perception
182
- blip_caption = await asyncio.to_thread(_generate_sync, "blip", image, 0.7, 50, 0.9)
183
 
184
- # 2. CLIP Scoring (Multimodal Alignment)
185
  clip_m = MODELS["clip"]["model"]
186
  clip_p = MODELS["clip"]["processor"]
187
 
@@ -193,7 +216,6 @@ async def internal_debate_audit(file: UploadFile = File(...), user_prompt: str =
193
 
194
  u_score, m_score = float(probs[0]), float(probs[1])
195
 
196
- # 3. Decision Logic
197
  if u_score < 0.35:
198
  verdict = "Perspective Divergence: Intent not grounded in image."
199
  elif abs(u_score - m_score) < 0.15:
 
6
  import numpy as np
7
  import torch
8
  import torch.nn.functional as F
9
+ import cv2
10
+ from PIL import Image
11
  from fastapi import FastAPI, UploadFile, File, Query
12
  from fastapi.middleware.cors import CORSMiddleware
13
+ from fastapi.responses import StreamingResponse
14
  from huggingface_hub import snapshot_download, login
15
 
16
  from transformers import (
 
21
 
22
  app = FastAPI(title="XAI Auditor Ensemble with CLIP Jury")
23
 
 
24
  app.add_middleware(
25
  CORSMiddleware,
26
  allow_origins=["*"],
27
  allow_credentials=True,
28
  allow_methods=["*"],
29
  allow_headers=["*"],
30
+ expose_headers=["X-Processing-Time", "X-Audit-Time"]
31
  )
32
 
33
  # --- Configuration & Paths ---
 
80
  )
81
  }
82
 
83
+ # 3. Load Fine-Tuned CLIP
84
  cfg_c = MODEL_CONFIGS["clip"]
85
  MODELS["clip"] = {
86
  "model": CLIPModel.from_pretrained(os.path.join(local_dir, cfg_c["model_subfolder"])).to(DEVICE),
87
  "processor": CLIPProcessor.from_pretrained(os.path.join(local_dir, cfg_c["proc_subfolder"]))
88
  }
89
 
90
+ print("All models synchronized. Ensemble backend is active.")
91
 
92
  # --- Utilities ---
93
 
94
+ def _generate_sync_batch(selection, image, temp, top_k, top_p, max_len=45, do_sample=True):
95
+ """Processes generation sequentially to eliminate CPU context-switching overhead."""
96
+ captions = []
97
+ for m_name in selection:
98
+ m_data = MODELS[m_name]
99
+ if m_name == "vit":
100
+ i_proc, t_proc = m_data["processor"]
101
+ inputs = i_proc(images=image, return_tensors="pt").to(DEVICE)
102
+ # Cap max_new_tokens for snappier generation runtimes
103
+ ids = m_data["model"].generate(
104
+ **inputs, max_new_tokens=max_len, do_sample=do_sample,
105
+ temperature=temp if do_sample else None,
106
+ top_k=top_k if do_sample else None,
107
+ top_p=top_p if do_sample else None
108
+ )
109
+ caption = t_proc.batch_decode(ids, skip_special_tokens=True)[0].strip()
110
+ else:
111
+ proc = m_data["processor"]
112
+ inputs = proc(images=image, return_tensors="pt").to(DEVICE)
113
+ ids = m_data["model"].generate(
114
+ **inputs, max_new_tokens=max_len, do_sample=do_sample,
115
+ temperature=temp if do_sample else None,
116
+ top_k=top_k if do_sample else None,
117
+ top_p=top_p if do_sample else None
118
+ )
119
+ caption = proc.batch_decode(ids, skip_special_tokens=True)[0].strip()
120
+ captions.append(caption)
121
+ return captions
122
 
123
  # --- Endpoints ---
124
 
125
  @app.post("/generate")
126
  async def generate_captions(
127
  file: UploadFile = File(...),
128
+ temp: float = Query(0.7),
129
+ top_k: int = Query(40),
130
  top_p: float = Query(0.9)
131
  ):
132
+ """Generates 5 diverse captions using an optimized sequential pipeline."""
133
  start_time = time.perf_counter()
134
  image = Image.open(file.file).convert("RGB")
135
  architectures = ["blip", "vit"]
136
  selection = random.choices(architectures, k=5)
137
 
138
+ # Run loop sequentially inside a thread worker to safely dodge GIL contention
139
+ captions = await asyncio.to_thread(_generate_sync_batch, selection, image, temp, top_k, top_p, 45, True)
 
140
 
141
  elapsed_time = time.perf_counter() - start_time
142
  print(f"[BENCHMARK] /generate ensemble turnaround: {elapsed_time:.4f}s")
 
152
 
153
  @app.post("/saliency")
154
  async def get_vision_saliency(file: UploadFile = File(...)):
155
+ """Objective Saliency: Highly optimized native vision encoder self-attention mapping."""
156
  start_time = time.perf_counter()
157
  image_bytes = await file.read()
158
  orig_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
 
167
  grid_size = int(np.sqrt(mask_1d.shape[-1]))
168
  mask = mask_1d.view(grid_size, grid_size).cpu().numpy()
169
 
170
+ # Normalize attention matrix
171
  mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
 
 
172
 
173
+ # Vectorized OpenCV handling for super fast image processing
174
+ w, h = orig_img.size
175
+ mask_resized = cv2.resize(mask, (w, h), interpolation=cv2.INTER_CUBIC)
176
+ mask_blurred = cv2.GaussianBlur(mask_resized, (21, 21), 0)
177
+
178
+ # Convert normalized heatmap to standard color map space
179
+ heatmap_uint8 = np.uint8(255 * mask_blurred)
180
+ heatmap_bgr = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_MAGMA)
181
+ heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)
182
+
183
+ # Composite overlay mix
184
+ orig_np = np.array(orig_img)
185
+ blended_np = cv2.addWeighted(orig_np, 0.5, heatmap_rgb, 0.5, 0)
186
 
187
+ blended_img = Image.fromarray(blended_np)
188
  buf = io.BytesIO()
189
+ blended_img.save(buf, format="PNG")
190
  buf.seek(0)
191
 
192
  elapsed_time = time.perf_counter() - start_time
193
  print(f"[BENCHMARK] /saliency last-layer map turnaround: {elapsed_time:.4f}s")
194
 
195
+ return StreamingResponse(buf, media_type="image/png")
 
 
 
 
196
 
197
  @app.post("/audit")
198
  async def internal_debate_audit(file: UploadFile = File(...), user_prompt: str = Query(...)):
199
+ """The CLIP-Powered Jury: Fast deterministic grounding verification track."""
200
  start_time = time.perf_counter()
201
  image_bytes = await file.read()
202
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
203
 
204
+ # OPTIMIZATION: Greedy decoding (do_sample=False) + short length constraint
205
+ blip_caption = (await asyncio.to_thread(_generate_sync_batch, ["blip"], image, 1.0, 1, 1.0, 25, False))[0]
206
 
207
+ # CLIP Scoring
208
  clip_m = MODELS["clip"]["model"]
209
  clip_p = MODELS["clip"]["processor"]
210
 
 
216
 
217
  u_score, m_score = float(probs[0]), float(probs[1])
218
 
 
219
  if u_score < 0.35:
220
  verdict = "Perspective Divergence: Intent not grounded in image."
221
  elif abs(u_score - m_score) < 0.15: