# HATTAL / main.py
# Darknsu — "Update main.py" (commit a719c01, verified)
import os
import json
import torch
import torchvision
import torch.nn.parallel
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import opts_egtea as opts
import time
import h5py
from tqdm import tqdm
from iou_utils import *
from eval import evaluation_detection
from tensorboardX import SummaryWriter
from dataset import VideoDataSet, calc_iou
from models import MYNET, SuppressNet
from loss_func import cls_loss_func, cls_loss_func_, regress_loss_func
from loss_func import MultiCrossEntropyLoss
from functools import partial
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import cv2
from typing import List, Dict, Optional
from PIL import Image, ImageDraw, ImageFont
import warnings
import gradio as gr
import subprocess
# Silence non-critical UserWarning / DeprecationWarning noise from libraries.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Visualization settings (kept small for a CPU-only Hugging Face Space).
# Hex colour strings are used by matplotlib; the ``video_*_text_color``
# tuples are in OpenCV's BGR channel order.
VIS_CONFIG = dict(
    frame_interval=1.0,
    max_frames=5,  # few thumbnails to limit memory on CPU
    save_dir=os.path.join(os.getcwd(), 'output', 'visualizations'),
    video_save_dir=os.path.join(os.getcwd(), 'output', 'videos'),
    gt_color='#1f77b4',    # blue bars for ground truth
    pred_color='#ff7f0e',  # orange bars for predictions
    fontsize_label=10,
    fontsize_title=14,
    frame_highlight_both='green',
    frame_highlight_gt='red',
    frame_highlight_pred='black',
    iou_threshold=0.3,
    frame_scale_factor=0.3,
    video_text_scale=0.5,
    video_gt_text_color=(180, 119, 31),   # BGR of '#1f77b4'
    video_pred_text_color=(14, 127, 255),  # BGR of '#ff7f0e'
    video_text_thickness=1,
    video_font_path=os.path.join(os.getcwd(), 'fonts', 'Poppins ExtraBold Italic 800.ttf'),
    video_font_fallback='/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf',
    video_pred_text_y=0.45,
    video_gt_text_y=0.55,
    video_footer_height=150,
    video_gt_bar_y=0.2,
    video_pred_bar_y=0.5,
    video_bar_height=0.15,
    video_bar_text_scale=0.7,
    min_segment_duration=1.0,
    video_frame_text_y=0.05,
    video_bar_label_x=10,
    video_bar_label_scale=0.5,
    scroll_window_duration=30.0,
    scroll_speed=0.5,
)
# BGR colours cycled over the (sorted) action labels for the timeline bars.
_ACTION_COLOR_PALETTE = [
    (128, 0, 0), (60, 20, 220), (0, 128, 0), (128, 0, 128), (79, 69, 54),
    (128, 128, 0), (0, 0, 128), (130, 0, 75), (34, 139, 34), (0, 85, 204),
    (149, 146, 209), (235, 206, 135), (250, 230, 230), (191, 226, 159),
    (185, 218, 255), (255, 204, 204), (193, 182, 255), (201, 252, 189),
    (144, 128, 112), (112, 25, 25), (102, 51, 102), (0, 128, 128), (171, 71, 0)
]
# Frame rate assumed when the container reports none (CAP_PROP_FPS is 0 or NaN).
_FALLBACK_FPS = 30.0


def _load_overlay_fonts(text_scale: float):
    """Return ``(font, bar_font)`` TrueType fonts, or ``(None, None)``.

    The configured font is tried first, then the fallback font; when neither
    loads, callers render text with OpenCV's Hershey font instead.
    """
    font_size = int(20 * text_scale)
    bar_font_size = int(20 * VIS_CONFIG['video_bar_text_scale'])
    for path in (VIS_CONFIG['video_font_path'], VIS_CONFIG['video_font_fallback']):
        try:
            return ImageFont.truetype(path, font_size), ImageFont.truetype(path, bar_font_size)
        except IOError:
            continue
    return None, None


def _put_text(image, text, origin, scale, color, thickness):
    """Draw anti-aliased Hershey-simplex text on ``image`` in place."""
    cv2.putText(image, text, origin, cv2.FONT_HERSHEY_SIMPLEX, scale, color, thickness, cv2.LINE_AA)


def _draw_timeline_bars(image, segments, bar_y, bar_height, bar_start_x, bar_width,
                        window_start, window_end, timestamp, color_map):
    """Fill the already-played part of every segment overlapping the scroll window.

    Each segment is clipped to ``[window_start, timestamp]`` and mapped linearly
    onto ``bar_width`` pixels starting at ``bar_start_x``. The caller must
    guarantee ``window_end > window_start``.
    """
    window_duration = window_end - window_start
    for seg in segments:
        if seg['start'] > window_end or seg['end'] < window_start:
            continue
        start_t = max(seg['start'], window_start)
        end_t = min(seg['end'], timestamp)  # reveal bars progressively as the video plays
        start_x = bar_start_x + int(((start_t - window_start) / window_duration) * bar_width)
        end_x = bar_start_x + int(((end_t - window_start) / window_duration) * bar_width)
        if end_x > start_x:
            cv2.rectangle(image, (start_x, bar_y), (end_x, bar_y + bar_height),
                          color_map[seg['label']], -1)


def annotate_video_with_actions(
    video_id: str,
    pred_segments: List[Dict],
    gt_segments: List[Dict],
    video_path: str,
    save_dir: str = VIS_CONFIG['video_save_dir'],
    text_scale: float = VIS_CONFIG['video_text_scale'] * 1.5,
    gt_text_color: tuple = VIS_CONFIG['video_gt_text_color'],
    pred_text_color: tuple = VIS_CONFIG['video_pred_text_color'],
    text_thickness: int = VIS_CONFIG['video_text_thickness'],
    exp: Optional[str] = None,
) -> str:
    """Render GT / prediction labels and scrolling timeline bars onto a video.

    Every frame gets a white footer with two bars (GT and prediction) covering
    a ``scroll_window_duration``-second window; segments shorter than
    ``min_segment_duration`` are ignored. The result is written as XVID AVI and
    then converted to H.264 MP4 with ffmpeg when it is available.

    Args:
        video_id: Identifier used in the output file name.
        pred_segments / gt_segments: dicts with ``label``, ``start``, ``end``
            and ``duration`` (seconds).
        video_path: Source video.
        save_dir: Output directory (created if missing).
        text_scale, gt_text_color, pred_text_color, text_thickness: text
            styling; colours are BGR.
        exp: Experiment tag for the file name; defaults to the module-level
            ``opt['exp']``.

    Returns:
        The MP4 path when conversion succeeds, otherwise the AVI path, or ""
        when the video could not be opened or written.
    """
    if exp is None:
        exp = opt['exp']
    os.makedirs(save_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video {video_path}. Skipping.")
        return ""
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not (fps > 0):  # also catches NaN
        fps = _FALLBACK_FPS
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    duration = max(total_frames, 0) / fps
    footer_height = VIS_CONFIG['video_footer_height']
    output_height = frame_height + footer_height
    output_path = os.path.join(save_dir, f"annotated_{video_id}_{exp}.avi")
    mp4_path = os.path.splitext(output_path)[0] + '.mp4'
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, output_height))
    if not out.isOpened():
        print(f"Error: Could not initialize video writer for {output_path}.")
        cap.release()
        return ""

    min_duration = VIS_CONFIG['min_segment_duration']
    gt_segments = [seg for seg in gt_segments if seg['duration'] >= min_duration]
    pred_segments = [seg for seg in pred_segments if seg['duration'] >= min_duration]
    # Sorted so that every run assigns the same colour to the same label.
    action_labels = sorted({seg['label'] for seg in gt_segments} | {seg['label'] for seg in pred_segments})
    action_color_map = {
        label: _ACTION_COLOR_PALETTE[i % len(_ACTION_COLOR_PALETTE)]
        for i, label in enumerate(action_labels)
    }
    # PIL draws in RGB while the configured colours are BGR.
    gt_color_rgb = (gt_text_color[2], gt_text_color[1], gt_text_color[0])
    pred_color_rgb = (pred_text_color[2], pred_text_color[1], pred_text_color[0])
    font, bar_font = _load_overlay_fonts(text_scale)

    window_size = VIS_CONFIG['scroll_window_duration']
    bar_text_scale = VIS_CONFIG['video_bar_text_scale']
    text_bar_gap = 48
    text_x = VIS_CONFIG['video_bar_label_x']

    # Layout is identical for every frame, so compute it once.
    footer_y = frame_height
    gt_bar_y = footer_y + int(VIS_CONFIG['video_gt_bar_y'] * footer_height)
    pred_bar_y = footer_y + int(VIS_CONFIG['video_pred_bar_y'] * footer_height)
    bar_height = int(VIS_CONFIG['video_bar_height'] * footer_height)
    if font is not None:
        gt_bbox = bar_font.getbbox("GT")
        pred_bbox = bar_font.getbbox("Pred")
        max_label_width = max(gt_bbox[2] - gt_bbox[0], pred_bbox[2] - pred_bbox[0])
    else:
        max_label_width = max(
            cv2.getTextSize(name, cv2.FONT_HERSHEY_SIMPLEX, bar_text_scale, 1)[0][0]
            for name in ("GT", "Pred")
        )
    bar_start_x = text_x + max_label_width + text_bar_gap
    bar_width = frame_width - bar_start_x
    frame_text_y = int(frame_height * VIS_CONFIG['video_frame_text_y'])
    gt_y = int(frame_height * VIS_CONFIG['video_gt_text_y'])
    pred_y = int(frame_height * VIS_CONFIG['video_pred_text_y'])

    frame_idx = 0
    written_frames = 0
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame.shape[:2] != (frame_height, frame_width):
                frame = cv2.resize(frame, (frame_width, frame_height))
            extended_frame = np.full((output_height, frame_width, 3), 255, dtype=np.uint8)
            extended_frame[:frame_height] = frame

            timestamp = frame_idx / fps
            window_start = int(timestamp // window_size) * window_size
            # CAP_PROP_FRAME_COUNT is only an estimate for many containers, so the
            # window must never end before the current timestamp.
            window_end = max(min(window_start + window_size, duration), timestamp)

            gt_labels = [seg['label'] for seg in gt_segments if seg['start'] <= timestamp <= seg['end']]
            gt_text = "GT: " + ", ".join(gt_labels) if gt_labels else "GT: None"
            pred_labels = [seg['label'] for seg in pred_segments if seg['start'] <= timestamp <= seg['end']]
            pred_text = "Pred: " + ", ".join(pred_labels) if pred_labels else "Pred: None"

            if window_end > window_start:
                _draw_timeline_bars(extended_frame, gt_segments, gt_bar_y, bar_height, bar_start_x,
                                    bar_width, window_start, window_end, timestamp, action_color_map)
                _draw_timeline_bars(extended_frame, pred_segments, pred_bar_y, bar_height, bar_start_x,
                                    bar_width, window_start, window_end, timestamp, action_color_map)

            frame_info = f"Frame: {frame_idx} | FPS: {fps:.2f}"
            window_info = f"{window_start:.1f}s - {window_end:.1f}s"
            if font is not None:
                pil_image = Image.fromarray(cv2.cvtColor(extended_frame, cv2.COLOR_BGR2RGB))
                draw = ImageDraw.Draw(pil_image)
                bbox = draw.textbbox((0, 0), frame_info, font=font)
                draw.text(((frame_width - (bbox[2] - bbox[0])) // 2, frame_text_y),
                          frame_info, font=font, fill=(0, 0, 0))
                bbox = draw.textbbox((0, 0), window_info, font=bar_font)
                draw.text(((frame_width - (bbox[2] - bbox[0])) // 2, footer_y + 10),
                          window_info, font=bar_font, fill=(0, 0, 0))
                draw.text((text_x, gt_bar_y + bar_height // 2), "GT", font=bar_font, fill=gt_color_rgb)
                draw.text((text_x, pred_bar_y + bar_height // 2), "Pred", font=bar_font, fill=pred_color_rgb)
                draw.text((10, gt_y), gt_text, font=font, fill=gt_color_rgb)
                draw.text((10, pred_y), pred_text, font=font, fill=pred_color_rgb)
                extended_frame = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
            else:
                info_width = cv2.getTextSize(frame_info, cv2.FONT_HERSHEY_SIMPLEX, text_scale, text_thickness)[0][0]
                _put_text(extended_frame, frame_info, ((frame_width - info_width) // 2, frame_text_y + 20),
                          text_scale, (0, 0, 0), text_thickness)
                window_width = cv2.getTextSize(window_info, cv2.FONT_HERSHEY_SIMPLEX, bar_text_scale, 1)[0][0]
                _put_text(extended_frame, window_info, ((frame_width - window_width) // 2, footer_y + 20),
                          bar_text_scale, (0, 0, 0), 1)
                _put_text(extended_frame, gt_text, (10, gt_y), text_scale, gt_text_color, text_thickness)
                _put_text(extended_frame, pred_text, (10, pred_y), text_scale, pred_text_color, text_thickness)
                _put_text(extended_frame, "GT", (text_x, gt_bar_y + bar_height // 2 + 5),
                          bar_text_scale, gt_text_color, 1)
                _put_text(extended_frame, "Pred", (text_x, pred_bar_y + bar_height // 2 + 5),
                          bar_text_scale, pred_text_color, 1)

            out.write(extended_frame)
            written_frames += 1
            frame_idx += 1
    finally:
        cap.release()
        out.release()

    print(f"[✅ Saved Annotated Video]: {output_path}, Frames={written_frames}")
    try:
        # ``-y`` overwrites an existing MP4 instead of prompting (which would
        # block a non-interactive server); stdin is closed for the same reason.
        subprocess.run(
            ['ffmpeg', '-y', '-i', output_path, '-vcodec', 'libx264', '-acodec', 'aac', mp4_path],
            check=True,
            stdin=subprocess.DEVNULL,
        )
        print(f"[✅ Converted to MP4]: {mp4_path}")
        return mp4_path
    except (subprocess.CalledProcessError, FileNotFoundError):
        print("Note: FFmpeg not available or failed. Returning .avi (may not play in browsers).")
        return output_path if os.path.exists(output_path) else ""
def _sample_frames(video_path, frame_times):
    """Grab one RGB thumbnail per timestamp (seconds); white 100x100 placeholders on failure."""
    def placeholder():
        return np.ones((100, 100, 3), dtype=np.uint8) * 255

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Warning: Could not open video {video_path}. Using placeholders.")
        return [placeholder() for _ in frame_times]
    scale = VIS_CONFIG['frame_scale_factor']
    frames = []
    try:
        for t in frame_times:
            cap.set(cv2.CAP_PROP_POS_MSEC, t * 1000)
            ret, frame = cap.read()
            if not ret:
                frames.append(placeholder())
                continue
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # Clamp to 1 px so tiny inputs never produce an invalid 0-sized resize.
            new_size = (max(int(frame.shape[1] * scale), 1), max(int(frame.shape[0] * scale), 1))
            frames.append(cv2.resize(frame, new_size))
    finally:
        cap.release()
    return frames


def _draw_segment_row(ax, segments, title, face_color, duration, tick_y):
    """Draw one timeline row: a bar per segment with its label and start/end times.

    ``tick_y`` is the y position (data units in [0, 1]) of the start/end time
    annotations, letting one row place them below and another above the bars.
    """
    ax.set_xlim(0, duration)
    ax.set_ylim(0, 1)
    ax.axis('off')
    ax.text(-0.02 * duration, 0.5, title, fontsize=VIS_CONFIG['fontsize_title'],
            va='center', ha='right', weight='bold')
    for seg in segments:
        start, end = seg['start'], seg['end']
        label = seg['label'][:10] + '...' if len(seg['label']) > 10 else seg['label']
        ax.add_patch(patches.Rectangle(
            (start, 0.3), end - start, 0.4, facecolor=face_color,
            edgecolor='black', alpha=0.8
        ))
        ax.text((start + end) / 2, 0.5, label, ha='center', va='center',
                fontsize=VIS_CONFIG['fontsize_label'], color='white')
        ax.text(start, tick_y, f"{start:.1f}", ha='center', fontsize=8, color='black')
        ax.text(end, tick_y, f"{end:.1f}", ha='center', fontsize=8, color='black')


def visualize_action_lengths(
    video_id: str,
    pred_segments: List[Dict],
    gt_segments: List[Dict],
    video_path: str,
    duration: float,
    save_dir: str = VIS_CONFIG['save_dir'],
    frame_interval: float = VIS_CONFIG['frame_interval'],
    exp: Optional[str] = None,
) -> str:
    """Save a PNG with sampled thumbnails above GT and prediction timelines.

    At most ``VIS_CONFIG['max_frames']`` thumbnails are sampled evenly over
    ``[0, duration]``. A thumbnail's border/title is coloured when its time
    falls inside a GT segment, a predicted segment, or both.

    Args:
        video_id: Identifier used in the output file name.
        pred_segments / gt_segments: dicts with ``label``, ``start``, ``end``.
        video_path: Source video for thumbnails (placeholders if unreadable).
        duration: Video duration in seconds (x-axis extent).
        save_dir: Output directory (created if missing).
        frame_interval: Desired seconds between thumbnails before capping.
        exp: Experiment tag for the file name; defaults to the module-level
            ``opt['exp']``.

    Returns:
        Path of the saved PNG.
    """
    if exp is None:
        exp = opt['exp']
    os.makedirs(save_dir, exist_ok=True)
    max_frames = max(int(VIS_CONFIG['max_frames']), 1)
    wanted = int(duration / frame_interval) + 1 if frame_interval > 0 else max_frames
    num_frames = max(1, min(wanted, max_frames))
    frame_times = np.linspace(0, duration, num_frames, endpoint=True)
    frames = _sample_frames(video_path, frame_times)

    # Work on the explicit figure object rather than pyplot's "current figure",
    # which is shared global state (unsafe with concurrent Gradio requests).
    fig = plt.figure(figsize=(num_frames * 2, 6), constrained_layout=True)
    try:
        gs = fig.add_gridspec(3, num_frames, height_ratios=[3, 1, 1])
        for i, (t, frame) in enumerate(zip(frame_times, frames)):
            ax = fig.add_subplot(gs[0, i])
            gt_hit = any(seg['start'] <= t <= seg['end'] for seg in gt_segments)
            pred_hit = any(seg['start'] <= t <= seg['end'] for seg in pred_segments)
            if gt_hit and pred_hit:
                border_color = VIS_CONFIG['frame_highlight_both']
            elif gt_hit:
                border_color = VIS_CONFIG['frame_highlight_gt']
            elif pred_hit:
                border_color = VIS_CONFIG['frame_highlight_pred']
            else:
                border_color = None
            ax.imshow(frame)
            # ``ax.axis('off')`` would also stop the spines from being drawn,
            # hiding the highlight border; remove only the ticks instead.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                if border_color:
                    spine.set_edgecolor(border_color)
                    spine.set_linewidth(2)
                else:
                    spine.set_visible(False)
            ax.set_title(f"{t:.1f}s", fontsize=VIS_CONFIG['fontsize_label'],
                         color=border_color if border_color else 'black')

        _draw_segment_row(fig.add_subplot(gs[1, :]), gt_segments, "Ground Truth",
                          VIS_CONFIG['gt_color'], duration, tick_y=0.2)
        _draw_segment_row(fig.add_subplot(gs[2, :]), pred_segments, "Prediction",
                          VIS_CONFIG['pred_color'], duration, tick_y=0.8)

        png_path = os.path.join(save_dir, f"viz_{video_id}_{exp}.png")
        fig.savefig(png_path, dpi=100, bbox_inches='tight')
    finally:
        plt.close(fig)
    print(f"[INFO] Saved visualization: {png_path}")
    return png_path
def train_one_epoch(opt, model, train_dataset, optimizer, warmup=False):
    """Run one training epoch on CPU and return accumulated costs.

    The total cost per batch is
    ``alpha * cls + beta * reg + gamma * snip`` (weights from ``opt``). Gradient
    hooks on the two classification heads feed ``collect_grad`` of their focal
    ``MultiCrossEntropyLoss`` instances. With ``warmup=True`` the learning rate
    ramps linearly from 0 towards ``opt['lr']`` over the epoch.

    Returns:
        ``(n_iter, epoch_cost, epoch_cost_cls, epoch_cost_reg, epoch_cost_snip)``
        where ``n_iter`` is the index of the last batch (0 if the loader is
        empty) and the costs are sums over all batches.
    """
    device = torch.device("cpu")
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt['batch_size'], shuffle=True,
        num_workers=0, pin_memory=False, drop_last=True
    )
    epoch_cost = 0
    epoch_cost_cls = 0
    epoch_cost_reg = 0
    epoch_cost_snip = 0
    total_iter = len(train_dataset) // opt['batch_size']
    cls_loss = MultiCrossEntropyLoss(focal=True)
    snip_loss = MultiCrossEntropyLoss(focal=True)
    n_iter = 0  # defined even when drop_last leaves the loader empty
    for n_iter, (input_data, cls_label, reg_label, snip_label) in enumerate(tqdm(train_loader)):
        if warmup:
            for g in optimizer.param_groups:
                g['lr'] = n_iter * opt['lr'] / total_iter
        act_cls, act_reg, snip_cls = model(input_data.float().to(device))
        # Both hooks call collect_grad(target, grad); previously the snippet hook
        # passed the arguments in the opposite order.
        act_cls.register_hook(partial(cls_loss.collect_grad, cls_label))
        snip_cls.register_hook(partial(snip_loss.collect_grad, snip_label))

        # cls_loss_func_(loss_fn, target, prediction), as used for the snippet head.
        cost_cls = cls_loss_func_(cls_loss, cls_label, act_cls)
        epoch_cost_cls += cost_cls.detach().cpu().numpy()
        cost_reg = regress_loss_func(reg_label, act_reg)
        epoch_cost_reg += cost_reg.detach().cpu().numpy()
        cost_snip = cls_loss_func_(snip_loss, snip_label, snip_cls)
        epoch_cost_snip += cost_snip.detach().cpu().numpy()

        cost = opt['alpha'] * cost_cls + opt['beta'] * cost_reg + opt['gamma'] * cost_snip
        epoch_cost += cost.detach().cpu().numpy()
        optimizer.zero_grad()
        cost.backward()
        optimizer.step()
    return n_iter, epoch_cost, epoch_cost_cls, epoch_cost_reg, epoch_cost_snip
def eval_frame(opt, model, dataset):
    """Run the model over ``dataset`` (no gradients) and collect per-video outputs.

    Samples are mapped back to videos through ``dataset.inputs`` in loader
    order (``shuffle=False``). The softmax of the class head is stored.

    Returns:
        ``(cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls,
        labels_reg, working_time, total_frames)`` where the losses are means
        over batches (0 for an empty dataset) and the four dicts map video name
        -> array stacked along axis 0 in sample order.
    """
    device = torch.device("cpu")
    test_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=opt['batch_size'], shuffle=False,
        num_workers=0, pin_memory=False, drop_last=False
    )
    labels_cls = {video_name: [] for video_name in dataset.video_names}
    labels_reg = {video_name: [] for video_name in dataset.video_names}
    output_cls = {video_name: [] for video_name in dataset.video_names}
    output_reg = {video_name: [] for video_name in dataset.video_names}
    start_time = time.time()
    total_frames = 0
    epoch_cost = 0
    epoch_cost_cls = 0
    epoch_cost_reg = 0
    n_batches = 0
    sample_offset = 0  # index into dataset.inputs of the current batch's first sample
    cls_loss_fn = MultiCrossEntropyLoss(focal=True)
    # Evaluation only: no autograd graph is needed, which saves memory and time.
    with torch.no_grad():
        for input_data, cls_label, reg_label, _snip_label in tqdm(test_loader):
            act_cls, act_reg, _ = model(input_data.float().to(device))
            # cls_loss_func_(loss_fn, target, prediction), matching the training call.
            cost_cls = cls_loss_func_(cls_loss_fn, cls_label, act_cls)
            epoch_cost_cls += cost_cls.detach().cpu().numpy()
            cost_reg = regress_loss_func(reg_label, act_reg)
            epoch_cost_reg += cost_reg.detach().cpu().numpy()
            cost = opt['alpha'] * cost_cls + opt['beta'] * cost_reg
            epoch_cost += cost.detach().cpu().numpy()

            act_cls = torch.softmax(act_cls, dim=-1)
            batch_size = input_data.size(0)
            total_frames += batch_size
            for idx in range(batch_size):
                video_name, _st, _ed, _data_idx = dataset.inputs[sample_offset + idx]
                output_cls[video_name].append(act_cls[idx].cpu().numpy())
                output_reg[video_name].append(act_reg[idx].cpu().numpy())
                labels_cls[video_name].append(cls_label[idx].cpu().numpy())
                labels_reg[video_name].append(reg_label[idx].cpu().numpy())
            sample_offset += batch_size
            n_batches += 1
    working_time = time.time() - start_time

    for video_name in dataset.video_names:
        labels_cls[video_name] = np.stack(labels_cls[video_name], axis=0)
        labels_reg[video_name] = np.stack(labels_reg[video_name], axis=0)
        output_cls[video_name] = np.stack(output_cls[video_name], axis=0)
        output_reg[video_name] = np.stack(output_reg[video_name], axis=0)

    # Mean over batches; a single batch is a valid mean too (the old
    # ``n_iter > 0`` check reported 0 in that case).
    cls_loss = epoch_cost_cls / n_batches if n_batches else 0
    reg_loss = epoch_cost_reg / n_batches if n_batches else 0
    tot_loss = epoch_cost / n_batches if n_batches else 0
    return cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames
def eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg):
    """Decode per-frame anchor outputs into NMS-filtered action proposals.

    For each frame and anchor, every class (except the last column) whose score
    exceeds ``opt['threshold']`` yields a proposal. The end point is
    ``frame + anchor * reg[0]`` and the length is ``anchor * exp(reg[1])``, both
    in frames, converted to seconds via the video's annotated duration. All
    proposals of one video then go through ``non_max_suppression`` with
    ``overlapThresh=opt['soft_nms']``. ``labels_cls`` / ``labels_reg`` are
    accepted for signature symmetry and not used.

    Returns:
        dict: video name -> list of ``{"segment", "score", "label", "gentime"}``.
    """
    anchors = opt['anchors']
    threshold = opt['threshold']
    result_dict = {}
    for video_name in dataset.video_names:
        n_frames = dataset.video_len[video_name]
        # Seconds per frame scaled by 100; the factor is divided out again below.
        frame_to_time = 100.0 * float(dataset.video_dict[video_name]["duration"]) / n_frames
        candidates = []
        for frame_idx in range(n_frames):
            frame_cls = output_cls[video_name][frame_idx]
            frame_reg = output_reg[video_name][frame_idx]
            for anc_idx, anchor in enumerate(anchors):
                scores = frame_cls[anc_idx]
                # The last column is excluded (presumably the background class).
                hits = np.argwhere(scores[:-1] > threshold).reshape(-1)
                if hits.size == 0:
                    continue
                ed = frame_idx + anchor * frame_reg[anc_idx][0]
                st = ed - anchor * np.exp(frame_reg[anc_idx][1])
                candidates.extend(
                    {
                        "segment": [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)],
                        "score": float(scores[label]),
                        "label": dataset.label_name[label],
                        "gentime": float(frame_idx * frame_to_time / 100.0),
                    }
                    for label in hits
                )
        result_dict[video_name] = non_max_suppression(candidates, overlapThresh=opt['soft_nms'])
    return result_dict
def eval_map_suppress(opt, dataset, output_cls, output_reg, labels_cls, labels_reg):
    """Decode anchor outputs and filter them online with SuppressNet.

    Falls back to ``eval_map_nms`` when ``ckp_best_suppress.pth.tar`` is missing.
    Per frame, proposals are decoded like in ``eval_map_nms`` and NMS-filtered.
    Their scores are pushed into a sliding ``(segment_size, num_of_class - 1)``
    confidence queue that SuppressNet scores. A proposal is kept when its class
    confidence exceeds ``opt['sup_threshold']`` and ``check_overlap_proposal``
    against the already-kept proposals returns None.

    Returns:
        dict: video name -> list of kept proposal dicts.
    """
    device = torch.device("cpu")
    model = SuppressNet(opt).to(device)
    checkpoint_path = os.path.join(opt["checkpoint_path"], "ckp_best_suppress.pth.tar")
    if not os.path.exists(checkpoint_path):
        print(f"[WARNING] SuppressNet checkpoint {checkpoint_path} not found. Using NMS.")
        return eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    result_dict = {}
    num_class = opt["num_of_class"]
    unit_size = opt['segment_size']
    threshold = opt['threshold']
    anchors = opt['anchors']
    # Inference only: skip building an autograd graph for every per-frame forward.
    with torch.no_grad():
        for video_name in dataset.video_names:
            duration = dataset.video_len[video_name]
            video_time = float(dataset.video_dict[video_name]["duration"])
            frame_to_time = 100.0 * video_time / duration
            conf_queue = torch.zeros((unit_size, num_class - 1))
            proposal_dict = []
            for idx in range(duration):
                cls_anc = output_cls[video_name][idx]
                reg_anc = output_reg[video_name][idx]
                proposal_anc_dict = []
                for anc_idx in range(len(anchors)):
                    hits = np.argwhere(cls_anc[anc_idx][:-1] > threshold).reshape(-1)
                    if len(hits) == 0:
                        continue
                    ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
                    length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
                    st = ed - length
                    for label in hits:
                        proposal_anc_dict.append({
                            "segment": [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)],
                            "score": float(cls_anc[anc_idx][label]),
                            "label": dataset.label_name[label],
                            "gentime": float(idx * frame_to_time / 100.0)
                        })
                proposal_anc_dict = non_max_suppression(proposal_anc_dict, overlapThresh=opt['soft_nms'])

                # Slide the confidence history by one frame and record this frame's scores.
                conf_queue[:-1, :] = conf_queue[1:, :].clone()
                conf_queue[-1, :] = 0
                for proposal in proposal_anc_dict:
                    cls_idx = dataset.label_name.index(proposal['label'])
                    conf_queue[-1, cls_idx] = proposal["score"]
                suppress_conf = model(conf_queue.unsqueeze(0).to(device)).squeeze(0).cpu().numpy()

                for c in range(num_class - 1):
                    if suppress_conf[c] <= opt['sup_threshold']:
                        continue
                    for proposal in proposal_anc_dict:
                        if proposal['label'] == dataset.label_name[c]:
                            if check_overlap_proposal(proposal_dict, proposal, overlapThresh=opt['soft_nms']) is None:
                                proposal_dict.append(proposal)
            result_dict[video_name] = proposal_dict
    return result_dict
def _match_segments(pred_segments, gt_segments, iou_threshold):
    """Greedily pair each prediction with its best still-unused GT segment.

    A pair is accepted only when the IoU (via ``calc_iou`` on
    ``[end, duration]`` pairs) is at least ``iou_threshold``. Returns one
    ``{'pred', 'gt', 'iou'}`` dict per prediction (``gt`` None when
    unmatched), followed by one per GT segment left unmatched.
    """
    matches = []
    used_gt_indices = set()
    for pred in pred_segments:
        best_iou, best_gt_idx = 0, None
        for gt_idx, gt in enumerate(gt_segments):
            if gt_idx in used_gt_indices:
                continue
            iou = calc_iou([pred['end'], pred['duration']], [gt['end'], gt['duration']])
            if iou > best_iou and iou >= iou_threshold:
                best_iou, best_gt_idx = iou, gt_idx
        if best_gt_idx is None:
            matches.append({'pred': pred, 'gt': None, 'iou': 0})
        else:
            used_gt_indices.add(best_gt_idx)
            matches.append({'pred': pred, 'gt': gt_segments[best_gt_idx], 'iou': best_iou})
    matches.extend(
        {'pred': None, 'gt': gt, 'iou': 0}
        for gt_idx, gt in enumerate(gt_segments) if gt_idx not in used_gt_indices
    )
    return matches


def _format_comparison(matches, num_preds, num_gts):
    """Render the prediction-vs-GT table and summary statistics as text."""
    def seg_str(seg):
        return f"[{seg['start']:.2f}, {seg['end']:.2f}] ({seg['duration']:.2f}s)"

    unmatched_row = "{:<20} {:<30} {:<30} {:<15} {:<10.2f}\n"
    parts = [
        "\n{:<20} {:<30} {:<30} {:<15} {:<10}\n".format(
            "Action Label", "Predicted Segment (s)", "Ground Truth Segment (s)", "Duration Diff (s)", "IoU"),
        "-" * 105 + "\n",
    ]
    for match in matches:
        pred, gt, iou = match['pred'], match['gt'], match['iou']
        if pred is not None and gt is not None:
            label = pred['label'] if pred['label'] == gt['label'] else f"{pred['label']} (GT: {gt['label']})"
            parts.append("{:<20} {:<30} {:<30} {:<15.2f} {:<10.2f}\n".format(
                label, seg_str(pred), seg_str(gt), pred['duration'] - gt['duration'], iou))
        elif pred is not None:
            parts.append(unmatched_row.format(pred['label'], seg_str(pred), "None", "N/A", iou))
        elif gt is not None:
            parts.append(unmatched_row.format(gt['label'], "None", seg_str(gt), "N/A", iou))

    matched = [m for m in matches if m['pred'] is not None and m['gt'] is not None]
    avg_duration_diff = np.mean([m['pred']['duration'] - m['gt']['duration'] for m in matched]) if matched else 0
    positive_ious = [m['iou'] for m in matches if m['iou'] > 0]
    avg_iou = np.mean(positive_ious) if positive_ious else 0
    parts.append("\nSummary:\n")
    parts.append(f"- Total Predictions: {num_preds}\n")
    parts.append(f"- Total Ground Truths: {num_gts}\n")
    parts.append(f"- Matched Segments: {len(matched)}\n")
    parts.append(f"- Average Duration Difference (s): {avg_duration_diff:.2f}\n")
    parts.append(f"- Average IoU (Matched): {avg_iou:.2f}\n")
    return "".join(parts)


def test(opt, video_name=None):
    """Evaluate the offline model and optionally compare one video with its GT.

    Loads ``<exp>_ckp_best.pth.tar``, runs ``eval_frame`` on
    ``opt['inference_subset']``, post-processes with NMS or SuppressNet
    according to ``opt['pptype']``, writes the results JSON to
    ``opt['result_file']`` and computes mAP with ``evaluation_detection``.
    When ``video_name`` is given, a text comparison against the annotations is
    built and, if ``opt['video_path']`` exists, visualisations are rendered.

    Returns:
        ``(mAP, comparison_text, viz_path, video_out_path)``. ``mAP`` is None
        when the model checkpoint is missing; string fields are "" for skipped
        steps.
    """
    device = torch.device("cpu")
    model = MYNET(opt).to(device)
    checkpoint_path = os.path.join(opt["checkpoint_path"], f"{opt['exp']}_ckp_best.pth.tar")
    if not os.path.exists(checkpoint_path):
        print(f"[ERROR] Checkpoint {checkpoint_path} not found.")
        return None, "", "", ""
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)
    _, _, _, output_cls, output_reg, labels_cls, labels_reg, _, _ = eval_frame(opt, model, dataset)
    if opt["pptype"] == "nms":
        result_dict = eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg)
    elif opt["pptype"] == "net":
        result_dict = eval_map_suppress(opt, dataset, output_cls, output_reg, labels_cls, labels_reg)
    else:
        print(f"[WARNING] Unknown pptype {opt['pptype']}. Using NMS.")
        result_dict = eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg)

    output_dict = {"version": "VERSION 1.3", "results": result_dict, "external_data": None}
    result_path = opt["result_file"].format(opt['exp'])
    os.makedirs(os.path.dirname(result_path), exist_ok=True)
    with open(result_path, 'w') as f:
        json.dump(output_dict, f, indent=4)
    mAP = evaluation_detection(opt, verbose=False)
    mAP_value = sum(mAP) / len(mAP) if mAP else 0

    # Without a specific video there is nothing to compare; still return the
    # 4-tuple that ``main`` unpacks.
    if not video_name:
        return mAP_value, "", "", ""

    print(f"\n[INFO] Comparing Predicted and Ground Truth Actions for Video: {video_name}")
    anno_path = opt["video_anno"].format(opt["split"])
    if not os.path.exists(anno_path):
        print(f"[ERROR] Annotation file {anno_path} not found. Skipping comparison.")
        return mAP_value, "", "", ""
    with open(anno_path, 'r') as f:
        anno_data = json.load(f)
    video_entry = anno_data.get('database', {}).get(video_name)
    if video_entry is None:
        print(f"[ERROR] Video {video_name} not found in {anno_path}. Skipping comparison.")
        return mAP_value, "", "", ""
    duration = video_entry['duration']
    gt_segments = [{
        'label': anno['label'],
        'start': anno['segment'][0],
        'end': anno['segment'][1],
        'duration': anno['segment'][1] - anno['segment'][0]
    } for anno in video_entry['annotations']]
    pred_segments = [{
        'label': pred['label'],
        'start': pred['segment'][0],
        'end': pred['segment'][1],
        'duration': pred['segment'][1] - pred['segment'][0],
        'score': pred['score']
    } for pred in result_dict.get(video_name, [])]

    matches = _match_segments(pred_segments, gt_segments, VIS_CONFIG['iou_threshold'])
    comparison_text = _format_comparison(matches, len(pred_segments), len(gt_segments))

    video_path = opt.get('video_path', '')
    viz_path = ""
    video_out_path = ""
    if os.path.exists(video_path):
        viz_path = visualize_action_lengths(
            video_id=video_name,
            pred_segments=pred_segments,
            gt_segments=gt_segments,
            video_path=video_path,
            duration=duration
        )
        video_out_path = annotate_video_with_actions(
            video_id=video_name,
            pred_segments=pred_segments,
            gt_segments=gt_segments,
            video_path=video_path
        )
    else:
        print(f"[WARNING] Video {video_path} not found. Skipping visualization.")
    return mAP_value, comparison_text, viz_path, video_out_path
def test_online(opt, video_name=None):
    """Simulate streaming inference frame by frame and return the mean mAP.

    Each video's features are fed one frame at a time through a sliding
    ``(segment_size, feat_dim)`` window into MYNET. Decoded proposals are
    NMS-filtered, their scores pushed into a ``(segment_size, num_of_class - 1)``
    history that SuppressNet scores, and proposals are kept when their class
    confidence exceeds ``opt['sup_threshold']`` and ``check_overlap_proposal``
    returns None. Results are written to ``opt['result_file']``.

    Returns:
        Mean of ``evaluation_detection`` results, or 0 when the model
        checkpoint is missing.
    """
    device = torch.device("cpu")
    model = MYNET(opt).to(device)
    checkpoint_path = os.path.join(opt["checkpoint_path"], f"{opt['exp']}_ckp_best.pth.tar")
    if not os.path.exists(checkpoint_path):
        print(f"[ERROR] Checkpoint {checkpoint_path} not found.")
        return 0
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    sup_model = SuppressNet(opt).to(device)
    sup_checkpoint_path = os.path.join(opt["checkpoint_path"], "ckp_best_suppress.pth.tar")
    if os.path.exists(sup_checkpoint_path):
        sup_checkpoint = torch.load(sup_checkpoint_path, map_location=device)
        sup_model.load_state_dict(sup_checkpoint['state_dict'])
    else:
        print(f"[WARNING] SuppressNet checkpoint {sup_checkpoint_path} not found.")
    # Put SuppressNet in inference mode whether or not weights were loaded.
    sup_model.eval()

    dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)
    result_dict = {}
    num_class = opt["num_of_class"]
    unit_size = opt['segment_size']
    threshold = opt['threshold']
    anchors = opt['anchors']
    start_time = time.time()
    total_frames = 0
    # Inference only: no autograd graph for the per-frame forwards.
    with torch.no_grad():
        for vid in dataset.video_names:
            input_queue = torch.zeros((unit_size, opt['feat_dim']))
            sup_queue = torch.zeros((unit_size, num_class - 1))
            duration = dataset.video_len[vid]
            video_time = float(dataset.video_dict[vid]["duration"])
            frame_to_time = 100.0 * video_time / duration
            proposal_dict = []
            for idx in range(duration):
                total_frames += 1
                # Slide the feature window by one frame and append the newest feature.
                input_queue[:-1, :] = input_queue[1:, :].clone()
                input_queue[-1, :] = dataset._get_base_data(vid, idx, idx + 1).squeeze(0)
                act_cls, act_reg, _ = model(input_queue.unsqueeze(0).to(device))
                cls_anc = torch.softmax(act_cls, dim=-1).squeeze(0).cpu().numpy()
                reg_anc = act_reg.squeeze(0).cpu().numpy()

                proposal_anc_dict = []
                for anc_idx in range(len(anchors)):
                    hits = np.argwhere(cls_anc[anc_idx][:-1] > threshold).reshape(-1)
                    if len(hits) == 0:
                        continue
                    ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
                    length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
                    st = ed - length
                    for label in hits:
                        proposal_anc_dict.append({
                            "segment": [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)],
                            "score": float(cls_anc[anc_idx][label]),
                            "label": dataset.label_name[label],
                            "gentime": float(idx * frame_to_time / 100.0)
                        })
                proposal_anc_dict = non_max_suppression(proposal_anc_dict, overlapThresh=opt['soft_nms'])

                sup_queue[:-1, :] = sup_queue[1:, :].clone()
                sup_queue[-1, :] = 0
                for proposal in proposal_anc_dict:
                    cls_idx = dataset.label_name.index(proposal['label'])
                    sup_queue[-1, cls_idx] = proposal["score"]
                # squeeze(0) (not squeeze()) keeps a 1-D vector even when there is a single class,
                # matching eval_map_suppress.
                suppress_conf = sup_model(sup_queue.unsqueeze(0).to(device)).squeeze(0).cpu().numpy()

                for c in range(num_class - 1):
                    if suppress_conf[c] <= opt['sup_threshold']:
                        continue
                    for proposal in proposal_anc_dict:
                        if proposal['label'] == dataset.label_name[c]:
                            if check_overlap_proposal(proposal_dict, proposal, overlapThresh=opt['soft_nms']) is None:
                                proposal_dict.append(proposal)
            result_dict[vid] = proposal_dict

    working_time = time.time() - start_time
    throughput = total_frames / working_time if working_time > 0 else 0.0
    print(f"[INFO] Working time: {working_time:.2f}s, {throughput:.1f}fps, {total_frames} frames")
    output_dict = {"version": "VERSION 1.3", "results": result_dict, "external_data": None}
    result_path = opt["result_file"].format(opt['exp'])
    os.makedirs(os.path.dirname(result_path), exist_ok=True)
    with open(result_path, "w") as f:
        json.dump(output_dict, f, indent=4)
    mAP = evaluation_detection(opt, verbose=False)
    mAP_value = sum(mAP) / len(mAP) if mAP else 0
    return mAP_value
def main(opt, video_name=None):
    """Dispatch on ``opt['mode']`` and return its performance result.

    ``video_name`` falls back to ``opt['video_name']`` when not given.
    ``'test'`` returns the 4-tuple ``(mAP, comparison_text, viz_path,
    video_out_path)``; ``'train'``, ``'test_online'`` and ``'eval'`` return a
    single value; any other mode returns 0.
    """
    if not video_name and 'video_name' in opt:
        video_name = opt['video_name']
    mode = opt['mode']
    if mode == 'test':
        score, comparison_text, viz_path, video_out_path = test(opt, video_name=video_name)
        return score, comparison_text, viz_path, video_out_path
    if mode == 'train':
        return train(opt)
    if mode == 'test_online':
        return test_online(opt, video_name=video_name)
    if mode == 'eval':
        return evaluation_detection(opt, verbose=False)
    return 0
def gradio_interface(video):
    """Gradio callback: run single-video test inference on an uploaded clip.

    The clip's base name selects the pre-extracted features
    ``data/features/<name>.npz``. All paths and dataset settings are pinned
    under the current working directory before calling ``main`` in ``'test'``
    mode.

    Returns:
        ``(viz_path, annotated_video_path, results_text)`` for the Image,
        Video and Textbox outputs. Artefacts that were not produced are
        returned as None so Gradio does not try to load an empty path.
    """
    if not video:
        return None, None, "Please upload a video."
    cwd = os.getcwd()
    video_name = os.path.splitext(os.path.basename(video))[0]
    feature_path = os.path.join(cwd, 'data', 'features', f"{video_name}.npz")
    if not os.path.exists(feature_path):
        return None, None, f"[ERROR] Feature file {feature_path} not found for video {video_name}."

    opt_dict = vars(opts.parse_opt())
    opt_dict.update({
        'mode': 'test',
        'video_name': video_name,
        'video_path': video,
        'video_anno': os.path.join(cwd, 'data', 'annotations.json'),
        'video_feature_all_test': os.path.join(cwd, 'data', 'features') + os.sep,
        'checkpoint_path': os.path.join(cwd, 'checkpoint'),
        'result_file': os.path.join(cwd, 'results', 'result_{}.json'),
        'frame_result_file': os.path.join(cwd, 'results', 'frame_result_{}.h5'),
        'video_len_file': os.path.join(cwd, 'data', 'video_len_{}.json'),
        'proposal_label_file': os.path.join(cwd, 'data', 'proposal_label_{}.h5'),
        'suppress_label_file': os.path.join(cwd, 'data', 'suppress_label_{}.h5'),
        'batch_size': 1,
        'data_format': 'npz_i3d',
        'rgb_only': False,
        'anchors': [int(item) for item in opt_dict['anchors'].split(',')],
        'predefined_fps': 30,  # Adjust if needed
        'split': 'test',
        'setup': 'default',
        'data_rescale': 1.0,
        'pos_threshold': 0.5,
    })
    mAP, comparison_text, viz_path, video_out_path = main(opt_dict, video_name=video_name)
    if mAP is None:
        # ``test`` reports a missing model checkpoint as mAP None; formatting it
        # with ``:.4f`` would raise inside the Gradio callback.
        return None, None, f"[ERROR] Model checkpoint not found in {opt_dict['checkpoint_path']}."
    return viz_path or None, video_out_path or None, f"mAP: {mAP:.4f}\n\n{comparison_text}"
if __name__ == "__main__":
    # ``opt`` is deliberately a module-level global: the visualisation helpers
    # read ``opt['exp']`` from it.
    opt = vars(opts.parse_opt())
    _cwd = os.getcwd()
    # Pin every data / checkpoint / result location under the working directory
    # and fix the dataset settings this demo expects.
    opt.update({
        'checkpoint_path': os.path.join(_cwd, 'checkpoint'),
        'result_file': os.path.join(_cwd, 'results', 'result_{}.json'),
        'frame_result_file': os.path.join(_cwd, 'results', 'frame_result_{}.h5'),
        'video_anno': os.path.join(_cwd, 'data', 'annotations.json'),
        'video_feature_all_test': os.path.join(_cwd, 'data', 'features') + os.sep,
        'video_len_file': os.path.join(_cwd, 'data', 'video_len_{}.json'),
        'proposal_label_file': os.path.join(_cwd, 'data', 'proposal_label_{}.h5'),
        'suppress_label_file': os.path.join(_cwd, 'data', 'suppress_label_{}.h5'),
        'data_format': 'npz_i3d',
        'rgb_only': False,
        'predefined_fps': 30,
        'split': 'test',
        'setup': 'default',
        'data_rescale': 1.0,
        'pos_threshold': 0.5,
    })
    for _dir in (
        opt["checkpoint_path"],
        os.path.dirname(opt["result_file"].format(opt['exp'])),
        os.path.dirname(opt["video_anno"]),
    ):
        os.makedirs(_dir, exist_ok=True)
    # Snapshot the options; ``anchors`` is still the raw comma-separated string here.
    with open(os.path.join(opt["checkpoint_path"], f"{opt['exp']}_opts.json"), "w") as f:
        json.dump(opt, f, indent=4)
    if opt['seed'] >= 0:
        torch.manual_seed(opt['seed'])
        np.random.seed(opt['seed'])
    opt['anchors'] = [int(anchor) for anchor in opt['anchors'].split(',')]
    video_name = opt.get('video_name', None)
    if not opt.get('gradio', False):
        main(opt, video_name=video_name)
    else:
        iface = gr.Interface(
            fn=gradio_interface,
            inputs=gr.Video(label="Upload Video"),
            outputs=[
                gr.Image(label="Action Length Visualization"),
                gr.Video(label="Annotated Video"),
                gr.Textbox(label="Results and mAP")
            ],
            title="Action Detection Model",
            description="Upload a video to detect actions using pre-extracted I3D features. Ensure a corresponding .npz file exists in data/features/. View visualizations and performance metrics."
        )
        iface.launch()