import os
import json
import time
import subprocess
import warnings
from functools import partial
from typing import List, Dict, Optional

import cv2
import h5py
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import torch
import torchvision
import torch.nn.parallel
import torch.nn.functional as F
import torch.optim as optim
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm
from tensorboardX import SummaryWriter

import opts_egtea as opts
from iou_utils import *
from eval import evaluation_detection
from dataset import VideoDataSet, calc_iou
from models import MYNET, SuppressNet
from loss_func import cls_loss_func, cls_loss_func_, regress_loss_func, MultiCrossEntropyLoss

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
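

# Default visualization settings. Hex colors are for matplotlib; the (B, G, R)
# tuples are the same colors in OpenCV channel order. Keys ending in _y are
# fractions of the frame or footer height.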
VIS_CONFIG = {
    'frame_interval': 1.0,
    'max_frames': 5,
    'save_dir': os.path.join(os.getcwd(), 'output', 'visualizations'),
    'video_save_dir': os.path.join(os.getcwd(), 'output', 'videos'),
    'gt_color': '#1f77b4',
    'pred_color': '#ff7f0e',
    'fontsize_label': 10,
    'fontsize_title': 14,
    'frame_highlight_both': 'green',
    'frame_highlight_gt': 'red',
    'frame_highlight_pred': 'black',
    'iou_threshold': 0.3,
    'frame_scale_factor': 0.3,
    'video_text_scale': 0.5,
    'video_gt_text_color': (180, 119, 31),
    'video_pred_text_color': (14, 127, 255),
    'video_text_thickness': 1,
    'video_font_path': os.path.join(os.getcwd(), 'fonts', 'Poppins ExtraBold Italic 800.ttf'),
    'video_font_fallback': '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf',
    'video_pred_text_y': 0.45,
    'video_gt_text_y': 0.55,
    'video_footer_height': 150,
    'video_gt_bar_y': 0.2,
    'video_pred_bar_y': 0.5,
    'video_bar_height': 0.15,
    'video_bar_text_scale': 0.7,
    'min_segment_duration': 1.0,
    'video_frame_text_y': 0.05,
    'video_bar_label_x': 10,
    'video_bar_label_scale': 0.5,
    'scroll_window_duration': 30.0,
    'scroll_speed': 0.5,
}
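

# Write an annotated copy of the video: per-frame GT/Pred action labels are
# drawn on the frame, and a white footer shows scrolling GT and Pred timeline
# bars for the current time window. Returns the path of the written file
# (.mp4 when ffmpeg is available, otherwise .avi; "" on failure).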
def annotate_video_with_actions(
    video_id: str,
    pred_segments: List[Dict],
    gt_segments: List[Dict],
    video_path: str,
    save_dir: str = VIS_CONFIG['video_save_dir'],
    text_scale: float = VIS_CONFIG['video_text_scale'] * 1.5,
    gt_text_color: tuple = VIS_CONFIG['video_gt_text_color'],
    pred_text_color: tuple = VIS_CONFIG['video_pred_text_color'],
    text_thickness: int = VIS_CONFIG['video_text_thickness']
) -> str:
    os.makedirs(save_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video {video_path}. Skipping.")
        return ""

    fps = cap.get(cv2.CAP_PROP_FPS)
    if fps <= 0:
        # Guard against corrupt metadata: every timestamp below divides by fps.
        print(f"Error: Invalid FPS ({fps}) for {video_path}. Skipping.")
        cap.release()
        return ""
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    duration = total_frames / fps

    # Output frames are the original frame plus a white footer strip that
    # holds the GT/Pred timeline bars.
    footer_height = VIS_CONFIG['video_footer_height']
    output_height = frame_height + footer_height
    # NOTE: relies on the module-level `opt` dict for the experiment tag.
    output_path = os.path.join(save_dir, f"annotated_{video_id}_{opt['exp']}.avi")
    mp4_path = output_path.replace('.avi', '.mp4')
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, output_height))
    if not out.isOpened():
        print(f"Error: Could not initialize video writer for {output_path}.")
        cap.release()
        return ""

    # Drop segments too short to be legible on the timeline bars.
    min_duration = VIS_CONFIG['min_segment_duration']
    gt_segments = [seg for seg in gt_segments if seg['duration'] >= min_duration]
    pred_segments = [seg for seg in pred_segments if seg['duration'] >= min_duration]

    # Fixed BGR palette; each action label gets a stable color on both bars.
    color_palette = [
        (128, 0, 0), (60, 20, 220), (0, 128, 0), (128, 0, 128), (79, 69, 54),
        (128, 128, 0), (0, 0, 128), (130, 0, 75), (34, 139, 34), (0, 85, 204),
        (149, 146, 209), (235, 206, 135), (250, 230, 230), (191, 226, 159),
        (185, 218, 255), (255, 204, 204), (193, 182, 255), (201, 252, 189),
        (144, 128, 112), (112, 25, 25), (102, 51, 102), (0, 128, 128), (171, 71, 0)
    ]
    action_labels = set(seg['label'] for seg in gt_segments).union(set(seg['label'] for seg in pred_segments))
    action_color_map = {label: color_palette[i % len(color_palette)] for i, label in enumerate(action_labels)}

    # PIL draws in RGB while OpenCV uses BGR, so keep both orderings around.
    gt_color_rgb = (gt_text_color[2], gt_text_color[1], gt_text_color[0])
    pred_color_rgb = (pred_text_color[2], pred_text_color[1], pred_text_color[0])

    # Prefer the bundled font, fall back to DejaVu, and finally to
    # cv2.putText if neither TrueType font can be loaded.
    font_path = VIS_CONFIG['video_font_path']
    font_fallback = VIS_CONFIG['video_font_fallback']
    font_size = int(20 * text_scale)
    bar_font_size = int(20 * VIS_CONFIG['video_bar_text_scale'])
    font = None
    bar_font = None
    try:
        font = ImageFont.truetype(font_path, font_size)
        bar_font = ImageFont.truetype(font_path, bar_font_size)
    except IOError:
        try:
            font = ImageFont.truetype(font_fallback, font_size)
            bar_font = ImageFont.truetype(font_fallback, bar_font_size)
        except IOError:
            font = None
            bar_font = None

    # The timeline bars show a scrolling time window rather than the whole
    # duration, so long videos stay readable.
    window_size = VIS_CONFIG['scroll_window_duration']
    text_bar_gap = 48
    text_x = VIS_CONFIG['video_bar_label_x']

    frame_idx = 0
    written_frames = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # Compose the original frame on top of a white footer.
        extended_frame = np.zeros((output_height, frame_width, 3), dtype=np.uint8)
        extended_frame[:frame_height, :, :] = frame
        extended_frame[frame_height:, :, :] = 255

        timestamp = frame_idx / fps
        window_idx = int(timestamp // window_size)
        window_start = window_idx * window_size
        window_end = min(window_start + window_size, duration)
        window_duration = window_end - window_start
        window_timestamp = timestamp - window_start

        gt_labels = [seg['label'] for seg in gt_segments if seg['start'] <= timestamp <= seg['end']]
        gt_text = "GT: " + ", ".join(gt_labels) if gt_labels else "GT: None"
        pred_labels = [seg['label'] for seg in pred_segments if seg['start'] <= timestamp <= seg['end']]
        pred_text = "Pred: " + ", ".join(pred_labels) if pred_labels else "Pred: None"

        footer_y = frame_height
        gt_bar_y = footer_y + int(VIS_CONFIG['video_gt_bar_y'] * footer_height)
        pred_bar_y = footer_y + int(VIS_CONFIG['video_pred_bar_y'] * footer_height)
        bar_height = int(VIS_CONFIG['video_bar_height'] * footer_height)

        # Measure the "GT"/"Pred" row labels so both bars start at the same x.
        if font:
            gt_text_bbox = bar_font.getbbox("GT")
            pred_text_bbox = bar_font.getbbox("Pred")
            gt_text_width = gt_text_bbox[2] - gt_text_bbox[0]
            pred_text_width = pred_text_bbox[2] - pred_text_bbox[0]
        else:
            gt_text_size = cv2.getTextSize("GT", cv2.FONT_HERSHEY_SIMPLEX, VIS_CONFIG['video_bar_text_scale'], 1)[0]
            pred_text_size = cv2.getTextSize("Pred", cv2.FONT_HERSHEY_SIMPLEX, VIS_CONFIG['video_bar_text_scale'], 1)[0]
            gt_text_width = gt_text_size[0]
            pred_text_width = pred_text_size[0]
        max_text_width = max(gt_text_width, pred_text_width)
        bar_start_x = text_x + max_text_width + text_bar_gap
        bar_width = frame_width - bar_start_x

        # Draw GT segments: clip to the current window and fill only up to the
        # current timestamp, so the bars grow with playback.
        for seg in gt_segments:
            if seg['start'] <= window_end and seg['end'] >= window_start:
                start_t = max(seg['start'], window_start)
                end_t = min(seg['end'], window_start + window_timestamp)
                start_x = bar_start_x + int(((start_t - window_start) / window_duration) * bar_width)
                end_x = bar_start_x + int(((end_t - window_start) / window_duration) * bar_width)
                if end_x > start_x:
                    cv2.rectangle(
                        extended_frame,
                        (start_x, gt_bar_y),
                        (end_x, gt_bar_y + bar_height),
                        action_color_map[seg['label']],
                        -1
                    )
        # Same clipping for predicted segments.
        for seg in pred_segments:
            if seg['start'] <= window_end and seg['end'] >= window_start:
                start_t = max(seg['start'], window_start)
                end_t = min(seg['end'], window_start + window_timestamp)
                start_x = bar_start_x + int(((start_t - window_start) / window_duration) * bar_width)
                end_x = bar_start_x + int(((end_t - window_start) / window_duration) * bar_width)
                if end_x > start_x:
                    cv2.rectangle(
                        extended_frame,
                        (start_x, pred_bar_y),
                        (end_x, pred_bar_y + bar_height),
                        action_color_map[seg['label']],
                        -1
                    )

        if font:
            # TrueType path: round-trip through PIL for nicer text rendering.
            frame_rgb = cv2.cvtColor(extended_frame, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(frame_rgb)
            draw = ImageDraw.Draw(pil_image)
            frame_info = f"Frame: {frame_idx} | FPS: {fps:.2f}"
            frame_text_bbox = draw.textbbox((0, 0), frame_info, font=font)
            frame_text_width = frame_text_bbox[2] - frame_text_bbox[0]
            frame_text_x = (frame_width - frame_text_width) // 2
            draw.text((frame_text_x, int(frame_height * VIS_CONFIG['video_frame_text_y'])), frame_info, font=font, fill=(0, 0, 0))
            window_info = f"{window_start:.1f}s - {window_end:.1f}s"
            window_text_bbox = draw.textbbox((0, 0), window_info, font=bar_font)
            window_text_width = window_text_bbox[2] - window_text_bbox[0]
            window_text_x = (frame_width - window_text_width) // 2
            draw.text((window_text_x, footer_y + 10), window_info, font=bar_font, fill=(0, 0, 0))
            draw.text((text_x, gt_bar_y + bar_height // 2), "GT", font=bar_font, fill=gt_color_rgb)
            draw.text((text_x, pred_bar_y + bar_height // 2), "Pred", font=bar_font, fill=pred_color_rgb)
            gt_y = int(frame_height * VIS_CONFIG['video_gt_text_y'])
            pred_y = int(frame_height * VIS_CONFIG['video_pred_text_y'])
            draw.text((10, gt_y), gt_text, font=font, fill=gt_color_rgb)
            draw.text((10, pred_y), pred_text, font=font, fill=pred_color_rgb)
            extended_frame = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
        else:
            # Fallback path: plain cv2.putText for all overlays.
            frame_info = f"Frame: {frame_idx} | FPS: {fps:.2f}"
            text_size = cv2.getTextSize(frame_info, cv2.FONT_HERSHEY_SIMPLEX, text_scale, text_thickness)[0]
            frame_text_x = (frame_width - text_size[0]) // 2
            cv2.putText(extended_frame, frame_info,
                        (frame_text_x, int(frame_height * VIS_CONFIG['video_frame_text_y']) + 20),
                        cv2.FONT_HERSHEY_SIMPLEX, text_scale, (0, 0, 0), text_thickness, cv2.LINE_AA)
            window_info = f"{window_start:.1f}s - {window_end:.1f}s"
            window_text_size = cv2.getTextSize(window_info, cv2.FONT_HERSHEY_SIMPLEX, VIS_CONFIG['video_bar_text_scale'], 1)[0]
            window_text_x = (frame_width - window_text_size[0]) // 2
            cv2.putText(extended_frame, window_info, (window_text_x, footer_y + 20),
                        cv2.FONT_HERSHEY_SIMPLEX, VIS_CONFIG['video_bar_text_scale'], (0, 0, 0), 1, cv2.LINE_AA)
            cv2.putText(extended_frame, gt_text,
                        (10, int(frame_height * VIS_CONFIG['video_gt_text_y'])),
                        cv2.FONT_HERSHEY_SIMPLEX, text_scale, gt_text_color, text_thickness, cv2.LINE_AA)
            cv2.putText(extended_frame, pred_text,
                        (10, int(frame_height * VIS_CONFIG['video_pred_text_y'])),
                        cv2.FONT_HERSHEY_SIMPLEX, text_scale, pred_text_color, text_thickness, cv2.LINE_AA)
            cv2.putText(extended_frame, "GT", (text_x, gt_bar_y + bar_height // 2 + 5),
                        cv2.FONT_HERSHEY_SIMPLEX, VIS_CONFIG['video_bar_text_scale'], gt_text_color, 1, cv2.LINE_AA)
            cv2.putText(extended_frame, "Pred", (text_x, pred_bar_y + bar_height // 2 + 5),
                        cv2.FONT_HERSHEY_SIMPLEX, VIS_CONFIG['video_bar_text_scale'], pred_text_color, 1, cv2.LINE_AA)

        out.write(extended_frame)
        written_frames += 1
        frame_idx += 1

    cap.release()
    out.release()
    print(f"[✅ Saved Annotated Video]: {output_path}, Frames={written_frames}")

    # Browsers generally cannot play XVID .avi, so try to transcode to H.264.
    try:
        subprocess.run(['ffmpeg', '-i', output_path, '-vcodec', 'libx264', '-acodec', 'aac', mp4_path], check=True)
        print(f"[✅ Converted to MP4]: {mp4_path}")
        return mp4_path
    except (subprocess.CalledProcessError, FileNotFoundError):
        print("Note: FFmpeg not available or failed. Returning .avi (may not play in browsers).")
        return output_path if os.path.exists(output_path) else ""
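

# Build a static summary figure: a row of evenly spaced thumbnails, each
# bordered by whether GT and/or Pred cover that instant, above GT and Pred
# timeline rows. Returns the path of the saved .png.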
def visualize_action_lengths(
    video_id: str,
    pred_segments: List[Dict],
    gt_segments: List[Dict],
    video_path: str,
    duration: float,
    save_dir: str = VIS_CONFIG['save_dir'],
    frame_interval: float = VIS_CONFIG['frame_interval']
) -> str:
    os.makedirs(save_dir, exist_ok=True)

    # Sample at most max_frames evenly spaced thumbnails across the video.
    num_frames = int(duration / frame_interval) + 1
    if num_frames > VIS_CONFIG['max_frames']:
        frame_interval = duration / (VIS_CONFIG['max_frames'] - 1)
        num_frames = VIS_CONFIG['max_frames']
    frame_times = np.linspace(0, duration, num_frames, endpoint=True)

    frames = []
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Warning: Could not open video {video_path}. Using placeholders.")
        frames = [np.ones((100, 100, 3), dtype=np.uint8) * 255 for _ in frame_times]
    else:
        for t in frame_times:
            cap.set(cv2.CAP_PROP_POS_MSEC, t * 1000)
            ret, frame = cap.read()
            if ret:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = cv2.resize(frame, (int(frame.shape[1] * VIS_CONFIG['frame_scale_factor']),
                                           int(frame.shape[0] * VIS_CONFIG['frame_scale_factor'])))
                frames.append(frame)
            else:
                frames.append(np.ones((100, 100, 3), dtype=np.uint8) * 255)
        cap.release()

    # Row 0: thumbnails; row 1: ground-truth timeline; row 2: prediction timeline.
    fig = plt.figure(figsize=(num_frames * 2, 6), constrained_layout=True)
    gs = fig.add_gridspec(3, num_frames, height_ratios=[3, 1, 1])

    for i, (t, frame) in enumerate(zip(frame_times, frames)):
        ax = fig.add_subplot(gs[0, i])
        gt_hit = any(seg['start'] <= t <= seg['end'] for seg in gt_segments)
        pred_hit = any(seg['start'] <= t <= seg['end'] for seg in pred_segments)
        border_color = None
        if gt_hit and pred_hit:
            border_color = VIS_CONFIG['frame_highlight_both']
        elif gt_hit:
            border_color = VIS_CONFIG['frame_highlight_gt']
        elif pred_hit:
            border_color = VIS_CONFIG['frame_highlight_pred']
        ax.imshow(frame)
        # Hide ticks but keep the spines drawable, so the colored border below
        # is actually visible (ax.axis('off') would hide the spines as well).
        ax.set_xticks([])
        ax.set_yticks([])
        if border_color:
            for spine in ax.spines.values():
                spine.set_edgecolor(border_color)
                spine.set_linewidth(2)
        else:
            for spine in ax.spines.values():
                spine.set_visible(False)
        ax.set_title(f"{t:.1f}s", fontsize=VIS_CONFIG['fontsize_label'],
                     color=border_color if border_color else 'black')

    ax_gt = fig.add_subplot(gs[1, :])
    ax_gt.set_xlim(0, duration)
    ax_gt.set_ylim(0, 1)
    ax_gt.axis('off')
    ax_gt.text(-0.02 * duration, 0.5, "Ground Truth", fontsize=VIS_CONFIG['fontsize_title'],
               va='center', ha='right', weight='bold')
    for seg in gt_segments:
        start, end = seg['start'], seg['end']
        width = end - start
        label = seg['label'][:10] + '...' if len(seg['label']) > 10 else seg['label']
        ax_gt.add_patch(patches.Rectangle(
            (start, 0.3), width, 0.4, facecolor=VIS_CONFIG['gt_color'],
            edgecolor='black', alpha=0.8
        ))
        ax_gt.text((start + end) / 2, 0.5, label, ha='center', va='center',
                   fontsize=VIS_CONFIG['fontsize_label'], color='white')
        ax_gt.text(start, 0.2, f"{start:.1f}", ha='center', fontsize=8, color='black')
        ax_gt.text(end, 0.2, f"{end:.1f}", ha='center', fontsize=8, color='black')

    ax_pred = fig.add_subplot(gs[2, :])
    ax_pred.set_xlim(0, duration)
    ax_pred.set_ylim(0, 1)
    ax_pred.axis('off')
    ax_pred.text(-0.02 * duration, 0.5, "Prediction", fontsize=VIS_CONFIG['fontsize_title'],
                 va='center', ha='right', weight='bold')
    for seg in pred_segments:
        start, end = seg['start'], seg['end']
        width = end - start
        label = seg['label'][:10] + '...' if len(seg['label']) > 10 else seg['label']
        ax_pred.add_patch(patches.Rectangle(
            (start, 0.3), width, 0.4, facecolor=VIS_CONFIG['pred_color'],
            edgecolor='black', alpha=0.8
        ))
        ax_pred.text((start + end) / 2, 0.5, label, ha='center', va='center',
                     fontsize=VIS_CONFIG['fontsize_label'], color='white')
        ax_pred.text(start, 0.8, f"{start:.1f}", ha='center', fontsize=8, color='black')
        ax_pred.text(end, 0.8, f"{end:.1f}", ha='center', fontsize=8, color='black')

    # NOTE: relies on the module-level `opt` dict for the experiment tag.
    fig_path = os.path.join(save_dir, f"viz_{video_id}_{opt['exp']}.png")
    plt.savefig(fig_path, dpi=100, bbox_inches='tight')
    plt.close()
    print(f"[INFO] Saved visualization: {fig_path}")
    return fig_path
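

# Run one training epoch. The total cost is a weighted sum (alpha/beta/gamma)
# of action classification, boundary regression, and snippet classification
# losses; warmup=True linearly ramps the learning rate over the epoch.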
def train_one_epoch(opt, model, train_dataset, optimizer, warmup=False):
    device = torch.device("cpu")
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt['batch_size'], shuffle=True,
        num_workers=0, pin_memory=False, drop_last=True
    )

    epoch_cost = 0
    epoch_cost_cls = 0
    epoch_cost_reg = 0
    epoch_cost_snip = 0
    total_iter = len(train_dataset) // opt['batch_size']

    cls_loss = MultiCrossEntropyLoss(focal=True)
    snip_loss = MultiCrossEntropyLoss(focal=True)

    for n_iter, (input_data, cls_label, reg_label, snip_label) in enumerate(tqdm(train_loader)):
        if warmup:
            # Linear learning-rate warmup over the first epoch.
            for g in optimizer.param_groups:
                g['lr'] = n_iter * opt['lr'] / total_iter

        act_cls, act_reg, snip_cls = model(input_data.float().to(device))
        # Backward hooks hand each head's gradient, together with its labels,
        # to the focal-loss gradient collectors (same argument order for both).
        act_cls.register_hook(partial(cls_loss.collect_grad, cls_label))
        snip_cls.register_hook(partial(snip_loss.collect_grad, snip_label))

        loss = cls_loss_func_(cls_loss, act_cls)
        cost_cls = loss
        epoch_cost_cls += loss.detach().cpu().numpy()

        loss = regress_loss_func(reg_label, act_reg)
        cost_reg = loss
        epoch_cost_reg += loss.detach().cpu().numpy()

        loss = cls_loss_func_(snip_loss, snip_label, snip_cls)
        cost_snip = loss
        epoch_cost_snip += loss.detach().cpu().numpy()

        cost = opt['alpha'] * cost_cls + opt['beta'] * cost_reg + opt['gamma'] * cost_snip
        epoch_cost += cost.detach().cpu().numpy()

        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

    return n_iter, epoch_cost, epoch_cost_cls, epoch_cost_reg, epoch_cost_snip
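

# Run the model over every frame of the evaluation split, gathering per-video
# classification/regression outputs and labels plus the average losses.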
def eval_frame(opt, model, dataset):
    device = torch.device("cpu")
    test_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=opt['batch_size'], shuffle=False,
        num_workers=0, pin_memory=False, drop_last=False
    )

    labels_cls = {video_name: [] for video_name in dataset.video_names}
    labels_reg = {video_name: [] for video_name in dataset.video_names}
    output_cls = {video_name: [] for video_name in dataset.video_names}
    output_reg = {video_name: [] for video_name in dataset.video_names}

    start_time = time.time()
    total_frames = 0
    epoch_cost = 0
    epoch_cost_cls = 0
    epoch_cost_reg = 0
    cls_loss_fn = MultiCrossEntropyLoss(focal=True)
    num_batches = 0

    for n_iter, (input_data, cls_label, reg_label, snip_label) in enumerate(tqdm(test_loader)):
        act_cls, act_reg, _ = model(input_data.float().to(device))

        loss = cls_loss_func_(cls_loss_fn, act_cls)
        cost_cls = loss
        epoch_cost_cls += loss.detach().cpu().numpy()

        loss = regress_loss_func(reg_label, act_reg)
        cost_reg = loss
        epoch_cost_reg += loss.detach().cpu().numpy()

        cost = opt['alpha'] * cost_cls + opt['beta'] * cost_reg
        epoch_cost += cost.detach().cpu().numpy()

        act_cls = torch.softmax(act_cls, dim=-1)
        total_frames += input_data.size(0)
        num_batches += 1

        # Scatter per-frame outputs back to their source videos.
        for idx in range(input_data.size(0)):
            video_name, st, ed, data_idx = dataset.inputs[n_iter * opt['batch_size'] + idx]
            output_cls[video_name].append(act_cls[idx].detach().cpu().numpy())
            output_reg[video_name].append(act_reg[idx].detach().cpu().numpy())
            labels_cls[video_name].append(cls_label[idx].cpu().numpy())
            labels_reg[video_name].append(reg_label[idx].cpu().numpy())

    end_time = time.time()
    working_time = end_time - start_time

    for video_name in dataset.video_names:
        labels_cls[video_name] = np.stack(labels_cls[video_name], axis=0)
        labels_reg[video_name] = np.stack(labels_reg[video_name], axis=0)
        output_cls[video_name] = np.stack(output_cls[video_name], axis=0)
        output_reg[video_name] = np.stack(output_reg[video_name], axis=0)

    # Average losses over the batch count (an explicit counter, so a
    # single-batch loader is not miscounted and an empty loader is safe).
    cls_loss = epoch_cost_cls / num_batches if num_batches > 0 else 0
    reg_loss = epoch_cost_reg / num_batches if num_batches > 0 else 0
    tot_loss = epoch_cost / num_batches if num_batches > 0 else 0

    return cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames
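

# Decode per-frame anchor outputs into timestamped segment proposals and
# deduplicate each video's proposals with non-maximum suppression.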
def eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg):
    result_dict = {}
    proposal_dict = []
    threshold = opt['threshold']
    anchors = opt['anchors']

    for video_name in dataset.video_names:
        duration = dataset.video_len[video_name]
        video_time = float(dataset.video_dict[video_name]["duration"])
        frame_to_time = 100.0 * video_time / duration

        for idx in range(duration):
            cls_anc = output_cls[video_name][idx]
            reg_anc = output_reg[video_name][idx]
            proposal_anc_dict = []
            for anc_idx in range(len(anchors)):
                # Classes above threshold (the last class is background).
                cls = np.argwhere(cls_anc[anc_idx][:-1] > threshold).reshape(-1)
                if len(cls) == 0:
                    continue
                # Decode the anchor regression: an offset for the end point
                # and a log-scale factor for the segment length.
                ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
                length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
                st = ed - length
                for cidx in range(len(cls)):
                    label = cls[cidx]
                    tmp_dict = {
                        "segment": [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)],
                        "score": float(cls_anc[anc_idx][label]),
                        "label": dataset.label_name[label],
                        "gentime": float(idx * frame_to_time / 100.0)
                    }
                    proposal_anc_dict.append(tmp_dict)
            proposal_dict += proposal_anc_dict

        proposal_dict = non_max_suppression(proposal_dict, overlapThresh=opt['soft_nms'])
        result_dict[video_name] = proposal_dict
        proposal_dict = []

    return result_dict
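

# Same anchor decoding as eval_map_nms, but proposals are additionally
# filtered by a trained SuppressNet fed a sliding queue of per-class
# confidences. Falls back to plain NMS when no SuppressNet checkpoint exists.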
def eval_map_suppress(opt, dataset, output_cls, output_reg, labels_cls, labels_reg):
    device = torch.device("cpu")
    model = SuppressNet(opt).to(device)
    checkpoint_path = os.path.join(opt["checkpoint_path"], "ckp_best_suppress.pth.tar")
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['state_dict'])
        model.eval()
    else:
        print(f"[WARNING] SuppressNet checkpoint {checkpoint_path} not found. Using NMS.")
        return eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg)

    result_dict = {}
    proposal_dict = []
    num_class = opt["num_of_class"]
    unit_size = opt['segment_size']
    threshold = opt['threshold']
    anchors = opt['anchors']

    for video_name in dataset.video_names:
        duration = dataset.video_len[video_name]
        video_time = float(dataset.video_dict[video_name]["duration"])
        frame_to_time = 100.0 * video_time / duration
        # Sliding queue of per-class confidences fed to SuppressNet.
        conf_queue = torch.zeros((unit_size, num_class - 1))

        for idx in range(duration):
            cls_anc = output_cls[video_name][idx]
            reg_anc = output_reg[video_name][idx]
            proposal_anc_dict = []
            for anc_idx in range(len(anchors)):
                cls = np.argwhere(cls_anc[anc_idx][:-1] > threshold).reshape(-1)
                if len(cls) == 0:
                    continue
                ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
                length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
                st = ed - length
                for cidx in range(len(cls)):
                    label = cls[cidx]
                    tmp_dict = {
                        "segment": [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)],
                        "score": float(cls_anc[anc_idx][label]),
                        "label": dataset.label_name[label],
                        "gentime": float(idx * frame_to_time / 100.0)
                    }
                    proposal_anc_dict.append(tmp_dict)

            proposal_anc_dict = non_max_suppression(proposal_anc_dict, overlapThresh=opt['soft_nms'])

            # Shift the queue one step and write the current frame's scores.
            conf_queue[:-1, :] = conf_queue[1:, :].clone()
            conf_queue[-1, :] = 0
            for proposal in proposal_anc_dict:
                cls_idx = dataset.label_name.index(proposal['label'])
                conf_queue[-1, cls_idx] = proposal["score"]

            minput = conf_queue.unsqueeze(0).to(device)
            suppress_conf = model(minput)
            suppress_conf = suppress_conf.squeeze(0).detach().cpu().numpy()

            # Keep a proposal only if SuppressNet is confident in its class and
            # it does not overlap an already accepted proposal.
            for cls in range(num_class - 1):
                if suppress_conf[cls] > opt['sup_threshold']:
                    for proposal in proposal_anc_dict:
                        if proposal['label'] == dataset.label_name[cls]:
                            if check_overlap_proposal(proposal_dict, proposal, overlapThresh=opt['soft_nms']) is None:
                                proposal_dict.append(proposal)

        result_dict[video_name] = proposal_dict
        proposal_dict = []

    return result_dict
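

# Offline evaluation: run the detector, post-process proposals (NMS or
# SuppressNet), dump detections to JSON, and compute mAP. When a single video
# is requested, also build a GT-vs-Pred comparison table and visualizations.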
def test(opt, video_name=None):
    device = torch.device("cpu")
    model = MYNET(opt).to(device)
    checkpoint_path = os.path.join(opt["checkpoint_path"], f"{opt['exp']}_ckp_best.pth.tar")
    if not os.path.exists(checkpoint_path):
        print(f"[ERROR] Checkpoint {checkpoint_path} not found.")
        return None, "", "", ""
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)
    cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, _, _ = eval_frame(opt, model, dataset)

    # Post-process per-frame outputs into segment proposals.
    if opt["pptype"] == "nms":
        result_dict = eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg)
    elif opt["pptype"] == "net":
        result_dict = eval_map_suppress(opt, dataset, output_cls, output_reg, labels_cls, labels_reg)
    else:
        print(f"[WARNING] Unknown pptype {opt['pptype']}. Using NMS.")
        result_dict = eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg)

    output_dict = {"version": "VERSION 1.3", "results": result_dict, "external_data": None}
    result_path = opt["result_file"].format(opt['exp'])
    os.makedirs(os.path.dirname(result_path), exist_ok=True)
    with open(result_path, 'w') as f:
        json.dump(output_dict, f, indent=4)

    mAP = evaluation_detection(opt, verbose=False)
    mAP_value = sum(mAP) / len(mAP) if mAP else 0

    if video_name:
        print(f"\n[INFO] Comparing Predicted and Ground Truth Actions for Video: {video_name}")
        anno_path = opt["video_anno"].format(opt["split"])
        if not os.path.exists(anno_path):
            print(f"[ERROR] Annotation file {anno_path} not found. Skipping comparison.")
            return mAP_value, "", "", ""
        with open(anno_path, 'r') as f:
            anno_data = json.load(f)
        gt_annotations = anno_data['database'][video_name]['annotations']
        duration = anno_data['database'][video_name]['duration']

        gt_segments = [{
            'label': anno['label'],
            'start': anno['segment'][0],
            'end': anno['segment'][1],
            'duration': anno['segment'][1] - anno['segment'][0]
        } for anno in gt_annotations]
        pred_segments = [{
            'label': pred['label'],
            'start': pred['segment'][0],
            'end': pred['segment'][1],
            'duration': pred['segment'][1] - pred['segment'][0],
            'score': pred['score']
        } for pred in result_dict.get(video_name, [])]

        # Greedy one-to-one matching: each prediction takes the unused GT
        # segment with the highest IoU above the threshold.
        matches = []
        iou_threshold = VIS_CONFIG['iou_threshold']
        used_gt_indices = set()
        for pred in pred_segments:
            best_iou = 0
            best_gt_idx = None
            for gt_idx, gt in enumerate(gt_segments):
                if gt_idx in used_gt_indices:
                    continue
                iou = calc_iou([pred['end'], pred['duration']], [gt['end'], gt['duration']])
                if iou > best_iou and iou >= iou_threshold:
                    best_iou = iou
                    best_gt_idx = gt_idx
            if best_gt_idx is not None:
                matches.append({
                    'pred': pred,
                    'gt': gt_segments[best_gt_idx],
                    'iou': best_iou
                })
                used_gt_indices.add(best_gt_idx)
            else:
                matches.append({'pred': pred, 'gt': None, 'iou': 0})
        for gt_idx, gt in enumerate(gt_segments):
            if gt_idx not in used_gt_indices:
                matches.append({'pred': None, 'gt': gt, 'iou': 0})

        # Fixed-width comparison table of matched and unmatched segments.
        comparison_text = "\n{:<20} {:<30} {:<30} {:<15} {:<10}\n".format(
            "Action Label", "Predicted Segment (s)", "Ground Truth Segment (s)", "Duration Diff (s)", "IoU")
        comparison_text += "-" * 105 + "\n"
        for match in matches:
            pred = match['pred']
            gt = match['gt']
            iou = match['iou']
            if pred and gt:
                label = pred['label'] if pred['label'] == gt['label'] else f"{pred['label']} (GT: {gt['label']})"
                pred_str = f"[{pred['start']:.2f}, {pred['end']:.2f}] ({pred['duration']:.2f}s)"
                gt_str = f"[{gt['start']:.2f}, {gt['end']:.2f}] ({gt['duration']:.2f}s)"
                duration_diff = pred['duration'] - gt['duration']
                comparison_text += "{:<20} {:<30} {:<30} {:<15.2f} {:<10.2f}\n".format(
                    label, pred_str, gt_str, duration_diff, iou)
            elif pred:
                pred_str = f"[{pred['start']:.2f}, {pred['end']:.2f}] ({pred['duration']:.2f}s)"
                comparison_text += "{:<20} {:<30} {:<30} {:<15} {:<10.2f}\n".format(
                    pred['label'], pred_str, "None", "N/A", iou)
            elif gt:
                gt_str = f"[{gt['start']:.2f}, {gt['end']:.2f}] ({gt['duration']:.2f}s)"
                comparison_text += "{:<20} {:<30} {:<30} {:<15} {:<10.2f}\n".format(
                    gt['label'], "None", gt_str, "N/A", iou)

        matched_count = sum(1 for m in matches if m['pred'] and m['gt'])
        avg_duration_diff = np.mean([m['pred']['duration'] - m['gt']['duration'] for m in matches if m['pred'] and m['gt']]) if matched_count > 0 else 0
        avg_iou = np.mean([m['iou'] for m in matches if m['iou'] > 0]) if any(m['iou'] > 0 for m in matches) else 0
        comparison_text += "\nSummary:\n"
        comparison_text += f"- Total Predictions: {len(pred_segments)}\n"
        comparison_text += f"- Total Ground Truths: {len(gt_segments)}\n"
        comparison_text += f"- Matched Segments: {matched_count}\n"
        comparison_text += f"- Average Duration Difference (s): {avg_duration_diff:.2f}\n"
        comparison_text += f"- Average IoU (Matched): {avg_iou:.2f}\n"

        video_path = opt.get('video_path', '')
        viz_path = ""
        video_out_path = ""
        if os.path.exists(video_path):
            viz_path = visualize_action_lengths(
                video_id=video_name,
                pred_segments=pred_segments,
                gt_segments=gt_segments,
                video_path=video_path,
                duration=duration
            )
            video_out_path = annotate_video_with_actions(
                video_id=video_name,
                pred_segments=pred_segments,
                gt_segments=gt_segments,
                video_path=video_path
            )
        else:
            print(f"[WARNING] Video {video_path} not found. Skipping visualization.")
        return mAP_value, comparison_text, viz_path, video_out_path

    # No per-video comparison requested; return the metric with empty extras
    # so callers can always unpack four values.
    return mAP_value, "", "", ""
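

# Online (streaming) evaluation: features enter frame by frame through sliding
# queues, so proposals are generated causally; results are scored with mAP.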
def test_online(opt, video_name=None):
    device = torch.device("cpu")
    model = MYNET(opt).to(device)
    checkpoint_path = os.path.join(opt["checkpoint_path"], f"{opt['exp']}_ckp_best.pth.tar")
    if not os.path.exists(checkpoint_path):
        print(f"[ERROR] Checkpoint {checkpoint_path} not found.")
        return 0
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    sup_model = SuppressNet(opt).to(device)
    sup_checkpoint_path = os.path.join(opt["checkpoint_path"], "ckp_best_suppress.pth.tar")
    if os.path.exists(sup_checkpoint_path):
        checkpoint = torch.load(sup_checkpoint_path, map_location=device)
        sup_model.load_state_dict(checkpoint['state_dict'])
        sup_model.eval()
    else:
        # Without a checkpoint, SuppressNet runs with random weights.
        print(f"[WARNING] SuppressNet checkpoint {sup_checkpoint_path} not found.")

    dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)

    result_dict = {}
    proposal_dict = []
    num_class = opt["num_of_class"]
    unit_size = opt['segment_size']
    threshold = opt['threshold']
    anchors = opt['anchors']

    start_time = time.time()
    total_frames = 0

    for video_name in dataset.video_names:
        # Streaming queues: raw features for the detector and per-class
        # confidences for SuppressNet, each updated one frame at a time.
        input_queue = torch.zeros((unit_size, opt['feat_dim']))
        sup_queue = torch.zeros((unit_size, num_class - 1))
        duration = dataset.video_len[video_name]
        video_time = float(dataset.video_dict[video_name]["duration"])
        frame_to_time = 100.0 * video_time / duration

        for idx in range(duration):
            total_frames += 1
            input_queue[:-1, :] = input_queue[1:, :].clone()
            input_queue[-1, :] = dataset._get_base_data(video_name, idx, idx + 1).squeeze(0)
            minput = input_queue.unsqueeze(0).to(device)
            act_cls, act_reg, _ = model(minput)
            act_cls = torch.softmax(act_cls, dim=-1)
            cls_anc = act_cls.squeeze(0).detach().cpu().numpy()
            reg_anc = act_reg.squeeze(0).detach().cpu().numpy()

            proposal_anc_dict = []
            for anc_idx in range(len(anchors)):
                cls = np.argwhere(cls_anc[anc_idx][:-1] > threshold).reshape(-1)
                if len(cls) == 0:
                    continue
                ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
                length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
                st = ed - length
                for cidx in range(len(cls)):
                    label = cls[cidx]
                    tmp_dict = {
                        "segment": [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)],
                        "score": float(cls_anc[anc_idx][label]),
                        "label": dataset.label_name[label],
                        "gentime": float(idx * frame_to_time / 100.0)
                    }
                    proposal_anc_dict.append(tmp_dict)

            proposal_anc_dict = non_max_suppression(proposal_anc_dict, overlapThresh=opt['soft_nms'])

            sup_queue[:-1, :] = sup_queue[1:, :].clone()
            sup_queue[-1, :] = 0
            for proposal in proposal_anc_dict:
                cls_idx = dataset.label_name.index(proposal['label'])
                sup_queue[-1, cls_idx] = proposal["score"]

            minput = sup_queue.unsqueeze(0).to(device)
            suppress_conf = sup_model(minput)
            suppress_conf = suppress_conf.squeeze().detach().cpu().numpy()

            for cls in range(num_class - 1):
                if suppress_conf[cls] > opt['sup_threshold']:
                    for proposal in proposal_anc_dict:
                        if proposal['label'] == dataset.label_name[cls]:
                            if check_overlap_proposal(proposal_dict, proposal, overlapThresh=opt['soft_nms']) is None:
                                proposal_dict.append(proposal)

        result_dict[video_name] = proposal_dict
        proposal_dict = []

    end_time = time.time()
    working_time = end_time - start_time
    print(f"[INFO] Working time: {working_time:.2f}s, {total_frames / working_time:.1f}fps, {total_frames} frames")

    output_dict = {"version": "VERSION 1.3", "results": result_dict, "external_data": None}
    result_path = opt["result_file"].format(opt['exp'])
    os.makedirs(os.path.dirname(result_path), exist_ok=True)
    with open(result_path, "w") as f:
        json.dump(output_dict, f, indent=4)

    mAP = evaluation_detection(opt, verbose=False)
    mAP_value = sum(mAP) / len(mAP) if mAP else 0
    return mAP_value
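

# Dispatch on opt['mode']. Only 'test' returns the comparison text and
# visualization paths alongside the metric.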
def main(opt, video_name=None):
    max_perf = 0
    if not video_name and 'video_name' in opt:
        video_name = opt['video_name']
    if opt['mode'] == 'train':
        max_perf = train(opt)  # train() is expected to be provided elsewhere in this codebase
    elif opt['mode'] == 'test':
        max_perf, comparison_text, viz_path, video_out_path = test(opt, video_name=video_name)
        return max_perf, comparison_text, viz_path, video_out_path
    elif opt['mode'] == 'test_online':
        max_perf = test_online(opt, video_name=video_name)
    elif opt['mode'] == 'eval':
        max_perf = evaluation_detection(opt, verbose=False)
    return max_perf
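

# Gradio callback. The uploaded video must have matching pre-extracted I3D
# features at data/features/<video_name>.npz.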
def gradio_interface(video):
    if not video:
        return None, None, "Please upload a video."
    video_name = os.path.splitext(os.path.basename(video))[0]
    feature_path = os.path.join(os.getcwd(), 'data', 'features', f"{video_name}.npz")
    if not os.path.exists(feature_path):
        return None, None, f"[ERROR] Feature file {feature_path} not found for video {video_name}."

    # Build a fresh option dict for this request, mirroring the paths set up
    # in __main__ and pointing at the uploaded video.
    opt_dict = vars(opts.parse_opt())
    opt_dict['mode'] = 'test'
    opt_dict['video_name'] = video_name
    opt_dict['video_path'] = video
    opt_dict['video_anno'] = os.path.join(os.getcwd(), 'data', 'annotations.json')
    opt_dict['video_feature_all_test'] = os.path.join(os.getcwd(), 'data', 'features') + os.sep
    opt_dict['checkpoint_path'] = os.path.join(os.getcwd(), 'checkpoint')
    opt_dict['result_file'] = os.path.join(os.getcwd(), 'results', 'result_{}.json')
    opt_dict['frame_result_file'] = os.path.join(os.getcwd(), 'results', 'frame_result_{}.h5')
    opt_dict['video_len_file'] = os.path.join(os.getcwd(), 'data', 'video_len_{}.json')
    opt_dict['proposal_label_file'] = os.path.join(os.getcwd(), 'data', 'proposal_label_{}.h5')
    opt_dict['suppress_label_file'] = os.path.join(os.getcwd(), 'data', 'suppress_label_{}.h5')
    opt_dict['batch_size'] = 1
    opt_dict['data_format'] = 'npz_i3d'
    opt_dict['rgb_only'] = False
    opt_dict['anchors'] = [int(item) for item in opt_dict['anchors'].split(',')]
    opt_dict['predefined_fps'] = 30
    opt_dict['split'] = 'test'
    opt_dict['setup'] = 'default'
    opt_dict['data_rescale'] = 1.0
    opt_dict['pos_threshold'] = 0.5

    mAP, comparison_text, viz_path, video_out_path = main(opt_dict, video_name=video_name)
    # test() returns None for mAP when the model checkpoint is missing.
    if mAP is None:
        return None, None, "[ERROR] Model checkpoint not found. See console output."
    return viz_path, video_out_path, f"mAP: {mAP:.4f}\n\n{comparison_text}"
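

# Script entry point: parse options, resolve paths, optionally seed, then
# either launch the Gradio demo or run main(). A typical invocation might be
#   python this_script.py --mode test --video_name <video_id>
# though the exact flags are defined by opts_egtea.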
if __name__ == "__main__":
    opt = opts.parse_opt()
    opt = vars(opt)

    # Resolve all I/O paths relative to the working directory.
    opt['checkpoint_path'] = os.path.join(os.getcwd(), 'checkpoint')
    opt['result_file'] = os.path.join(os.getcwd(), 'results', 'result_{}.json')
    opt['frame_result_file'] = os.path.join(os.getcwd(), 'results', 'frame_result_{}.h5')
    opt['video_anno'] = os.path.join(os.getcwd(), 'data', 'annotations.json')
    opt['video_feature_all_test'] = os.path.join(os.getcwd(), 'data', 'features') + os.sep
    opt['video_len_file'] = os.path.join(os.getcwd(), 'data', 'video_len_{}.json')
    opt['proposal_label_file'] = os.path.join(os.getcwd(), 'data', 'proposal_label_{}.h5')
    opt['suppress_label_file'] = os.path.join(os.getcwd(), 'data', 'suppress_label_{}.h5')
    opt['data_format'] = 'npz_i3d'
    opt['rgb_only'] = False
    opt['predefined_fps'] = 30
    opt['split'] = 'test'
    opt['setup'] = 'default'
    opt['data_rescale'] = 1.0
    opt['pos_threshold'] = 0.5

    os.makedirs(opt["checkpoint_path"], exist_ok=True)
    os.makedirs(os.path.dirname(opt["result_file"].format(opt['exp'])), exist_ok=True)
    os.makedirs(os.path.dirname(opt["video_anno"]), exist_ok=True)

    # Snapshot the options next to the checkpoints for reproducibility.
    with open(os.path.join(opt["checkpoint_path"], f"{opt['exp']}_opts.json"), "w") as f:
        json.dump(opt, f, indent=4)

    if opt['seed'] >= 0:
        torch.manual_seed(opt['seed'])
        np.random.seed(opt['seed'])

    opt['anchors'] = [int(item) for item in opt['anchors'].split(',')]
    video_name = opt.get('video_name', None)

    if opt.get('gradio', False):
        iface = gr.Interface(
            fn=gradio_interface,
            inputs=gr.Video(label="Upload Video"),
            outputs=[
                gr.Image(label="Action Length Visualization"),
                gr.Video(label="Annotated Video"),
                gr.Textbox(label="Results and mAP")
            ],
            title="Action Detection Model",
            description="Upload a video to detect actions using pre-extracted I3D features. Ensure a corresponding .npz file exists in data/features/. View visualizations and performance metrics."
        )
        iface.launch()
    else:
        main(opt, video_name=video_name)