Update main.py
Browse files
main.py
CHANGED
|
@@ -45,7 +45,7 @@ VIS_CONFIG = {
|
|
| 45 |
'frame_highlight_gt': 'red',
|
| 46 |
'frame_highlight_pred': 'black',
|
| 47 |
'iou_threshold': 0.3,
|
| 48 |
-
'frame_scale_factor': 0.3,
|
| 49 |
'video_text_scale': 0.5,
|
| 50 |
'video_gt_text_color': (180, 119, 31), # BGR
|
| 51 |
'video_pred_text_color': (14, 127, 255), # BGR
|
|
@@ -55,8 +55,8 @@ VIS_CONFIG = {
|
|
| 55 |
'video_pred_text_y': 0.45,
|
| 56 |
'video_gt_text_y': 0.55,
|
| 57 |
'video_footer_height': 150,
|
| 58 |
-
'video_gt_bar_y': 0.
|
| 59 |
-
'video_pred_bar_y': 0.
|
| 60 |
'video_bar_height': 0.15,
|
| 61 |
'video_bar_text_scale': 0.7,
|
| 62 |
'min_segment_duration': 1.0,
|
|
@@ -108,7 +108,7 @@ def annotate_video_with_actions(
|
|
| 108 |
(185, 218, 255), (255, 204, 204), (193, 182, 255), (201, 252, 189),
|
| 109 |
(144, 128, 112), (112, 25, 25), (102, 51, 102), (0, 128, 128), (171, 71, 0)
|
| 110 |
]
|
| 111 |
-
action_labels = set(seg['label'] for seg in gt_segments).union(seg['label'] for seg in pred_segments)
|
| 112 |
action_color_map = {label: color_palette[i % len(color_palette)] for i, label in enumerate(action_labels)}
|
| 113 |
gt_color_rgb = (gt_text_color[2], gt_text_color[1], gt_text_color[0])
|
| 114 |
pred_color_rgb = (pred_text_color[2], pred_text_color[1], pred_text_color[0])
|
|
@@ -128,10 +128,10 @@ def annotate_video_with_actions(
|
|
| 128 |
except IOError:
|
| 129 |
font = None
|
| 130 |
bar_font = None
|
| 131 |
-
window_size =
|
| 132 |
num_windows = int(np.ceil(duration / window_size))
|
| 133 |
text_bar_gap = 48
|
| 134 |
-
text_x =
|
| 135 |
frame_idx = 0
|
| 136 |
written_frames = 0
|
| 137 |
while cap.isOpened():
|
|
@@ -148,12 +148,12 @@ def annotate_video_with_actions(
|
|
| 148 |
window_duration = window_end - window_start
|
| 149 |
window_timestamp = timestamp - window_start
|
| 150 |
gt_labels = [seg['label'] for seg in gt_segments if seg['start'] <= timestamp <= seg['end']]
|
| 151 |
-
gt_text = "GT: " + ", ".join(gt_labels) if gt_labels else ""
|
| 152 |
pred_labels = [seg['label'] for seg in pred_segments if seg['start'] <= timestamp <= seg['end']]
|
| 153 |
-
pred_text = "Pred: " + ", ".join(pred_labels) if pred_labels else ""
|
| 154 |
footer_y = frame_height
|
| 155 |
-
gt_bar_y = footer_y + int(
|
| 156 |
-
pred_bar_y = footer_y + int(
|
| 157 |
bar_height = int(VIS_CONFIG['video_bar_height'] * footer_height)
|
| 158 |
if font:
|
| 159 |
gt_text_bbox = bar_font.getbbox("GT")
|
|
@@ -204,20 +204,18 @@ def annotate_video_with_actions(
|
|
| 204 |
frame_text_bbox = draw.textbbox((0, 0), frame_info, font=font)
|
| 205 |
frame_text_width = frame_text_bbox[2] - frame_text_bbox[0]
|
| 206 |
frame_text_x = (frame_width - frame_text_width) // 2
|
| 207 |
-
draw.text((frame_text_x,
|
| 208 |
window_info = f"{window_start:.1f}s - {window_end:.1f}s"
|
| 209 |
window_text_bbox = draw.textbbox((0, 0), window_info, font=bar_font)
|
| 210 |
window_text_width = window_text_bbox[2] - window_text_bbox[0]
|
| 211 |
window_text_x = (frame_width - window_text_width) // 2
|
| 212 |
draw.text((window_text_x, footer_y + 10), window_info, font=bar_font, fill=(0, 0, 0))
|
| 213 |
-
if gt_text:
|
| 214 |
-
gt_y = int(frame_height * VIS_CONFIG['video_gt_text_y'])
|
| 215 |
-
draw.text((10, gt_y), gt_text, font=font, fill=gt_color_rgb)
|
| 216 |
-
if pred_text:
|
| 217 |
-
pred_y = int(frame_height * VIS_CONFIG['video_pred_text_y'])
|
| 218 |
-
draw.text((10, pred_y), pred_text, font=font, fill=pred_color_rgb)
|
| 219 |
draw.text((text_x, gt_bar_y + bar_height // 2), "GT", font=bar_font, fill=gt_color_rgb)
|
| 220 |
draw.text((text_x, pred_bar_y + bar_height // 2), "Pred", font=bar_font, fill=pred_color_rgb)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
extended_frame = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
|
| 222 |
else:
|
| 223 |
frame_info = f"Frame: {frame_idx} | FPS: {fps:.2f}"
|
|
@@ -226,7 +224,7 @@ def annotate_video_with_actions(
|
|
| 226 |
cv2.putText(
|
| 227 |
extended_frame,
|
| 228 |
frame_info,
|
| 229 |
-
(frame_text_x,
|
| 230 |
cv2.FONT_HERSHEY_SIMPLEX,
|
| 231 |
text_scale,
|
| 232 |
(0, 0, 0),
|
|
@@ -246,28 +244,26 @@ def annotate_video_with_actions(
|
|
| 246 |
1,
|
| 247 |
cv2.LINE_AA
|
| 248 |
)
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
cv2.LINE_AA
|
| 270 |
-
)
|
| 271 |
cv2.putText(
|
| 272 |
extended_frame,
|
| 273 |
"GT",
|
|
@@ -316,7 +312,7 @@ def visualize_action_lengths(
|
|
| 316 |
if num_frames > VIS_CONFIG['max_frames']:
|
| 317 |
frame_interval = duration / (VIS_CONFIG['max_frames'] - 1)
|
| 318 |
num_frames = VIS_CONFIG['max_frames']
|
| 319 |
-
frame_times = np.linspace(0, duration, num_frames, endpoint=
|
| 320 |
frames = []
|
| 321 |
cap = cv2.VideoCapture(video_path)
|
| 322 |
if not cap.isOpened():
|
|
@@ -328,12 +324,12 @@ def visualize_action_lengths(
|
|
| 328 |
ret, frame = cap.read()
|
| 329 |
if ret:
|
| 330 |
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 331 |
-
frame = cv2.resize(frame, (int(frame.shape[1] *
|
| 332 |
frames.append(frame)
|
| 333 |
else:
|
| 334 |
frames.append(np.ones((100, 100, 3), dtype=np.uint8) * 255)
|
| 335 |
cap.release()
|
| 336 |
-
fig = plt.figure(figsize=(num_frames *
|
| 337 |
gs = fig.add_gridspec(3, num_frames, height_ratios=[3, 1, 1])
|
| 338 |
for i, (t, frame) in enumerate(zip(frame_times, frames)):
|
| 339 |
ax = fig.add_subplot(gs[0, i])
|
|
@@ -401,7 +397,7 @@ def train_one_epoch(opt, model, train_dataset, optimizer, warmup=False):
|
|
| 401 |
train_loader = torch.utils.data.DataLoader(
|
| 402 |
train_dataset,
|
| 403 |
batch_size=opt['batch_size'], shuffle=True,
|
| 404 |
-
num_workers=
|
| 405 |
)
|
| 406 |
epoch_cost = 0
|
| 407 |
epoch_cost_cls = 0
|
|
@@ -416,7 +412,7 @@ def train_one_epoch(opt, model, train_dataset, optimizer, warmup=False):
|
|
| 416 |
g['lr'] = n_iter * opt['lr'] / total_iter
|
| 417 |
act_cls, act_reg, snip_cls = model(input_data.float().to(device))
|
| 418 |
act_cls.register_hook(partial(cls_loss.collect_grad, cls_label))
|
| 419 |
-
snip_cls.register_hook(
|
| 420 |
cost_reg = 0
|
| 421 |
cost_cls = 0
|
| 422 |
loss = cls_loss_func_(cls_loss, act_cls)
|
|
@@ -435,60 +431,12 @@ def train_one_epoch(opt, model, train_dataset, optimizer, warmup=False):
|
|
| 435 |
optimizer.step()
|
| 436 |
return n_iter, epoch_cost, epoch_cost_cls, epoch_cost_reg, epoch_cost_snip
|
| 437 |
|
| 438 |
-
def
|
| 439 |
-
device = torch.device("cpu")
|
| 440 |
-
cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, _, _ = eval_frame(opt, model, test_dataset)
|
| 441 |
-
result_dict = eval_map_nms(opt, test_dataset, output_cls, output_reg, labels_cls, labels_reg)
|
| 442 |
-
output_dict = {"version": "VERSION 1.3, "results": result_dict, "external_data": None}
|
| 443 |
-
result_path = opt["result_file"].format(opt['exp"])
|
| 444 |
-
os.makedirs(os.path.dirname(result_path), exist_ok=True)
|
| 445 |
-
with open(result_path, "w") as outfile:
|
| 446 |
-
json.dump(output_dict, outfile, indent=2)
|
| 447 |
-
IoUmAP = evaluation_detection(opt, verbose=False)
|
| 448 |
-
IoUmAP_5 = sum(IoUmAP[0:]) / len(IoUmAP[0:]) if IoUmAP else 0
|
| 449 |
-
return cls_loss, reg_loss, tot_loss, IoUmAP_5
|
| 450 |
-
|
| 451 |
-
def train(opt):
|
| 452 |
-
writer = SummaryWriter()
|
| 453 |
-
device = torch.device("cpu")
|
| 454 |
-
model = MYNET(opt).to(device)
|
| 455 |
-
rest_of_model_params = [param for name, param in model.named_parameters() if "history_unit" not in name]
|
| 456 |
-
optimizer = optim.Adam([
|
| 457 |
-
{'params': model.history_unit.parameters(), 'lr': 1e-6},
|
| 458 |
-
{'params': rest_of_model_params}
|
| 459 |
-
], lr=opt["lr"], weight_decay=opt["weight_decay"])
|
| 460 |
-
scheduler = optim.StepLR(optimizer, step_size=opt["lr_step"])
|
| 461 |
-
train_dataset = VideoDataSet(opt, subset="training")
|
| 462 |
-
test_dataset = VideoDataSet(opt, subset=opt['inference_subset'])
|
| 463 |
-
warmup = False
|
| 464 |
-
for n_epoch in range(opt['epoch']):
|
| 465 |
-
if n_epoch >= 1:
|
| 466 |
-
warmup = False
|
| 467 |
-
n_iter, epoch_cost, epoch_cost_cls, epoch_cost_reg, epoch_cost_snip = train_one_epoch(opt, model, train_dataset, optimizer, warmup)
|
| 468 |
-
writer.add_scalar('data/cost', epoch_cost / (n_iter + 1), n_epoch)
|
| 469 |
-
print(f"Epoch {n_epoch}: Train Loss: {epoch_cost / (n_iter + 1):.4f}, cls: {epoch_cost_cls / (n_iter + 1):.4f}, reg: {epoch_cost_reg / (n_iter + 1):.4f}, snip: {epoch_cost_snip / (n_iter + 1):.4f}, lr: {optimizer.param_groups[-1]['lr']:.6f}")
|
| 470 |
-
scheduler.step()
|
| 471 |
-
model.eval()
|
| 472 |
-
cls_loss, reg_loss, tot_loss, IoUmAP_5 = eval_one_epoch(opt, model, test_dataset)
|
| 473 |
-
writer.add_scalar('data/mAP', IoUmAP_5, n_epoch)
|
| 474 |
-
print(f"Epoch {n_epoch}: Test Loss: {tot_loss:.4f}, cls: {cls_loss:.4f}, reg: {reg_loss:.4f}, mAP: {IoUmAP_5:.4f}")
|
| 475 |
-
state = {'epoch': n_epoch + 1, 'state_dict': model.state_dict()}
|
| 476 |
-
checkpoint_path = os.path.join(opt["checkpoint_path"], f"{opt['exp']}_checkpoint_{n_epoch + 1}.pth.tar")
|
| 477 |
-
os.makedirs(opt["checkpoint_path"], exist_ok=True)
|
| 478 |
-
torch.save(state, checkpoint_path)
|
| 479 |
-
if IoUmAP_5 > getattr(model, 'best_map', 0):
|
| 480 |
-
model.best_map = IoUmAP_5
|
| 481 |
-
torch.save(state, os.path.join(opt["checkpoint_path"], f"{opt['exp']}_ckp_best.pth.tar"))
|
| 482 |
-
model.train()
|
| 483 |
-
writer.close()
|
| 484 |
-
return model.best_map
|
| 485 |
-
|
| 486 |
-
def eval_frame_data(opt, model, dataset):
|
| 487 |
device = torch.device("cpu")
|
| 488 |
test_loader = torch.utils.data.DataLoader(
|
| 489 |
dataset,
|
| 490 |
-
batch_size=opt['batch_size'], shuffle=
|
| 491 |
-
num_workers=
|
| 492 |
)
|
| 493 |
labels_cls = {video_name: [] for video_name in dataset.video_names}
|
| 494 |
labels_reg = {video_name: [] for video_name in dataset.video_names}
|
|
@@ -499,11 +447,12 @@ def eval_frame_data(opt, model, dataset):
|
|
| 499 |
epoch_cost = 0
|
| 500 |
epoch_cost_cls = 0
|
| 501 |
epoch_cost_reg = 0
|
| 502 |
-
|
|
|
|
| 503 |
act_cls, act_reg, _ = model(input_data.float().to(device))
|
| 504 |
cost_reg = 0
|
| 505 |
cost_cls = 0
|
| 506 |
-
loss =
|
| 507 |
cost_cls = loss
|
| 508 |
epoch_cost_cls += loss.detach().cpu().numpy()
|
| 509 |
loss = regress_loss_func(reg_label, act_reg)
|
|
@@ -515,10 +464,10 @@ def eval_frame_data(opt, model, dataset):
|
|
| 515 |
total_frames += input_data.size(0)
|
| 516 |
for idx in range(input_data.size(0)):
|
| 517 |
video_name, st, ed, data_idx = dataset.inputs[n_iter * opt['batch_size'] + idx]
|
| 518 |
-
output_cls[video_name].append(act_cls[idx
|
| 519 |
-
output_reg[video_name].append(act_reg[idx
|
| 520 |
-
labels_cls[video_name].append(cls_label[idx
|
| 521 |
-
labels_reg[video_name].append(reg_label[idx
|
| 522 |
end_time = time.time()
|
| 523 |
working_time = end_time - start_time
|
| 524 |
for video_name in dataset.video_names:
|
|
@@ -547,7 +496,7 @@ def eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg):
|
|
| 547 |
reg_anc = output_reg[video_name][idx]
|
| 548 |
proposal_anc_dict = []
|
| 549 |
for anc_idx in range(len(anchors)):
|
| 550 |
-
cls = np.argwhere(cls_anc[anc_idx][:-1] >
|
| 551 |
if len(cls) == 0:
|
| 552 |
continue
|
| 553 |
ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
|
|
@@ -571,7 +520,7 @@ def eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg):
|
|
| 571 |
def eval_map_suppress(opt, dataset, output_cls, output_reg, labels_cls, labels_reg):
|
| 572 |
device = torch.device("cpu")
|
| 573 |
model = SuppressNet(opt).to(device)
|
| 574 |
-
checkpoint_path = os.path.join(opt["checkpoint_path"], "ckp_best_suppress.pth.tar")
|
| 575 |
if os.path.exists(checkpoint_path):
|
| 576 |
checkpoint = torch.load(checkpoint_path, map_location=device)
|
| 577 |
model.load_state_dict(checkpoint['state_dict'])
|
|
@@ -595,7 +544,7 @@ def eval_map_suppress(opt, dataset, output_cls, output_reg, labels_cls, labels_r
|
|
| 595 |
reg_anc = output_reg[video_name][idx]
|
| 596 |
proposal_anc_dict = []
|
| 597 |
for anc_idx in range(len(anchors)):
|
| 598 |
-
cls = np.argwhere(cls_anc[anc_idx][:-1] >
|
| 599 |
if len(cls) == 0:
|
| 600 |
continue
|
| 601 |
ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
|
|
@@ -605,7 +554,7 @@ def eval_map_suppress(opt, dataset, output_cls, output_reg, labels_cls, labels_r
|
|
| 605 |
label = cls[cidx]
|
| 606 |
tmp_dict = {
|
| 607 |
"segment": [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)],
|
| 608 |
-
"score": float(cls_anc
|
| 609 |
"label": dataset.label_name[label],
|
| 610 |
"gentime": float(idx * frame_to_time / 100.0)
|
| 611 |
}
|
|
@@ -622,60 +571,13 @@ def eval_map_suppress(opt, dataset, output_cls, output_reg, labels_cls, labels_r
|
|
| 622 |
for cls in range(num_class - 1):
|
| 623 |
if suppress_conf[cls] > opt['sup_threshold']:
|
| 624 |
for proposal in proposal_anc_dict:
|
| 625 |
-
if proposal['label
|
| 626 |
if check_overlap_proposal(proposal_dict, proposal, overlapThresh=opt['soft_nms']) is None:
|
| 627 |
proposal_dict.append(proposal)
|
| 628 |
result_dict[video_name] = proposal_dict
|
| 629 |
proposal_dict = []
|
| 630 |
return result_dict
|
| 631 |
|
| 632 |
-
def test_frame(opt, video_name=None):
|
| 633 |
-
device = torch.device("cpu")
|
| 634 |
-
model = MYNET(opt).to(device)
|
| 635 |
-
checkpoint_path = os.path.join(opt["checkpoint_path"], "ckp_best.pth.tar")
|
| 636 |
-
if not os.path.exists(checkpoint_path):
|
| 637 |
-
print(f"[ERROR] Checkpoint {checkpoint_path} not found.")
|
| 638 |
-
return 0, 0, 0
|
| 639 |
-
checkpoint = torch.load(checkpoint_path, map_location=device)
|
| 640 |
-
model.load_state_dict(checkpoint['state_dict'])
|
| 641 |
-
model.eval()
|
| 642 |
-
dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)
|
| 643 |
-
outfile = h5py.File(opt['output_file'].format(opt['exp']), 'w')
|
| 644 |
-
cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames = eval_frame(opt, model, dataset)
|
| 645 |
-
print(f"testing loss: {tot_loss:.4f}, cls_loss: {cls_loss:.4f}, reg_loss: {reg_loss:.4f}")
|
| 646 |
-
for video_name in dataset.video_names:
|
| 647 |
-
o_cls = output_cls[video_name]
|
| 648 |
-
o_reg = output_reg[video_name]
|
| 649 |
-
l_cls = labels_cls[video_name]
|
| 650 |
-
l_reg = labels_reg[video_name]
|
| 651 |
-
dset_predcls = outfile.create_dataset(video_name + '/pred_cls', o_cls.shape, maxshape=o_cls.shape, chunks=True, dtype=np.float32)
|
| 652 |
-
dset_predcls[:, :] = o_cls[:, :]
|
| 653 |
-
dset_predreg = outfile.create_dataset(video_name + '/pred_reg', o_reg.shape, maxshape=o_reg.shape, chunks=True, dtype=np.float32)
|
| 654 |
-
dset_predreg[:, :] = o_reg[:, :]
|
| 655 |
-
dset_labelcls = outfile.create_dataset(video_name + '/label_cls', l_cls.shape, maxshape=l_cls.shape, chunks=True, dtype=np.float32)
|
| 656 |
-
dset_labelcls[:, :] = l_cls[:, :]
|
| 657 |
-
dset_labelreg = outfile.create_dataset(video_name + '/label_reg', l_reg.shape, maxshape=l_reg.shape, chunks=True, dtype=np.float32)
|
| 658 |
-
dset_labelreg[:, :] = l_reg[:, :]
|
| 659 |
-
outfile.close()
|
| 660 |
-
print(f"[INFO] Working time: {working_time:.2f}s, {total_frames / working_time:.1f}fps, {total_frames} frames")
|
| 661 |
-
return cls_loss, reg_loss, tot_loss
|
| 662 |
-
|
| 663 |
-
def patch_attention(m):
|
| 664 |
-
forward_orig = m.forward
|
| 665 |
-
def wrap(*args, **kwargs):
|
| 666 |
-
kwargs["need_weights"] = True
|
| 667 |
-
kwargs["average_attn_weights"] = False
|
| 668 |
-
return forward_orig(*args, **kwargs)
|
| 669 |
-
m.forward = wrap
|
| 670 |
-
|
| 671 |
-
class SaveOutput:
|
| 672 |
-
def __init__(self):
|
| 673 |
-
self.outputs = []
|
| 674 |
-
def __call__(self, module, module_in, module_out):
|
| 675 |
-
self.outputs.append(module_out[1])
|
| 676 |
-
def clear(self):
|
| 677 |
-
self.outputs = []
|
| 678 |
-
|
| 679 |
def test(opt, video_name=None):
|
| 680 |
device = torch.device("cpu")
|
| 681 |
model = MYNET(opt).to(device)
|
|
@@ -701,12 +603,13 @@ def test(opt, video_name=None):
|
|
| 701 |
with open(result_path, 'w') as f:
|
| 702 |
json.dump(output_dict, f, indent=4)
|
| 703 |
mAP = evaluation_detection(opt, verbose=False)
|
|
|
|
| 704 |
if video_name:
|
| 705 |
print(f"\n[INFO] Comparing Predicted and Ground Truth Actions for Video: {video_name}")
|
| 706 |
-
anno_path = opt["
|
| 707 |
if not os.path.exists(anno_path):
|
| 708 |
print(f"[ERROR] Annotation file {anno_path} not found. Skipping comparison.")
|
| 709 |
-
return
|
| 710 |
with open(anno_path, 'r') as f:
|
| 711 |
anno_data = json.load(f)
|
| 712 |
gt_annotations = anno_data['database'][video_name]['annotations']
|
|
@@ -723,7 +626,7 @@ def test(opt, video_name=None):
|
|
| 723 |
'end': pred['segment'][1],
|
| 724 |
'duration': pred['segment'][1] - pred['segment'][0],
|
| 725 |
'score': pred['score']
|
| 726 |
-
} for pred in result_dict
|
| 727 |
matches = []
|
| 728 |
iou_threshold = VIS_CONFIG['iou_threshold']
|
| 729 |
used_gt_indices = set()
|
|
@@ -733,7 +636,7 @@ def test(opt, video_name=None):
|
|
| 733 |
for gt_idx, gt in enumerate(gt_segments):
|
| 734 |
if gt_idx in used_gt_indices:
|
| 735 |
continue
|
| 736 |
-
iou = calc_iou([pred['
|
| 737 |
if iou > best_iou and iou >= iou_threshold:
|
| 738 |
best_iou = iou
|
| 739 |
best_gt_idx = gt_idx
|
|
@@ -778,7 +681,7 @@ def test(opt, video_name=None):
|
|
| 778 |
comparison_text += f"- Total Predictions: {len(pred_segments)}\n"
|
| 779 |
comparison_text += f"- Total Ground Truths: {len(gt_segments)}\n"
|
| 780 |
comparison_text += f"- Matched Segments: {matched_count}\n"
|
| 781 |
-
comparison_text += f"- Average Duration Difference
|
| 782 |
comparison_text += f"- Average IoU (Matched): {avg_iou:.2f}\n"
|
| 783 |
video_path = opt.get('video_path', '')
|
| 784 |
viz_path = ""
|
|
@@ -791,7 +694,7 @@ def test(opt, video_name=None):
|
|
| 791 |
video_path=video_path,
|
| 792 |
duration=duration
|
| 793 |
)
|
| 794 |
-
video_out_path =
|
| 795 |
video_id=video_name,
|
| 796 |
pred_segments=pred_segments,
|
| 797 |
gt_segments=gt_segments,
|
|
@@ -799,12 +702,12 @@ def test(opt, video_name=None):
|
|
| 799 |
)
|
| 800 |
else:
|
| 801 |
print(f"[WARNING] Video {video_path} not found. Skipping visualization.")
|
| 802 |
-
return
|
| 803 |
|
| 804 |
def test_online(opt, video_name=None):
|
| 805 |
device = torch.device("cpu")
|
| 806 |
model = MYNET(opt).to(device)
|
| 807 |
-
checkpoint_path = os.path.join(opt["checkpoint_path"], "
|
| 808 |
if not os.path.exists(checkpoint_path):
|
| 809 |
print(f"[ERROR] Checkpoint {checkpoint_path} not found.")
|
| 810 |
return 0
|
|
@@ -812,7 +715,7 @@ def test_online(opt, video_name=None):
|
|
| 812 |
model.load_state_dict(checkpoint['state_dict'])
|
| 813 |
model.eval()
|
| 814 |
sup_model = SuppressNet(opt).to(device)
|
| 815 |
-
sup_checkpoint_path = os.path.join(opt["checkpoint_path"], "ckp_best_suppress.pth.tar")
|
| 816 |
if os.path.exists(sup_checkpoint_path):
|
| 817 |
checkpoint = torch.load(sup_checkpoint_path, map_location=device)
|
| 818 |
sup_model.load_state_dict(checkpoint['state_dict'])
|
|
@@ -824,8 +727,8 @@ def test_online(opt, video_name=None):
|
|
| 824 |
dataset,
|
| 825 |
batch_size=1,
|
| 826 |
shuffle=False,
|
| 827 |
-
num_workers=
|
| 828 |
-
pin_memory=
|
| 829 |
)
|
| 830 |
result_dict = {}
|
| 831 |
proposal_dict = []
|
|
@@ -844,7 +747,7 @@ def test_online(opt, video_name=None):
|
|
| 844 |
for idx in range(duration):
|
| 845 |
total_frames += 1
|
| 846 |
input_queue[:-1, :] = input_queue[1:, :].clone()
|
| 847 |
-
input_queue[-1, :] = dataset._get_base_data(video_name, idx)
|
| 848 |
minput = input_queue.unsqueeze(0).to(device)
|
| 849 |
act_cls, act_reg, _ = model(minput)
|
| 850 |
act_cls = torch.softmax(act_cls, dim=-1)
|
|
@@ -852,7 +755,7 @@ def test_online(opt, video_name=None):
|
|
| 852 |
reg_anc = act_reg.squeeze(0).detach().cpu().numpy()
|
| 853 |
proposal_anc_dict = []
|
| 854 |
for anc_idx in range(len(anchors)):
|
| 855 |
-
cls = np.argwhere(cls_anc[anc_idx][:-1] >
|
| 856 |
if len(cls) == 0:
|
| 857 |
continue
|
| 858 |
ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
|
|
@@ -893,7 +796,8 @@ def test_online(opt, video_name=None):
|
|
| 893 |
with open(result_path, "w") as f:
|
| 894 |
json.dump(output_dict, f, indent=4)
|
| 895 |
mAP = evaluation_detection(opt, verbose=False)
|
| 896 |
-
|
|
|
|
| 897 |
|
| 898 |
def main(opt, video_name=None):
|
| 899 |
max_perf = 0
|
|
@@ -904,8 +808,6 @@ def main(opt, video_name=None):
|
|
| 904 |
elif opt['mode'] == 'test':
|
| 905 |
max_perf, comparison_text, viz_path, video_out_path = test(opt, video_name=video_name)
|
| 906 |
return max_perf, comparison_text, viz_path, video_out_path
|
| 907 |
-
elif opt['mode'] == 'test_frame':
|
| 908 |
-
max_perf = test_frame(opt, video_name=video_name)
|
| 909 |
elif opt['mode'] == 'test_online':
|
| 910 |
max_perf = test_online(opt, video_name=video_name)
|
| 911 |
elif opt['mode'] == 'eval':
|
|
@@ -924,13 +826,23 @@ def gradio_interface(video):
|
|
| 924 |
opt_dict['mode'] = 'test'
|
| 925 |
opt_dict['video_name'] = video_name
|
| 926 |
opt_dict['video_path'] = video
|
| 927 |
-
opt_dict['
|
|
|
|
| 928 |
opt_dict['checkpoint_path'] = os.path.join(os.getcwd(), 'checkpoint')
|
| 929 |
-
opt_dict['result_file'] = os.path.join(os.getcwd(), 'results',
|
| 930 |
-
opt_dict['
|
| 931 |
-
opt_dict['
|
|
|
|
|
|
|
| 932 |
opt_dict['batch_size'] = 1
|
|
|
|
|
|
|
| 933 |
opt_dict['anchors'] = [int(item) for item in opt_dict['anchors'].split(',')]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 934 |
mAP, comparison_text, viz_path, video_out_path = main(opt_dict, video_name=video_name)
|
| 935 |
return viz_path, video_out_path, f"mAP: {mAP:.4f}\n\n{comparison_text}"
|
| 936 |
|
|
@@ -938,12 +850,23 @@ if __name__ == "__main__":
|
|
| 938 |
opt = opts.parse_opt()
|
| 939 |
opt = vars(opt)
|
| 940 |
opt['checkpoint_path'] = os.path.join(os.getcwd(), 'checkpoint')
|
| 941 |
-
opt['result_file'] = os.path.join(os.getcwd(), 'results',
|
| 942 |
-
opt['frame_result_file'] = os.path.join(os.getcwd(), 'results',
|
| 943 |
-
opt['
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 944 |
os.makedirs(opt["checkpoint_path"], exist_ok=True)
|
| 945 |
os.makedirs(os.path.dirname(opt["result_file"].format(opt['exp'])), exist_ok=True)
|
| 946 |
-
os.makedirs(opt["
|
| 947 |
with open(os.path.join(opt["checkpoint_path"], f"{opt['exp']}_opts.json"), "w") as f:
|
| 948 |
json.dump(opt, f, indent=4)
|
| 949 |
if opt['seed'] >= 0:
|
|
@@ -961,7 +884,7 @@ if __name__ == "__main__":
|
|
| 961 |
gr.Textbox(label="Results and mAP")
|
| 962 |
],
|
| 963 |
title="Action Detection Model",
|
| 964 |
-
description="Upload a video to detect actions using pre-extracted I3D features. View visualizations and performance metrics."
|
| 965 |
)
|
| 966 |
iface.launch()
|
| 967 |
else:
|
|
|
|
| 45 |
'frame_highlight_gt': 'red',
|
| 46 |
'frame_highlight_pred': 'black',
|
| 47 |
'iou_threshold': 0.3,
|
| 48 |
+
'frame_scale_factor': 0.3,
|
| 49 |
'video_text_scale': 0.5,
|
| 50 |
'video_gt_text_color': (180, 119, 31), # BGR
|
| 51 |
'video_pred_text_color': (14, 127, 255), # BGR
|
|
|
|
| 55 |
'video_pred_text_y': 0.45,
|
| 56 |
'video_gt_text_y': 0.55,
|
| 57 |
'video_footer_height': 150,
|
| 58 |
+
'video_gt_bar_y': 0.2,
|
| 59 |
+
'video_pred_bar_y': 0.5,
|
| 60 |
'video_bar_height': 0.15,
|
| 61 |
'video_bar_text_scale': 0.7,
|
| 62 |
'min_segment_duration': 1.0,
|
|
|
|
| 108 |
(185, 218, 255), (255, 204, 204), (193, 182, 255), (201, 252, 189),
|
| 109 |
(144, 128, 112), (112, 25, 25), (102, 51, 102), (0, 128, 128), (171, 71, 0)
|
| 110 |
]
|
| 111 |
+
action_labels = set(seg['label'] for seg in gt_segments).union(set(seg['label'] for seg in pred_segments))
|
| 112 |
action_color_map = {label: color_palette[i % len(color_palette)] for i, label in enumerate(action_labels)}
|
| 113 |
gt_color_rgb = (gt_text_color[2], gt_text_color[1], gt_text_color[0])
|
| 114 |
pred_color_rgb = (pred_text_color[2], pred_text_color[1], pred_text_color[0])
|
|
|
|
| 128 |
except IOError:
|
| 129 |
font = None
|
| 130 |
bar_font = None
|
| 131 |
+
window_size = VIS_CONFIG['scroll_window_duration']
|
| 132 |
num_windows = int(np.ceil(duration / window_size))
|
| 133 |
text_bar_gap = 48
|
| 134 |
+
text_x = VIS_CONFIG['video_bar_label_x']
|
| 135 |
frame_idx = 0
|
| 136 |
written_frames = 0
|
| 137 |
while cap.isOpened():
|
|
|
|
| 148 |
window_duration = window_end - window_start
|
| 149 |
window_timestamp = timestamp - window_start
|
| 150 |
gt_labels = [seg['label'] for seg in gt_segments if seg['start'] <= timestamp <= seg['end']]
|
| 151 |
+
gt_text = "GT: " + ", ".join(gt_labels) if gt_labels else "GT: None"
|
| 152 |
pred_labels = [seg['label'] for seg in pred_segments if seg['start'] <= timestamp <= seg['end']]
|
| 153 |
+
pred_text = "Pred: " + ", ".join(pred_labels) if pred_labels else "Pred: None"
|
| 154 |
footer_y = frame_height
|
| 155 |
+
gt_bar_y = footer_y + int(VIS_CONFIG['video_gt_bar_y'] * footer_height)
|
| 156 |
+
pred_bar_y = footer_y + int(VIS_CONFIG['video_pred_bar_y'] * footer_height)
|
| 157 |
bar_height = int(VIS_CONFIG['video_bar_height'] * footer_height)
|
| 158 |
if font:
|
| 159 |
gt_text_bbox = bar_font.getbbox("GT")
|
|
|
|
| 204 |
frame_text_bbox = draw.textbbox((0, 0), frame_info, font=font)
|
| 205 |
frame_text_width = frame_text_bbox[2] - frame_text_bbox[0]
|
| 206 |
frame_text_x = (frame_width - frame_text_width) // 2
|
| 207 |
+
draw.text((frame_text_x, int(frame_height * VIS_CONFIG['video_frame_text_y'])), frame_info, font=font, fill=(0, 0, 0))
|
| 208 |
window_info = f"{window_start:.1f}s - {window_end:.1f}s"
|
| 209 |
window_text_bbox = draw.textbbox((0, 0), window_info, font=bar_font)
|
| 210 |
window_text_width = window_text_bbox[2] - window_text_bbox[0]
|
| 211 |
window_text_x = (frame_width - window_text_width) // 2
|
| 212 |
draw.text((window_text_x, footer_y + 10), window_info, font=bar_font, fill=(0, 0, 0))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
draw.text((text_x, gt_bar_y + bar_height // 2), "GT", font=bar_font, fill=gt_color_rgb)
|
| 214 |
draw.text((text_x, pred_bar_y + bar_height // 2), "Pred", font=bar_font, fill=pred_color_rgb)
|
| 215 |
+
gt_y = int(frame_height * VIS_CONFIG['video_gt_text_y'])
|
| 216 |
+
pred_y = int(frame_height * VIS_CONFIG['video_pred_text_y'])
|
| 217 |
+
draw.text((10, gt_y), gt_text, font=font, fill=gt_color_rgb)
|
| 218 |
+
draw.text((10, pred_y), pred_text, font=font, fill=pred_color_rgb)
|
| 219 |
extended_frame = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
|
| 220 |
else:
|
| 221 |
frame_info = f"Frame: {frame_idx} | FPS: {fps:.2f}"
|
|
|
|
| 224 |
cv2.putText(
|
| 225 |
extended_frame,
|
| 226 |
frame_info,
|
| 227 |
+
(frame_text_x, int(frame_height * VIS_CONFIG['video_frame_text_y']) + 20),
|
| 228 |
cv2.FONT_HERSHEY_SIMPLEX,
|
| 229 |
text_scale,
|
| 230 |
(0, 0, 0),
|
|
|
|
| 244 |
1,
|
| 245 |
cv2.LINE_AA
|
| 246 |
)
|
| 247 |
+
cv2.putText(
|
| 248 |
+
extended_frame,
|
| 249 |
+
gt_text,
|
| 250 |
+
(10, int(frame_height * VIS_CONFIG['video_gt_text_y'])),
|
| 251 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
| 252 |
+
text_scale,
|
| 253 |
+
gt_text_color,
|
| 254 |
+
text_thickness,
|
| 255 |
+
cv2.LINE_AA
|
| 256 |
+
)
|
| 257 |
+
cv2.putText(
|
| 258 |
+
extended_frame,
|
| 259 |
+
pred_text,
|
| 260 |
+
(10, int(frame_height * VIS_CONFIG['video_pred_text_y'])),
|
| 261 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
| 262 |
+
text_scale,
|
| 263 |
+
pred_text_color,
|
| 264 |
+
text_thickness,
|
| 265 |
+
cv2.LINE_AA
|
| 266 |
+
)
|
|
|
|
|
|
|
| 267 |
cv2.putText(
|
| 268 |
extended_frame,
|
| 269 |
"GT",
|
|
|
|
| 312 |
if num_frames > VIS_CONFIG['max_frames']:
|
| 313 |
frame_interval = duration / (VIS_CONFIG['max_frames'] - 1)
|
| 314 |
num_frames = VIS_CONFIG['max_frames']
|
| 315 |
+
frame_times = np.linspace(0, duration, num_frames, endpoint=True)
|
| 316 |
frames = []
|
| 317 |
cap = cv2.VideoCapture(video_path)
|
| 318 |
if not cap.isOpened():
|
|
|
|
| 324 |
ret, frame = cap.read()
|
| 325 |
if ret:
|
| 326 |
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 327 |
+
frame = cv2.resize(frame, (int(frame.shape[1] * VIS_CONFIG['frame_scale_factor']), int(frame.shape[0] * VIS_CONFIG['frame_scale_factor'])))
|
| 328 |
frames.append(frame)
|
| 329 |
else:
|
| 330 |
frames.append(np.ones((100, 100, 3), dtype=np.uint8) * 255)
|
| 331 |
cap.release()
|
| 332 |
+
fig = plt.figure(figsize=(num_frames * 2, 6), constrained_layout=True)
|
| 333 |
gs = fig.add_gridspec(3, num_frames, height_ratios=[3, 1, 1])
|
| 334 |
for i, (t, frame) in enumerate(zip(frame_times, frames)):
|
| 335 |
ax = fig.add_subplot(gs[0, i])
|
|
|
|
| 397 |
train_loader = torch.utils.data.DataLoader(
|
| 398 |
train_dataset,
|
| 399 |
batch_size=opt['batch_size'], shuffle=True,
|
| 400 |
+
num_workers=0, pin_memory=False, drop_last=True
|
| 401 |
)
|
| 402 |
epoch_cost = 0
|
| 403 |
epoch_cost_cls = 0
|
|
|
|
| 412 |
g['lr'] = n_iter * opt['lr'] / total_iter
|
| 413 |
act_cls, act_reg, snip_cls = model(input_data.float().to(device))
|
| 414 |
act_cls.register_hook(partial(cls_loss.collect_grad, cls_label))
|
| 415 |
+
snip_cls.register_hook(lambda grad: snip_loss.collect_grad(grad, snip_label))
|
| 416 |
cost_reg = 0
|
| 417 |
cost_cls = 0
|
| 418 |
loss = cls_loss_func_(cls_loss, act_cls)
|
|
|
|
| 431 |
optimizer.step()
|
| 432 |
return n_iter, epoch_cost, epoch_cost_cls, epoch_cost_reg, epoch_cost_snip
|
| 433 |
|
| 434 |
+
def eval_frame(opt, model, dataset):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 435 |
device = torch.device("cpu")
|
| 436 |
test_loader = torch.utils.data.DataLoader(
|
| 437 |
dataset,
|
| 438 |
+
batch_size=opt['batch_size'], shuffle=False,
|
| 439 |
+
num_workers=0, pin_memory=False, drop_last=False
|
| 440 |
)
|
| 441 |
labels_cls = {video_name: [] for video_name in dataset.video_names}
|
| 442 |
labels_reg = {video_name: [] for video_name in dataset.video_names}
|
|
|
|
| 447 |
epoch_cost = 0
|
| 448 |
epoch_cost_cls = 0
|
| 449 |
epoch_cost_reg = 0
|
| 450 |
+
cls_loss_fn = MultiCrossEntropyLoss(focal=True)
|
| 451 |
+
for n_iter, (input_data, cls_label, reg_label, snip_label) in enumerate(tqdm(test_loader)):
|
| 452 |
act_cls, act_reg, _ = model(input_data.float().to(device))
|
| 453 |
cost_reg = 0
|
| 454 |
cost_cls = 0
|
| 455 |
+
loss = cls_loss_func_(cls_loss_fn, act_cls)
|
| 456 |
cost_cls = loss
|
| 457 |
epoch_cost_cls += loss.detach().cpu().numpy()
|
| 458 |
loss = regress_loss_func(reg_label, act_reg)
|
|
|
|
| 464 |
total_frames += input_data.size(0)
|
| 465 |
for idx in range(input_data.size(0)):
|
| 466 |
video_name, st, ed, data_idx = dataset.inputs[n_iter * opt['batch_size'] + idx]
|
| 467 |
+
output_cls[video_name].append(act_cls[idx].detach().cpu().numpy())
|
| 468 |
+
output_reg[video_name].append(act_reg[idx].detach().cpu().numpy())
|
| 469 |
+
labels_cls[video_name].append(cls_label[idx].cpu().numpy())
|
| 470 |
+
labels_reg[video_name].append(reg_label[idx].cpu().numpy())
|
| 471 |
end_time = time.time()
|
| 472 |
working_time = end_time - start_time
|
| 473 |
for video_name in dataset.video_names:
|
|
|
|
| 496 |
reg_anc = output_reg[video_name][idx]
|
| 497 |
proposal_anc_dict = []
|
| 498 |
for anc_idx in range(len(anchors)):
|
| 499 |
+
cls = np.argwhere(cls_anc[anc_idx][:-1] > threshold).reshape(-1)
|
| 500 |
if len(cls) == 0:
|
| 501 |
continue
|
| 502 |
ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
|
|
|
|
| 520 |
def eval_map_suppress(opt, dataset, output_cls, output_reg, labels_cls, labels_reg):
|
| 521 |
device = torch.device("cpu")
|
| 522 |
model = SuppressNet(opt).to(device)
|
| 523 |
+
checkpoint_path = os.path.join(opt["checkpoint_path"], f"ckp_best_suppress.pth.tar")
|
| 524 |
if os.path.exists(checkpoint_path):
|
| 525 |
checkpoint = torch.load(checkpoint_path, map_location=device)
|
| 526 |
model.load_state_dict(checkpoint['state_dict'])
|
|
|
|
| 544 |
reg_anc = output_reg[video_name][idx]
|
| 545 |
proposal_anc_dict = []
|
| 546 |
for anc_idx in range(len(anchors)):
|
| 547 |
+
cls = np.argwhere(cls_anc[anc_idx][:-1] > threshold).reshape(-1)
|
| 548 |
if len(cls) == 0:
|
| 549 |
continue
|
| 550 |
ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
|
|
|
|
| 554 |
label = cls[cidx]
|
| 555 |
tmp_dict = {
|
| 556 |
"segment": [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)],
|
| 557 |
+
"score": float(cls_anc[anc_idx][label]),
|
| 558 |
"label": dataset.label_name[label],
|
| 559 |
"gentime": float(idx * frame_to_time / 100.0)
|
| 560 |
}
|
|
|
|
| 571 |
for cls in range(num_class - 1):
|
| 572 |
if suppress_conf[cls] > opt['sup_threshold']:
|
| 573 |
for proposal in proposal_anc_dict:
|
| 574 |
+
if proposal['label'] == dataset.label_name[cls]:
|
| 575 |
if check_overlap_proposal(proposal_dict, proposal, overlapThresh=opt['soft_nms']) is None:
|
| 576 |
proposal_dict.append(proposal)
|
| 577 |
result_dict[video_name] = proposal_dict
|
| 578 |
proposal_dict = []
|
| 579 |
return result_dict
|
| 580 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 581 |
def test(opt, video_name=None):
|
| 582 |
device = torch.device("cpu")
|
| 583 |
model = MYNET(opt).to(device)
|
|
|
|
| 603 |
with open(result_path, 'w') as f:
|
| 604 |
json.dump(output_dict, f, indent=4)
|
| 605 |
mAP = evaluation_detection(opt, verbose=False)
|
| 606 |
+
mAP_value = sum(mAP) / len(mAP) if mAP else 0
|
| 607 |
if video_name:
|
| 608 |
print(f"\n[INFO] Comparing Predicted and Ground Truth Actions for Video: {video_name}")
|
| 609 |
+
anno_path = opt["video_anno"].format(opt["split"])
|
| 610 |
if not os.path.exists(anno_path):
|
| 611 |
print(f"[ERROR] Annotation file {anno_path} not found. Skipping comparison.")
|
| 612 |
+
return mAP_value, "", "", ""
|
| 613 |
with open(anno_path, 'r') as f:
|
| 614 |
anno_data = json.load(f)
|
| 615 |
gt_annotations = anno_data['database'][video_name]['annotations']
|
|
|
|
| 626 |
'end': pred['segment'][1],
|
| 627 |
'duration': pred['segment'][1] - pred['segment'][0],
|
| 628 |
'score': pred['score']
|
| 629 |
+
} for pred in result_dict.get(video_name, [])]
|
| 630 |
matches = []
|
| 631 |
iou_threshold = VIS_CONFIG['iou_threshold']
|
| 632 |
used_gt_indices = set()
|
|
|
|
| 636 |
for gt_idx, gt in enumerate(gt_segments):
|
| 637 |
if gt_idx in used_gt_indices:
|
| 638 |
continue
|
| 639 |
+
iou = calc_iou([pred['end'], pred['duration']], [gt['end'], gt['duration']])
|
| 640 |
if iou > best_iou and iou >= iou_threshold:
|
| 641 |
best_iou = iou
|
| 642 |
best_gt_idx = gt_idx
|
|
|
|
| 681 |
comparison_text += f"- Total Predictions: {len(pred_segments)}\n"
|
| 682 |
comparison_text += f"- Total Ground Truths: {len(gt_segments)}\n"
|
| 683 |
comparison_text += f"- Matched Segments: {matched_count}\n"
|
| 684 |
+
comparison_text += f"- Average Duration Difference (s): {avg_duration_diff:.2f}\n"
|
| 685 |
comparison_text += f"- Average IoU (Matched): {avg_iou:.2f}\n"
|
| 686 |
video_path = opt.get('video_path', '')
|
| 687 |
viz_path = ""
|
|
|
|
| 694 |
video_path=video_path,
|
| 695 |
duration=duration
|
| 696 |
)
|
| 697 |
+
video_out_path = annotate_video_with_actions(
|
| 698 |
video_id=video_name,
|
| 699 |
pred_segments=pred_segments,
|
| 700 |
gt_segments=gt_segments,
|
|
|
|
| 702 |
)
|
| 703 |
else:
|
| 704 |
print(f"[WARNING] Video {video_path} not found. Skipping visualization.")
|
| 705 |
+
return mAP_value, comparison_text, viz_path, video_out_path
|
| 706 |
|
| 707 |
def test_online(opt, video_name=None):
|
| 708 |
device = torch.device("cpu")
|
| 709 |
model = MYNET(opt).to(device)
|
| 710 |
+
checkpoint_path = os.path.join(opt["checkpoint_path"], f"{opt['exp']}_ckp_best.pth.tar")
|
| 711 |
if not os.path.exists(checkpoint_path):
|
| 712 |
print(f"[ERROR] Checkpoint {checkpoint_path} not found.")
|
| 713 |
return 0
|
|
|
|
| 715 |
model.load_state_dict(checkpoint['state_dict'])
|
| 716 |
model.eval()
|
| 717 |
sup_model = SuppressNet(opt).to(device)
|
| 718 |
+
sup_checkpoint_path = os.path.join(opt["checkpoint_path"], f"ckp_best_suppress.pth.tar")
|
| 719 |
if os.path.exists(sup_checkpoint_path):
|
| 720 |
checkpoint = torch.load(sup_checkpoint_path, map_location=device)
|
| 721 |
sup_model.load_state_dict(checkpoint['state_dict'])
|
|
|
|
| 727 |
dataset,
|
| 728 |
batch_size=1,
|
| 729 |
shuffle=False,
|
| 730 |
+
num_workers=0,
|
| 731 |
+
pin_memory=False
|
| 732 |
)
|
| 733 |
result_dict = {}
|
| 734 |
proposal_dict = []
|
|
|
|
| 747 |
for idx in range(duration):
|
| 748 |
total_frames += 1
|
| 749 |
input_queue[:-1, :] = input_queue[1:, :].clone()
|
| 750 |
+
input_queue[-1, :] = dataset._get_base_data(video_name, idx, idx + 1).squeeze(0)
|
| 751 |
minput = input_queue.unsqueeze(0).to(device)
|
| 752 |
act_cls, act_reg, _ = model(minput)
|
| 753 |
act_cls = torch.softmax(act_cls, dim=-1)
|
|
|
|
| 755 |
reg_anc = act_reg.squeeze(0).detach().cpu().numpy()
|
| 756 |
proposal_anc_dict = []
|
| 757 |
for anc_idx in range(len(anchors)):
|
| 758 |
+
cls = np.argwhere(cls_anc[anc_idx][:-1] > threshold).reshape(-1)
|
| 759 |
if len(cls) == 0:
|
| 760 |
continue
|
| 761 |
ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
|
|
|
|
| 796 |
with open(result_path, "w") as f:
|
| 797 |
json.dump(output_dict, f, indent=4)
|
| 798 |
mAP = evaluation_detection(opt, verbose=False)
|
| 799 |
+
mAP_value = sum(mAP) / len(mAP) if mAP else 0
|
| 800 |
+
return mAP_value
|
| 801 |
|
| 802 |
def main(opt, video_name=None):
|
| 803 |
max_perf = 0
|
|
|
|
| 808 |
elif opt['mode'] == 'test':
|
| 809 |
max_perf, comparison_text, viz_path, video_out_path = test(opt, video_name=video_name)
|
| 810 |
return max_perf, comparison_text, viz_path, video_out_path
|
|
|
|
|
|
|
| 811 |
elif opt['mode'] == 'test_online':
|
| 812 |
max_perf = test_online(opt, video_name=video_name)
|
| 813 |
elif opt['mode'] == 'eval':
|
|
|
|
| 826 |
opt_dict['mode'] = 'test'
|
| 827 |
opt_dict['video_name'] = video_name
|
| 828 |
opt_dict['video_path'] = video
|
| 829 |
+
opt_dict['video_anno'] = os.path.join(os.getcwd(), 'data', 'annotations.json')
|
| 830 |
+
opt_dict['video_feature_all_test'] = os.path.join(os.getcwd(), 'data', 'features') + os.sep
|
| 831 |
opt_dict['checkpoint_path'] = os.path.join(os.getcwd(), 'checkpoint')
|
| 832 |
+
opt_dict['result_file'] = os.path.join(os.getcwd(), 'results', 'result_{}.json')
|
| 833 |
+
opt_dict['frame_result_file'] = os.path.join(os.getcwd(), 'results', 'frame_result_{}.h5')
|
| 834 |
+
opt_dict['video_len_file'] = os.path.join(os.getcwd(), 'data', 'video_len_{}.json')
|
| 835 |
+
opt_dict['proposal_label_file'] = os.path.join(os.getcwd(), 'data', 'proposal_label_{}.h5')
|
| 836 |
+
opt_dict['suppress_label_file'] = os.path.join(os.getcwd(), 'data', 'suppress_label_{}.h5')
|
| 837 |
opt_dict['batch_size'] = 1
|
| 838 |
+
opt_dict['data_format'] = 'npz_i3d'
|
| 839 |
+
opt_dict['rgb_only'] = False
|
| 840 |
opt_dict['anchors'] = [int(item) for item in opt_dict['anchors'].split(',')]
|
| 841 |
+
opt_dict['predefined_fps'] = 30 # Adjust if needed
|
| 842 |
+
opt_dict['split'] = 'test'
|
| 843 |
+
opt_dict['setup'] = 'default'
|
| 844 |
+
opt_dict['data_rescale'] = 1.0
|
| 845 |
+
opt_dict['pos_threshold'] = 0.5
|
| 846 |
mAP, comparison_text, viz_path, video_out_path = main(opt_dict, video_name=video_name)
|
| 847 |
return viz_path, video_out_path, f"mAP: {mAP:.4f}\n\n{comparison_text}"
|
| 848 |
|
|
|
|
| 850 |
opt = opts.parse_opt()
|
| 851 |
opt = vars(opt)
|
| 852 |
opt['checkpoint_path'] = os.path.join(os.getcwd(), 'checkpoint')
|
| 853 |
+
opt['result_file'] = os.path.join(os.getcwd(), 'results', 'result_{}.json')
|
| 854 |
+
opt['frame_result_file'] = os.path.join(os.getcwd(), 'results', 'frame_result_{}.h5')
|
| 855 |
+
opt['video_anno'] = os.path.join(os.getcwd(), 'data', 'annotations.json')
|
| 856 |
+
opt['video_feature_all_test'] = os.path.join(os.getcwd(), 'data', 'features') + os.sep
|
| 857 |
+
opt['video_len_file'] = os.path.join(os.getcwd(), 'data', 'video_len_{}.json')
|
| 858 |
+
opt['proposal_label_file'] = os.path.join(os.getcwd(), 'data', 'proposal_label_{}.h5')
|
| 859 |
+
opt['suppress_label_file'] = os.path.join(os.getcwd(), 'data', 'suppress_label_{}.h5')
|
| 860 |
+
opt['data_format'] = 'npz_i3d'
|
| 861 |
+
opt['rgb_only'] = False
|
| 862 |
+
opt['predefined_fps'] = 30
|
| 863 |
+
opt['split'] = 'test'
|
| 864 |
+
opt['setup'] = 'default'
|
| 865 |
+
opt['data_rescale'] = 1.0
|
| 866 |
+
opt['pos_threshold'] = 0.5
|
| 867 |
os.makedirs(opt["checkpoint_path"], exist_ok=True)
|
| 868 |
os.makedirs(os.path.dirname(opt["result_file"].format(opt['exp'])), exist_ok=True)
|
| 869 |
+
os.makedirs(os.path.dirname(opt["video_anno"]), exist_ok=True)
|
| 870 |
with open(os.path.join(opt["checkpoint_path"], f"{opt['exp']}_opts.json"), "w") as f:
|
| 871 |
json.dump(opt, f, indent=4)
|
| 872 |
if opt['seed'] >= 0:
|
|
|
|
| 884 |
gr.Textbox(label="Results and mAP")
|
| 885 |
],
|
| 886 |
title="Action Detection Model",
|
| 887 |
+
description="Upload a video to detect actions using pre-extracted I3D features. Ensure a corresponding .npz file exists in data/features/. View visualizations and performance metrics."
|
| 888 |
)
|
| 889 |
iface.launch()
|
| 890 |
else:
|