SaniaE commited on
Commit
6d073fb
·
verified ·
1 Parent(s): d0c7c50

testing further optimizations

Browse files
Files changed (1) hide show
  1. app.py +43 -30
app.py CHANGED
@@ -20,7 +20,7 @@ from transformers import (
20
  CLIPModel, CLIPProcessor
21
  )
22
 
23
- app = FastAPI(title="Optimized Dual-Ensemble XAI Auditor Backend")
24
 
25
  app.add_middleware(
26
  CORSMiddleware,
@@ -32,7 +32,6 @@ app.add_middleware(
32
  )
33
 
34
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
35
- # Use float16 on GPU to slice memory overhead in half; fallback to float32 on CPU
36
  DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
37
  MODELS = {}
38
 
@@ -42,50 +41,54 @@ async def startup_event():
42
  token = os.getenv("HF_Token")
43
  if token: login(token=token)
44
 
45
- print("Syncing ensemble weights...")
46
  local_dir = snapshot_download(repo_id="SaniaE/Image_Captioning_Ensemble", token=token, local_dir="weights")
47
 
48
- # 1. Load BLIP-Large (.half() cuts VRAM allocation in half)
49
  blip_model = BlipForConditionalGeneration.from_pretrained(os.path.join(local_dir, "blip"))
50
  MODELS["blip"] = {
51
- "model": blip_model.to(device=DEVICE, dtype=DTYPE),
52
  "processor": BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
53
  }
54
 
55
- # 2. Load ViT / Descriptive Language Model Track
56
- # Point this to your new fine-tuned folder/repo path once your retraining runs are complete
57
  vit_model = AutoModelForCausalLM.from_pretrained(os.path.join(local_dir, "vit"))
58
  MODELS["vit"] = {
59
- "model": vit_model.to(device=DEVICE, dtype=DTYPE),
60
  "processor": (
61
  ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning"),
62
  AutoProcessor.from_pretrained("microsoft/git-large")
63
  )
64
  }
65
 
66
- # 3. Load CLIP Jury
67
  clip_model = CLIPModel.from_pretrained(os.path.join(local_dir, "clip/clip_model"))
68
  MODELS["clip"] = {
69
  "model": clip_model.to(device=DEVICE, dtype=DTYPE),
70
  "processor": CLIPProcessor.from_pretrained(os.path.join(local_dir, "clip/clip_processor"))
71
  }
72
 
73
- print("Dual-ensemble successfully loaded in optimized low-precision layout.")
74
 
75
- # --- Optimized Core Utility ---
76
 
77
  def _generate_batched_ensemble(selection, image, temp, top_k, top_p, max_len=20):
78
- """高效的双模型批处理生成引擎 (Optimized Dual-Model Batch Generation Engine)"""
 
 
79
  counts = {arch: selection.count(arch) for arch in ["blip", "vit"]}
80
  results_map = {"blip": [], "vit": []}
81
 
82
  with torch.inference_mode():
83
 
84
- # ---- 1. Optimized BLIP Pass ----
85
  if counts["blip"] > 0:
86
  b_data = MODELS["blip"]
 
 
 
 
87
  inputs = b_data["processor"](images=image, return_tensors="pt")
88
- # Ensure pixel tensors match our low-precision runtime configuration
89
  pixel_values = inputs.pixel_values.to(device=DEVICE, dtype=DTYPE)
90
  batched_pixels = pixel_values.repeat(counts["blip"], 1, 1, 1)
91
 
@@ -101,12 +104,20 @@ def _generate_batched_ensemble(selection, image, temp, top_k, top_p, max_len=20)
101
 
102
  decoded = b_data["processor"].batch_decode(ids, skip_special_tokens=True)
103
  results_map["blip"] = [cap.strip() for cap in decoded]
 
 
 
 
 
104
 
105
- # ---- 2. Optimized ViT/GIT Pass ----
106
  if counts["vit"] > 0:
107
  v_data = MODELS["vit"]
108
- i_proc, t_proc = v_data["processor"]
109
 
 
 
 
 
110
  inputs = i_proc(images=image, return_tensors="pt")
111
  pixel_values = inputs.pixel_values.to(device=DEVICE, dtype=DTYPE)
112
  batched_pixels = pixel_values.repeat(counts["vit"], 1, 1, 1)
@@ -129,8 +140,13 @@ def _generate_batched_ensemble(selection, image, temp, top_k, top_p, max_len=20)
129
 
130
  decoded = t_proc.batch_decode(ids, skip_special_tokens=True)
131
  results_map["vit"] = [cap.strip() for cap in decoded]
 
 
 
 
 
132
 
133
- # Map outputs back to the original order requested by your random draw selection
134
  final_captions = []
135
  blip_idx, vit_idx = 0, 0
136
  for arch in selection:
@@ -152,7 +168,7 @@ async def generate_captions(
152
  top_k: int = Query(40),
153
  top_p: float = Query(0.9)
154
  ):
155
- """Generates 5 diverse captions across an explicit multi-model selection field."""
156
  start_time = time.perf_counter()
157
  image = Image.open(file.file).convert("RGB")
158
 
@@ -163,9 +179,6 @@ async def generate_captions(
163
  _generate_batched_ensemble, selection, image, temp, top_k, top_p, 20
164
  )
165
 
166
- if DEVICE == "cuda": torch.cuda.empty_cache()
167
- gc.collect()
168
-
169
  elapsed_time = time.perf_counter() - start_time
170
  print(f"[BENCHMARK] /generate dual-ensemble turnaround: {elapsed_time:.4f}s")
171
 
@@ -185,6 +198,7 @@ async def get_vision_saliency(file: UploadFile = File(...)):
185
  orig_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
186
 
187
  blip = MODELS["blip"]
 
188
  inputs = blip["processor"](images=orig_img, return_tensors="pt")
189
  pixel_values = inputs.pixel_values.to(device=DEVICE, dtype=DTYPE)
190
 
@@ -195,6 +209,12 @@ async def get_vision_saliency(file: UploadFile = File(...)):
195
  grid_size = int(np.sqrt(mask_1d.shape[-1]))
196
  mask = mask_1d.view(grid_size, grid_size).cpu().numpy()
197
 
 
 
 
 
 
 
198
  mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
199
  w, h = orig_img.size
200
  mask_resized = cv2.resize(mask, (w, h), interpolation=cv2.INTER_CUBIC)
@@ -211,9 +231,6 @@ async def get_vision_saliency(file: UploadFile = File(...)):
211
  blended_img.save(buf, format="PNG")
212
  buf.seek(0)
213
 
214
- if DEVICE == "cuda": torch.cuda.empty_cache()
215
- gc.collect()
216
-
217
  return StreamingResponse(buf, media_type="image/png")
218
 
219
  @app.post("/audit")
@@ -223,12 +240,12 @@ async def internal_debate_audit(file: UploadFile = File(...), user_prompt: str =
223
  image_bytes = await file.read()
224
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
225
 
226
- # 1. Deterministic Base Prediction Pass
227
  blip_caption = (await asyncio.to_thread(
228
  _generate_batched_ensemble, ["blip"], image, 1.0, 1, 1.0, 20
229
  ))[0]
230
 
231
- # 2. Low-Precision Decoupled CLIP Scoring Matrix
232
  clip_m = MODELS["clip"]["model"]
233
  clip_p = MODELS["clip"]["processor"]
234
 
@@ -236,7 +253,6 @@ async def internal_debate_audit(file: UploadFile = File(...), user_prompt: str =
236
  text_inputs = clip_p(text=[user_prompt, blip_caption], return_tensors="pt", padding=True)
237
 
238
  with torch.inference_mode():
239
- # Move inputs to device and cast features dynamically
240
  img_pixels = image_inputs.pixel_values.to(device=DEVICE, dtype=DTYPE)
241
  txt_ids = text_inputs.input_ids.to(DEVICE)
242
  txt_mask = text_inputs.attention_mask.to(DEVICE)
@@ -254,9 +270,6 @@ async def internal_debate_audit(file: UploadFile = File(...), user_prompt: str =
254
  verdict = "Model Bias Detected." if abs(u_score - m_score) >= 0.15 else "Consensus: High Alignment."
255
  if u_score < 0.35: verdict = "Perspective Divergence: Intent not grounded in image."
256
 
257
- if DEVICE == "cuda": torch.cuda.empty_cache()
258
- gc.collect()
259
-
260
  return {
261
  "perspectives": {"user": user_prompt, "ai": blip_caption},
262
  "audit_scores": {"intent_grounding": round(u_score, 4), "ai_grounding": round(m_score, 4)},
 
20
  CLIPModel, CLIPProcessor
21
  )
22
 
23
+ app = FastAPI(title="XAI Auditor: Hot-Swapping Dual Ensemble")
24
 
25
  app.add_middleware(
26
  CORSMiddleware,
 
32
  )
33
 
34
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
35
  DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
36
  MODELS = {}
37
 
 
41
  token = os.getenv("HF_Token")
42
  if token: login(token=token)
43
 
44
+ print("Syncing dual-ensemble weights from repository...")
45
  local_dir = snapshot_download(repo_id="SaniaE/Image_Captioning_Ensemble", token=token, local_dir="weights")
46
 
47
+ # 1. Initialize BLIP-Large on CPU
48
  blip_model = BlipForConditionalGeneration.from_pretrained(os.path.join(local_dir, "blip"))
49
  MODELS["blip"] = {
50
+ "model": blip_model.to(dtype=DTYPE), # Kept on CPU until called
51
  "processor": BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
52
  }
53
 
54
+ # 2. Initialize ViT/GIT Tracker on CPU
 
55
  vit_model = AutoModelForCausalLM.from_pretrained(os.path.join(local_dir, "vit"))
56
  MODELS["vit"] = {
57
+ "model": vit_model.to(dtype=DTYPE), # Kept on CPU until called
58
  "processor": (
59
  ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning"),
60
  AutoProcessor.from_pretrained("microsoft/git-large")
61
  )
62
  }
63
 
64
+ # 3. Load Fine-Tuned CLIP Jury onto active hardware (Crucial for fast parallel scoring)
65
  clip_model = CLIPModel.from_pretrained(os.path.join(local_dir, "clip/clip_model"))
66
  MODELS["clip"] = {
67
  "model": clip_model.to(device=DEVICE, dtype=DTYPE),
68
  "processor": CLIPProcessor.from_pretrained(os.path.join(local_dir, "clip/clip_processor"))
69
  }
70
 
71
+ print("Ensemble pipeline initialized with CPU-backed hot-swapping optimization.")
72
 
73
+ # --- Hot-Swapping Core Logic ---
74
 
75
  def _generate_batched_ensemble(selection, image, temp, top_k, top_p, max_len=20):
76
+ """
77
+ Executes inference by isolating model execution windows to prevent VRAM thrashing.
78
+ """
79
  counts = {arch: selection.count(arch) for arch in ["blip", "vit"]}
80
  results_map = {"blip": [], "vit": []}
81
 
82
  with torch.inference_mode():
83
 
84
+ # ---- 1. Isolated BLIP Window ----
85
  if counts["blip"] > 0:
86
  b_data = MODELS["blip"]
87
+
88
+ # Hot-load weights directly onto active device
89
+ b_data["model"].to(DEVICE)
90
+
91
  inputs = b_data["processor"](images=image, return_tensors="pt")
 
92
  pixel_values = inputs.pixel_values.to(device=DEVICE, dtype=DTYPE)
93
  batched_pixels = pixel_values.repeat(counts["blip"], 1, 1, 1)
94
 
 
104
 
105
  decoded = b_data["processor"].batch_decode(ids, skip_special_tokens=True)
106
  results_map["blip"] = [cap.strip() for cap in decoded]
107
+
108
+ # Evict model back to system storage space
109
+ b_data["model"].to("cpu")
110
+ if DEVICE == "cuda": torch.cuda.empty_cache()
111
+ gc.collect()
112
 
113
+ # ---- 2. Isolated ViT Window ----
114
  if counts["vit"] > 0:
115
  v_data = MODELS["vit"]
 
116
 
117
+ # Hot-load model weights now that BLIP has completely cleared out
118
+ v_data["model"].to(DEVICE)
119
+
120
+ i_proc, t_proc = v_data["processor"]
121
  inputs = i_proc(images=image, return_tensors="pt")
122
  pixel_values = inputs.pixel_values.to(device=DEVICE, dtype=DTYPE)
123
  batched_pixels = pixel_values.repeat(counts["vit"], 1, 1, 1)
 
140
 
141
  decoded = t_proc.batch_decode(ids, skip_special_tokens=True)
142
  results_map["vit"] = [cap.strip() for cap in decoded]
143
+
144
+ # Clear device footprints immediately
145
+ v_data["model"].to("cpu")
146
+ if DEVICE == "cuda": torch.cuda.empty_cache()
147
+ gc.collect()
148
 
149
+ # Align predictions back to original random generation array
150
  final_captions = []
151
  blip_idx, vit_idx = 0, 0
152
  for arch in selection:
 
168
  top_k: int = Query(40),
169
  top_p: float = Query(0.9)
170
  ):
171
+ """Generates 5 diverse captions via a hot-swapping tensor batching routine."""
172
  start_time = time.perf_counter()
173
  image = Image.open(file.file).convert("RGB")
174
 
 
179
  _generate_batched_ensemble, selection, image, temp, top_k, top_p, 20
180
  )
181
 
 
 
 
182
  elapsed_time = time.perf_counter() - start_time
183
  print(f"[BENCHMARK] /generate dual-ensemble turnaround: {elapsed_time:.4f}s")
184
 
 
198
  orig_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
199
 
200
  blip = MODELS["blip"]
201
+ blip["model"].to(DEVICE) # Bring up to map attentions
202
  inputs = blip["processor"](images=orig_img, return_tensors="pt")
203
  pixel_values = inputs.pixel_values.to(device=DEVICE, dtype=DTYPE)
204
 
 
209
  grid_size = int(np.sqrt(mask_1d.shape[-1]))
210
  mask = mask_1d.view(grid_size, grid_size).cpu().numpy()
211
 
212
+ # Offload right after extraction
213
+ blip["model"].to("cpu")
214
+ if DEVICE == "cuda": torch.cuda.empty_cache()
215
+ gc.collect()
216
+
217
+ # Native OpenCV Heatmap Generation Matrix
218
  mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
219
  w, h = orig_img.size
220
  mask_resized = cv2.resize(mask, (w, h), interpolation=cv2.INTER_CUBIC)
 
231
  blended_img.save(buf, format="PNG")
232
  buf.seek(0)
233
 
 
 
 
234
  return StreamingResponse(buf, media_type="image/png")
235
 
236
  @app.post("/audit")
 
240
  image_bytes = await file.read()
241
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
242
 
243
+ # 1. Get deterministic baseline prediction string
244
  blip_caption = (await asyncio.to_thread(
245
  _generate_batched_ensemble, ["blip"], image, 1.0, 1, 1.0, 20
246
  ))[0]
247
 
248
+ # 2. Match Embeddings (CLIP stays pinned to target hardware device)
249
  clip_m = MODELS["clip"]["model"]
250
  clip_p = MODELS["clip"]["processor"]
251
 
 
253
  text_inputs = clip_p(text=[user_prompt, blip_caption], return_tensors="pt", padding=True)
254
 
255
  with torch.inference_mode():
 
256
  img_pixels = image_inputs.pixel_values.to(device=DEVICE, dtype=DTYPE)
257
  txt_ids = text_inputs.input_ids.to(DEVICE)
258
  txt_mask = text_inputs.attention_mask.to(DEVICE)
 
270
  verdict = "Model Bias Detected." if abs(u_score - m_score) >= 0.15 else "Consensus: High Alignment."
271
  if u_score < 0.35: verdict = "Perspective Divergence: Intent not grounded in image."
272
 
 
 
 
273
  return {
274
  "perspectives": {"user": user_prompt, "ai": blip_caption},
275
  "audit_scores": {"intent_grounding": round(u_score, 4), "ai_grounding": round(m_score, 4)},