added optimizations
Browse files
app.py
CHANGED
|
@@ -6,11 +6,11 @@ import random
|
|
| 6 |
import numpy as np
|
| 7 |
import torch
|
| 8 |
import torch.nn.functional as F
|
| 9 |
-
import
|
| 10 |
-
from PIL import Image
|
| 11 |
from fastapi import FastAPI, UploadFile, File, Query
|
| 12 |
from fastapi.middleware.cors import CORSMiddleware
|
| 13 |
-
from fastapi.responses import StreamingResponse
|
| 14 |
from huggingface_hub import snapshot_download, login
|
| 15 |
|
| 16 |
from transformers import (
|
|
@@ -21,14 +21,13 @@ from transformers import (
|
|
| 21 |
|
| 22 |
app = FastAPI(title="XAI Auditor Ensemble with CLIP Jury")
|
| 23 |
|
| 24 |
-
# Enable smooth frontend cross-origin header interceptions for performance metrics
|
| 25 |
app.add_middleware(
|
| 26 |
CORSMiddleware,
|
| 27 |
allow_origins=["*"],
|
| 28 |
allow_credentials=True,
|
| 29 |
allow_methods=["*"],
|
| 30 |
allow_headers=["*"],
|
| 31 |
-
expose_headers=["X-Processing-Time", "X-Audit-Time"
|
| 32 |
)
|
| 33 |
|
| 34 |
# --- Configuration & Paths ---
|
|
@@ -81,47 +80,63 @@ async def startup_event():
|
|
| 81 |
)
|
| 82 |
}
|
| 83 |
|
| 84 |
-
# 3. Load Fine-Tuned CLIP
|
| 85 |
cfg_c = MODEL_CONFIGS["clip"]
|
| 86 |
MODELS["clip"] = {
|
| 87 |
"model": CLIPModel.from_pretrained(os.path.join(local_dir, cfg_c["model_subfolder"])).to(DEVICE),
|
| 88 |
"processor": CLIPProcessor.from_pretrained(os.path.join(local_dir, cfg_c["proc_subfolder"]))
|
| 89 |
}
|
| 90 |
|
| 91 |
-
print("All models synchronized.
|
| 92 |
|
| 93 |
# --- Utilities ---
|
| 94 |
|
| 95 |
-
def
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
# --- Endpoints ---
|
| 109 |
|
| 110 |
@app.post("/generate")
|
| 111 |
async def generate_captions(
|
| 112 |
file: UploadFile = File(...),
|
| 113 |
-
temp: float = Query(0.
|
| 114 |
-
top_k: int = Query(
|
| 115 |
top_p: float = Query(0.9)
|
| 116 |
):
|
|
|
|
| 117 |
start_time = time.perf_counter()
|
| 118 |
image = Image.open(file.file).convert("RGB")
|
| 119 |
architectures = ["blip", "vit"]
|
| 120 |
selection = random.choices(architectures, k=5)
|
| 121 |
|
| 122 |
-
#
|
| 123 |
-
|
| 124 |
-
captions = await asyncio.gather(*tasks)
|
| 125 |
|
| 126 |
elapsed_time = time.perf_counter() - start_time
|
| 127 |
print(f"[BENCHMARK] /generate ensemble turnaround: {elapsed_time:.4f}s")
|
|
@@ -137,6 +152,7 @@ async def generate_captions(
|
|
| 137 |
|
| 138 |
@app.post("/saliency")
|
| 139 |
async def get_vision_saliency(file: UploadFile = File(...)):
|
|
|
|
| 140 |
start_time = time.perf_counter()
|
| 141 |
image_bytes = await file.read()
|
| 142 |
orig_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
|
@@ -151,37 +167,44 @@ async def get_vision_saliency(file: UploadFile = File(...)):
|
|
| 151 |
grid_size = int(np.sqrt(mask_1d.shape[-1]))
|
| 152 |
mask = mask_1d.view(grid_size, grid_size).cpu().numpy()
|
| 153 |
|
|
|
|
| 154 |
mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
|
| 155 |
-
mask_img = Image.fromarray((mask * 255).astype('uint8')).resize(orig_img.size, resample=Image.BICUBIC)
|
| 156 |
-
mask_img = mask_img.filter(ImageFilter.GaussianBlur(radius=10))
|
| 157 |
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
|
|
|
|
| 162 |
buf = io.BytesIO()
|
| 163 |
-
|
| 164 |
buf.seek(0)
|
| 165 |
|
| 166 |
elapsed_time = time.perf_counter() - start_time
|
| 167 |
print(f"[BENCHMARK] /saliency last-layer map turnaround: {elapsed_time:.4f}s")
|
| 168 |
|
| 169 |
-
return StreamingResponse(
|
| 170 |
-
buf,
|
| 171 |
-
media_type="image/png",
|
| 172 |
-
headers={"X-Processing-Time": f"{elapsed_time:.4f}"}
|
| 173 |
-
)
|
| 174 |
|
| 175 |
@app.post("/audit")
|
| 176 |
async def internal_debate_audit(file: UploadFile = File(...), user_prompt: str = Query(...)):
|
|
|
|
| 177 |
start_time = time.perf_counter()
|
| 178 |
image_bytes = await file.read()
|
| 179 |
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
| 180 |
|
| 181 |
-
#
|
| 182 |
-
blip_caption = await asyncio.to_thread(
|
| 183 |
|
| 184 |
-
#
|
| 185 |
clip_m = MODELS["clip"]["model"]
|
| 186 |
clip_p = MODELS["clip"]["processor"]
|
| 187 |
|
|
@@ -193,7 +216,6 @@ async def internal_debate_audit(file: UploadFile = File(...), user_prompt: str =
|
|
| 193 |
|
| 194 |
u_score, m_score = float(probs[0]), float(probs[1])
|
| 195 |
|
| 196 |
-
# 3. Decision Logic
|
| 197 |
if u_score < 0.35:
|
| 198 |
verdict = "Perspective Divergence: Intent not grounded in image."
|
| 199 |
elif abs(u_score - m_score) < 0.15:
|
|
|
|
| 6 |
import numpy as np
|
| 7 |
import torch
|
| 8 |
import torch.nn.functional as F
|
| 9 |
+
import cv2
|
| 10 |
+
from PIL import Image
|
| 11 |
from fastapi import FastAPI, UploadFile, File, Query
|
| 12 |
from fastapi.middleware.cors import CORSMiddleware
|
| 13 |
+
from fastapi.responses import StreamingResponse
|
| 14 |
from huggingface_hub import snapshot_download, login
|
| 15 |
|
| 16 |
from transformers import (
|
|
|
|
| 21 |
|
| 22 |
app = FastAPI(title="XAI Auditor Ensemble with CLIP Jury")
|
| 23 |
|
|
|
|
| 24 |
app.add_middleware(
|
| 25 |
CORSMiddleware,
|
| 26 |
allow_origins=["*"],
|
| 27 |
allow_credentials=True,
|
| 28 |
allow_methods=["*"],
|
| 29 |
allow_headers=["*"],
|
| 30 |
+
expose_headers=["X-Processing-Time", "X-Audit-Time"]
|
| 31 |
)
|
| 32 |
|
| 33 |
# --- Configuration & Paths ---
|
|
|
|
| 80 |
)
|
| 81 |
}
|
| 82 |
|
| 83 |
+
# 3. Load Fine-Tuned CLIP
|
| 84 |
cfg_c = MODEL_CONFIGS["clip"]
|
| 85 |
MODELS["clip"] = {
|
| 86 |
"model": CLIPModel.from_pretrained(os.path.join(local_dir, cfg_c["model_subfolder"])).to(DEVICE),
|
| 87 |
"processor": CLIPProcessor.from_pretrained(os.path.join(local_dir, cfg_c["proc_subfolder"]))
|
| 88 |
}
|
| 89 |
|
| 90 |
+
print("All models synchronized. Ensemble backend is active.")
|
| 91 |
|
| 92 |
# --- Utilities ---
|
| 93 |
|
| 94 |
+
def _generate_sync_batch(selection, image, temp, top_k, top_p, max_len=45, do_sample=True):
    """Generate one caption per entry of ``selection``, strictly one after another.

    Args:
        selection: Iterable of keys into the module-level ``MODELS`` registry.
            The same key may appear several times; one caption is produced
            for every occurrence, in order.
        image: The image handed to each model's image processor.
        temp: Sampling temperature, forwarded only when ``do_sample`` is true.
        top_k: Top-k cutoff, forwarded only when ``do_sample`` is true.
        top_p: Nucleus cutoff, forwarded only when ``do_sample`` is true.
        max_len: Passed to ``generate`` as ``max_new_tokens``.
        do_sample: Passed to ``generate``; when false the three sampling
            knobs are sent as ``None`` instead of the caller's values.

    Returns:
        A list of stripped caption strings, one per entry of ``selection``.
    """
    # The sampling knobs do not change between iterations, so resolve them once.
    sampling_kwargs = {
        "temperature": temp if do_sample else None,
        "top_k": top_k if do_sample else None,
        "top_p": top_p if do_sample else None,
    }

    # Per-call cache: model name -> (model, preprocessed inputs, decoder).
    # The image is identical for every entry, so preprocessing it and moving
    # it to DEVICE only needs to happen once per distinct model.
    prepared = {}
    captions = []
    for m_name in selection:
        if m_name not in prepared:
            m_data = MODELS[m_name]
            if m_name == "vit":
                # The "vit" entry stores its processor as a pair:
                # one object for the image, one for decoding token ids.
                image_proc, text_proc = m_data["processor"]
            else:
                # Every other entry uses a single processor for both steps.
                image_proc = text_proc = m_data["processor"]
            inputs = image_proc(images=image, return_tensors="pt").to(DEVICE)
            prepared[m_name] = (m_data["model"], inputs, text_proc)

        model, inputs, text_proc = prepared[m_name]
        # Cap max_new_tokens for snappier generation runtimes.
        ids = model.generate(
            **inputs,
            max_new_tokens=max_len,
            do_sample=do_sample,
            **sampling_kwargs,
        )
        captions.append(text_proc.batch_decode(ids, skip_special_tokens=True)[0].strip())
    return captions
|
| 122 |
|
| 123 |
# --- Endpoints ---
|
| 124 |
|
| 125 |
@app.post("/generate")
|
| 126 |
async def generate_captions(
|
| 127 |
file: UploadFile = File(...),
|
| 128 |
+
temp: float = Query(0.7),
|
| 129 |
+
top_k: int = Query(40),
|
| 130 |
top_p: float = Query(0.9)
|
| 131 |
):
|
| 132 |
+
"""Generates 5 diverse captions using an optimized sequential pipeline."""
|
| 133 |
start_time = time.perf_counter()
|
| 134 |
image = Image.open(file.file).convert("RGB")
|
| 135 |
architectures = ["blip", "vit"]
|
| 136 |
selection = random.choices(architectures, k=5)
|
| 137 |
|
| 138 |
+
# Run the whole batch sequentially in a single worker thread so the event loop stays responsive
|
| 139 |
+
captions = await asyncio.to_thread(_generate_sync_batch, selection, image, temp, top_k, top_p, 45, True)
|
|
|
|
| 140 |
|
| 141 |
elapsed_time = time.perf_counter() - start_time
|
| 142 |
print(f"[BENCHMARK] /generate ensemble turnaround: {elapsed_time:.4f}s")
|
|
|
|
| 152 |
|
| 153 |
@app.post("/saliency")
|
| 154 |
async def get_vision_saliency(file: UploadFile = File(...)):
|
| 155 |
+
"""Objective Saliency: Highly optimized native vision encoder self-attention mapping."""
|
| 156 |
start_time = time.perf_counter()
|
| 157 |
image_bytes = await file.read()
|
| 158 |
orig_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
|
|
|
| 167 |
grid_size = int(np.sqrt(mask_1d.shape[-1]))
|
| 168 |
mask = mask_1d.view(grid_size, grid_size).cpu().numpy()
|
| 169 |
|
| 170 |
+
# Normalize attention matrix
|
| 171 |
mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
|
|
|
|
|
|
|
| 172 |
|
| 173 |
+
# Vectorized OpenCV handling for super fast image processing
|
| 174 |
+
w, h = orig_img.size
|
| 175 |
+
mask_resized = cv2.resize(mask, (w, h), interpolation=cv2.INTER_CUBIC)
|
| 176 |
+
mask_blurred = cv2.GaussianBlur(mask_resized, (21, 21), 0)
|
| 177 |
+
|
| 178 |
+
# Convert normalized heatmap to standard color map space
|
| 179 |
+
heatmap_uint8 = np.uint8(255 * mask_blurred)
|
| 180 |
+
heatmap_bgr = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_MAGMA)
|
| 181 |
+
heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)
|
| 182 |
+
|
| 183 |
+
# Composite overlay mix
|
| 184 |
+
orig_np = np.array(orig_img)
|
| 185 |
+
blended_np = cv2.addWeighted(orig_np, 0.5, heatmap_rgb, 0.5, 0)
|
| 186 |
|
| 187 |
+
blended_img = Image.fromarray(blended_np)
|
| 188 |
buf = io.BytesIO()
|
| 189 |
+
blended_img.save(buf, format="PNG")
|
| 190 |
buf.seek(0)
|
| 191 |
|
| 192 |
elapsed_time = time.perf_counter() - start_time
|
| 193 |
print(f"[BENCHMARK] /saliency last-layer map turnaround: {elapsed_time:.4f}s")
|
| 194 |
|
| 195 |
+
return StreamingResponse(buf, media_type="image/png")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
|
| 197 |
@app.post("/audit")
|
| 198 |
async def internal_debate_audit(file: UploadFile = File(...), user_prompt: str = Query(...)):
|
| 199 |
+
"""The CLIP-Powered Jury: Fast deterministic grounding verification track."""
|
| 200 |
start_time = time.perf_counter()
|
| 201 |
image_bytes = await file.read()
|
| 202 |
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
| 203 |
|
| 204 |
+
# OPTIMIZATION: Greedy decoding (do_sample=False) + short length constraint
|
| 205 |
+
blip_caption = (await asyncio.to_thread(_generate_sync_batch, ["blip"], image, 1.0, 1, 1.0, 25, False))[0]
|
| 206 |
|
| 207 |
+
# CLIP Scoring
|
| 208 |
clip_m = MODELS["clip"]["model"]
|
| 209 |
clip_p = MODELS["clip"]["processor"]
|
| 210 |
|
|
|
|
| 216 |
|
| 217 |
u_score, m_score = float(probs[0]), float(probs[1])
|
| 218 |
|
|
|
|
| 219 |
if u_score < 0.35:
|
| 220 |
verdict = "Perspective Divergence: Intent not grounded in image."
|
| 221 |
elif abs(u_score - m_score) < 0.15:
|