testing further optimizations
Browse files
app.py
CHANGED
|
@@ -20,7 +20,7 @@ from transformers import (
|
|
| 20 |
CLIPModel, CLIPProcessor
|
| 21 |
)
|
| 22 |
|
| 23 |
-
app = FastAPI(title="
|
| 24 |
|
| 25 |
app.add_middleware(
|
| 26 |
CORSMiddleware,
|
|
@@ -32,7 +32,6 @@ app.add_middleware(
|
|
| 32 |
)
|
| 33 |
|
| 34 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 35 |
-
# Use float16 on GPU to slice memory overhead in half; fallback to float32 on CPU
|
| 36 |
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
|
| 37 |
MODELS = {}
|
| 38 |
|
|
@@ -42,50 +41,54 @@ async def startup_event():
|
|
| 42 |
token = os.getenv("HF_Token")
|
| 43 |
if token: login(token=token)
|
| 44 |
|
| 45 |
-
print("Syncing ensemble weights...")
|
| 46 |
local_dir = snapshot_download(repo_id="SaniaE/Image_Captioning_Ensemble", token=token, local_dir="weights")
|
| 47 |
|
| 48 |
-
# 1.
|
| 49 |
blip_model = BlipForConditionalGeneration.from_pretrained(os.path.join(local_dir, "blip"))
|
| 50 |
MODELS["blip"] = {
|
| 51 |
-
"model": blip_model.to(
|
| 52 |
"processor": BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
|
| 53 |
}
|
| 54 |
|
| 55 |
-
# 2.
|
| 56 |
-
# Point this to your new fine-tuned folder/repo path once your retraining runs are complete
|
| 57 |
vit_model = AutoModelForCausalLM.from_pretrained(os.path.join(local_dir, "vit"))
|
| 58 |
MODELS["vit"] = {
|
| 59 |
-
"model": vit_model.to(
|
| 60 |
"processor": (
|
| 61 |
ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning"),
|
| 62 |
AutoProcessor.from_pretrained("microsoft/git-large")
|
| 63 |
)
|
| 64 |
}
|
| 65 |
|
| 66 |
-
# 3. Load CLIP Jury
|
| 67 |
clip_model = CLIPModel.from_pretrained(os.path.join(local_dir, "clip/clip_model"))
|
| 68 |
MODELS["clip"] = {
|
| 69 |
"model": clip_model.to(device=DEVICE, dtype=DTYPE),
|
| 70 |
"processor": CLIPProcessor.from_pretrained(os.path.join(local_dir, "clip/clip_processor"))
|
| 71 |
}
|
| 72 |
|
| 73 |
-
print("
|
| 74 |
|
| 75 |
-
# ---
|
| 76 |
|
| 77 |
def _generate_batched_ensemble(selection, image, temp, top_k, top_p, max_len=20):
|
| 78 |
-
"""
|
|
|
|
|
|
|
| 79 |
counts = {arch: selection.count(arch) for arch in ["blip", "vit"]}
|
| 80 |
results_map = {"blip": [], "vit": []}
|
| 81 |
|
| 82 |
with torch.inference_mode():
|
| 83 |
|
| 84 |
-
# ---- 1.
|
| 85 |
if counts["blip"] > 0:
|
| 86 |
b_data = MODELS["blip"]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
inputs = b_data["processor"](images=image, return_tensors="pt")
|
| 88 |
-
# Ensure pixel tensors match our low-precision runtime configuration
|
| 89 |
pixel_values = inputs.pixel_values.to(device=DEVICE, dtype=DTYPE)
|
| 90 |
batched_pixels = pixel_values.repeat(counts["blip"], 1, 1, 1)
|
| 91 |
|
|
@@ -101,12 +104,20 @@ def _generate_batched_ensemble(selection, image, temp, top_k, top_p, max_len=20)
|
|
| 101 |
|
| 102 |
decoded = b_data["processor"].batch_decode(ids, skip_special_tokens=True)
|
| 103 |
results_map["blip"] = [cap.strip() for cap in decoded]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
-
# ---- 2.
|
| 106 |
if counts["vit"] > 0:
|
| 107 |
v_data = MODELS["vit"]
|
| 108 |
-
i_proc, t_proc = v_data["processor"]
|
| 109 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
inputs = i_proc(images=image, return_tensors="pt")
|
| 111 |
pixel_values = inputs.pixel_values.to(device=DEVICE, dtype=DTYPE)
|
| 112 |
batched_pixels = pixel_values.repeat(counts["vit"], 1, 1, 1)
|
|
@@ -129,8 +140,13 @@ def _generate_batched_ensemble(selection, image, temp, top_k, top_p, max_len=20)
|
|
| 129 |
|
| 130 |
decoded = t_proc.batch_decode(ids, skip_special_tokens=True)
|
| 131 |
results_map["vit"] = [cap.strip() for cap in decoded]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
|
| 133 |
-
#
|
| 134 |
final_captions = []
|
| 135 |
blip_idx, vit_idx = 0, 0
|
| 136 |
for arch in selection:
|
|
@@ -152,7 +168,7 @@ async def generate_captions(
|
|
| 152 |
top_k: int = Query(40),
|
| 153 |
top_p: float = Query(0.9)
|
| 154 |
):
|
| 155 |
-
"""Generates 5 diverse captions
|
| 156 |
start_time = time.perf_counter()
|
| 157 |
image = Image.open(file.file).convert("RGB")
|
| 158 |
|
|
@@ -163,9 +179,6 @@ async def generate_captions(
|
|
| 163 |
_generate_batched_ensemble, selection, image, temp, top_k, top_p, 20
|
| 164 |
)
|
| 165 |
|
| 166 |
-
if DEVICE == "cuda": torch.cuda.empty_cache()
|
| 167 |
-
gc.collect()
|
| 168 |
-
|
| 169 |
elapsed_time = time.perf_counter() - start_time
|
| 170 |
print(f"[BENCHMARK] /generate dual-ensemble turnaround: {elapsed_time:.4f}s")
|
| 171 |
|
|
@@ -185,6 +198,7 @@ async def get_vision_saliency(file: UploadFile = File(...)):
|
|
| 185 |
orig_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
| 186 |
|
| 187 |
blip = MODELS["blip"]
|
|
|
|
| 188 |
inputs = blip["processor"](images=orig_img, return_tensors="pt")
|
| 189 |
pixel_values = inputs.pixel_values.to(device=DEVICE, dtype=DTYPE)
|
| 190 |
|
|
@@ -195,6 +209,12 @@ async def get_vision_saliency(file: UploadFile = File(...)):
|
|
| 195 |
grid_size = int(np.sqrt(mask_1d.shape[-1]))
|
| 196 |
mask = mask_1d.view(grid_size, grid_size).cpu().numpy()
|
| 197 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
|
| 199 |
w, h = orig_img.size
|
| 200 |
mask_resized = cv2.resize(mask, (w, h), interpolation=cv2.INTER_CUBIC)
|
|
@@ -211,9 +231,6 @@ async def get_vision_saliency(file: UploadFile = File(...)):
|
|
| 211 |
blended_img.save(buf, format="PNG")
|
| 212 |
buf.seek(0)
|
| 213 |
|
| 214 |
-
if DEVICE == "cuda": torch.cuda.empty_cache()
|
| 215 |
-
gc.collect()
|
| 216 |
-
|
| 217 |
return StreamingResponse(buf, media_type="image/png")
|
| 218 |
|
| 219 |
@app.post("/audit")
|
|
@@ -223,12 +240,12 @@ async def internal_debate_audit(file: UploadFile = File(...), user_prompt: str =
|
|
| 223 |
image_bytes = await file.read()
|
| 224 |
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
| 225 |
|
| 226 |
-
# 1.
|
| 227 |
blip_caption = (await asyncio.to_thread(
|
| 228 |
_generate_batched_ensemble, ["blip"], image, 1.0, 1, 1.0, 20
|
| 229 |
))[0]
|
| 230 |
|
| 231 |
-
# 2.
|
| 232 |
clip_m = MODELS["clip"]["model"]
|
| 233 |
clip_p = MODELS["clip"]["processor"]
|
| 234 |
|
|
@@ -236,7 +253,6 @@ async def internal_debate_audit(file: UploadFile = File(...), user_prompt: str =
|
|
| 236 |
text_inputs = clip_p(text=[user_prompt, blip_caption], return_tensors="pt", padding=True)
|
| 237 |
|
| 238 |
with torch.inference_mode():
|
| 239 |
-
# Move inputs to device and cast features dynamically
|
| 240 |
img_pixels = image_inputs.pixel_values.to(device=DEVICE, dtype=DTYPE)
|
| 241 |
txt_ids = text_inputs.input_ids.to(DEVICE)
|
| 242 |
txt_mask = text_inputs.attention_mask.to(DEVICE)
|
|
@@ -254,9 +270,6 @@ async def internal_debate_audit(file: UploadFile = File(...), user_prompt: str =
|
|
| 254 |
verdict = "Model Bias Detected." if abs(u_score - m_score) >= 0.15 else "Consensus: High Alignment."
|
| 255 |
if u_score < 0.35: verdict = "Perspective Divergence: Intent not grounded in image."
|
| 256 |
|
| 257 |
-
if DEVICE == "cuda": torch.cuda.empty_cache()
|
| 258 |
-
gc.collect()
|
| 259 |
-
|
| 260 |
return {
|
| 261 |
"perspectives": {"user": user_prompt, "ai": blip_caption},
|
| 262 |
"audit_scores": {"intent_grounding": round(u_score, 4), "ai_grounding": round(m_score, 4)},
|
|
|
|
| 20 |
CLIPModel, CLIPProcessor
|
| 21 |
)
|
| 22 |
|
| 23 |
+
app = FastAPI(title="XAI Auditor: Hot-Swapping Dual Ensemble")
|
| 24 |
|
| 25 |
app.add_middleware(
|
| 26 |
CORSMiddleware,
|
|
|
|
| 32 |
)
|
| 33 |
|
| 34 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 35 |
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
|
| 36 |
MODELS = {}
|
| 37 |
|
|
|
|
| 41 |
token = os.getenv("HF_Token")
|
| 42 |
if token: login(token=token)
|
| 43 |
|
| 44 |
+
print("Syncing dual-ensemble weights from repository...")
|
| 45 |
local_dir = snapshot_download(repo_id="SaniaE/Image_Captioning_Ensemble", token=token, local_dir="weights")
|
| 46 |
|
| 47 |
+
# 1. Initialize BLIP-Large on CPU
|
| 48 |
blip_model = BlipForConditionalGeneration.from_pretrained(os.path.join(local_dir, "blip"))
|
| 49 |
MODELS["blip"] = {
|
| 50 |
+
"model": blip_model.to(dtype=DTYPE), # Kept on CPU until called
|
| 51 |
"processor": BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
|
| 52 |
}
|
| 53 |
|
| 54 |
+
# 2. Initialize ViT/GIT Tracker on CPU
|
|
|
|
| 55 |
vit_model = AutoModelForCausalLM.from_pretrained(os.path.join(local_dir, "vit"))
|
| 56 |
MODELS["vit"] = {
|
| 57 |
+
"model": vit_model.to(dtype=DTYPE), # Kept on CPU until called
|
| 58 |
"processor": (
|
| 59 |
ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning"),
|
| 60 |
AutoProcessor.from_pretrained("microsoft/git-large")
|
| 61 |
)
|
| 62 |
}
|
| 63 |
|
| 64 |
+
# 3. Load Fine-Tuned CLIP Jury onto active hardware (Crucial for fast parallel scoring)
|
| 65 |
clip_model = CLIPModel.from_pretrained(os.path.join(local_dir, "clip/clip_model"))
|
| 66 |
MODELS["clip"] = {
|
| 67 |
"model": clip_model.to(device=DEVICE, dtype=DTYPE),
|
| 68 |
"processor": CLIPProcessor.from_pretrained(os.path.join(local_dir, "clip/clip_processor"))
|
| 69 |
}
|
| 70 |
|
| 71 |
+
print("Ensemble pipeline initialized with CPU-backed hot-swapping optimization.")
|
| 72 |
|
| 73 |
+
# --- Hot-Swapping Core Logic ---
|
| 74 |
|
| 75 |
def _generate_batched_ensemble(selection, image, temp, top_k, top_p, max_len=20):
|
| 76 |
+
"""
|
| 77 |
+
Executes inference by isolating model execution windows to prevent VRAM thrashing.
|
| 78 |
+
"""
|
| 79 |
counts = {arch: selection.count(arch) for arch in ["blip", "vit"]}
|
| 80 |
results_map = {"blip": [], "vit": []}
|
| 81 |
|
| 82 |
with torch.inference_mode():
|
| 83 |
|
| 84 |
+
# ---- 1. Isolated BLIP Window ----
|
| 85 |
if counts["blip"] > 0:
|
| 86 |
b_data = MODELS["blip"]
|
| 87 |
+
|
| 88 |
+
# Hot-load weights directly onto active device
|
| 89 |
+
b_data["model"].to(DEVICE)
|
| 90 |
+
|
| 91 |
inputs = b_data["processor"](images=image, return_tensors="pt")
|
|
|
|
| 92 |
pixel_values = inputs.pixel_values.to(device=DEVICE, dtype=DTYPE)
|
| 93 |
batched_pixels = pixel_values.repeat(counts["blip"], 1, 1, 1)
|
| 94 |
|
|
|
|
| 104 |
|
| 105 |
decoded = b_data["processor"].batch_decode(ids, skip_special_tokens=True)
|
| 106 |
results_map["blip"] = [cap.strip() for cap in decoded]
|
| 107 |
+
|
| 108 |
+
# Evict model back to system storage space
|
| 109 |
+
b_data["model"].to("cpu")
|
| 110 |
+
if DEVICE == "cuda": torch.cuda.empty_cache()
|
| 111 |
+
gc.collect()
|
| 112 |
|
| 113 |
+
# ---- 2. Isolated ViT Window ----
|
| 114 |
if counts["vit"] > 0:
|
| 115 |
v_data = MODELS["vit"]
|
|
|
|
| 116 |
|
| 117 |
+
# Hot-load model weights now that BLIP has completely cleared out
|
| 118 |
+
v_data["model"].to(DEVICE)
|
| 119 |
+
|
| 120 |
+
i_proc, t_proc = v_data["processor"]
|
| 121 |
inputs = i_proc(images=image, return_tensors="pt")
|
| 122 |
pixel_values = inputs.pixel_values.to(device=DEVICE, dtype=DTYPE)
|
| 123 |
batched_pixels = pixel_values.repeat(counts["vit"], 1, 1, 1)
|
|
|
|
| 140 |
|
| 141 |
decoded = t_proc.batch_decode(ids, skip_special_tokens=True)
|
| 142 |
results_map["vit"] = [cap.strip() for cap in decoded]
|
| 143 |
+
|
| 144 |
+
# Clear device footprints immediately
|
| 145 |
+
v_data["model"].to("cpu")
|
| 146 |
+
if DEVICE == "cuda": torch.cuda.empty_cache()
|
| 147 |
+
gc.collect()
|
| 148 |
|
| 149 |
+
# Align predictions back to original random generation array
|
| 150 |
final_captions = []
|
| 151 |
blip_idx, vit_idx = 0, 0
|
| 152 |
for arch in selection:
|
|
|
|
| 168 |
top_k: int = Query(40),
|
| 169 |
top_p: float = Query(0.9)
|
| 170 |
):
|
| 171 |
+
"""Generates 5 diverse captions via a hot-swapping tensor batching routine."""
|
| 172 |
start_time = time.perf_counter()
|
| 173 |
image = Image.open(file.file).convert("RGB")
|
| 174 |
|
|
|
|
| 179 |
_generate_batched_ensemble, selection, image, temp, top_k, top_p, 20
|
| 180 |
)
|
| 181 |
|
|
|
|
|
|
|
|
|
|
| 182 |
elapsed_time = time.perf_counter() - start_time
|
| 183 |
print(f"[BENCHMARK] /generate dual-ensemble turnaround: {elapsed_time:.4f}s")
|
| 184 |
|
|
|
|
| 198 |
orig_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
| 199 |
|
| 200 |
blip = MODELS["blip"]
|
| 201 |
+
blip["model"].to(DEVICE) # Bring up to map attentions
|
| 202 |
inputs = blip["processor"](images=orig_img, return_tensors="pt")
|
| 203 |
pixel_values = inputs.pixel_values.to(device=DEVICE, dtype=DTYPE)
|
| 204 |
|
|
|
|
| 209 |
grid_size = int(np.sqrt(mask_1d.shape[-1]))
|
| 210 |
mask = mask_1d.view(grid_size, grid_size).cpu().numpy()
|
| 211 |
|
| 212 |
+
# Offload right after extraction
|
| 213 |
+
blip["model"].to("cpu")
|
| 214 |
+
if DEVICE == "cuda": torch.cuda.empty_cache()
|
| 215 |
+
gc.collect()
|
| 216 |
+
|
| 217 |
+
# Native OpenCV Heatmap Generation Matrix
|
| 218 |
mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
|
| 219 |
w, h = orig_img.size
|
| 220 |
mask_resized = cv2.resize(mask, (w, h), interpolation=cv2.INTER_CUBIC)
|
|
|
|
| 231 |
blended_img.save(buf, format="PNG")
|
| 232 |
buf.seek(0)
|
| 233 |
|
|
|
|
|
|
|
|
|
|
| 234 |
return StreamingResponse(buf, media_type="image/png")
|
| 235 |
|
| 236 |
@app.post("/audit")
|
|
|
|
| 240 |
image_bytes = await file.read()
|
| 241 |
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
| 242 |
|
| 243 |
+
# 1. Get deterministic baseline prediction string
|
| 244 |
blip_caption = (await asyncio.to_thread(
|
| 245 |
_generate_batched_ensemble, ["blip"], image, 1.0, 1, 1.0, 20
|
| 246 |
))[0]
|
| 247 |
|
| 248 |
+
# 2. Match Embeddings (CLIP stays pinned to target hardware device)
|
| 249 |
clip_m = MODELS["clip"]["model"]
|
| 250 |
clip_p = MODELS["clip"]["processor"]
|
| 251 |
|
|
|
|
| 253 |
text_inputs = clip_p(text=[user_prompt, blip_caption], return_tensors="pt", padding=True)
|
| 254 |
|
| 255 |
with torch.inference_mode():
|
|
|
|
| 256 |
img_pixels = image_inputs.pixel_values.to(device=DEVICE, dtype=DTYPE)
|
| 257 |
txt_ids = text_inputs.input_ids.to(DEVICE)
|
| 258 |
txt_mask = text_inputs.attention_mask.to(DEVICE)
|
|
|
|
| 270 |
verdict = "Model Bias Detected." if abs(u_score - m_score) >= 0.15 else "Consensus: High Alignment."
|
| 271 |
if u_score < 0.35: verdict = "Perspective Divergence: Intent not grounded in image."
|
| 272 |
|
|
|
|
|
|
|
|
|
|
| 273 |
return {
|
| 274 |
"perspectives": {"user": user_prompt, "ai": blip_caption},
|
| 275 |
"audit_scores": {"intent_grounding": round(u_score, 4), "ai_grounding": round(m_score, 4)},
|