Darknsu committed on
Commit
a719c01
·
verified ·
1 Parent(s): d631501

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +102 -179
main.py CHANGED
@@ -45,7 +45,7 @@ VIS_CONFIG = {
45
  'frame_highlight_gt': 'red',
46
  'frame_highlight_pred': 'black',
47
  'iou_threshold': 0.3,
48
- 'frame_scale_factor': 0.3, # Reduced for CPU
49
  'video_text_scale': 0.5,
50
  'video_gt_text_color': (180, 119, 31), # BGR
51
  'video_pred_text_color': (14, 127, 255), # BGR
@@ -55,8 +55,8 @@ VIS_CONFIG = {
55
  'video_pred_text_y': 0.45,
56
  'video_gt_text_y': 0.55,
57
  'video_footer_height': 150,
58
- 'video_gt_bar_y': 0.5,
59
- 'video_pred_bar_y': 0.8,
60
  'video_bar_height': 0.15,
61
  'video_bar_text_scale': 0.7,
62
  'min_segment_duration': 1.0,
@@ -108,7 +108,7 @@ def annotate_video_with_actions(
108
  (185, 218, 255), (255, 204, 204), (193, 182, 255), (201, 252, 189),
109
  (144, 128, 112), (112, 25, 25), (102, 51, 102), (0, 128, 128), (171, 71, 0)
110
  ]
111
- action_labels = set(seg['label'] for seg in gt_segments).union(seg['label'] for seg in pred_segments)
112
  action_color_map = {label: color_palette[i % len(color_palette)] for i, label in enumerate(action_labels)}
113
  gt_color_rgb = (gt_text_color[2], gt_text_color[1], gt_text_color[0])
114
  pred_color_rgb = (pred_text_color[2], pred_text_color[1], pred_text_color[0])
@@ -128,10 +128,10 @@ def annotate_video_with_actions(
128
  except IOError:
129
  font = None
130
  bar_font = None
131
- window_size = 20.0
132
  num_windows = int(np.ceil(duration / window_size))
133
  text_bar_gap = 48
134
- text_x = 10
135
  frame_idx = 0
136
  written_frames = 0
137
  while cap.isOpened():
@@ -148,12 +148,12 @@ def annotate_video_with_actions(
148
  window_duration = window_end - window_start
149
  window_timestamp = timestamp - window_start
150
  gt_labels = [seg['label'] for seg in gt_segments if seg['start'] <= timestamp <= seg['end']]
151
- gt_text = "GT: " + ", ".join(gt_labels) if gt_labels else ""
152
  pred_labels = [seg['label'] for seg in pred_segments if seg['start'] <= timestamp <= seg['end']]
153
- pred_text = "Pred: " + ", ".join(pred_labels) if pred_labels else ""
154
  footer_y = frame_height
155
- gt_bar_y = footer_y + int(0.2 * footer_height)
156
- pred_bar_y = footer_y + int(0.5 * footer_height)
157
  bar_height = int(VIS_CONFIG['video_bar_height'] * footer_height)
158
  if font:
159
  gt_text_bbox = bar_font.getbbox("GT")
@@ -204,20 +204,18 @@ def annotate_video_with_actions(
204
  frame_text_bbox = draw.textbbox((0, 0), frame_info, font=font)
205
  frame_text_width = frame_text_bbox[2] - frame_text_bbox[0]
206
  frame_text_x = (frame_width - frame_text_width) // 2
207
- draw.text((frame_text_x, 10), frame_info, font=font, fill=(0, 0, 0))
208
  window_info = f"{window_start:.1f}s - {window_end:.1f}s"
209
  window_text_bbox = draw.textbbox((0, 0), window_info, font=bar_font)
210
  window_text_width = window_text_bbox[2] - window_text_bbox[0]
211
  window_text_x = (frame_width - window_text_width) // 2
212
  draw.text((window_text_x, footer_y + 10), window_info, font=bar_font, fill=(0, 0, 0))
213
- if gt_text:
214
- gt_y = int(frame_height * VIS_CONFIG['video_gt_text_y'])
215
- draw.text((10, gt_y), gt_text, font=font, fill=gt_color_rgb)
216
- if pred_text:
217
- pred_y = int(frame_height * VIS_CONFIG['video_pred_text_y'])
218
- draw.text((10, pred_y), pred_text, font=font, fill=pred_color_rgb)
219
  draw.text((text_x, gt_bar_y + bar_height // 2), "GT", font=bar_font, fill=gt_color_rgb)
220
  draw.text((text_x, pred_bar_y + bar_height // 2), "Pred", font=bar_font, fill=pred_color_rgb)
 
 
 
 
221
  extended_frame = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
222
  else:
223
  frame_info = f"Frame: {frame_idx} | FPS: {fps:.2f}"
@@ -226,7 +224,7 @@ def annotate_video_with_actions(
226
  cv2.putText(
227
  extended_frame,
228
  frame_info,
229
- (frame_text_x, 30),
230
  cv2.FONT_HERSHEY_SIMPLEX,
231
  text_scale,
232
  (0, 0, 0),
@@ -246,28 +244,26 @@ def annotate_video_with_actions(
246
  1,
247
  cv2.LINE_AA
248
  )
249
- if gt_text:
250
- cv2.putText(
251
- extended_frame,
252
- gt_text,
253
- (10, int(frame_height * VIS_CONFIG['video_gt_text_y'])),
254
- cv2.FONT_HERSHEY_SIMPLEX,
255
- text_scale,
256
- gt_text_color,
257
- text_thickness,
258
- cv2.LINE_AA
259
- )
260
- if pred_text:
261
- cv2.putText(
262
- extended_frame,
263
- pred_text,
264
- (10, int(frame_height * VIS_CONFIG['video_pred_text_y'])),
265
- cv2.FONT_HERSHEY_SIMPLEX,
266
- text_scale,
267
- pred_text_color,
268
- text_thickness,
269
- cv2.LINE_AA
270
- )
271
  cv2.putText(
272
  extended_frame,
273
  "GT",
@@ -316,7 +312,7 @@ def visualize_action_lengths(
316
  if num_frames > VIS_CONFIG['max_frames']:
317
  frame_interval = duration / (VIS_CONFIG['max_frames'] - 1)
318
  num_frames = VIS_CONFIG['max_frames']
319
- frame_times = np.linspace(0, duration, num_frames, endpoint=False)
320
  frames = []
321
  cap = cv2.VideoCapture(video_path)
322
  if not cap.isOpened():
@@ -328,12 +324,12 @@ def visualize_action_lengths(
328
  ret, frame = cap.read()
329
  if ret:
330
  frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
331
- frame = cv2.resize(frame, (int(frame.shape[1] * 0.3), int(frame.shape[0] * 0.3))) # Smaller resize
332
  frames.append(frame)
333
  else:
334
  frames.append(np.ones((100, 100, 3), dtype=np.uint8) * 255)
335
  cap.release()
336
- fig = plt.figure(figsize=(num_frames * VIS_CONFIG['frame_scale_factor'], 6), constrained_layout=True)
337
  gs = fig.add_gridspec(3, num_frames, height_ratios=[3, 1, 1])
338
  for i, (t, frame) in enumerate(zip(frame_times, frames)):
339
  ax = fig.add_subplot(gs[0, i])
@@ -401,7 +397,7 @@ def train_one_epoch(opt, model, train_dataset, optimizer, warmup=False):
401
  train_loader = torch.utils.data.DataLoader(
402
  train_dataset,
403
  batch_size=opt['batch_size'], shuffle=True,
404
- num_workers=1, pin_memory=True, drop_last=True
405
  )
406
  epoch_cost = 0
407
  epoch_cost_cls = 0
@@ -416,7 +412,7 @@ def train_one_epoch(opt, model, train_dataset, optimizer, warmup=False):
416
  g['lr'] = n_iter * opt['lr'] / total_iter
417
  act_cls, act_reg, snip_cls = model(input_data.float().to(device))
418
  act_cls.register_hook(partial(cls_loss.collect_grad, cls_label))
419
- snip_cls.register_hook(partial_labels=snip_label)
420
  cost_reg = 0
421
  cost_cls = 0
422
  loss = cls_loss_func_(cls_loss, act_cls)
@@ -435,60 +431,12 @@ def train_one_epoch(opt, model, train_dataset, optimizer, warmup=False):
435
  optimizer.step()
436
  return n_iter, epoch_cost, epoch_cost_cls, epoch_cost_reg, epoch_cost_snip
437
 
438
- def eval_one_epoch(opt, model, test_dataset):
439
- device = torch.device("cpu")
440
- cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, _, _ = eval_frame(opt, model, test_dataset)
441
- result_dict = eval_map_nms(opt, test_dataset, output_cls, output_reg, labels_cls, labels_reg)
442
- output_dict = {"version": "VERSION 1.3, "results": result_dict, "external_data": None}
443
- result_path = opt["result_file"].format(opt['exp"])
444
- os.makedirs(os.path.dirname(result_path), exist_ok=True)
445
- with open(result_path, "w") as outfile:
446
- json.dump(output_dict, outfile, indent=2)
447
- IoUmAP = evaluation_detection(opt, verbose=False)
448
- IoUmAP_5 = sum(IoUmAP[0:]) / len(IoUmAP[0:]) if IoUmAP else 0
449
- return cls_loss, reg_loss, tot_loss, IoUmAP_5
450
-
451
- def train(opt):
452
- writer = SummaryWriter()
453
- device = torch.device("cpu")
454
- model = MYNET(opt).to(device)
455
- rest_of_model_params = [param for name, param in model.named_parameters() if "history_unit" not in name]
456
- optimizer = optim.Adam([
457
- {'params': model.history_unit.parameters(), 'lr': 1e-6},
458
- {'params': rest_of_model_params}
459
- ], lr=opt["lr"], weight_decay=opt["weight_decay"])
460
- scheduler = optim.StepLR(optimizer, step_size=opt["lr_step"])
461
- train_dataset = VideoDataSet(opt, subset="training")
462
- test_dataset = VideoDataSet(opt, subset=opt['inference_subset'])
463
- warmup = False
464
- for n_epoch in range(opt['epoch']):
465
- if n_epoch >= 1:
466
- warmup = False
467
- n_iter, epoch_cost, epoch_cost_cls, epoch_cost_reg, epoch_cost_snip = train_one_epoch(opt, model, train_dataset, optimizer, warmup)
468
- writer.add_scalar('data/cost', epoch_cost / (n_iter + 1), n_epoch)
469
- print(f"Epoch {n_epoch}: Train Loss: {epoch_cost / (n_iter + 1):.4f}, cls: {epoch_cost_cls / (n_iter + 1):.4f}, reg: {epoch_cost_reg / (n_iter + 1):.4f}, snip: {epoch_cost_snip / (n_iter + 1):.4f}, lr: {optimizer.param_groups[-1]['lr']:.6f}")
470
- scheduler.step()
471
- model.eval()
472
- cls_loss, reg_loss, tot_loss, IoUmAP_5 = eval_one_epoch(opt, model, test_dataset)
473
- writer.add_scalar('data/mAP', IoUmAP_5, n_epoch)
474
- print(f"Epoch {n_epoch}: Test Loss: {tot_loss:.4f}, cls: {cls_loss:.4f}, reg: {reg_loss:.4f}, mAP: {IoUmAP_5:.4f}")
475
- state = {'epoch': n_epoch + 1, 'state_dict': model.state_dict()}
476
- checkpoint_path = os.path.join(opt["checkpoint_path"], f"{opt['exp']}_checkpoint_{n_epoch + 1}.pth.tar")
477
- os.makedirs(opt["checkpoint_path"], exist_ok=True)
478
- torch.save(state, checkpoint_path)
479
- if IoUmAP_5 > getattr(model, 'best_map', 0):
480
- model.best_map = IoUmAP_5
481
- torch.save(state, os.path.join(opt["checkpoint_path"], f"{opt['exp']}_ckp_best.pth.tar"))
482
- model.train()
483
- writer.close()
484
- return model.best_map
485
-
486
- def eval_frame_data(opt, model, dataset):
487
  device = torch.device("cpu")
488
  test_loader = torch.utils.data.DataLoader(
489
  dataset,
490
- batch_size=opt['batch_size'], shuffle=True,
491
- num_workers=1, pin_memory=True, drop_last=True
492
  )
493
  labels_cls = {video_name: [] for video_name in dataset.video_names}
494
  labels_reg = {video_name: [] for video_name in dataset.video_names}
@@ -499,11 +447,12 @@ def eval_frame_data(opt, model, dataset):
499
  epoch_cost = 0
500
  epoch_cost_cls = 0
501
  epoch_cost_reg = 0
502
- for n_iter, (input_data, cls_label, reg_label, _) in enumerate(tqdm(test_loader)):
 
503
  act_cls, act_reg, _ = model(input_data.float().to(device))
504
  cost_reg = 0
505
  cost_cls = 0
506
- loss = cls_loss_func(cls_label, act_cls)
507
  cost_cls = loss
508
  epoch_cost_cls += loss.detach().cpu().numpy()
509
  loss = regress_loss_func(reg_label, act_reg)
@@ -515,10 +464,10 @@ def eval_frame_data(opt, model, dataset):
515
  total_frames += input_data.size(0)
516
  for idx in range(input_data.size(0)):
517
  video_name, st, ed, data_idx = dataset.inputs[n_iter * opt['batch_size'] + idx]
518
- output_cls[video_name].append(act_cls[idx, :].detach().cpu().numpy())
519
- output_reg[video_name].append(act_reg[idx, :].detach().cpu().numpy())
520
- labels_cls[video_name].append(cls_label[idx, :].cpu().numpy())
521
- labels_reg[video_name].append(reg_label[idx, :].cpu().numpy())
522
  end_time = time.time()
523
  working_time = end_time - start_time
524
  for video_name in dataset.video_names:
@@ -547,7 +496,7 @@ def eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg):
547
  reg_anc = output_reg[video_name][idx]
548
  proposal_anc_dict = []
549
  for anc_idx in range(len(anchors)):
550
- cls = np.argwhere(cls_anc[anc_idx][:-1] > opt['threshold']).reshape(-1)
551
  if len(cls) == 0:
552
  continue
553
  ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
@@ -571,7 +520,7 @@ def eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg):
571
  def eval_map_suppress(opt, dataset, output_cls, output_reg, labels_cls, labels_reg):
572
  device = torch.device("cpu")
573
  model = SuppressNet(opt).to(device)
574
- checkpoint_path = os.path.join(opt["checkpoint_path"], "ckp_best_suppress.pth.tar")
575
  if os.path.exists(checkpoint_path):
576
  checkpoint = torch.load(checkpoint_path, map_location=device)
577
  model.load_state_dict(checkpoint['state_dict'])
@@ -595,7 +544,7 @@ def eval_map_suppress(opt, dataset, output_cls, output_reg, labels_cls, labels_r
595
  reg_anc = output_reg[video_name][idx]
596
  proposal_anc_dict = []
597
  for anc_idx in range(len(anchors)):
598
- cls = np.argwhere(cls_anc[anc_idx][:-1] > opt['threshold']).reshape(-1)
599
  if len(cls) == 0:
600
  continue
601
  ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
@@ -605,7 +554,7 @@ def eval_map_suppress(opt, dataset, output_cls, output_reg, labels_cls, labels_r
605
  label = cls[cidx]
606
  tmp_dict = {
607
  "segment": [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)],
608
- "score": float(cls_anc][anc_idx][label]),
609
  "label": dataset.label_name[label],
610
  "gentime": float(idx * frame_to_time / 100.0)
611
  }
@@ -622,60 +571,13 @@ def eval_map_suppress(opt, dataset, output_cls, output_reg, labels_cls, labels_r
622
  for cls in range(num_class - 1):
623
  if suppress_conf[cls] > opt['sup_threshold']:
624
  for proposal in proposal_anc_dict:
625
- if proposal['label"] == dataset.label_name[cls]:
626
  if check_overlap_proposal(proposal_dict, proposal, overlapThresh=opt['soft_nms']) is None:
627
  proposal_dict.append(proposal)
628
  result_dict[video_name] = proposal_dict
629
  proposal_dict = []
630
  return result_dict
631
 
632
- def test_frame(opt, video_name=None):
633
- device = torch.device("cpu")
634
- model = MYNET(opt).to(device)
635
- checkpoint_path = os.path.join(opt["checkpoint_path"], "ckp_best.pth.tar")
636
- if not os.path.exists(checkpoint_path):
637
- print(f"[ERROR] Checkpoint {checkpoint_path} not found.")
638
- return 0, 0, 0
639
- checkpoint = torch.load(checkpoint_path, map_location=device)
640
- model.load_state_dict(checkpoint['state_dict'])
641
- model.eval()
642
- dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)
643
- outfile = h5py.File(opt['output_file'].format(opt['exp']), 'w')
644
- cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames = eval_frame(opt, model, dataset)
645
- print(f"testing loss: {tot_loss:.4f}, cls_loss: {cls_loss:.4f}, reg_loss: {reg_loss:.4f}")
646
- for video_name in dataset.video_names:
647
- o_cls = output_cls[video_name]
648
- o_reg = output_reg[video_name]
649
- l_cls = labels_cls[video_name]
650
- l_reg = labels_reg[video_name]
651
- dset_predcls = outfile.create_dataset(video_name + '/pred_cls', o_cls.shape, maxshape=o_cls.shape, chunks=True, dtype=np.float32)
652
- dset_predcls[:, :] = o_cls[:, :]
653
- dset_predreg = outfile.create_dataset(video_name + '/pred_reg', o_reg.shape, maxshape=o_reg.shape, chunks=True, dtype=np.float32)
654
- dset_predreg[:, :] = o_reg[:, :]
655
- dset_labelcls = outfile.create_dataset(video_name + '/label_cls', l_cls.shape, maxshape=l_cls.shape, chunks=True, dtype=np.float32)
656
- dset_labelcls[:, :] = l_cls[:, :]
657
- dset_labelreg = outfile.create_dataset(video_name + '/label_reg', l_reg.shape, maxshape=l_reg.shape, chunks=True, dtype=np.float32)
658
- dset_labelreg[:, :] = l_reg[:, :]
659
- outfile.close()
660
- print(f"[INFO] Working time: {working_time:.2f}s, {total_frames / working_time:.1f}fps, {total_frames} frames")
661
- return cls_loss, reg_loss, tot_loss
662
-
663
- def patch_attention(m):
664
- forward_orig = m.forward
665
- def wrap(*args, **kwargs):
666
- kwargs["need_weights"] = True
667
- kwargs["average_attn_weights"] = False
668
- return forward_orig(*args, **kwargs)
669
- m.forward = wrap
670
-
671
- class SaveOutput:
672
- def __init__(self):
673
- self.outputs = []
674
- def __call__(self, module, module_in, module_out):
675
- self.outputs.append(module_out[1])
676
- def clear(self):
677
- self.outputs = []
678
-
679
  def test(opt, video_name=None):
680
  device = torch.device("cpu")
681
  model = MYNET(opt).to(device)
@@ -701,12 +603,13 @@ def test(opt, video_name=None):
701
  with open(result_path, 'w') as f:
702
  json.dump(output_dict, f, indent=4)
703
  mAP = evaluation_detection(opt, verbose=False)
 
704
  if video_name:
705
  print(f"\n[INFO] Comparing Predicted and Ground Truth Actions for Video: {video_name}")
706
- anno_path = opt["annotations_path"]
707
  if not os.path.exists(anno_path):
708
  print(f"[ERROR] Annotation file {anno_path} not found. Skipping comparison.")
709
- return mAP, "", "", ""
710
  with open(anno_path, 'r') as f:
711
  anno_data = json.load(f)
712
  gt_annotations = anno_data['database'][video_name]['annotations']
@@ -723,7 +626,7 @@ def test(opt, video_name=None):
723
  'end': pred['segment'][1],
724
  'duration': pred['segment'][1] - pred['segment'][0],
725
  'score': pred['score']
726
- } for pred in result_dict[video_name]]
727
  matches = []
728
  iou_threshold = VIS_CONFIG['iou_threshold']
729
  used_gt_indices = set()
@@ -733,7 +636,7 @@ def test(opt, video_name=None):
733
  for gt_idx, gt in enumerate(gt_segments):
734
  if gt_idx in used_gt_indices:
735
  continue
736
- iou = calc_iou([pred['start'], pred['end']], [gt['start'], gt['end']])
737
  if iou > best_iou and iou >= iou_threshold:
738
  best_iou = iou
739
  best_gt_idx = gt_idx
@@ -778,7 +681,7 @@ def test(opt, video_name=None):
778
  comparison_text += f"- Total Predictions: {len(pred_segments)}\n"
779
  comparison_text += f"- Total Ground Truths: {len(gt_segments)}\n"
780
  comparison_text += f"- Matched Segments: {matched_count}\n"
781
- comparison_text += f"- Average Duration Difference: (s): {avg_duration_diff:.2f}s\n"
782
  comparison_text += f"- Average IoU (Matched): {avg_iou:.2f}\n"
783
  video_path = opt.get('video_path', '')
784
  viz_path = ""
@@ -791,7 +694,7 @@ def test(opt, video_name=None):
791
  video_path=video_path,
792
  duration=duration
793
  )
794
- video_out_path = annotate_action_with_video(
795
  video_id=video_name,
796
  pred_segments=pred_segments,
797
  gt_segments=gt_segments,
@@ -799,12 +702,12 @@ def test(opt, video_name=None):
799
  )
800
  else:
801
  print(f"[WARNING] Video {video_path} not found. Skipping visualization.")
802
- return mAP, comparison_text, viz_path, video_out_path
803
 
804
  def test_online(opt, video_name=None):
805
  device = torch.device("cpu")
806
  model = MYNET(opt).to(device)
807
- checkpoint_path = os.path.join(opt["checkpoint_path"], "ckp_best.pth.tar")
808
  if not os.path.exists(checkpoint_path):
809
  print(f"[ERROR] Checkpoint {checkpoint_path} not found.")
810
  return 0
@@ -812,7 +715,7 @@ def test_online(opt, video_name=None):
812
  model.load_state_dict(checkpoint['state_dict'])
813
  model.eval()
814
  sup_model = SuppressNet(opt).to(device)
815
- sup_checkpoint_path = os.path.join(opt["checkpoint_path"], "ckp_best_suppress.pth.tar")
816
  if os.path.exists(sup_checkpoint_path):
817
  checkpoint = torch.load(sup_checkpoint_path, map_location=device)
818
  sup_model.load_state_dict(checkpoint['state_dict'])
@@ -824,8 +727,8 @@ def test_online(opt, video_name=None):
824
  dataset,
825
  batch_size=1,
826
  shuffle=False,
827
- num_workers=1,
828
- pin_memory=True
829
  )
830
  result_dict = {}
831
  proposal_dict = []
@@ -844,7 +747,7 @@ def test_online(opt, video_name=None):
844
  for idx in range(duration):
845
  total_frames += 1
846
  input_queue[:-1, :] = input_queue[1:, :].clone()
847
- input_queue[-1, :] = dataset._get_base_data(video_name, idx)
848
  minput = input_queue.unsqueeze(0).to(device)
849
  act_cls, act_reg, _ = model(minput)
850
  act_cls = torch.softmax(act_cls, dim=-1)
@@ -852,7 +755,7 @@ def test_online(opt, video_name=None):
852
  reg_anc = act_reg.squeeze(0).detach().cpu().numpy()
853
  proposal_anc_dict = []
854
  for anc_idx in range(len(anchors)):
855
- cls = np.argwhere(cls_anc[anc_idx][:-1] > opt['threshold']).reshape(-1)
856
  if len(cls) == 0:
857
  continue
858
  ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
@@ -893,7 +796,8 @@ def test_online(opt, video_name=None):
893
  with open(result_path, "w") as f:
894
  json.dump(output_dict, f, indent=4)
895
  mAP = evaluation_detection(opt, verbose=False)
896
- return mAP
 
897
 
898
  def main(opt, video_name=None):
899
  max_perf = 0
@@ -904,8 +808,6 @@ def main(opt, video_name=None):
904
  elif opt['mode'] == 'test':
905
  max_perf, comparison_text, viz_path, video_out_path = test(opt, video_name=video_name)
906
  return max_perf, comparison_text, viz_path, video_out_path
907
- elif opt['mode'] == 'test_frame':
908
- max_perf = test_frame(opt, video_name=video_name)
909
  elif opt['mode'] == 'test_online':
910
  max_perf = test_online(opt, video_name=video_name)
911
  elif opt['mode'] == 'eval':
@@ -924,13 +826,23 @@ def gradio_interface(video):
924
  opt_dict['mode'] = 'test'
925
  opt_dict['video_name'] = video_name
926
  opt_dict['video_path'] = video
927
- opt_dict['feature_dir'] = os.path.join(os.getcwd(), 'data', 'features')
 
928
  opt_dict['checkpoint_path'] = os.path.join(os.getcwd(), 'checkpoint')
929
- opt_dict['result_file'] = os.path.join(os.getcwd(), 'results', f"result_{{}}.json")
930
- opt_dict['annotations_path'] = os.path.join(os.getcwd(), 'data', 'annotations.json')
931
- opt_dict['frame_result_file'] = os.path.join(os.getcwd(), 'results', f"frame_result_{{}}.h5")
 
 
932
  opt_dict['batch_size'] = 1
 
 
933
  opt_dict['anchors'] = [int(item) for item in opt_dict['anchors'].split(',')]
 
 
 
 
 
934
  mAP, comparison_text, viz_path, video_out_path = main(opt_dict, video_name=video_name)
935
  return viz_path, video_out_path, f"mAP: {mAP:.4f}\n\n{comparison_text}"
936
 
@@ -938,12 +850,23 @@ if __name__ == "__main__":
938
  opt = opts.parse_opt()
939
  opt = vars(opt)
940
  opt['checkpoint_path'] = os.path.join(os.getcwd(), 'checkpoint')
941
- opt['result_file'] = os.path.join(os.getcwd(), 'results', f"result_{{}}.json")
942
- opt['frame_result_file'] = os.path.join(os.getcwd(), 'results', f"frame_result_{{}}.h5")
943
- opt['feature_dir'] = os.path.join(os.getcwd(), 'data', 'features')
 
 
 
 
 
 
 
 
 
 
 
944
  os.makedirs(opt["checkpoint_path"], exist_ok=True)
945
  os.makedirs(os.path.dirname(opt["result_file"].format(opt['exp'])), exist_ok=True)
946
- os.makedirs(opt["feature_dir"], exist_ok=True)
947
  with open(os.path.join(opt["checkpoint_path"], f"{opt['exp']}_opts.json"), "w") as f:
948
  json.dump(opt, f, indent=4)
949
  if opt['seed'] >= 0:
@@ -961,7 +884,7 @@ if __name__ == "__main__":
961
  gr.Textbox(label="Results and mAP")
962
  ],
963
  title="Action Detection Model",
964
- description="Upload a video to detect actions using pre-extracted I3D features. View visualizations and performance metrics."
965
  )
966
  iface.launch()
967
  else:
 
45
  'frame_highlight_gt': 'red',
46
  'frame_highlight_pred': 'black',
47
  'iou_threshold': 0.3,
48
+ 'frame_scale_factor': 0.3,
49
  'video_text_scale': 0.5,
50
  'video_gt_text_color': (180, 119, 31), # BGR
51
  'video_pred_text_color': (14, 127, 255), # BGR
 
55
  'video_pred_text_y': 0.45,
56
  'video_gt_text_y': 0.55,
57
  'video_footer_height': 150,
58
+ 'video_gt_bar_y': 0.2,
59
+ 'video_pred_bar_y': 0.5,
60
  'video_bar_height': 0.15,
61
  'video_bar_text_scale': 0.7,
62
  'min_segment_duration': 1.0,
 
108
  (185, 218, 255), (255, 204, 204), (193, 182, 255), (201, 252, 189),
109
  (144, 128, 112), (112, 25, 25), (102, 51, 102), (0, 128, 128), (171, 71, 0)
110
  ]
111
+ action_labels = set(seg['label'] for seg in gt_segments).union(set(seg['label'] for seg in pred_segments))
112
  action_color_map = {label: color_palette[i % len(color_palette)] for i, label in enumerate(action_labels)}
113
  gt_color_rgb = (gt_text_color[2], gt_text_color[1], gt_text_color[0])
114
  pred_color_rgb = (pred_text_color[2], pred_text_color[1], pred_text_color[0])
 
128
  except IOError:
129
  font = None
130
  bar_font = None
131
+ window_size = VIS_CONFIG['scroll_window_duration']
132
  num_windows = int(np.ceil(duration / window_size))
133
  text_bar_gap = 48
134
+ text_x = VIS_CONFIG['video_bar_label_x']
135
  frame_idx = 0
136
  written_frames = 0
137
  while cap.isOpened():
 
148
  window_duration = window_end - window_start
149
  window_timestamp = timestamp - window_start
150
  gt_labels = [seg['label'] for seg in gt_segments if seg['start'] <= timestamp <= seg['end']]
151
+ gt_text = "GT: " + ", ".join(gt_labels) if gt_labels else "GT: None"
152
  pred_labels = [seg['label'] for seg in pred_segments if seg['start'] <= timestamp <= seg['end']]
153
+ pred_text = "Pred: " + ", ".join(pred_labels) if pred_labels else "Pred: None"
154
  footer_y = frame_height
155
+ gt_bar_y = footer_y + int(VIS_CONFIG['video_gt_bar_y'] * footer_height)
156
+ pred_bar_y = footer_y + int(VIS_CONFIG['video_pred_bar_y'] * footer_height)
157
  bar_height = int(VIS_CONFIG['video_bar_height'] * footer_height)
158
  if font:
159
  gt_text_bbox = bar_font.getbbox("GT")
 
204
  frame_text_bbox = draw.textbbox((0, 0), frame_info, font=font)
205
  frame_text_width = frame_text_bbox[2] - frame_text_bbox[0]
206
  frame_text_x = (frame_width - frame_text_width) // 2
207
+ draw.text((frame_text_x, int(frame_height * VIS_CONFIG['video_frame_text_y'])), frame_info, font=font, fill=(0, 0, 0))
208
  window_info = f"{window_start:.1f}s - {window_end:.1f}s"
209
  window_text_bbox = draw.textbbox((0, 0), window_info, font=bar_font)
210
  window_text_width = window_text_bbox[2] - window_text_bbox[0]
211
  window_text_x = (frame_width - window_text_width) // 2
212
  draw.text((window_text_x, footer_y + 10), window_info, font=bar_font, fill=(0, 0, 0))
 
 
 
 
 
 
213
  draw.text((text_x, gt_bar_y + bar_height // 2), "GT", font=bar_font, fill=gt_color_rgb)
214
  draw.text((text_x, pred_bar_y + bar_height // 2), "Pred", font=bar_font, fill=pred_color_rgb)
215
+ gt_y = int(frame_height * VIS_CONFIG['video_gt_text_y'])
216
+ pred_y = int(frame_height * VIS_CONFIG['video_pred_text_y'])
217
+ draw.text((10, gt_y), gt_text, font=font, fill=gt_color_rgb)
218
+ draw.text((10, pred_y), pred_text, font=font, fill=pred_color_rgb)
219
  extended_frame = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
220
  else:
221
  frame_info = f"Frame: {frame_idx} | FPS: {fps:.2f}"
 
224
  cv2.putText(
225
  extended_frame,
226
  frame_info,
227
+ (frame_text_x, int(frame_height * VIS_CONFIG['video_frame_text_y']) + 20),
228
  cv2.FONT_HERSHEY_SIMPLEX,
229
  text_scale,
230
  (0, 0, 0),
 
244
  1,
245
  cv2.LINE_AA
246
  )
247
+ cv2.putText(
248
+ extended_frame,
249
+ gt_text,
250
+ (10, int(frame_height * VIS_CONFIG['video_gt_text_y'])),
251
+ cv2.FONT_HERSHEY_SIMPLEX,
252
+ text_scale,
253
+ gt_text_color,
254
+ text_thickness,
255
+ cv2.LINE_AA
256
+ )
257
+ cv2.putText(
258
+ extended_frame,
259
+ pred_text,
260
+ (10, int(frame_height * VIS_CONFIG['video_pred_text_y'])),
261
+ cv2.FONT_HERSHEY_SIMPLEX,
262
+ text_scale,
263
+ pred_text_color,
264
+ text_thickness,
265
+ cv2.LINE_AA
266
+ )
 
 
267
  cv2.putText(
268
  extended_frame,
269
  "GT",
 
312
  if num_frames > VIS_CONFIG['max_frames']:
313
  frame_interval = duration / (VIS_CONFIG['max_frames'] - 1)
314
  num_frames = VIS_CONFIG['max_frames']
315
+ frame_times = np.linspace(0, duration, num_frames, endpoint=True)
316
  frames = []
317
  cap = cv2.VideoCapture(video_path)
318
  if not cap.isOpened():
 
324
  ret, frame = cap.read()
325
  if ret:
326
  frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
327
+ frame = cv2.resize(frame, (int(frame.shape[1] * VIS_CONFIG['frame_scale_factor']), int(frame.shape[0] * VIS_CONFIG['frame_scale_factor'])))
328
  frames.append(frame)
329
  else:
330
  frames.append(np.ones((100, 100, 3), dtype=np.uint8) * 255)
331
  cap.release()
332
+ fig = plt.figure(figsize=(num_frames * 2, 6), constrained_layout=True)
333
  gs = fig.add_gridspec(3, num_frames, height_ratios=[3, 1, 1])
334
  for i, (t, frame) in enumerate(zip(frame_times, frames)):
335
  ax = fig.add_subplot(gs[0, i])
 
397
  train_loader = torch.utils.data.DataLoader(
398
  train_dataset,
399
  batch_size=opt['batch_size'], shuffle=True,
400
+ num_workers=0, pin_memory=False, drop_last=True
401
  )
402
  epoch_cost = 0
403
  epoch_cost_cls = 0
 
412
  g['lr'] = n_iter * opt['lr'] / total_iter
413
  act_cls, act_reg, snip_cls = model(input_data.float().to(device))
414
  act_cls.register_hook(partial(cls_loss.collect_grad, cls_label))
415
+ snip_cls.register_hook(lambda grad: snip_loss.collect_grad(grad, snip_label))
416
  cost_reg = 0
417
  cost_cls = 0
418
  loss = cls_loss_func_(cls_loss, act_cls)
 
431
  optimizer.step()
432
  return n_iter, epoch_cost, epoch_cost_cls, epoch_cost_reg, epoch_cost_snip
433
 
434
+ def eval_frame(opt, model, dataset):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
435
  device = torch.device("cpu")
436
  test_loader = torch.utils.data.DataLoader(
437
  dataset,
438
+ batch_size=opt['batch_size'], shuffle=False,
439
+ num_workers=0, pin_memory=False, drop_last=False
440
  )
441
  labels_cls = {video_name: [] for video_name in dataset.video_names}
442
  labels_reg = {video_name: [] for video_name in dataset.video_names}
 
447
  epoch_cost = 0
448
  epoch_cost_cls = 0
449
  epoch_cost_reg = 0
450
+ cls_loss_fn = MultiCrossEntropyLoss(focal=True)
451
+ for n_iter, (input_data, cls_label, reg_label, snip_label) in enumerate(tqdm(test_loader)):
452
  act_cls, act_reg, _ = model(input_data.float().to(device))
453
  cost_reg = 0
454
  cost_cls = 0
455
+ loss = cls_loss_func_(cls_loss_fn, act_cls)
456
  cost_cls = loss
457
  epoch_cost_cls += loss.detach().cpu().numpy()
458
  loss = regress_loss_func(reg_label, act_reg)
 
464
  total_frames += input_data.size(0)
465
  for idx in range(input_data.size(0)):
466
  video_name, st, ed, data_idx = dataset.inputs[n_iter * opt['batch_size'] + idx]
467
+ output_cls[video_name].append(act_cls[idx].detach().cpu().numpy())
468
+ output_reg[video_name].append(act_reg[idx].detach().cpu().numpy())
469
+ labels_cls[video_name].append(cls_label[idx].cpu().numpy())
470
+ labels_reg[video_name].append(reg_label[idx].cpu().numpy())
471
  end_time = time.time()
472
  working_time = end_time - start_time
473
  for video_name in dataset.video_names:
 
496
  reg_anc = output_reg[video_name][idx]
497
  proposal_anc_dict = []
498
  for anc_idx in range(len(anchors)):
499
+ cls = np.argwhere(cls_anc[anc_idx][:-1] > threshold).reshape(-1)
500
  if len(cls) == 0:
501
  continue
502
  ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
 
520
  def eval_map_suppress(opt, dataset, output_cls, output_reg, labels_cls, labels_reg):
521
  device = torch.device("cpu")
522
  model = SuppressNet(opt).to(device)
523
+ checkpoint_path = os.path.join(opt["checkpoint_path"], f"ckp_best_suppress.pth.tar")
524
  if os.path.exists(checkpoint_path):
525
  checkpoint = torch.load(checkpoint_path, map_location=device)
526
  model.load_state_dict(checkpoint['state_dict'])
 
544
  reg_anc = output_reg[video_name][idx]
545
  proposal_anc_dict = []
546
  for anc_idx in range(len(anchors)):
547
+ cls = np.argwhere(cls_anc[anc_idx][:-1] > threshold).reshape(-1)
548
  if len(cls) == 0:
549
  continue
550
  ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
 
554
  label = cls[cidx]
555
  tmp_dict = {
556
  "segment": [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)],
557
+ "score": float(cls_anc[anc_idx][label]),
558
  "label": dataset.label_name[label],
559
  "gentime": float(idx * frame_to_time / 100.0)
560
  }
 
571
  for cls in range(num_class - 1):
572
  if suppress_conf[cls] > opt['sup_threshold']:
573
  for proposal in proposal_anc_dict:
574
+ if proposal['label'] == dataset.label_name[cls]:
575
  if check_overlap_proposal(proposal_dict, proposal, overlapThresh=opt['soft_nms']) is None:
576
  proposal_dict.append(proposal)
577
  result_dict[video_name] = proposal_dict
578
  proposal_dict = []
579
  return result_dict
580
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
581
  def test(opt, video_name=None):
582
  device = torch.device("cpu")
583
  model = MYNET(opt).to(device)
 
603
  with open(result_path, 'w') as f:
604
  json.dump(output_dict, f, indent=4)
605
  mAP = evaluation_detection(opt, verbose=False)
606
+ mAP_value = sum(mAP) / len(mAP) if mAP else 0
607
  if video_name:
608
  print(f"\n[INFO] Comparing Predicted and Ground Truth Actions for Video: {video_name}")
609
+ anno_path = opt["video_anno"].format(opt["split"])
610
  if not os.path.exists(anno_path):
611
  print(f"[ERROR] Annotation file {anno_path} not found. Skipping comparison.")
612
+ return mAP_value, "", "", ""
613
  with open(anno_path, 'r') as f:
614
  anno_data = json.load(f)
615
  gt_annotations = anno_data['database'][video_name]['annotations']
 
626
  'end': pred['segment'][1],
627
  'duration': pred['segment'][1] - pred['segment'][0],
628
  'score': pred['score']
629
+ } for pred in result_dict.get(video_name, [])]
630
  matches = []
631
  iou_threshold = VIS_CONFIG['iou_threshold']
632
  used_gt_indices = set()
 
636
  for gt_idx, gt in enumerate(gt_segments):
637
  if gt_idx in used_gt_indices:
638
  continue
639
+ iou = calc_iou([pred['end'], pred['duration']], [gt['end'], gt['duration']])
640
  if iou > best_iou and iou >= iou_threshold:
641
  best_iou = iou
642
  best_gt_idx = gt_idx
 
681
  comparison_text += f"- Total Predictions: {len(pred_segments)}\n"
682
  comparison_text += f"- Total Ground Truths: {len(gt_segments)}\n"
683
  comparison_text += f"- Matched Segments: {matched_count}\n"
684
+ comparison_text += f"- Average Duration Difference (s): {avg_duration_diff:.2f}\n"
685
  comparison_text += f"- Average IoU (Matched): {avg_iou:.2f}\n"
686
  video_path = opt.get('video_path', '')
687
  viz_path = ""
 
694
  video_path=video_path,
695
  duration=duration
696
  )
697
+ video_out_path = annotate_video_with_actions(
698
  video_id=video_name,
699
  pred_segments=pred_segments,
700
  gt_segments=gt_segments,
 
702
  )
703
  else:
704
  print(f"[WARNING] Video {video_path} not found. Skipping visualization.")
705
+ return mAP_value, comparison_text, viz_path, video_out_path
706
 
707
  def test_online(opt, video_name=None):
708
  device = torch.device("cpu")
709
  model = MYNET(opt).to(device)
710
+ checkpoint_path = os.path.join(opt["checkpoint_path"], f"{opt['exp']}_ckp_best.pth.tar")
711
  if not os.path.exists(checkpoint_path):
712
  print(f"[ERROR] Checkpoint {checkpoint_path} not found.")
713
  return 0
 
715
  model.load_state_dict(checkpoint['state_dict'])
716
  model.eval()
717
  sup_model = SuppressNet(opt).to(device)
718
+ sup_checkpoint_path = os.path.join(opt["checkpoint_path"], f"ckp_best_suppress.pth.tar")
719
  if os.path.exists(sup_checkpoint_path):
720
  checkpoint = torch.load(sup_checkpoint_path, map_location=device)
721
  sup_model.load_state_dict(checkpoint['state_dict'])
 
727
  dataset,
728
  batch_size=1,
729
  shuffle=False,
730
+ num_workers=0,
731
+ pin_memory=False
732
  )
733
  result_dict = {}
734
  proposal_dict = []
 
747
  for idx in range(duration):
748
  total_frames += 1
749
  input_queue[:-1, :] = input_queue[1:, :].clone()
750
+ input_queue[-1, :] = dataset._get_base_data(video_name, idx, idx + 1).squeeze(0)
751
  minput = input_queue.unsqueeze(0).to(device)
752
  act_cls, act_reg, _ = model(minput)
753
  act_cls = torch.softmax(act_cls, dim=-1)
 
755
  reg_anc = act_reg.squeeze(0).detach().cpu().numpy()
756
  proposal_anc_dict = []
757
  for anc_idx in range(len(anchors)):
758
+ cls = np.argwhere(cls_anc[anc_idx][:-1] > threshold).reshape(-1)
759
  if len(cls) == 0:
760
  continue
761
  ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
 
796
  with open(result_path, "w") as f:
797
  json.dump(output_dict, f, indent=4)
798
  mAP = evaluation_detection(opt, verbose=False)
799
+ mAP_value = sum(mAP) / len(mAP) if mAP else 0
800
+ return mAP_value
801
 
802
  def main(opt, video_name=None):
803
  max_perf = 0
 
808
  elif opt['mode'] == 'test':
809
  max_perf, comparison_text, viz_path, video_out_path = test(opt, video_name=video_name)
810
  return max_perf, comparison_text, viz_path, video_out_path
 
 
811
  elif opt['mode'] == 'test_online':
812
  max_perf = test_online(opt, video_name=video_name)
813
  elif opt['mode'] == 'eval':
 
826
  opt_dict['mode'] = 'test'
827
  opt_dict['video_name'] = video_name
828
  opt_dict['video_path'] = video
829
+ opt_dict['video_anno'] = os.path.join(os.getcwd(), 'data', 'annotations.json')
830
+ opt_dict['video_feature_all_test'] = os.path.join(os.getcwd(), 'data', 'features') + os.sep
831
  opt_dict['checkpoint_path'] = os.path.join(os.getcwd(), 'checkpoint')
832
+ opt_dict['result_file'] = os.path.join(os.getcwd(), 'results', 'result_{}.json')
833
+ opt_dict['frame_result_file'] = os.path.join(os.getcwd(), 'results', 'frame_result_{}.h5')
834
+ opt_dict['video_len_file'] = os.path.join(os.getcwd(), 'data', 'video_len_{}.json')
835
+ opt_dict['proposal_label_file'] = os.path.join(os.getcwd(), 'data', 'proposal_label_{}.h5')
836
+ opt_dict['suppress_label_file'] = os.path.join(os.getcwd(), 'data', 'suppress_label_{}.h5')
837
  opt_dict['batch_size'] = 1
838
+ opt_dict['data_format'] = 'npz_i3d'
839
+ opt_dict['rgb_only'] = False
840
  opt_dict['anchors'] = [int(item) for item in opt_dict['anchors'].split(',')]
841
+ opt_dict['predefined_fps'] = 30 # Adjust if needed
842
+ opt_dict['split'] = 'test'
843
+ opt_dict['setup'] = 'default'
844
+ opt_dict['data_rescale'] = 1.0
845
+ opt_dict['pos_threshold'] = 0.5
846
  mAP, comparison_text, viz_path, video_out_path = main(opt_dict, video_name=video_name)
847
  return viz_path, video_out_path, f"mAP: {mAP:.4f}\n\n{comparison_text}"
848
 
 
850
  opt = opts.parse_opt()
851
  opt = vars(opt)
852
  opt['checkpoint_path'] = os.path.join(os.getcwd(), 'checkpoint')
853
+ opt['result_file'] = os.path.join(os.getcwd(), 'results', 'result_{}.json')
854
+ opt['frame_result_file'] = os.path.join(os.getcwd(), 'results', 'frame_result_{}.h5')
855
+ opt['video_anno'] = os.path.join(os.getcwd(), 'data', 'annotations.json')
856
+ opt['video_feature_all_test'] = os.path.join(os.getcwd(), 'data', 'features') + os.sep
857
+ opt['video_len_file'] = os.path.join(os.getcwd(), 'data', 'video_len_{}.json')
858
+ opt['proposal_label_file'] = os.path.join(os.getcwd(), 'data', 'proposal_label_{}.h5')
859
+ opt['suppress_label_file'] = os.path.join(os.getcwd(), 'data', 'suppress_label_{}.h5')
860
+ opt['data_format'] = 'npz_i3d'
861
+ opt['rgb_only'] = False
862
+ opt['predefined_fps'] = 30
863
+ opt['split'] = 'test'
864
+ opt['setup'] = 'default'
865
+ opt['data_rescale'] = 1.0
866
+ opt['pos_threshold'] = 0.5
867
  os.makedirs(opt["checkpoint_path"], exist_ok=True)
868
  os.makedirs(os.path.dirname(opt["result_file"].format(opt['exp'])), exist_ok=True)
869
+ os.makedirs(os.path.dirname(opt["video_anno"]), exist_ok=True)
870
  with open(os.path.join(opt["checkpoint_path"], f"{opt['exp']}_opts.json"), "w") as f:
871
  json.dump(opt, f, indent=4)
872
  if opt['seed'] >= 0:
 
884
  gr.Textbox(label="Results and mAP")
885
  ],
886
  title="Action Detection Model",
887
+ description="Upload a video to detect actions using pre-extracted I3D features. Ensure a corresponding .npz file exists in data/features/. View visualizations and performance metrics."
888
  )
889
  iface.launch()
890
  else: