jiang-cc commited on
Commit
dbd07aa
·
verified ·
1 Parent(s): 5195a24

feat: show preprocessing vs inference timing in output

Browse files
Files changed (1) hide show
  1. app.py +13 -4
app.py CHANGED
@@ -5,6 +5,7 @@ AD-Copilot Demo: Comparison-Aware Anomaly Detection with Vision-Language Model
5
  import json
6
  import os
7
  import re
 
8
  import traceback
9
  import spaces
10
  import gradio as gr
@@ -115,28 +116,26 @@ def predict(
115
  return "Please upload at least one image.", None
116
 
117
  try:
 
118
  max_new_tokens = int(max_new_tokens)
119
 
120
  # Build message content based on available images
121
  content = []
122
 
123
  if has_ref and has_test:
124
- # Paired comparison mode
125
  ref = reference_image.copy()
126
  tst = test_image.copy()
127
  ref.thumbnail((512, 512), Image.Resampling.LANCZOS)
128
  tst.thumbnail((512, 512), Image.Resampling.LANCZOS)
129
  content.append({"type": "image", "image": ref})
130
  content.append({"type": "image", "image": tst})
131
- vis_source = tst # visualize on test image
132
  elif has_test:
133
- # Single image mode (only test image)
134
  tst = test_image.copy()
135
  tst.thumbnail((512, 512), Image.Resampling.LANCZOS)
136
  content.append({"type": "image", "image": tst})
137
  vis_source = tst
138
  else:
139
- # Single image mode (only reference image)
140
  ref = reference_image.copy()
141
  ref.thumbnail((512, 512), Image.Resampling.LANCZOS)
142
  content.append({"type": "image", "image": ref})
@@ -158,9 +157,14 @@ def predict(
158
  return_tensors="pt",
159
  ).to(model.device)
160
 
 
 
161
  generated_ids = model.generate(
162
  **inputs, max_new_tokens=max_new_tokens, do_sample=False
163
  )
 
 
 
164
  generated_ids_trimmed = [
165
  out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids)
166
  ]
@@ -176,6 +180,11 @@ def predict(
176
  if bboxes:
177
  vis_image = draw_bboxes(vis_source, bboxes)
178
 
 
 
 
 
 
179
  return output, vis_image
180
  except Exception as e:
181
  tb = traceback.format_exc()
 
5
  import json
6
  import os
7
  import re
8
+ import time
9
  import traceback
10
  import spaces
11
  import gradio as gr
 
116
  return "Please upload at least one image.", None
117
 
118
  try:
119
+ t_start = time.time()
120
  max_new_tokens = int(max_new_tokens)
121
 
122
  # Build message content based on available images
123
  content = []
124
 
125
  if has_ref and has_test:
 
126
  ref = reference_image.copy()
127
  tst = test_image.copy()
128
  ref.thumbnail((512, 512), Image.Resampling.LANCZOS)
129
  tst.thumbnail((512, 512), Image.Resampling.LANCZOS)
130
  content.append({"type": "image", "image": ref})
131
  content.append({"type": "image", "image": tst})
132
+ vis_source = tst
133
  elif has_test:
 
134
  tst = test_image.copy()
135
  tst.thumbnail((512, 512), Image.Resampling.LANCZOS)
136
  content.append({"type": "image", "image": tst})
137
  vis_source = tst
138
  else:
 
139
  ref = reference_image.copy()
140
  ref.thumbnail((512, 512), Image.Resampling.LANCZOS)
141
  content.append({"type": "image", "image": ref})
 
157
  return_tensors="pt",
158
  ).to(model.device)
159
 
160
+ t_preprocess = time.time()
161
+
162
  generated_ids = model.generate(
163
  **inputs, max_new_tokens=max_new_tokens, do_sample=False
164
  )
165
+
166
+ t_generate = time.time()
167
+
168
  generated_ids_trimmed = [
169
  out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids)
170
  ]
 
180
  if bboxes:
181
  vis_image = draw_bboxes(vis_source, bboxes)
182
 
183
+ # Append timing info
184
+ prep_time = t_preprocess - t_start
185
+ gen_time = t_generate - t_preprocess
186
+ output += f"\n\n---\nPreprocessing: {prep_time:.1f}s | Inference: {gen_time:.1f}s"
187
+
188
  return output, vis_image
189
  except Exception as e:
190
  tb = traceback.format_exc()