dylanplummer committed on
Commit
2621d57
·
1 Parent(s): ed4eebf

speed optim

Browse files
Files changed (1) hide show
  1. app.py +19 -26
app.py CHANGED
@@ -76,8 +76,8 @@ def preprocess_image(img, img_size):
76
 
77
  def run_inference(batch_X):
78
  global ort_sess
79
- batch_X = torch.cat(batch_X)
80
- return ort_sess.run(None, {'video': batch_X.numpy()})
81
 
82
 
83
  def sigmoid(x):
@@ -208,17 +208,11 @@ def count_phases(phase_sin, phase_cos, threshold=0.5):
208
  count: Number of phase transitions
209
  phase_indices: Indices where transitions occur
210
  """
211
- phase_indices = []
212
- count = 0
213
- for i in range(1, len(phase_sin)):
214
- # Check if the sine and cosine phases cross each other
215
- if (phase_sin[i-1] < threshold and phase_sin[i] >= threshold) or \
216
- (phase_sin[i-1] >= threshold and phase_sin[i] < threshold):
217
- # Check if the cosine phase crosses the threshold
218
- if (phase_cos[i-1] < threshold and phase_cos[i] >= threshold) or \
219
- (phase_cos[i-1] >= threshold and phase_cos[i] < threshold):
220
- phase_indices.append(i)
221
- count += 1
222
  return count, phase_indices
223
 
224
 
@@ -263,7 +257,7 @@ def inference(in_video, use_60fps,
263
  frame = all_frames[-1] # padding will be with last frame
264
  break
265
 
266
- frame = cv2.cvtColor(np.uint8(frame), cv2.COLOR_BGR2RGB)
267
  # add square padding with opencv
268
  #frame = square_pad_opencv(frame)
269
  # frame_center_x = frame.shape[1] // 2
@@ -274,7 +268,7 @@ def inference(in_video, use_60fps,
274
  # crop_x = frame_center_x - IMG_SIZE // 2
275
  # crop_y = frame_center_y - IMG_SIZE // 2
276
  # frame = frame[crop_y:crop_y+IMG_SIZE, crop_x:crop_x+IMG_SIZE]
277
- frame = cv2.resize(frame, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_CUBIC)
278
  all_frames.append(frame)
279
 
280
  cap.release()
@@ -294,21 +288,20 @@ def inference(in_video, use_60fps,
294
  batch_list = []
295
  idx_list = []
296
  inference_futures = []
297
- with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
298
  for i in progress.tqdm(range(0, length + stride_length - stride_pad, stride_length)):
299
  batch = all_frames[i:i + seq_len]
300
- Xlist = []
301
- preprocess_tasks = [(idx, executor.submit(preprocess_image, img, IMG_SIZE)) for idx, img in enumerate(batch)]
302
- for idx, future in sorted(preprocess_tasks, key=lambda x: x[0]):
303
- Xlist.append(future.result())
304
 
305
- if len(Xlist) < seq_len:
306
- for _ in range(seq_len - len(Xlist)):
307
- Xlist.append(Xlist[-1])
 
 
 
308
 
309
- X = torch.cat(Xlist)
310
- X *= 255
311
- batch_list.append(X.unsqueeze(0))
312
  idx_list.append(i)
313
 
314
  if len(batch_list) == batch_size:
 
76
 
77
  def run_inference(batch_X):
78
  global ort_sess
79
+ batch_X = np.concatenate(batch_X, axis=0)
80
+ return ort_sess.run(None, {'video': batch_X})
81
 
82
 
83
  def sigmoid(x):
 
208
  count: Number of phase transitions
209
  phase_indices: Indices where transitions occur
210
  """
211
+ sin_crosses = (phase_sin[:-1] < threshold) != (phase_sin[1:] < threshold)
212
+ cos_crosses = (phase_cos[:-1] < threshold) != (phase_cos[1:] < threshold)
213
+ both_cross = sin_crosses & cos_crosses
214
+ phase_indices = (np.where(both_cross)[0] + 1).tolist()
215
+ count = len(phase_indices)
 
 
 
 
 
 
216
  return count, phase_indices
217
 
218
 
 
257
  frame = all_frames[-1] # padding will be with last frame
258
  break
259
 
260
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
261
  # add square padding with opencv
262
  #frame = square_pad_opencv(frame)
263
  # frame_center_x = frame.shape[1] // 2
 
268
  # crop_x = frame_center_x - IMG_SIZE // 2
269
  # crop_y = frame_center_y - IMG_SIZE // 2
270
  # frame = frame[crop_y:crop_y+IMG_SIZE, crop_x:crop_x+IMG_SIZE]
271
+ frame = cv2.resize(frame, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_LINEAR)
272
  all_frames.append(frame)
273
 
274
  cap.release()
 
288
  batch_list = []
289
  idx_list = []
290
  inference_futures = []
291
+ with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
292
  for i in progress.tqdm(range(0, length + stride_length - stride_pad, stride_length)):
293
  batch = all_frames[i:i + seq_len]
294
+ if len(batch) < seq_len:
295
+ batch = batch + [batch[-1]] * (seq_len - len(batch))
 
 
296
 
297
+ # Vectorized preprocessing: stack, transpose HWC->CHW, convert to float32
298
+ # (replaces per-frame PIL conversion + torchvision ToTensor + X*=255 undo)
299
+ X = np.ascontiguousarray(
300
+ np.stack(batch).transpose(0, 3, 1, 2),
301
+ dtype=np.float32
302
+ )
303
 
304
+ batch_list.append(X[np.newaxis]) # add batch dim: (1, seq_len, 3, H, W)
 
 
305
  idx_list.append(i)
306
 
307
  if len(batch_list) == batch_size: