# SAT / main.py
# (Hugging Face Space page-header text removed: "Darknsu's picture", "Update main.py", "efa906f verified" —
#  these lines were scrape artifacts, not Python, and prevented the file from parsing.)
# import os
# import json
# import torch
# import numpy as np
# import gradio as gr
# import opts_egtea as opts
# from dataset import VideoDataSet, calc_iou
# from models import MYNET, SuppressNet
# from loss_func import cls_loss_func, regress_loss_func
# from eval import evaluation_detection
# from iou_utils import non_max_suppression, check_overlap_proposal
# from typing import List, Dict, Optional
# # Configuration
# VIS_CONFIG = {
# 'iou_threshold': 0.3,
# 'min_segment_duration': 1.0,
# }
# # Determine device
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"Using device: {device}")
# def eval_frame(opt, model, dataset):
# """Evaluate model frame by frame"""
# test_loader = torch.utils.data.DataLoader(
# dataset,
# batch_size=opt['batch_size'],
# shuffle=False,
# num_workers=0,
# pin_memory=False
# )
# labels_cls = {video_name: [] for video_name in dataset.video_list}
# labels_reg = {video_name: [] for video_name in dataset.video_list}
# output_cls = {video_name: [] for video_name in dataset.video_list}
# output_reg = {video_name: [] for video_name in dataset.video_list}
# model.eval()
# with torch.no_grad():
# for n_iter, batch_data in enumerate(test_loader):
# try:
# if len(batch_data) == 4:
# input_data, cls_label, reg_label, _ = batch_data
# else:
# input_data, cls_label, reg_label = batch_data
# input_data = input_data.to(device)
# cls_label = cls_label.to(device) if cls_label is not None else None
# reg_label = reg_label.to(device) if reg_label is not None else None
# act_cls, act_reg, _ = model(input_data.float())
# act_cls = torch.softmax(act_cls, dim=-1)
# for b in range(input_data.size(0)):
# batch_idx = n_iter * opt['batch_size'] + b
# if batch_idx < len(dataset.inputs):
# video_name = dataset.inputs[batch_idx][0]
# output_cls[video_name].append(act_cls[b, :].detach().cpu().numpy())
# output_reg[video_name].append(act_reg[b, :].detach().cpu().numpy())
# if cls_label is not None:
# labels_cls[video_name].append(cls_label[b, :].cpu().numpy())
# if reg_label is not None:
# labels_reg[video_name].append(reg_label[b, :].cpu().numpy())
# except Exception as e:
# print(f"Error in batch {n_iter}: {str(e)}")
# continue
# # Stack arrays
# for video_name in dataset.video_list:
# if output_cls[video_name]:
# output_cls[video_name] = np.stack(output_cls[video_name], axis=0)
# output_reg[video_name] = np.stack(output_reg[video_name], axis=0)
# if labels_cls[video_name]:
# labels_cls[video_name] = np.stack(labels_cls[video_name], axis=0)
# if labels_reg[video_name]:
# labels_reg[video_name] = np.stack(labels_reg[video_name], axis=0)
# return output_cls, output_reg, labels_cls, labels_reg
# def eval_map_nms(opt, dataset, output_cls, output_reg):
# """Evaluate with Non-Maximum Suppression"""
# result_dict = {}
# anchors = opt['anchors']
# for video_name in dataset.video_list:
# if video_name not in output_cls or len(output_cls[video_name]) == 0:
# result_dict[video_name] = []
# continue
# duration = dataset.video_len[video_name]
# video_time = float(dataset.video_dict[video_name]["duration"])
# frame_to_time = 100.0 * video_time / duration
# proposal_dict = []
# for idx in range(min(duration, len(output_cls[video_name]))):
# cls_anc = output_cls[video_name][idx]
# reg_anc = output_reg[video_name][idx]
# for anc_idx in range(len(anchors)):
# if anc_idx >= len(cls_anc):
# continue
# cls = np.argwhere(cls_anc[anc_idx][:-1] > opt['threshold']).reshape(-1)
# if len(cls) == 0:
# continue
# ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
# length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
# st = ed - length
# for cidx in range(len(cls)):
# label = cls[cidx]
# if label < len(dataset.label_name):
# tmp_dict = {
# "segment": [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)],
# "score": float(cls_anc[anc_idx][label]),
# "label": dataset.label_name[label],
# "gentime": float(idx * frame_to_time / 100.0)
# }
# proposal_dict.append(tmp_dict)
# # Apply NMS
# proposal_dict = non_max_suppression(proposal_dict, overlapThresh=opt['soft_nms'])
# result_dict[video_name] = proposal_dict
# return result_dict
# def load_ground_truth(opt, video_name):
# """Load ground truth annotations if available"""
# gt_segments = []
# duration = 0
# try:
# video_anno_file = opt["video_anno"].format(opt["split"])
# if os.path.exists(video_anno_file):
# with open(video_anno_file, 'r') as f:
# anno_data = json.load(f)
# if video_name in anno_data['database']:
# gt_annotations = anno_data['database'][video_name]['annotations']
# duration = anno_data['database'][video_name]['duration']
# for anno in gt_annotations:
# start, end = anno['segment']
# gt_segments.append({
# 'label': anno['label'],
# 'start': start,
# 'end': end,
# 'duration': end - start
# })
# except Exception as e:
# print(f"Could not load ground truth: {str(e)}")
# return gt_segments, duration
# def process_video(video_name, split_number):
# """Process a single video for action localization"""
# try:
# # Parse options
# opt = opts.parse_opt()
# opt = vars(opt)
# opt['mode'] = 'test'
# opt['split'] = str(split_number)
# opt['checkpoint_path'] = './checkpoint'
# opt['video_feature_all_test'] = './data/I3D/'
# opt['anchors'] = [int(item) for item in opt['anchors'].split(',')]
# opt['batch_size'] = 1
# # Check if required files exist
# checkpoint_path = './checkpoint/01_ckp_best.pth.tar'
# if not os.path.exists(checkpoint_path):
# return "Error: Model checkpoint not found at ./checkpoint/01_ckp_best.pth.tar"
# npz_path = os.path.join(opt['video_feature_all_test'], f"{video_name}.npz")
# if not os.path.exists(npz_path):
# return f"Error: Feature file not found at {npz_path}"
# # Load model
# model = MYNET(opt).to(device)
# checkpoint = torch.load(checkpoint_path, map_location=device)
# # Handle different checkpoint formats
# if 'state_dict' in checkpoint:
# model.load_state_dict(checkpoint['state_dict'])
# else:
# model.load_state_dict(checkpoint)
# model.eval()
# # Create dataset
# dataset = VideoDataSet(opt, subset='test', video_name=video_name)
# if len(dataset.video_list) == 0:
# return f"Error: No video found with name '{video_name}' in dataset"
# # Run inference
# output_cls, output_reg, labels_cls, labels_reg = eval_frame(opt, model, dataset)
# result_dict = eval_map_nms(opt, dataset, output_cls, output_reg)
# # Load ground truth
# gt_segments, duration = load_ground_truth(opt, video_name)
# # Process predictions
# pred_segments = []
# for pred in result_dict.get(video_name, []):
# start, end = pred['segment']
# pred_segments.append({
# 'label': pred['label'],
# 'start': start,
# 'end': end,
# 'duration': end - start,
# 'score': pred['score']
# })
# # Generate output text
# output_text = f"Predicted Actions for Video: {video_name}\n"
# output_text += "=" * 50 + "\n\n"
# if pred_segments:
# output_text += "PREDICTED ACTIONS:\n"
# output_text += "-" * 30 + "\n"
# for i, pred in enumerate(pred_segments, 1):
# output_text += f"{i}. {pred['label']}\n"
# output_text += f" Time: [{pred['start']:.2f}s - {pred['end']:.2f}s]\n"
# output_text += f" Duration: {pred['duration']:.2f}s\n"
# output_text += f" Confidence: {pred['score']:.3f}\n\n"
# else:
# output_text += "No actions detected above threshold.\n\n"
# # Add ground truth comparison if available
# if gt_segments:
# output_text += "\nGROUND TRUTH COMPARISON:\n"
# output_text += "-" * 30 + "\n"
# # Calculate basic metrics
# matched_count = 0
# total_pred = len(pred_segments)
# total_gt = len(gt_segments)
# for gt in gt_segments:
# output_text += f"GT: {gt['label']} [{gt['start']:.2f}s - {gt['end']:.2f}s]\n"
# # Find best matching prediction
# best_match = None
# best_iou = 0
# for pred in pred_segments:
# # Simple overlap calculation
# overlap_start = max(gt['start'], pred['start'])
# overlap_end = min(gt['end'], pred['end'])
# if overlap_end > overlap_start:
# overlap = overlap_end - overlap_start
# union = (gt['end'] - gt['start']) + (pred['end'] - pred['start']) - overlap
# iou = overlap / union if union > 0 else 0
# if iou > best_iou:
# best_iou = iou
# best_match = pred
# if best_match and best_iou > VIS_CONFIG['iou_threshold']:
# matched_count += 1
# output_text += f" β†’ Matched with: {best_match['label']} (IoU: {best_iou:.3f})\n"
# else:
# output_text += f" β†’ No match found\n"
# output_text += "\n"
# # Summary statistics
# precision = matched_count / total_pred if total_pred > 0 else 0
# recall = matched_count / total_gt if total_gt > 0 else 0
# f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
# output_text += f"\nSUMMARY STATISTICS:\n"
# output_text += f"Total Predictions: {total_pred}\n"
# output_text += f"Total Ground Truth: {total_gt}\n"
# output_text += f"Matched: {matched_count}\n"
# output_text += f"Precision: {precision:.3f}\n"
# output_text += f"Recall: {recall:.3f}\n"
# output_text += f"F1-Score: {f1:.3f}\n"
# return output_text
# except Exception as e:
# return f"Error processing video: {str(e)}\n\nPlease check:\n1. Model checkpoint exists\n2. Feature file exists\n3. All dependencies are installed"
# def get_available_videos():
# """Get list of available videos from I3D features directory"""
# feature_dir = './data/I3D/'
# if not os.path.exists(feature_dir):
# return []
# videos = []
# for file in os.listdir(feature_dir):
# if file.endswith('.npz'):
# video_name = file.replace('.npz', '')
# videos.append(video_name)
# return sorted(videos)
# # Initialize available videos
# available_videos = get_available_videos()
# if not available_videos:
# available_videos = ["No videos found"]
# # Gradio Interface
# iface = gr.Interface(
# fn=process_video,
# inputs=[
# gr.Dropdown(
# label="Select Video",
# choices=available_videos,
# value=available_videos[0] if available_videos else None,
# info="Choose from pre-uploaded videos in data/I3D/ folder"
# ),
# gr.Dropdown(
# label="Split Number",
# choices=["1", "2", "3"],
# value="1",
# info="Dataset split for annotations"
# )
# ],
# outputs=[
# gr.Textbox(
# label="Action Predictions",
# lines=20,
# max_lines=50,
# show_copy_button=True
# )
# ],
# title="🎬 Temporal Action Localization",
# description="""
# This app performs temporal action localization on pre-uploaded videos using I3D features.
# **How to use:**
# 1. Select a video from the dropdown (videos must be in data/I3D/ folder as .npz files)
# 2. Choose the annotation split number
# 3. Click Submit to get action predictions
# **Requirements:**
# - Model checkpoint: `01_ckp_best.pth.tar` in root directory
# - Video features: `.npz` files in `data/I3D/` folder
# """,
# examples=[
# [available_videos[0] if available_videos and available_videos[0] != "No videos found" else "example_video", "1"],
# ] if available_videos and available_videos[0] != "No videos found" else None,
# cache_examples=False,
# theme=gr.themes.Soft()
# )
# if __name__ == '__main__':
# print(f"Available videos: {available_videos}")
# print(f"Using device: {device}")
# iface.launch(
# server_name="0.0.0.0",
# server_port=7860,
# share=False
# )
# import os
# import json
# import torch
# import numpy as np
# import gradio as gr
# import opts_egtea as opts
# from dataset import VideoDataSet, calc_iou
# from models import MYNET, SuppressNet
# from loss_func import cls_loss_func, regress_loss_func
# from eval import evaluation_detection
# from iou_utils import non_max_suppression, check_overlap_proposal
# from typing import List, Dict, Optional
# from huggingface_hub import hf_hub_download, list_repo_files
# import tempfile
# import shutil
# import traceback
# # Configuration
# VIS_CONFIG = {
# 'iou_threshold': 0.3,
# 'min_segment_duration': 1.0,
# }
# # Hugging Face Dataset Configuration
# HF_DATASET_REPO = "Darknsu/EGTEA_Dataset"
# HF_DATASET_SUBFOLDER = "I3D" # Adjust this based on your dataset structure
# # Determine device
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"Using device: {device}")
# # Create local cache directory for downloaded files
# CACHE_DIR = "./hf_cache"
# os.makedirs(CACHE_DIR, exist_ok=True)
# def download_npz_file(video_name: str) -> str:
# """
# Download .npz file from Hugging Face dataset repository
# Returns: Local path to the downloaded file
# """
# try:
# # Construct the file path in the dataset repo
# file_path = f"{video_name}.npz"
# # Check if file already exists in cache
# local_path = os.path.join(CACHE_DIR, f"{video_name}.npz")
# if os.path.exists(local_path):
# print(f"Using cached file: {local_path}")
# return local_path
# # Download from Hugging Face dataset
# print(f"Downloading {file_path} from {HF_DATASET_REPO}...")
# downloaded_path = hf_hub_download(
# repo_id=HF_DATASET_REPO,
# filename=file_path,
# repo_type="dataset",
# cache_dir=CACHE_DIR
# )
# # Copy to our expected location for easier access
# shutil.copy2(downloaded_path, local_path)
# print(f"File downloaded and cached: {local_path}")
# return local_path
# except Exception as e:
# raise Exception(f"Failed to download {video_name}.npz: {str(e)}")
# def get_available_videos_from_hf():
# """Get list of available videos from Hugging Face dataset repository"""
# try:
# print("Fetching available videos from Hugging Face dataset...")
# files = list_repo_files(
# repo_id=HF_DATASET_REPO,
# repo_type="dataset"
# )
# # Filter for .npz files in the I3D subfolder
# videos = []
# for file in files:
# if file.startswith(f"") and file.endswith('.npz'):
# # Extract the full filename without extension
# # For files like "I3D/OP02-R02-TurkeySandwich.npz"
# video_name = os.path.basename(file).replace('.npz', '')
# videos.append(video_name)
# videos = sorted(videos)
# print(f"Found {len(videos)} videos in dataset: {videos[:5]}{'...' if len(videos) > 5 else ''}")
# return videos
# except Exception as e:
# print(f"Error fetching videos from HF dataset: {str(e)}")
# return ["Error loading videos"]
# class HFVideoDataSet(VideoDataSet):
# """
# Modified VideoDataSet that downloads files from Hugging Face on demand
# """
# def __init__(self, opt, subset='test', video_name=None):
# # Store the original video_feature_all_test path
# self.original_feature_path = opt['video_feature_all_test']
# # Create temporary directory for this session
# self.temp_dir = tempfile.mkdtemp(prefix="hf_video_")
# print(f"Created temp directory: {self.temp_dir}")
# # Download the specific video file if video_name is provided
# if video_name:
# try:
# print(f"Downloading features for video: {video_name}")
# downloaded_path = download_npz_file(video_name)
# # Ensure the temp directory exists
# os.makedirs(self.temp_dir, exist_ok=True)
# # Copy to temp directory with expected structure - FIX: Add proper path separator
# temp_file_path = os.path.join(self.temp_dir, f"{video_name}.npz")
# print(f"Copying {downloaded_path} to {temp_file_path}")
# shutil.copy2(downloaded_path, temp_file_path)
# # Verify file exists and print debug info
# if not os.path.exists(temp_file_path):
# raise Exception(f"Failed to copy file to {temp_file_path}")
# else:
# print(f"Video file ready: {temp_file_path}")
# print(f"File size: {os.path.getsize(temp_file_path)} bytes")
# except Exception as e:
# print(f"Error downloading video {video_name}: {str(e)}")
# # Clean up temp directory on error
# if hasattr(self, 'temp_dir') and os.path.exists(self.temp_dir):
# shutil.rmtree(self.temp_dir)
# raise e
# # Set the feature path to our temp directory
# opt['video_feature_all_test'] = self.temp_dir
# print(f"Set video_feature_all_test to: {opt['video_feature_all_test']}")
# # Initialize parent class
# try:
# super().__init__(opt, subset, video_name)
# print(f"Successfully initialized dataset with {len(self.video_list)} videos")
# except Exception as e:
# print(f"Error initializing parent VideoDataSet: {str(e)}")
# # Clean up temp directory on error
# if hasattr(self, 'temp_dir') and os.path.exists(self.temp_dir):
# shutil.rmtree(self.temp_dir)
# raise e
# def __del__(self):
# # Clean up temporary directory
# try:
# if hasattr(self, 'temp_dir') and os.path.exists(self.temp_dir):
# shutil.rmtree(self.temp_dir)
# print(f"Cleaned up temp directory: {self.temp_dir}")
# except Exception as e:
# print(f"Warning: Could not clean up temp directory: {e}")
# def eval_frame(opt, model, dataset):
# """Evaluate model frame by frame"""
# try:
# test_loader = torch.utils.data.DataLoader(
# dataset,
# batch_size=opt['batch_size'],
# shuffle=False,
# num_workers=0,
# pin_memory=False
# )
# labels_cls = {video_name: [] for video_name in dataset.video_list}
# labels_reg = {video_name: [] for video_name in dataset.video_list}
# output_cls = {video_name: [] for video_name in dataset.video_list}
# output_reg = {video_name: [] for video_name in dataset.video_list}
# model.eval()
# with torch.no_grad():
# for n_iter, batch_data in enumerate(test_loader):
# try:
# if len(batch_data) == 4:
# input_data, cls_label, reg_label, _ = batch_data
# else:
# input_data, cls_label, reg_label = batch_data
# input_data = input_data.to(device)
# cls_label = cls_label.to(device) if cls_label is not None else None
# reg_label = reg_label.to(device) if reg_label is not None else None
# act_cls, act_reg, _ = model(input_data.float())
# act_cls = torch.softmax(act_cls, dim=-1)
# for b in range(input_data.size(0)):
# batch_idx = n_iter * opt['batch_size'] + b
# if batch_idx < len(dataset.inputs):
# video_name = dataset.inputs[batch_idx][0]
# output_cls[video_name].append(act_cls[b, :].detach().cpu().numpy())
# output_reg[video_name].append(act_reg[b, :].detach().cpu().numpy())
# if cls_label is not None:
# labels_cls[video_name].append(cls_label[b, :].cpu().numpy())
# if reg_label is not None:
# labels_reg[video_name].append(reg_label[b, :].cpu().numpy())
# except Exception as e:
# print(f"Error in batch {n_iter}: {str(e)}")
# continue
# # Stack arrays
# for video_name in dataset.video_list:
# if output_cls[video_name]:
# output_cls[video_name] = np.stack(output_cls[video_name], axis=0)
# output_reg[video_name] = np.stack(output_reg[video_name], axis=0)
# if labels_cls[video_name]:
# labels_cls[video_name] = np.stack(labels_cls[video_name], axis=0)
# if labels_reg[video_name]:
# labels_reg[video_name] = np.stack(labels_reg[video_name], axis=0)
# return output_cls, output_reg, labels_cls, labels_reg
# except Exception as e:
# print(f"Error in eval_frame: {str(e)}")
# raise e
# def eval_map_nms(opt, dataset, output_cls, output_reg):
# """Evaluate with Non-Maximum Suppression"""
# try:
# result_dict = {}
# anchors = opt['anchors']
# for video_name in dataset.video_list:
# if video_name not in output_cls or len(output_cls[video_name]) == 0:
# result_dict[video_name] = []
# continue
# duration = dataset.video_len[video_name]
# video_time = float(dataset.video_dict[video_name]["duration"])
# frame_to_time = 100.0 * video_time / duration
# proposal_dict = []
# for idx in range(min(duration, len(output_cls[video_name]))):
# cls_anc = output_cls[video_name][idx]
# reg_anc = output_reg[video_name][idx]
# for anc_idx in range(len(anchors)):
# if anc_idx >= len(cls_anc):
# continue
# cls = np.argwhere(cls_anc[anc_idx][:-1] > opt['threshold']).reshape(-1)
# if len(cls) == 0:
# continue
# ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
# length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
# st = ed - length
# for cidx in range(len(cls)):
# label = cls[cidx]
# if label < len(dataset.label_name):
# tmp_dict = {
# "segment": [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)],
# "score": float(cls_anc[anc_idx][label]),
# "label": dataset.label_name[label],
# "gentime": float(idx * frame_to_time / 100.0)
# }
# proposal_dict.append(tmp_dict)
# # Apply NMS
# proposal_dict = non_max_suppression(proposal_dict, overlapThresh=opt['soft_nms'])
# result_dict[video_name] = proposal_dict
# return result_dict
# except Exception as e:
# print(f"Error in eval_map_nms: {str(e)}")
# raise e
# def load_ground_truth(opt, video_name):
# """Load ground truth annotations if available"""
# gt_segments = []
# duration = 0
# try:
# video_anno_file = opt["video_anno"].format(opt["split"])
# if os.path.exists(video_anno_file):
# with open(video_anno_file, 'r') as f:
# anno_data = json.load(f)
# if video_name in anno_data['database']:
# gt_annotations = anno_data['database'][video_name]['annotations']
# duration = anno_data['database'][video_name]['duration']
# for anno in gt_annotations:
# start, end = anno['segment']
# gt_segments.append({
# 'label': anno['label'],
# 'start': start,
# 'end': end,
# 'duration': end - start
# })
# except Exception as e:
# print(f"Could not load ground truth: {str(e)}")
# return gt_segments, duration
# def process_video(video_name, split_number, progress=gr.Progress()):
# """Process a single video for action localization"""
# dataset = None # Initialize dataset variable
# try:
# if not video_name or video_name in ["Error: Could not load videos from HF dataset", "Error loading videos"]:
# return "Error: Please select a valid video name"
# progress(0.1, desc="Initializing...")
# # Parse options
# opt = opts.parse_opt()
# opt = vars(opt)
# opt['mode'] = 'test'
# opt['split'] = str(split_number)
# opt['checkpoint_path'] = './checkpoint'
# opt['video_feature_all_test'] = './data/I3D/' # This will be overridden by HFVideoDataSet
# opt['anchors'] = [int(item) for item in opt['anchors'].split(',')]
# opt['batch_size'] = 1
# progress(0.2, desc="Checking model checkpoint...")
# # Check if required files exist
# checkpoint_path = './checkpoint/01_ckp_best.pth.tar'
# if not os.path.exists(checkpoint_path):
# # Try alternative locations
# alt_paths = ['./01_ckp_best.pth.tar', '01_ckp_best.pth.tar']
# checkpoint_path = None
# for alt_path in alt_paths:
# if os.path.exists(alt_path):
# checkpoint_path = alt_path
# break
# if checkpoint_path is None:
# return "Error: Model checkpoint not found. Please ensure '01_ckp_best.pth.tar' is in the repository."
# progress(0.3, desc="Loading model...")
# # Load model
# model = MYNET(opt).to(device)
# checkpoint = torch.load(checkpoint_path, map_location=device)
# # Handle different checkpoint formats
# if 'state_dict' in checkpoint:
# model.load_state_dict(checkpoint['state_dict'])
# else:
# model.load_state_dict(checkpoint)
# model.eval()
# print("Model loaded successfully")
# progress(0.4, desc=f"Downloading video features for {video_name}...")
# # Create dataset with HF integration
# try:
# dataset = HFVideoDataSet(opt, subset='test', video_name=video_name)
# print(f"Dataset created successfully with {len(dataset.video_list)} videos")
# except Exception as e:
# error_msg = f"Error downloading or loading video '{video_name}': {str(e)}\n\nPlease check:\n1. Video name is correct\n2. File exists in HF dataset\n3. Network connection is stable"
# print(error_msg)
# return error_msg
# if len(dataset.video_list) == 0:
# return f"Error: No video found with name '{video_name}' in dataset after download"
# progress(0.6, desc="Running inference...")
# # Run inference
# try:
# output_cls, output_reg, labels_cls, labels_reg = eval_frame(opt, model, dataset)
# print("Inference completed successfully")
# except Exception as e:
# error_msg = f"Error during inference: {str(e)}"
# print(error_msg)
# return error_msg
# progress(0.8, desc="Processing results...")
# try:
# result_dict = eval_map_nms(opt, dataset, output_cls, output_reg)
# print("NMS processing completed")
# except Exception as e:
# error_msg = f"Error during NMS processing: {str(e)}"
# print(error_msg)
# return error_msg
# # Load ground truth
# gt_segments, duration = load_ground_truth(opt, video_name)
# # Process predictions
# pred_segments = []
# for pred in result_dict.get(video_name, []):
# start, end = pred['segment']
# pred_segments.append({
# 'label': pred['label'],
# 'start': start,
# 'end': end,
# 'duration': end - start,
# 'score': pred['score']
# })
# progress(0.9, desc="Generating output...")
# # Generate output text
# output_text = f"Predicted Actions for Video: {video_name}\n"
# output_text += "=" * 50 + "\n\n"
# if pred_segments:
# output_text += "PREDICTED ACTIONS:\n"
# output_text += "-" * 30 + "\n"
# for i, pred in enumerate(pred_segments, 1):
# output_text += f"{i}. {pred['label']}\n"
# output_text += f" Time: [{pred['start']:.2f}s - {pred['end']:.2f}s]\n"
# output_text += f" Duration: {pred['duration']:.2f}s\n"
# output_text += f" Confidence: {pred['score']:.3f}\n\n"
# else:
# output_text += "No actions detected above threshold.\n\n"
# # Add ground truth comparison if available
# if gt_segments:
# output_text += "\nGROUND TRUTH COMPARISON:\n"
# output_text += "-" * 30 + "\n"
# # Calculate basic metrics
# matched_count = 0
# total_pred = len(pred_segments)
# total_gt = len(gt_segments)
# for gt in gt_segments:
# output_text += f"GT: {gt['label']} [{gt['start']:.2f}s - {gt['end']:.2f}s]\n"
# # Find best matching prediction
# best_match = None
# best_iou = 0
# for pred in pred_segments:
# # Simple overlap calculation
# overlap_start = max(gt['start'], pred['start'])
# overlap_end = min(gt['end'], pred['end'])
# if overlap_end > overlap_start:
# overlap = overlap_end - overlap_start
# union = (gt['end'] - gt['start']) + (pred['end'] - pred['start']) - overlap
# iou = overlap / union if union > 0 else 0
# if iou > best_iou:
# best_iou = iou
# best_match = pred
# if best_match and best_iou > VIS_CONFIG['iou_threshold']:
# matched_count += 1
# output_text += f" β†’ Matched with: {best_match['label']} (IoU: {best_iou:.3f})\n"
# else:
# output_text += f" β†’ No match found\n"
# output_text += "\n"
# # Summary statistics
# precision = matched_count / total_pred if total_pred > 0 else 0
# recall = matched_count / total_gt if total_gt > 0 else 0
# f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
# output_text += f"\nSUMMARY STATISTICS:\n"
# output_text += f"Total Predictions: {total_pred}\n"
# output_text += f"Total Ground Truth: {total_gt}\n"
# output_text += f"Matched: {matched_count}\n"
# output_text += f"Precision: {precision:.3f}\n"
# output_text += f"Recall: {recall:.3f}\n"
# output_text += f"F1-Score: {f1:.3f}\n"
# progress(1.0, desc="Complete!")
# print("Processing completed successfully")
# return output_text
# except Exception as e:
# error_details = traceback.format_exc()
# error_msg = f"Error processing video: {str(e)}\n\nDetailed error:\n{error_details}\n\nPlease check:\n1. Model checkpoint exists\n2. Video exists in HF dataset\n3. All dependencies are installed"
# print(error_msg)
# return error_msg
# finally:
# # Ensure cleanup happens even if there's an error
# if dataset is not None and hasattr(dataset, '__del__'):
# try:
# dataset.__del__()
# except Exception as e:
# print(f"Warning: Error during dataset cleanup: {e}")
# def refresh_video_list():
# """Refresh the list of available videos"""
# try:
# new_videos = get_available_videos_from_hf()
# return gr.Dropdown(choices=new_videos)
# except Exception as e:
# print(f"Error refreshing video list: {e}")
# return gr.Dropdown(choices=["Error refreshing videos"])
# # Initialize available videos
# print("Loading available videos from Hugging Face dataset...")
# try:
# available_videos = get_available_videos_from_hf()
# if not available_videos or available_videos == ["Error loading videos"]:
# available_videos = ["Error: Could not load videos from HF dataset"]
# except Exception as e:
# print(f"Error loading initial video list: {e}")
# available_videos = ["Error: Could not load videos from HF dataset"]
# print(f"Available videos: {len(available_videos)} videos found")
# # Gradio Interface
# with gr.Blocks(theme=gr.themes.Soft(), title="🎬 Temporal Action Localization") as iface:
# gr.Markdown("""
# # 🎬 Temporal Action Localization
# This app performs temporal action localization on videos using I3D features loaded dynamically from Hugging Face datasets.
# **Features:**
# - βœ… Dynamic loading from HF dataset repository
# - βœ… Real-time inference with progress tracking
# - βœ… Ground truth comparison when available
# - βœ… Detailed action predictions with confidence scores
# """)
# with gr.Row():
# with gr.Column(scale=1):
# video_dropdown = gr.Dropdown(
# label="Select Video",
# choices=available_videos,
# value=available_videos[0] if available_videos and "Error" not in available_videos[0] else None,
# info="Videos loaded from Hugging Face dataset"
# )
# split_dropdown = gr.Dropdown(
# label="Split Number",
# choices=["1", "2", "3"],
# value="1",
# info="Dataset split for annotations"
# )
# refresh_btn = gr.Button("πŸ”„ Refresh Video List", variant="secondary")
# submit_btn = gr.Button("πŸš€ Run Action Localization", variant="primary")
# with gr.Column(scale=2):
# output_text = gr.Textbox(
# label="Action Predictions",
# lines=25,
# max_lines=50,
# show_copy_button=True,
# placeholder="Results will appear here..."
# )
# gr.Markdown(f"""
# **Dataset Source:** [{HF_DATASET_REPO}](https://huggingface.co/datasets/{HF_DATASET_REPO})
# **Requirements:**
# - Model checkpoint: `01_ckp_best.pth.tar` in repository root
# - Video features: Automatically downloaded from HF dataset
# """)
# # Event handlers
# refresh_btn.click(
# fn=refresh_video_list,
# outputs=video_dropdown
# )
# submit_btn.click(
# fn=process_video,
# inputs=[video_dropdown, split_dropdown],
# outputs=output_text
# )
# # Example
# if available_videos and "Error" not in available_videos[0]:
# gr.Examples(
# examples=[[available_videos[0], "1"]],
# inputs=[video_dropdown, split_dropdown],
# fn=process_video,
# outputs=output_text,
# cache_examples=False
# )
# if __name__ == '__main__':
# print(f"Available videos: {len(available_videos)}")
# print(f"Using device: {device}")
# print(f"HF Dataset: {HF_DATASET_REPO}")
# iface.launch(
# server_name="0.0.0.0",
# server_port=7860,
# share=False
# )
import os
import json
import torch
import numpy as np
import gradio as gr
import opts_egtea as opts
from dataset import VideoDataSet, calc_iou
from models import MYNET, SuppressNet
from loss_func import cls_loss_func, regress_loss_func
from eval import evaluation_detection
from iou_utils import non_max_suppression, check_overlap_proposal
from typing import List, Dict, Optional
from huggingface_hub import hf_hub_download, list_repo_files
import tempfile
import shutil
import traceback
# Configuration
VIS_CONFIG = {
    'iou_threshold': 0.3,  # minimum IoU for a prediction to count as matching a GT segment
    'min_segment_duration': 1.0,  # seconds; NOTE(review): not referenced in this chunk — confirm usage
}
# Hugging Face Dataset Configuration
HF_DATASET_REPO = "Darknsu/EGTEA_Dataset"  # dataset repo holding precomputed I3D features
HF_DATASET_SUBFOLDER = "I3D"  # Adjust this based on your dataset structure
# Determine device: prefer GPU when available, else CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Create local cache directory for downloaded files (idempotent)
CACHE_DIR = "./hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
def download_npz_file(video_name: str) -> str:
    """Fetch the I3D feature file for *video_name* from the HF dataset repo.

    Results are copied into ``CACHE_DIR`` under a flat, predictable name so
    repeat requests for the same video skip the network entirely.

    Args:
        video_name: Video identifier (file stem, without the ``.npz`` suffix).

    Returns:
        Local filesystem path to the cached ``.npz`` file.

    Raises:
        RuntimeError: If the download or cache copy fails; chained (``from e``)
            to the underlying huggingface_hub / IO error so the original
            traceback is preserved.
    """
    # Fast path: reuse a previously downloaded copy.
    local_path = os.path.join(CACHE_DIR, f"{video_name}.npz")
    if os.path.exists(local_path):
        print(f"Using cached file: {local_path}")
        return local_path
    # Path of the file inside the dataset repository.
    file_path = f"{HF_DATASET_SUBFOLDER}/{video_name}.npz"
    try:
        print(f"Downloading {file_path} from {HF_DATASET_REPO}...")
        downloaded_path = hf_hub_download(
            repo_id=HF_DATASET_REPO,
            filename=file_path,
            repo_type="dataset",
            cache_dir=CACHE_DIR
        )
        # hf_hub_download nests files in a repo-specific cache layout; copy to
        # our flat location for easier access.
        shutil.copy2(downloaded_path, local_path)
        print(f"File downloaded and cached: {local_path}")
        return local_path
    except Exception as e:
        # Chain the cause: the original version raised a bare Exception without
        # `from e`, losing the underlying traceback.
        raise RuntimeError(f"Failed to download {video_name}.npz: {str(e)}") from e
def get_available_videos_from_hf():
    """List video names (file stems) available in the HF dataset repository.

    Scans the repository file listing for ``.npz`` files located under
    ``HF_DATASET_SUBFOLDER`` and returns their sorted base names.

    Returns:
        Sorted list of video names, or the sentinel ``["Error loading videos"]``
        when the repository listing cannot be fetched (network/auth errors) —
        callers check for this sentinel string explicitly.
    """
    try:
        print("Fetching available videos from Hugging Face dataset...")
        files = list_repo_files(
            repo_id=HF_DATASET_REPO,
            repo_type="dataset"
        )
        prefix = f"{HF_DATASET_SUBFOLDER}/"
        # e.g. "I3D/OP02-R02-TurkeySandwich.npz" -> "OP02-R02-TurkeySandwich".
        # splitext strips only the trailing extension; the previous
        # str.replace('.npz', '') would also mangle a name that happened to
        # contain ".npz" in the middle.
        videos = sorted(
            os.path.splitext(os.path.basename(f))[0]
            for f in files
            if f.startswith(prefix) and f.endswith('.npz')
        )
        print(f"Found {len(videos)} videos in dataset: {videos[:5]}{'...' if len(videos) > 5 else ''}")
        return videos
    except Exception as e:
        print(f"Error fetching videos from HF dataset: {str(e)}")
        return ["Error loading videos"]
class HFVideoDataSet(VideoDataSet):
    """
    Modified VideoDataSet that downloads files from Hugging Face on demand.

    The requested video's .npz feature file is fetched into a per-instance
    temporary directory, and ``opt['video_feature_all_test']`` is redirected
    to that directory before the parent ``VideoDataSet`` loads features.

    NOTE(review): ``opt`` is mutated in place (``video_feature_all_test`` is
    overwritten with the temp path) — confirm callers do not rely on the
    original value after construction.
    """
    def __init__(self, opt, subset='test', video_name=None):
        # Store the original video_feature_all_test path (kept for reference;
        # the opt dict itself is repointed to the temp directory below)
        self.original_feature_path = opt['video_feature_all_test']
        # Create temporary directory for this session
        self.temp_dir = tempfile.mkdtemp(prefix="hf_video_")
        print(f"Created temp directory: {self.temp_dir}")
        # Download the specific video file if video_name is provided
        if video_name:
            try:
                print(f"Downloading features for video: {video_name}")
                downloaded_path = download_npz_file(video_name)
                # Ensure the temp directory exists
                os.makedirs(self.temp_dir, exist_ok=True)
                # Copy to temp directory with expected structure - FIX: Add proper path separator
                temp_file_path = os.path.join(self.temp_dir, f"{video_name}.npz")
                print(f"Copying {downloaded_path} to {temp_file_path}")
                shutil.copy2(downloaded_path, temp_file_path)
                # Verify file exists and print debug info
                if not os.path.exists(temp_file_path):
                    raise Exception(f"Failed to copy file to {temp_file_path}")
                else:
                    print(f"Video file ready: {temp_file_path}")
                    print(f"File size: {os.path.getsize(temp_file_path)} bytes")
            except Exception as e:
                print(f"Error downloading video {video_name}: {str(e)}")
                # Clean up temp directory on error
                if hasattr(self, 'temp_dir') and os.path.exists(self.temp_dir):
                    shutil.rmtree(self.temp_dir)
                raise e
        # Set the feature path to our temp directory — must happen BEFORE the
        # parent __init__ so it reads features from the downloaded location
        opt['video_feature_all_test'] = self.temp_dir
        print(f"Set video_feature_all_test to: {opt['video_feature_all_test']}")
        # Initialize parent class
        try:
            super().__init__(opt, subset, video_name)
            print(f"Successfully initialized dataset with {len(self.video_list)} videos")
        except Exception as e:
            print(f"Error initializing parent VideoDataSet: {str(e)}")
            # Clean up temp directory on error
            if hasattr(self, 'temp_dir') and os.path.exists(self.temp_dir):
                shutil.rmtree(self.temp_dir)
            raise e
    def __del__(self):
        # Best-effort removal of the temporary directory. NOTE(review): this
        # is also invoked explicitly by process_video's finally block, so it
        # may run twice — the os.path.exists guard makes the second call a no-op.
        try:
            if hasattr(self, 'temp_dir') and os.path.exists(self.temp_dir):
                shutil.rmtree(self.temp_dir)
                print(f"Cleaned up temp directory: {self.temp_dir}")
        except Exception as e:
            print(f"Warning: Could not clean up temp directory: {e}")
def eval_frame(opt, model, dataset):
    """Run the model over every sample of the dataset, frame by frame.

    Returns four dicts keyed by video name: classification outputs (after
    softmax), regression outputs, and the corresponding labels when the
    dataset provides them. Per-video lists are stacked into numpy arrays of
    shape (num_frames, ...) when non-empty.
    """
    try:
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=opt['batch_size'],
            shuffle=False,
            num_workers=0,
            pin_memory=False
        )
        # One growing list of per-frame arrays for each video.
        labels_cls = {name: [] for name in dataset.video_list}
        labels_reg = {name: [] for name in dataset.video_list}
        output_cls = {name: [] for name in dataset.video_list}
        output_reg = {name: [] for name in dataset.video_list}
        model.eval()
        with torch.no_grad():
            for n_iter, batch_data in enumerate(loader):
                try:
                    # Some dataset variants yield an extra (ignored) element.
                    if len(batch_data) == 4:
                        input_data, cls_label, reg_label, _ = batch_data
                    else:
                        input_data, cls_label, reg_label = batch_data
                    input_data = input_data.to(device)
                    if cls_label is not None:
                        cls_label = cls_label.to(device)
                    if reg_label is not None:
                        reg_label = reg_label.to(device)
                    act_cls, act_reg, _ = model(input_data.float())
                    act_cls = torch.softmax(act_cls, dim=-1)
                    for b in range(input_data.size(0)):
                        # Map the flat sample index back to its source video.
                        batch_idx = n_iter * opt['batch_size'] + b
                        if batch_idx >= len(dataset.inputs):
                            continue
                        video_name = dataset.inputs[batch_idx][0]
                        output_cls[video_name].append(act_cls[b, :].detach().cpu().numpy())
                        output_reg[video_name].append(act_reg[b, :].detach().cpu().numpy())
                        if cls_label is not None:
                            labels_cls[video_name].append(cls_label[b, :].cpu().numpy())
                        if reg_label is not None:
                            labels_reg[video_name].append(reg_label[b, :].cpu().numpy())
                except Exception as e:
                    print(f"Error in batch {n_iter}: {str(e)}")
                    continue
        # Stack per-frame arrays; videos with no outputs keep their empty lists.
        for video_name in dataset.video_list:
            if output_cls[video_name]:
                output_cls[video_name] = np.stack(output_cls[video_name], axis=0)
                output_reg[video_name] = np.stack(output_reg[video_name], axis=0)
            if labels_cls[video_name]:
                labels_cls[video_name] = np.stack(labels_cls[video_name], axis=0)
            if labels_reg[video_name]:
                labels_reg[video_name] = np.stack(labels_reg[video_name], axis=0)
        return output_cls, output_reg, labels_cls, labels_reg
    except Exception as e:
        print(f"Error in eval_frame: {str(e)}")
        raise e
def eval_map_nms(opt, dataset, output_cls, output_reg):
    """Decode per-frame anchor predictions into timed proposals, then apply
    non-maximum suppression per video.

    Returns a dict mapping each video name to its surviving proposal dicts
    (segment in seconds, score, label, generation time).
    """
    try:
        result_dict = {}
        anchors = opt['anchors']
        for video_name in dataset.video_list:
            # Videos with no frame outputs yield an empty proposal list.
            if video_name not in output_cls or len(output_cls[video_name]) == 0:
                result_dict[video_name] = []
                continue
            duration = dataset.video_len[video_name]
            video_time = float(dataset.video_dict[video_name]["duration"])
            # Conversion factor from frame index to seconds (scaled by 100).
            frame_to_time = 100.0 * video_time / duration
            proposal_dict = []
            n_frames = min(duration, len(output_cls[video_name]))
            for idx in range(n_frames):
                cls_anc = output_cls[video_name][idx]
                reg_anc = output_reg[video_name][idx]
                for anc_idx, anchor in enumerate(anchors):
                    if anc_idx >= len(cls_anc):
                        continue
                    # Classes above threshold, excluding the background slot
                    # (last column).
                    cls = np.argwhere(cls_anc[anc_idx][:-1] > opt['threshold']).reshape(-1)
                    if len(cls) == 0:
                        continue
                    # Regress the anchor into a (start, end) segment in frames:
                    # offset shifts the end point, exp(...) scales the length.
                    ed = idx + anchor * reg_anc[anc_idx][0]
                    st = ed - anchor * np.exp(reg_anc[anc_idx][1])
                    for label in cls:
                        if label >= len(dataset.label_name):
                            continue
                        proposal_dict.append({
                            "segment": [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)],
                            "score": float(cls_anc[anc_idx][label]),
                            "label": dataset.label_name[label],
                            "gentime": float(idx * frame_to_time / 100.0)
                        })
            # Apply NMS
            result_dict[video_name] = non_max_suppression(proposal_dict, overlapThresh=opt['soft_nms'])
        return result_dict
    except Exception as e:
        print(f"Error in eval_map_nms: {str(e)}")
        raise e
def load_ground_truth(opt, video_name):
    """Load ground-truth annotations for *video_name*, when available.

    Reads the split's annotation JSON (path template in opt["video_anno"],
    formatted with opt["split"]). Returns a (segments, duration) pair;
    segments is a list of dicts with label/start/end/duration keys. Missing
    files, missing video entries, or parse errors yield ([], 0).
    """
    gt_segments = []
    duration = 0
    try:
        anno_path = opt["video_anno"].format(opt["split"])
        if os.path.exists(anno_path):
            with open(anno_path, 'r') as f:
                database = json.load(f)['database']
            if video_name in database:
                entry = database[video_name]
                duration = entry['duration']
                for anno in entry['annotations']:
                    start, end = anno['segment']
                    gt_segments.append({
                        'label': anno['label'],
                        'start': start,
                        'end': end,
                        'duration': end - start
                    })
    except Exception as e:
        # Best-effort: missing/corrupt annotations are not fatal.
        print(f"Could not load ground truth: {str(e)}")
    return gt_segments, duration
def process_video(video_name, split_number, progress=gr.Progress()):
    """Process a single video for action localization.

    Downloads the video's I3D features from the HF dataset, loads the model
    checkpoint, runs frame-wise inference, applies NMS, and formats the
    predictions (plus a ground-truth comparison when annotations exist) as a
    human-readable report string. All failures are returned as error strings
    rather than raised, so the Gradio UI can display them.

    Args:
        video_name: Video name as listed in the HF dataset (no extension).
        split_number: Dataset split used to locate the annotation file.
        progress: Gradio progress tracker (injected by the UI).

    Returns:
        str: Formatted results, or an error message describing the failure.
    """
    dataset = None  # Initialize dataset variable so `finally` can clean up safely
    try:
        # Reject the placeholder entries used when the video list failed to load.
        if not video_name or video_name in ["Error: Could not load videos from HF dataset", "Error loading videos"]:
            return "Error: Please select a valid video name"
        progress(0.1, desc="Initializing...")
        # Parse options
        opt = opts.parse_opt()
        opt = vars(opt)
        opt['mode'] = 'test'
        opt['split'] = str(split_number)
        opt['checkpoint_path'] = './checkpoint'
        opt['video_feature_all_test'] = './data/I3D/'  # This will be overridden by HFVideoDataSet
        opt['anchors'] = [int(item) for item in opt['anchors'].split(',')]
        opt['batch_size'] = 1
        progress(0.2, desc="Checking model checkpoint...")
        # Check if required files exist
        checkpoint_path = './checkpoint/01_ckp_best.pth.tar'
        if not os.path.exists(checkpoint_path):
            # Try alternative locations
            alt_paths = ['./01_ckp_best.pth.tar', '01_ckp_best.pth.tar']
            checkpoint_path = None
            for alt_path in alt_paths:
                if os.path.exists(alt_path):
                    checkpoint_path = alt_path
                    break
            if checkpoint_path is None:
                return "Error: Model checkpoint not found. Please ensure '01_ckp_best.pth.tar' is in the repository."
        progress(0.3, desc="Loading model...")
        # Load model
        model = MYNET(opt).to(device)
        checkpoint = torch.load(checkpoint_path, map_location=device)
        # Handle different checkpoint formats (raw state dict vs. wrapped)
        if 'state_dict' in checkpoint:
            model.load_state_dict(checkpoint['state_dict'])
        else:
            model.load_state_dict(checkpoint)
        model.eval()
        print("Model loaded successfully")
        progress(0.4, desc=f"Downloading video features for {video_name}...")
        # Create dataset with HF integration
        try:
            dataset = HFVideoDataSet(opt, subset='test', video_name=video_name)
            print(f"Dataset created successfully with {len(dataset.video_list)} videos")
        except Exception as e:
            error_msg = f"Error downloading or loading video '{video_name}': {str(e)}\n\nPlease check:\n1. Video name is correct\n2. File exists in HF dataset\n3. Network connection is stable"
            print(error_msg)
            return error_msg
        if len(dataset.video_list) == 0:
            return f"Error: No video found with name '{video_name}' in dataset after download"
        progress(0.6, desc="Running inference...")
        # Run inference
        try:
            output_cls, output_reg, labels_cls, labels_reg = eval_frame(opt, model, dataset)
            print("Inference completed successfully")
        except Exception as e:
            error_msg = f"Error during inference: {str(e)}"
            print(error_msg)
            return error_msg
        progress(0.8, desc="Processing results...")
        # Decode anchor outputs into proposals and suppress overlaps.
        try:
            result_dict = eval_map_nms(opt, dataset, output_cls, output_reg)
            print("NMS processing completed")
        except Exception as e:
            error_msg = f"Error during NMS processing: {str(e)}"
            print(error_msg)
            return error_msg
        # Load ground truth (best-effort; may return an empty list)
        gt_segments, duration = load_ground_truth(opt, video_name)
        # Process predictions into flat segment records
        pred_segments = []
        for pred in result_dict.get(video_name, []):
            start, end = pred['segment']
            pred_segments.append({
                'label': pred['label'],
                'start': start,
                'end': end,
                'duration': end - start,
                'score': pred['score']
            })
        progress(0.9, desc="Generating output...")
        # Generate output text
        output_text = f"Predicted Actions for Video: {video_name}\n"
        output_text += "=" * 50 + "\n\n"
        if pred_segments:
            output_text += "PREDICTED ACTIONS:\n"
            output_text += "-" * 30 + "\n"
            for i, pred in enumerate(pred_segments, 1):
                output_text += f"{i}. {pred['label']}\n"
                output_text += f" Time: [{pred['start']:.2f}s - {pred['end']:.2f}s]\n"
                output_text += f" Duration: {pred['duration']:.2f}s\n"
                output_text += f" Confidence: {pred['score']:.3f}\n\n"
        else:
            output_text += "No actions detected above threshold.\n\n"
        # Add ground truth comparison if available
        if gt_segments:
            output_text += "\nGROUND TRUTH COMPARISON:\n"
            output_text += "-" * 30 + "\n"
            # Calculate basic metrics (greedy best-IoU matching per GT segment)
            matched_count = 0
            total_pred = len(pred_segments)
            total_gt = len(gt_segments)
            for gt in gt_segments:
                output_text += f"GT: {gt['label']} [{gt['start']:.2f}s - {gt['end']:.2f}s]\n"
                # Find best matching prediction by temporal IoU
                best_match = None
                best_iou = 0
                for pred in pred_segments:
                    # Simple overlap calculation (intersection over union of intervals)
                    overlap_start = max(gt['start'], pred['start'])
                    overlap_end = min(gt['end'], pred['end'])
                    if overlap_end > overlap_start:
                        overlap = overlap_end - overlap_start
                        union = (gt['end'] - gt['start']) + (pred['end'] - pred['start']) - overlap
                        iou = overlap / union if union > 0 else 0
                        if iou > best_iou:
                            best_iou = iou
                            best_match = pred
                # NOTE(review): a match counts regardless of label agreement —
                # only IoU is checked against VIS_CONFIG['iou_threshold'].
                if best_match and best_iou > VIS_CONFIG['iou_threshold']:
                    matched_count += 1
                    output_text += f" β†’ Matched with: {best_match['label']} (IoU: {best_iou:.3f})\n"
                else:
                    output_text += f" β†’ No match found\n"
                output_text += "\n"
            # Summary statistics
            precision = matched_count / total_pred if total_pred > 0 else 0
            recall = matched_count / total_gt if total_gt > 0 else 0
            f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
            output_text += f"\nSUMMARY STATISTICS:\n"
            output_text += f"Total Predictions: {total_pred}\n"
            output_text += f"Total Ground Truth: {total_gt}\n"
            output_text += f"Matched: {matched_count}\n"
            output_text += f"Precision: {precision:.3f}\n"
            output_text += f"Recall: {recall:.3f}\n"
            output_text += f"F1-Score: {f1:.3f}\n"
        progress(1.0, desc="Complete!")
        print("Processing completed successfully")
        return output_text
    except Exception as e:
        error_details = traceback.format_exc()
        error_msg = f"Error processing video: {str(e)}\n\nDetailed error:\n{error_details}\n\nPlease check:\n1. Model checkpoint exists\n2. Video exists in HF dataset\n3. All dependencies are installed"
        print(error_msg)
        return error_msg
    finally:
        # Ensure cleanup happens even if there's an error.
        # NOTE(review): __del__ is called explicitly here and will run again
        # at garbage collection; it guards on os.path.exists so the second
        # invocation is a no-op.
        if dataset is not None and hasattr(dataset, '__del__'):
            try:
                dataset.__del__()
            except Exception as e:
                print(f"Warning: Error during dataset cleanup: {e}")
def refresh_video_list():
    """Re-query the HF dataset and return an updated Dropdown of video names."""
    try:
        return gr.Dropdown(choices=get_available_videos_from_hf())
    except Exception as e:
        print(f"Error refreshing video list: {e}")
        return gr.Dropdown(choices=["Error refreshing videos"])
# Initialize available videos at import time so the dropdown is populated
# when the UI is built below.
print("Loading available videos from Hugging Face dataset...")
try:
    available_videos = get_available_videos_from_hf()
    # get_available_videos_from_hf() signals failure with a sentinel list;
    # normalize every failure mode to a single error placeholder entry.
    if not available_videos or available_videos == ["Error loading videos"]:
        available_videos = ["Error: Could not load videos from HF dataset"]
except Exception as e:
    print(f"Error loading initial video list: {e}")
    available_videos = ["Error: Could not load videos from HF dataset"]
print(f"Available videos: {len(available_videos)} videos found")
# Gradio Interface: two-column layout — controls on the left, text report on
# the right — wired to process_video / refresh_video_list above.
with gr.Blocks(theme=gr.themes.Soft(), title="🎬 Temporal Action Localization") as iface:
    gr.Markdown("""
    # 🎬 Temporal Action Localization
    This app performs temporal action localization on videos using I3D features loaded dynamically from Hugging Face datasets.
    **Features:**
    - βœ… Dynamic loading from HF dataset repository
    - βœ… Real-time inference with progress tracking
    - βœ… Ground truth comparison when available
    - βœ… Detailed action predictions with confidence scores
    """)
    with gr.Row():
        with gr.Column(scale=1):
            # Video picker seeded from the module-level available_videos list;
            # no default selection when the list failed to load.
            video_dropdown = gr.Dropdown(
                label="Select Video",
                choices=available_videos,
                value=available_videos[0] if available_videos and "Error" not in available_videos[0] else None,
                info="Videos loaded from Hugging Face dataset"
            )
            split_dropdown = gr.Dropdown(
                label="Split Number",
                choices=["1", "2", "3"],
                value="1",
                info="Dataset split for annotations"
            )
            refresh_btn = gr.Button("πŸ”„ Refresh Video List", variant="secondary")
            submit_btn = gr.Button("πŸš€ Run Action Localization", variant="primary")
        with gr.Column(scale=2):
            # Plain-text report produced by process_video.
            output_text = gr.Textbox(
                label="Action Predictions",
                lines=25,
                max_lines=50,
                show_copy_button=True,
                placeholder="Results will appear here..."
            )
    gr.Markdown(f"""
    **Dataset Source:** [{HF_DATASET_REPO}](https://huggingface.co/datasets/{HF_DATASET_REPO})
    **Requirements:**
    - Model checkpoint: `01_ckp_best.pth.tar` in repository root
    - Video features: Automatically downloaded from HF dataset
    """)
    # Event handlers
    refresh_btn.click(
        fn=refresh_video_list,
        outputs=video_dropdown
    )
    submit_btn.click(
        fn=process_video,
        inputs=[video_dropdown, split_dropdown],
        outputs=output_text
    )
    # Example row — only registered when the video list loaded successfully.
    if available_videos and "Error" not in available_videos[0]:
        gr.Examples(
            examples=[[available_videos[0], "1"]],
            inputs=[video_dropdown, split_dropdown],
            fn=process_video,
            outputs=output_text,
            cache_examples=False
        )
if __name__ == '__main__':
    # Startup diagnostics before launching the UI.
    print(f"Available videos: {len(available_videos)}")
    print(f"Using device: {device}")
    print(f"HF Dataset: {HF_DATASET_REPO}")
    # Bind to all interfaces on port 7860 (the standard HF Spaces port).
    iface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )