import os
import json
import torch
import torchvision
import torch.nn.parallel
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import opts_egtea as opts
import time
import h5py
from tqdm import tqdm
from iou_utils import *
from eval import evaluation_detection
from tensorboardX import SummaryWriter
from dataset import VideoDataSet, calc_iou
from models import MYNET, SuppressNet
from loss_func import cls_loss_func, cls_loss_func_, regress_loss_func
from loss_func import MultiCrossEntropyLoss
from functools import partial
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import cv2
from typing import List, Dict, Optional
from PIL import Image, ImageDraw, ImageFont
import warnings
import gradio as gr
import subprocess

# Suppress non-critical warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Visualization Configuration (Optimized for HF Free CPU)
VIS_CONFIG = {
    'frame_interval': 1.0,
    'max_frames': 5,  # Reduced for CPU memory
    'save_dir': os.path.join(os.getcwd(), 'output', 'visualizations'),
    'video_save_dir': os.path.join(os.getcwd(), 'output', 'videos'),
    'gt_color': '#1f77b4',  # Blue for ground truth
    'pred_color': '#ff7f0e',  # Orange for predictions
    'fontsize_label': 10,
    'fontsize_title': 14,
    'frame_highlight_both': 'green',
    'frame_highlight_gt': 'red',
    'frame_highlight_pred': 'black',
    'iou_threshold': 0.3,
    'frame_scale_factor': 0.3,
    'video_text_scale': 0.5,
    'video_gt_text_color': (180, 119, 31),  # BGR
    'video_pred_text_color': (14, 127, 255),  # BGR
    'video_text_thickness': 1,
    'video_font_path': os.path.join(os.getcwd(), 'fonts', 'Poppins ExtraBold Italic 800.ttf'),
    'video_font_fallback': '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf',
    'video_pred_text_y': 0.45,
    'video_gt_text_y': 0.55,
    'video_footer_height': 150,
    'video_gt_bar_y': 0.2,
    'video_pred_bar_y': 0.5,
    'video_bar_height': 0.15,
    'video_bar_text_scale': 0.7,
    'min_segment_duration': 1.0,
    'video_frame_text_y': 0.05,
    'video_bar_label_x': 10,
    'video_bar_label_scale': 0.5,
    'scroll_window_duration': 30.0,
    'scroll_speed': 0.5,
}


def _resolve_exp_name(exp_name: Optional[str] = None) -> str:
    """Return the experiment tag used in output file names.

    Uses ``exp_name`` when given; otherwise falls back to the module-level
    ``opt['exp']`` (set when the file runs as a script) and finally to
    ``'default'`` so the visualisation helpers also work when imported.
    """
    if exp_name is not None:
        return str(exp_name)
    global_opt = globals().get('opt')
    if isinstance(global_opt, dict) and 'exp' in global_opt:
        return str(global_opt['exp'])
    return 'default'


def _ensure_parent_dir(path: str) -> None:
    """Create the parent directory of ``path`` if the path has one."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _mean_ap(ap_values):
    """Average the per-threshold AP values; 0 when there are none.

    Uses a length check instead of truthiness so that a NumPy array result
    does not raise "truth value of an array is ambiguous".
    """
    if ap_values is None or len(ap_values) == 0:
        return 0
    return sum(ap_values) / len(ap_values)


def _load_fonts(font_size: int, bar_font_size: int):
    """Load the (text, bar) TrueType fonts, trying the fallback path second.

    Returns ``(None, None)`` when neither font file can be opened, in which
    case callers draw text with OpenCV instead of PIL.
    """
    for candidate in (VIS_CONFIG['video_font_path'], VIS_CONFIG['video_font_fallback']):
        try:
            return (ImageFont.truetype(candidate, font_size),
                    ImageFont.truetype(candidate, bar_font_size))
        except IOError:
            continue
    return None, None


def _draw_segment_bar(canvas, segments, color_map, window_start, window_end,
                      window_duration, now_t, bar_start_x, bar_width, bar_y, bar_height):
    """Fill the already-elapsed part of each segment in the current window."""
    for seg in segments:
        if seg['start'] <= window_end and seg['end'] >= window_start:
            start_t = max(seg['start'], window_start)
            end_t = min(seg['end'], now_t)  # only draw up to the current time
            start_x = bar_start_x + int(((start_t - window_start) / window_duration) * bar_width)
            end_x = bar_start_x + int(((end_t - window_start) / window_duration) * bar_width)
            if end_x > start_x:
                cv2.rectangle(
                    canvas,
                    (start_x, bar_y),
                    (end_x, bar_y + bar_height),
                    color_map[seg['label']],
                    -1
                )


def annotate_video_with_actions(
    video_id: str,
    pred_segments: List[Dict],
    gt_segments: List[Dict],
    video_path: str,
    save_dir: str = VIS_CONFIG['video_save_dir'],
    text_scale: float = VIS_CONFIG['video_text_scale'] * 1.5,
    gt_text_color: tuple = VIS_CONFIG['video_gt_text_color'],
    pred_text_color: tuple = VIS_CONFIG['video_pred_text_color'],
    text_thickness: int = VIS_CONFIG['video_text_thickness'],
    exp_name: Optional[str] = None
) -> str:
    """Write a copy of the video with GT/prediction labels and timeline bars.

    Each output frame is the source frame plus a white footer holding two
    progress bars (ground truth and prediction) for the current time window.
    The result is written as XVID ``.avi`` and then converted to ``.mp4``
    with ffmpeg when available.

    Args:
        video_id: Identifier used in the output file name.
        pred_segments: Dicts with 'label', 'start', 'end', 'duration'.
        gt_segments: Same schema as ``pred_segments``.
        video_path: Source video file.
        save_dir: Output directory (created if missing).
        text_scale: Scale for overlay text.
        gt_text_color: BGR colour for ground-truth text.
        pred_text_color: BGR colour for prediction text.
        text_thickness: OpenCV text thickness (fallback renderer only).
        exp_name: Experiment tag for the file name; defaults to the
            module-level ``opt['exp']`` when available.

    Returns:
        Path of the ``.mp4`` (or ``.avi`` if conversion failed), or ``""``
        when the video could not be read or written.
    """
    os.makedirs(save_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video {video_path}. Skipping.")
        return ""

    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps or fps <= 0:
        # Broken container metadata would otherwise cause a ZeroDivisionError.
        print(f"Error: Video {video_path} reports an invalid FPS ({fps}). Skipping.")
        cap.release()
        return ""
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    duration = total_frames / fps

    footer_height = VIS_CONFIG['video_footer_height']
    output_height = frame_height + footer_height
    output_path = os.path.join(save_dir, f"annotated_{video_id}_{_resolve_exp_name(exp_name)}.avi")
    # splitext only touches the extension; str.replace could rewrite a directory name.
    mp4_path = os.path.splitext(output_path)[0] + '.mp4'

    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, output_height))
    if not out.isOpened():
        print(f"Error: Could not initialize video writer for {output_path}.")
        cap.release()
        return ""

    min_duration = VIS_CONFIG['min_segment_duration']
    gt_segments = [seg for seg in gt_segments if seg['duration'] >= min_duration]
    pred_segments = [seg for seg in pred_segments if seg['duration'] >= min_duration]

    color_palette = [
        (128, 0, 0), (60, 20, 220), (0, 128, 0), (128, 0, 128), (79, 69, 54),
        (128, 128, 0), (0, 0, 128), (130, 0, 75), (34, 139, 34), (0, 85, 204),
        (149, 146, 209), (235, 206, 135), (250, 230, 230), (191, 226, 159),
        (185, 218, 255), (255, 204, 204), (193, 182, 255), (201, 252, 189),
        (144, 128, 112), (112, 25, 25), (102, 51, 102), (0, 128, 128), (171, 71, 0)
    ]
    # Sorted so that a label gets the same colour on every run.
    action_labels = sorted(
        {seg['label'] for seg in gt_segments} | {seg['label'] for seg in pred_segments}
    )
    action_color_map = {
        label: color_palette[i % len(color_palette)] for i, label in enumerate(action_labels)
    }

    # PIL draws in RGB while the configured colours are BGR.
    gt_color_rgb = (gt_text_color[2], gt_text_color[1], gt_text_color[0])
    pred_color_rgb = (pred_text_color[2], pred_text_color[1], pred_text_color[0])

    bar_text_scale = VIS_CONFIG['video_bar_text_scale']
    font, bar_font = _load_fonts(int(20 * text_scale), int(20 * bar_text_scale))

    window_size = VIS_CONFIG['scroll_window_duration']
    text_bar_gap = 48
    text_x = VIS_CONFIG['video_bar_label_x']

    # Layout values that do not change between frames.
    footer_y = frame_height
    gt_bar_y = footer_y + int(VIS_CONFIG['video_gt_bar_y'] * footer_height)
    pred_bar_y = footer_y + int(VIS_CONFIG['video_pred_bar_y'] * footer_height)
    bar_height = int(VIS_CONFIG['video_bar_height'] * footer_height)
    frame_text_y = int(frame_height * VIS_CONFIG['video_frame_text_y'])
    gt_y = int(frame_height * VIS_CONFIG['video_gt_text_y'])
    pred_y = int(frame_height * VIS_CONFIG['video_pred_text_y'])

    if font:
        gt_text_bbox = bar_font.getbbox("GT")
        pred_text_bbox = bar_font.getbbox("Pred")
        gt_text_width = gt_text_bbox[2] - gt_text_bbox[0]
        pred_text_width = pred_text_bbox[2] - pred_text_bbox[0]
    else:
        gt_text_width = cv2.getTextSize("GT", cv2.FONT_HERSHEY_SIMPLEX, bar_text_scale, 1)[0][0]
        pred_text_width = cv2.getTextSize("Pred", cv2.FONT_HERSHEY_SIMPLEX, bar_text_scale, 1)[0][0]
    bar_start_x = text_x + max(gt_text_width, pred_text_width) + text_bar_gap
    bar_width = frame_width - bar_start_x

    frame_idx = 0
    written_frames = 0
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            extended_frame = np.zeros((output_height, frame_width, 3), dtype=np.uint8)
            extended_frame[:frame_height, :, :] = frame
            extended_frame[frame_height:, :, :] = 255  # white footer

            timestamp = frame_idx / fps
            window_start = int(timestamp // window_size) * window_size
            window_end = min(window_start + window_size, duration)
            if window_end <= window_start:
                # Reported frame count can be smaller than the real one; keep
                # the window non-empty so the bar scaling never divides by 0.
                window_end = window_start + window_size
            window_duration = window_end - window_start

            gt_labels = [seg['label'] for seg in gt_segments
                         if seg['start'] <= timestamp <= seg['end']]
            gt_text = "GT: " + ", ".join(gt_labels) if gt_labels else "GT: None"
            pred_labels = [seg['label'] for seg in pred_segments
                           if seg['start'] <= timestamp <= seg['end']]
            pred_text = "Pred: " + ", ".join(pred_labels) if pred_labels else "Pred: None"

            _draw_segment_bar(extended_frame, gt_segments, action_color_map,
                              window_start, window_end, window_duration, timestamp,
                              bar_start_x, bar_width, gt_bar_y, bar_height)
            _draw_segment_bar(extended_frame, pred_segments, action_color_map,
                              window_start, window_end, window_duration, timestamp,
                              bar_start_x, bar_width, pred_bar_y, bar_height)

            frame_info = f"Frame: {frame_idx} | FPS: {fps:.2f}"
            window_info = f"{window_start:.1f}s - {window_end:.1f}s"

            if font:
                frame_rgb = cv2.cvtColor(extended_frame, cv2.COLOR_BGR2RGB)
                pil_image = Image.fromarray(frame_rgb)
                draw = ImageDraw.Draw(pil_image)

                frame_text_bbox = draw.textbbox((0, 0), frame_info, font=font)
                frame_text_width = frame_text_bbox[2] - frame_text_bbox[0]
                frame_text_x = (frame_width - frame_text_width) // 2
                draw.text((frame_text_x, frame_text_y), frame_info, font=font, fill=(0, 0, 0))

                window_text_bbox = draw.textbbox((0, 0), window_info, font=bar_font)
                window_text_width = window_text_bbox[2] - window_text_bbox[0]
                window_text_x = (frame_width - window_text_width) // 2
                draw.text((window_text_x, footer_y + 10), window_info, font=bar_font, fill=(0, 0, 0))

                draw.text((text_x, gt_bar_y + bar_height // 2), "GT",
                          font=bar_font, fill=gt_color_rgb)
                draw.text((text_x, pred_bar_y + bar_height // 2), "Pred",
                          font=bar_font, fill=pred_color_rgb)

                draw.text((10, gt_y), gt_text, font=font, fill=gt_color_rgb)
                draw.text((10, pred_y), pred_text, font=font, fill=pred_color_rgb)

                extended_frame = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
            else:
                text_size = cv2.getTextSize(frame_info, cv2.FONT_HERSHEY_SIMPLEX,
                                            text_scale, text_thickness)[0]
                frame_text_x = (frame_width - text_size[0]) // 2
                cv2.putText(
                    extended_frame, frame_info,
                    (frame_text_x, frame_text_y + 20),
                    cv2.FONT_HERSHEY_SIMPLEX, text_scale, (0, 0, 0),
                    text_thickness, cv2.LINE_AA
                )
                window_text_size = cv2.getTextSize(window_info, cv2.FONT_HERSHEY_SIMPLEX,
                                                   bar_text_scale, 1)[0]
                window_text_x = (frame_width - window_text_size[0]) // 2
                cv2.putText(
                    extended_frame, window_info,
                    (window_text_x, footer_y + 20),
                    cv2.FONT_HERSHEY_SIMPLEX, bar_text_scale, (0, 0, 0),
                    1, cv2.LINE_AA
                )
                cv2.putText(
                    extended_frame, gt_text, (10, gt_y),
                    cv2.FONT_HERSHEY_SIMPLEX, text_scale, gt_text_color,
                    text_thickness, cv2.LINE_AA
                )
                cv2.putText(
                    extended_frame, pred_text, (10, pred_y),
                    cv2.FONT_HERSHEY_SIMPLEX, text_scale, pred_text_color,
                    text_thickness, cv2.LINE_AA
                )
                cv2.putText(
                    extended_frame, "GT",
                    (text_x, gt_bar_y + bar_height // 2 + 5),
                    cv2.FONT_HERSHEY_SIMPLEX, bar_text_scale, gt_text_color,
                    1, cv2.LINE_AA
                )
                cv2.putText(
                    extended_frame, "Pred",
                    (text_x, pred_bar_y + bar_height // 2 + 5),
                    cv2.FONT_HERSHEY_SIMPLEX, bar_text_scale, pred_text_color,
                    1, cv2.LINE_AA
                )

            out.write(extended_frame)
            written_frames += 1
            frame_idx += 1
    finally:
        # Release handles even if drawing a frame raises.
        cap.release()
        out.release()

    print(f"[✅ Saved Annotated Video]: {output_path}, Frames={written_frames}")

    try:
        # '-y' overwrites an existing mp4; without it ffmpeg prompts and a
        # repeated run on the same video cannot complete non-interactively.
        subprocess.run(
            ['ffmpeg', '-y', '-i', output_path, '-vcodec', 'libx264', '-acodec', 'aac', mp4_path],
            check=True
        )
        print(f"[✅ Converted to MP4]: {mp4_path}")
        return mp4_path
    except (subprocess.CalledProcessError, FileNotFoundError):
        print("Note: FFmpeg not available or failed. Returning .avi (may not play in browsers).")
        return output_path if os.path.exists(output_path) else ""


def _draw_segment_row(ax, segments, duration, title, face_color, tick_y):
    """Draw one labelled timeline row (rectangles plus start/end times)."""
    ax.set_xlim(0, duration)
    ax.set_ylim(0, 1)
    ax.axis('off')
    ax.text(-0.02 * duration, 0.5, title, fontsize=VIS_CONFIG['fontsize_title'],
            va='center', ha='right', weight='bold')
    for seg in segments:
        start, end = seg['start'], seg['end']
        width = end - start
        # Long labels are truncated so they fit inside the rectangle.
        label = seg['label'][:10] + '...' if len(seg['label']) > 10 else seg['label']
        ax.add_patch(patches.Rectangle(
            (start, 0.3), width, 0.4,
            facecolor=face_color, edgecolor='black', alpha=0.8
        ))
        ax.text((start + end) / 2, 0.5, label, ha='center', va='center',
                fontsize=VIS_CONFIG['fontsize_label'], color='white')
        ax.text(start, tick_y, f"{start:.1f}", ha='center', fontsize=8, color='black')
        ax.text(end, tick_y, f"{end:.1f}", ha='center', fontsize=8, color='black')


def visualize_action_lengths(
    video_id: str,
    pred_segments: List[Dict],
    gt_segments: List[Dict],
    video_path: str,
    duration: float,
    save_dir: str = VIS_CONFIG['save_dir'],
    frame_interval: float = VIS_CONFIG['frame_interval'],
    exp_name: Optional[str] = None
) -> str:
    """Save a figure with sampled frames above GT and prediction timelines.

    Up to ``VIS_CONFIG['max_frames']`` frames are sampled evenly over the
    video. A frame gets a coloured border when its timestamp falls inside a
    ground-truth and/or predicted segment.

    Args:
        video_id: Identifier used in the output file name.
        pred_segments: Dicts with 'label', 'start', 'end'.
        gt_segments: Same schema as ``pred_segments``.
        video_path: Source video; white placeholders are used if unreadable.
        duration: Video duration in seconds (x-axis range).
        save_dir: Output directory (created if missing).
        frame_interval: Desired spacing in seconds between sampled frames.
        exp_name: Experiment tag for the file name; defaults to the
            module-level ``opt['exp']`` when available.

    Returns:
        Path of the saved PNG.
    """
    os.makedirs(save_dir, exist_ok=True)

    num_frames = min(int(duration / frame_interval) + 1, VIS_CONFIG['max_frames'])
    frame_times = np.linspace(0, duration, num_frames, endpoint=True)

    scale = VIS_CONFIG['frame_scale_factor']
    frames = []
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Warning: Could not open video {video_path}. Using placeholders.")
        frames = [np.ones((100, 100, 3), dtype=np.uint8) * 255 for _ in frame_times]
    else:
        for t in frame_times:
            cap.set(cv2.CAP_PROP_POS_MSEC, t * 1000)
            ret, frame = cap.read()
            if ret:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = cv2.resize(
                    frame,
                    (int(frame.shape[1] * scale), int(frame.shape[0] * scale))
                )
                frames.append(frame)
            else:
                frames.append(np.ones((100, 100, 3), dtype=np.uint8) * 255)
    cap.release()

    fig = plt.figure(figsize=(num_frames * 2, 6), constrained_layout=True)
    gs = fig.add_gridspec(3, num_frames, height_ratios=[3, 1, 1])

    for i, (t, frame) in enumerate(zip(frame_times, frames)):
        ax = fig.add_subplot(gs[0, i])
        gt_hit = any(seg['start'] <= t <= seg['end'] for seg in gt_segments)
        pred_hit = any(seg['start'] <= t <= seg['end'] for seg in pred_segments)
        border_color = None
        if gt_hit and pred_hit:
            border_color = VIS_CONFIG['frame_highlight_both']
        elif gt_hit:
            border_color = VIS_CONFIG['frame_highlight_gt']
        elif pred_hit:
            border_color = VIS_CONFIG['frame_highlight_pred']

        ax.imshow(frame)
        # axis('off') would also hide the spines, making the highlight border
        # invisible; hide only the ticks and toggle the spines explicitly.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            if border_color:
                spine.set_visible(True)
                spine.set_edgecolor(border_color)
                spine.set_linewidth(2)
            else:
                spine.set_visible(False)
        ax.set_title(f"{t:.1f}s", fontsize=VIS_CONFIG['fontsize_label'],
                     color=border_color if border_color else 'black')

    _draw_segment_row(fig.add_subplot(gs[1, :]), gt_segments, duration,
                      "Ground Truth", VIS_CONFIG['gt_color'], 0.2)
    _draw_segment_row(fig.add_subplot(gs[2, :]), pred_segments, duration,
                      "Prediction", VIS_CONFIG['pred_color'], 0.8)

    png_path = os.path.join(save_dir, f"viz_{video_id}_{_resolve_exp_name(exp_name)}.png")
    fig.savefig(png_path, dpi=100, bbox_inches='tight')
    plt.close(fig)  # close this figure explicitly, not whichever is current
    print(f"[INFO] Saved visualization: {png_path}")
    return png_path


def train_one_epoch(opt, model, train_dataset, optimizer, warmup=False):
    """Run one training epoch on CPU and return the accumulated losses.

    Args:
        opt: Options dict; reads 'batch_size', 'lr', 'alpha', 'beta', 'gamma'.
        model: Network returning ``(act_cls, act_reg, snip_cls)``.
        train_dataset: Dataset yielding
            ``(input_data, cls_label, reg_label, snip_label)``.
        optimizer: Optimizer stepped once per batch.
        warmup: When true, the learning rate ramps linearly from 0 to
            ``opt['lr']`` over the epoch.

    Returns:
        ``(n_iter, epoch_cost, epoch_cost_cls, epoch_cost_reg, epoch_cost_snip)``
        where ``n_iter`` is the index of the last batch and the costs are
        sums (not means) over batches.

    Raises:
        ValueError: If the dataset yields no full batch.
    """
    device = torch.device("cpu")
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt['batch_size'], shuffle=True,
        num_workers=0, pin_memory=False, drop_last=True
    )
    if len(train_loader) == 0:
        raise ValueError("Training dataset is smaller than one batch; nothing to train on.")

    epoch_cost = 0
    epoch_cost_cls = 0
    epoch_cost_reg = 0
    epoch_cost_snip = 0
    total_iter = len(train_dataset) // opt['batch_size']
    cls_loss = MultiCrossEntropyLoss(focal=True)
    snip_loss = MultiCrossEntropyLoss(focal=True)

    for n_iter, (input_data, cls_label, reg_label, snip_label) in enumerate(tqdm(train_loader)):
        if warmup:
            for g in optimizer.param_groups:
                g['lr'] = n_iter * opt['lr'] / total_iter

        act_cls, act_reg, snip_cls = model(input_data.float().to(device))

        # Both hooks call collect_grad(label, grad). The original registered
        # the snippet hook with the arguments swapped relative to this one.
        # NOTE(review): confirm the order against
        # MultiCrossEntropyLoss.collect_grad in loss_func.py.
        act_cls.register_hook(partial(cls_loss.collect_grad, cls_label))
        snip_cls.register_hook(partial(snip_loss.collect_grad, snip_label))

        # cls_loss_func_ takes (loss_module, label, output), as in the snippet
        # call below; the label argument was previously missing here.
        cost_cls = cls_loss_func_(cls_loss, cls_label, act_cls)
        epoch_cost_cls += cost_cls.detach().cpu().numpy()

        cost_reg = regress_loss_func(reg_label, act_reg)
        epoch_cost_reg += cost_reg.detach().cpu().numpy()

        cost_snip = cls_loss_func_(snip_loss, snip_label, snip_cls)
        epoch_cost_snip += cost_snip.detach().cpu().numpy()

        cost = opt['alpha'] * cost_cls + opt['beta'] * cost_reg + opt['gamma'] * cost_snip
        epoch_cost += cost.detach().cpu().numpy()

        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

    return n_iter, epoch_cost, epoch_cost_cls, epoch_cost_reg, epoch_cost_snip


def eval_frame(opt, model, dataset):
    """Run the model over ``dataset`` and collect per-video outputs/labels.

    Args:
        opt: Options dict; reads 'batch_size', 'alpha', 'beta'.
        model: Network returning ``(act_cls, act_reg, _)``.
        dataset: Dataset exposing ``video_names`` and ``inputs`` (whose
            entries start with the video name, in loader order).

    Returns:
        ``(cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls,
        labels_reg, working_time, total_frames)``. Losses are means over
        batches; the four dicts map video name to a stacked array.
    """
    device = torch.device("cpu")
    test_loader = torch.utils.data.DataLoader(
        dataset, batch_size=opt['batch_size'], shuffle=False,
        num_workers=0, pin_memory=False, drop_last=False
    )

    labels_cls = {video_name: [] for video_name in dataset.video_names}
    labels_reg = {video_name: [] for video_name in dataset.video_names}
    output_cls = {video_name: [] for video_name in dataset.video_names}
    output_reg = {video_name: [] for video_name in dataset.video_names}

    start_time = time.time()
    total_frames = 0
    num_batches = 0
    epoch_cost = 0
    epoch_cost_cls = 0
    epoch_cost_reg = 0
    cls_loss_fn = MultiCrossEntropyLoss(focal=True)

    with torch.no_grad():  # inference only; avoids building autograd graphs
        for n_iter, (input_data, cls_label, reg_label, snip_label) in enumerate(tqdm(test_loader)):
            num_batches += 1
            act_cls, act_reg, _ = model(input_data.float().to(device))

            # Same (loss_module, label, output) convention as in training.
            cost_cls = cls_loss_func_(cls_loss_fn, cls_label, act_cls)
            epoch_cost_cls += cost_cls.detach().cpu().numpy()

            cost_reg = regress_loss_func(reg_label, act_reg)
            epoch_cost_reg += cost_reg.detach().cpu().numpy()

            cost = opt['alpha'] * cost_cls + opt['beta'] * cost_reg
            epoch_cost += cost.detach().cpu().numpy()

            act_cls = torch.softmax(act_cls, dim=-1)
            total_frames += input_data.size(0)

            for idx in range(input_data.size(0)):
                # The loader is not shuffled, so the flat sample index maps
                # straight back to dataset.inputs.
                video_name, st, ed, data_idx = dataset.inputs[n_iter * opt['batch_size'] + idx]
                output_cls[video_name].append(act_cls[idx].detach().cpu().numpy())
                output_reg[video_name].append(act_reg[idx].detach().cpu().numpy())
                labels_cls[video_name].append(cls_label[idx].cpu().numpy())
                labels_reg[video_name].append(reg_label[idx].cpu().numpy())

    working_time = time.time() - start_time

    for video_name in dataset.video_names:
        labels_cls[video_name] = np.stack(labels_cls[video_name], axis=0)
        labels_reg[video_name] = np.stack(labels_reg[video_name], axis=0)
        output_cls[video_name] = np.stack(output_cls[video_name], axis=0)
        output_reg[video_name] = np.stack(output_reg[video_name], axis=0)

    # Average over the number of batches actually seen; a single-batch run
    # previously reported 0 and an empty loader raised NameError.
    cls_loss = epoch_cost_cls / num_batches if num_batches else 0
    reg_loss = epoch_cost_reg / num_batches if num_batches else 0
    tot_loss = epoch_cost / num_batches if num_batches else 0

    return (cls_loss, reg_loss, tot_loss, output_cls, output_reg,
            labels_cls, labels_reg, working_time, total_frames)


def _decode_anchor_proposals(cls_anc, reg_anc, idx, anchors, threshold,
                             frame_to_time, label_names):
    """Turn one time step's per-anchor outputs into proposal dicts.

    For every anchor, each class (excluding the last column) whose score
    exceeds ``threshold`` yields a proposal whose end is offset from ``idx``
    by the first regression value and whose length is the anchor scaled by
    ``exp`` of the second.
    """
    proposals = []
    for anc_idx, anchor in enumerate(anchors):
        class_ids = np.argwhere(cls_anc[anc_idx][:-1] > threshold).reshape(-1)
        if len(class_ids) == 0:
            continue
        ed = idx + anchor * reg_anc[anc_idx][0]
        length = anchor * np.exp(reg_anc[anc_idx][1])
        st = ed - length
        for label in class_ids:
            proposals.append({
                "segment": [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)],
                "score": float(cls_anc[anc_idx][label]),
                "label": label_names[label],
                "gentime": float(idx * frame_to_time / 100.0)
            })
    return proposals


def _push_scores(queue, proposals, label_names):
    """Shift the score queue one step and write this step's scores last."""
    queue[:-1, :] = queue[1:, :].clone()
    queue[-1, :] = 0
    for proposal in proposals:
        queue[-1, label_names.index(proposal['label'])] = proposal["score"]


def _keep_unsuppressed(kept, proposals, suppress_conf, label_names, num_class,
                       sup_threshold, overlap_thresh):
    """Append proposals whose class passes the suppression score.

    A proposal is added to ``kept`` (in place) only if
    ``check_overlap_proposal`` finds no overlap with what is already kept.
    """
    for cls_idx in range(num_class - 1):
        if suppress_conf[cls_idx] > sup_threshold:
            for proposal in proposals:
                if proposal['label'] == label_names[cls_idx]:
                    if check_overlap_proposal(kept, proposal, overlapThresh=overlap_thresh) is None:
                        kept.append(proposal)


def eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg):
    """Build per-video detections by thresholding followed by NMS.

    ``labels_cls`` and ``labels_reg`` are accepted for signature parity with
    ``eval_map_suppress`` but are not used.

    Returns:
        Dict mapping video name to its list of proposal dicts.
    """
    result_dict = {}
    threshold = opt['threshold']
    anchors = opt['anchors']

    for video_name in dataset.video_names:
        duration = dataset.video_len[video_name]
        video_time = float(dataset.video_dict[video_name]["duration"])
        frame_to_time = 100.0 * video_time / duration

        proposal_dict = []
        for idx in range(duration):
            proposal_dict += _decode_anchor_proposals(
                output_cls[video_name][idx], output_reg[video_name][idx], idx,
                anchors, threshold, frame_to_time, dataset.label_name
            )

        # NMS runs once per video over all time steps.
        result_dict[video_name] = non_max_suppression(proposal_dict, overlapThresh=opt['soft_nms'])

    return result_dict


def eval_map_suppress(opt, dataset, output_cls, output_reg, labels_cls, labels_reg):
    """Build per-video detections using SuppressNet as post-processing.

    Falls back to ``eval_map_nms`` when the SuppressNet checkpoint is missing.

    Returns:
        Dict mapping video name to its list of proposal dicts.
    """
    device = torch.device("cpu")
    model = SuppressNet(opt).to(device)
    checkpoint_path = os.path.join(opt["checkpoint_path"], f"ckp_best_suppress.pth.tar")
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['state_dict'])
        model.eval()
    else:
        print(f"[WARNING] SuppressNet checkpoint {checkpoint_path} not found. Using NMS.")
        return eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg)

    result_dict = {}
    num_class = opt["num_of_class"]
    unit_size = opt['segment_size']
    threshold = opt['threshold']
    anchors = opt['anchors']

    for video_name in dataset.video_names:
        duration = dataset.video_len[video_name]
        video_time = float(dataset.video_dict[video_name]["duration"])
        frame_to_time = 100.0 * video_time / duration
        conf_queue = torch.zeros((unit_size, num_class - 1))

        proposal_dict = []
        for idx in range(duration):
            proposal_anc_dict = _decode_anchor_proposals(
                output_cls[video_name][idx], output_reg[video_name][idx], idx,
                anchors, threshold, frame_to_time, dataset.label_name
            )
            proposal_anc_dict = non_max_suppression(proposal_anc_dict, overlapThresh=opt['soft_nms'])

            _push_scores(conf_queue, proposal_anc_dict, dataset.label_name)

            with torch.no_grad():
                suppress_conf = model(conf_queue.unsqueeze(0).to(device))
            suppress_conf = suppress_conf.squeeze(0).detach().cpu().numpy()

            _keep_unsuppressed(proposal_dict, proposal_anc_dict, suppress_conf,
                               dataset.label_name, num_class,
                               opt['sup_threshold'], opt['soft_nms'])

        result_dict[video_name] = proposal_dict

    return result_dict


def _match_segments(pred_segments, gt_segments, iou_threshold):
    """Greedily pair each prediction with its best unused ground truth.

    Returns a list of ``{'pred', 'gt', 'iou'}`` dicts: matched pairs,
    unmatched predictions (``gt`` is None) and finally unmatched ground
    truths (``pred`` is None).
    """
    matches = []
    used_gt_indices = set()
    for pred in pred_segments:
        best_iou = 0
        best_gt_idx = None
        for gt_idx, gt in enumerate(gt_segments):
            if gt_idx in used_gt_indices:
                continue
            # calc_iou is called with [end, duration] pairs, as in the original.
            iou = calc_iou([pred['end'], pred['duration']], [gt['end'], gt['duration']])
            if iou > best_iou and iou >= iou_threshold:
                best_iou = iou
                best_gt_idx = gt_idx
        if best_gt_idx is not None:
            matches.append({'pred': pred, 'gt': gt_segments[best_gt_idx], 'iou': best_iou})
            used_gt_indices.add(best_gt_idx)
        else:
            matches.append({'pred': pred, 'gt': None, 'iou': 0})

    for gt_idx, gt in enumerate(gt_segments):
        if gt_idx not in used_gt_indices:
            matches.append({'pred': None, 'gt': gt, 'iou': 0})
    return matches


def _format_comparison(matches, num_pred, num_gt):
    """Render the match list as a fixed-width table plus a summary."""
    parts = [
        "\n{:<20} {:<30} {:<30} {:<15} {:<10}\n".format(
            "Action Label", "Predicted Segment (s)", "Ground Truth Segment (s)",
            "Duration Diff (s)", "IoU"),
        "-" * 105 + "\n",
    ]
    for match in matches:
        pred = match['pred']
        gt = match['gt']
        iou = match['iou']
        if pred and gt:
            label = (pred['label'] if pred['label'] == gt['label']
                     else f"{pred['label']} (GT: {gt['label']})")
            pred_str = f"[{pred['start']:.2f}, {pred['end']:.2f}] ({pred['duration']:.2f}s)"
            gt_str = f"[{gt['start']:.2f}, {gt['end']:.2f}] ({gt['duration']:.2f}s)"
            duration_diff = pred['duration'] - gt['duration']
            parts.append("{:<20} {:<30} {:<30} {:<15.2f} {:<10.2f}\n".format(
                label, pred_str, gt_str, duration_diff, iou))
        elif pred:
            pred_str = f"[{pred['start']:.2f}, {pred['end']:.2f}] ({pred['duration']:.2f}s)"
            parts.append("{:<20} {:<30} {:<30} {:<15} {:<10.2f}\n".format(
                pred['label'], pred_str, "None", "N/A", iou))
        elif gt:
            gt_str = f"[{gt['start']:.2f}, {gt['end']:.2f}] ({gt['duration']:.2f}s)"
            parts.append("{:<20} {:<30} {:<30} {:<15} {:<10.2f}\n".format(
                gt['label'], "None", gt_str, "N/A", iou))

    paired = [m for m in matches if m['pred'] and m['gt']]
    matched_count = len(paired)
    avg_duration_diff = (
        np.mean([m['pred']['duration'] - m['gt']['duration'] for m in paired])
        if matched_count > 0 else 0
    )
    positive_ious = [m['iou'] for m in matches if m['iou'] > 0]
    avg_iou = np.mean(positive_ious) if positive_ious else 0

    parts.append(f"\nSummary:\n")
    parts.append(f"- Total Predictions: {num_pred}\n")
    parts.append(f"- Total Ground Truths: {num_gt}\n")
    parts.append(f"- Matched Segments: {matched_count}\n")
    parts.append(f"- Average Duration Difference (s): {avg_duration_diff:.2f}\n")
    parts.append(f"- Average IoU (Matched): {avg_iou:.2f}\n")
    return "".join(parts)


def test(opt, video_name=None):
    """Run offline inference, write the result JSON and evaluate mAP.

    When ``video_name`` is given, also compares predictions against the
    ground-truth annotations and, if ``opt['video_path']`` exists, renders
    the visualisation image and the annotated video.

    Returns:
        ``(mAP_value, comparison_text, viz_path, video_out_path)``. The
        first element is ``None`` when the model checkpoint is missing; the
        string elements are empty when not produced.
    """
    device = torch.device("cpu")
    model = MYNET(opt).to(device)
    checkpoint_path = os.path.join(opt["checkpoint_path"], f"{opt['exp']}_ckp_best.pth.tar")
    if not os.path.exists(checkpoint_path):
        print(f"[ERROR] Checkpoint {checkpoint_path} not found.")
        return None, "", "", ""
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)
    cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, _, _ = \
        eval_frame(opt, model, dataset)

    if opt["pptype"] == "nms":
        result_dict = eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg)
    elif opt["pptype"] == "net":
        result_dict = eval_map_suppress(opt, dataset, output_cls, output_reg, labels_cls, labels_reg)
    else:
        print(f"[WARNING] Unknown pptype {opt['pptype']}. Using NMS.")
        result_dict = eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg)

    output_dict = {"version": "VERSION 1.3", "results": result_dict, "external_data": None}
    result_path = opt["result_file"].format(opt['exp'])
    _ensure_parent_dir(result_path)
    with open(result_path, 'w') as f:
        json.dump(output_dict, f, indent=4)

    mAP_value = _mean_ap(evaluation_detection(opt, verbose=False))

    # Defaults so the 4-tuple is always well-defined, also without video_name.
    comparison_text = ""
    viz_path = ""
    video_out_path = ""

    if video_name:
        print(f"\n[INFO] Comparing Predicted and Ground Truth Actions for Video: {video_name}")
        anno_path = opt["video_anno"].format(opt["split"])
        if not os.path.exists(anno_path):
            print(f"[ERROR] Annotation file {anno_path} not found. Skipping comparison.")
            return mAP_value, "", "", ""

        with open(anno_path, 'r') as f:
            anno_data = json.load(f)

        database = anno_data.get('database', {})
        if video_name not in database:
            print(f"[ERROR] Video {video_name} not found in {anno_path}. Skipping comparison.")
            return mAP_value, "", "", ""

        gt_annotations = database[video_name]['annotations']
        duration = database[video_name]['duration']

        gt_segments = [{
            'label': anno['label'],
            'start': anno['segment'][0],
            'end': anno['segment'][1],
            'duration': anno['segment'][1] - anno['segment'][0]
        } for anno in gt_annotations]

        pred_segments = [{
            'label': pred['label'],
            'start': pred['segment'][0],
            'end': pred['segment'][1],
            'duration': pred['segment'][1] - pred['segment'][0],
            'score': pred['score']
        } for pred in result_dict.get(video_name, [])]

        matches = _match_segments(pred_segments, gt_segments, VIS_CONFIG['iou_threshold'])
        comparison_text = _format_comparison(matches, len(pred_segments), len(gt_segments))

        video_path = opt.get('video_path', '')
        if os.path.exists(video_path):
            viz_path = visualize_action_lengths(
                video_id=video_name,
                pred_segments=pred_segments,
                gt_segments=gt_segments,
                video_path=video_path,
                duration=duration,
                exp_name=opt['exp']
            )
            video_out_path = annotate_video_with_actions(
                video_id=video_name,
                pred_segments=pred_segments,
                gt_segments=gt_segments,
                video_path=video_path,
                exp_name=opt['exp']
            )
        else:
            print(f"[WARNING] Video {video_path} not found. Skipping visualization.")

    return mAP_value, comparison_text, viz_path, video_out_path


def test_online(opt, video_name=None):
    """Run frame-by-frame (streaming) inference with SuppressNet filtering.

    Features are pushed one time step at a time through a sliding input
    queue; proposals are decoded, NMS-filtered and then gated by the
    suppression network. Results are written to the result JSON.

    Returns:
        Mean AP over thresholds, or 0 when the model checkpoint is missing.
    """
    device = torch.device("cpu")
    model = MYNET(opt).to(device)
    checkpoint_path = os.path.join(opt["checkpoint_path"], f"{opt['exp']}_ckp_best.pth.tar")
    if not os.path.exists(checkpoint_path):
        print(f"[ERROR] Checkpoint {checkpoint_path} not found.")
        return 0
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    sup_model = SuppressNet(opt).to(device)
    sup_checkpoint_path = os.path.join(opt["checkpoint_path"], f"ckp_best_suppress.pth.tar")
    if os.path.exists(sup_checkpoint_path):
        checkpoint = torch.load(sup_checkpoint_path, map_location=device)
        sup_model.load_state_dict(checkpoint['state_dict'])
        sup_model.eval()
    else:
        # NOTE: inference continues with an untrained SuppressNet in this case.
        print(f"[WARNING] SuppressNet checkpoint {sup_checkpoint_path} not found.")

    dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)

    result_dict = {}
    num_class = opt["num_of_class"]
    unit_size = opt['segment_size']
    threshold = opt['threshold']
    anchors = opt['anchors']

    start_time = time.time()
    total_frames = 0

    with torch.no_grad():  # inference only
        for vid in dataset.video_names:
            input_queue = torch.zeros((unit_size, opt['feat_dim']))
            sup_queue = torch.zeros((unit_size, num_class - 1))
            duration = dataset.video_len[vid]
            video_time = float(dataset.video_dict[vid]["duration"])
            frame_to_time = 100.0 * video_time / duration

            proposal_dict = []
            for idx in range(duration):
                total_frames += 1
                # Slide the feature window forward by one step.
                input_queue[:-1, :] = input_queue[1:, :].clone()
                input_queue[-1, :] = dataset._get_base_data(vid, idx, idx + 1).squeeze(0)
                act_cls, act_reg, _ = model(input_queue.unsqueeze(0).to(device))
                act_cls = torch.softmax(act_cls, dim=-1)

                cls_anc = act_cls.squeeze(0).detach().cpu().numpy()
                reg_anc = act_reg.squeeze(0).detach().cpu().numpy()

                proposal_anc_dict = _decode_anchor_proposals(
                    cls_anc, reg_anc, idx, anchors, threshold,
                    frame_to_time, dataset.label_name
                )
                proposal_anc_dict = non_max_suppression(
                    proposal_anc_dict, overlapThresh=opt['soft_nms'])

                _push_scores(sup_queue, proposal_anc_dict, dataset.label_name)

                suppress_conf = sup_model(sup_queue.unsqueeze(0).to(device))
                suppress_conf = suppress_conf.squeeze().detach().cpu().numpy()

                _keep_unsuppressed(proposal_dict, proposal_anc_dict, suppress_conf,
                                   dataset.label_name, num_class,
                                   opt['sup_threshold'], opt['soft_nms'])

            result_dict[vid] = proposal_dict

    working_time = time.time() - start_time
    # Guard the rate against a zero elapsed time on an empty dataset.
    fps_rate = total_frames / working_time if working_time > 0 else 0.0
    print(f"[INFO] Working time: {working_time:.2f}s, {fps_rate:.1f}fps, {total_frames} frames")

    output_dict = {"version": "VERSION 1.3", "results": result_dict, "external_data": None}
    result_path = opt["result_file"].format(opt['exp'])
    _ensure_parent_dir(result_path)
    with open(result_path, "w") as f:
        json.dump(output_dict, f, indent=4)

    return _mean_ap(evaluation_detection(opt, verbose=False))


def main(opt, video_name=None):
    """Dispatch on ``opt['mode']``.

    Returns:
        In 'test' mode the 4-tuple from ``test``; otherwise a single
        performance value (0 for an unrecognised mode).

    Raises:
        NotImplementedError: In 'train' mode when no ``train`` function is
            available in this module.
    """
    max_perf = 0
    if not video_name and 'video_name' in opt:
        video_name = opt['video_name']

    if opt['mode'] == 'train':
        # No `train` function is defined in this file as seen; fail with a
        # clear message rather than a bare NameError.
        train_fn = globals().get('train')
        if train_fn is None:
            raise NotImplementedError(
                "mode='train' requires a train(opt) function, which this module does not define."
            )
        max_perf = train_fn(opt)
    elif opt['mode'] == 'test':
        max_perf, comparison_text, viz_path, video_out_path = test(opt, video_name=video_name)
        return max_perf, comparison_text, viz_path, video_out_path
    elif opt['mode'] == 'test_online':
        max_perf = test_online(opt, video_name=video_name)
    elif opt['mode'] == 'eval':
        max_perf = evaluation_detection(opt, verbose=False)

    return max_perf


def gradio_interface(video):
    """Gradio callback: run 'test' mode on an uploaded video.

    The video's base name must match a pre-extracted feature file
    ``data/features/<name>.npz``.

    Returns:
        ``(visualization_path, annotated_video_path, report_text)``; the
        paths are ``None`` when the corresponding artefact was not produced.
    """
    if not video:
        return None, None, "Please upload a video."

    video_name = os.path.splitext(os.path.basename(video))[0]
    feature_path = os.path.join(os.getcwd(), 'data', 'features', f"{video_name}.npz")
    if not os.path.exists(feature_path):
        return None, None, f"[ERROR] Feature file {feature_path} not found for video {video_name}."

    cwd = os.getcwd()
    opt_dict = vars(opts.parse_opt())
    opt_dict['mode'] = 'test'
    opt_dict['video_name'] = video_name
    opt_dict['video_path'] = video
    opt_dict['video_anno'] = os.path.join(cwd, 'data', 'annotations.json')
    opt_dict['video_feature_all_test'] = os.path.join(cwd, 'data', 'features') + os.sep
    opt_dict['checkpoint_path'] = os.path.join(cwd, 'checkpoint')
    opt_dict['result_file'] = os.path.join(cwd, 'results', 'result_{}.json')
    opt_dict['frame_result_file'] = os.path.join(cwd, 'results', 'frame_result_{}.h5')
    opt_dict['video_len_file'] = os.path.join(cwd, 'data', 'video_len_{}.json')
    opt_dict['proposal_label_file'] = os.path.join(cwd, 'data', 'proposal_label_{}.h5')
    opt_dict['suppress_label_file'] = os.path.join(cwd, 'data', 'suppress_label_{}.h5')
    opt_dict['batch_size'] = 1
    opt_dict['data_format'] = 'npz_i3d'
    opt_dict['rgb_only'] = False
    opt_dict['anchors'] = [int(item) for item in opt_dict['anchors'].split(',')]
    opt_dict['predefined_fps'] = 30  # Adjust if needed
    opt_dict['split'] = 'test'
    opt_dict['setup'] = 'default'
    opt_dict['data_rescale'] = 1.0
    opt_dict['pos_threshold'] = 0.5

    mAP, comparison_text, viz_path, video_out_path = main(opt_dict, video_name=video_name)
    if mAP is None:
        # test() signals a missing checkpoint with None; ':.4f' would raise.
        return None, None, "[ERROR] Model checkpoint not found; cannot run inference."
    # Empty paths mean "not produced"; pass None so the components stay empty.
    return viz_path or None, video_out_path or None, f"mAP: {mAP:.4f}\n\n{comparison_text}"


if __name__ == "__main__":
    opt = opts.parse_opt()
    opt = vars(opt)

    opt['checkpoint_path'] = os.path.join(os.getcwd(), 'checkpoint')
    opt['result_file'] = os.path.join(os.getcwd(), 'results', 'result_{}.json')
    opt['frame_result_file'] = os.path.join(os.getcwd(), 'results', 'frame_result_{}.h5')
    opt['video_anno'] = os.path.join(os.getcwd(), 'data', 'annotations.json')
    opt['video_feature_all_test'] = os.path.join(os.getcwd(), 'data', 'features') + os.sep
    opt['video_len_file'] = os.path.join(os.getcwd(), 'data', 'video_len_{}.json')
    opt['proposal_label_file'] = os.path.join(os.getcwd(), 'data', 'proposal_label_{}.h5')
    opt['suppress_label_file'] = os.path.join(os.getcwd(), 'data', 'suppress_label_{}.h5')
    opt['data_format'] = 'npz_i3d'
    opt['rgb_only'] = False
    opt['predefined_fps'] = 30
    opt['split'] = 'test'
    opt['setup'] = 'default'
    opt['data_rescale'] = 1.0
    opt['pos_threshold'] = 0.5

    os.makedirs(opt["checkpoint_path"], exist_ok=True)
    os.makedirs(os.path.dirname(opt["result_file"].format(opt['exp'])), exist_ok=True)
    os.makedirs(os.path.dirname(opt["video_anno"]), exist_ok=True)

    # Options are dumped before 'anchors' is parsed, so the JSON keeps the raw string.
    with open(os.path.join(opt["checkpoint_path"], f"{opt['exp']}_opts.json"), "w") as f:
        json.dump(opt, f, indent=4)

    if opt['seed'] >= 0:
        torch.manual_seed(opt['seed'])
        np.random.seed(opt['seed'])

    opt['anchors'] = [int(item) for item in opt['anchors'].split(',')]
    video_name = opt.get('video_name', None)

    if opt.get('gradio', False):
        iface = gr.Interface(
            fn=gradio_interface,
            inputs=gr.Video(label="Upload Video"),
            outputs=[
                gr.Image(label="Action Length Visualization"),
                gr.Video(label="Annotated Video"),
                gr.Textbox(label="Results and mAP")
            ],
            title="Action Detection Model",
            description="Upload a video to detect actions using pre-extracted I3D features. Ensure a corresponding .npz file exists in data/features/. View visualizations and performance metrics."
        )
        iface.launch()
    else:
        main(opt, video_name=video_name)