#!/usr/bin/env python3
"""
Matchcommentary Model Inference Script - HuggingFace Version
For automatic soccer commentary generation
"""
import torch
import argparse
import os
import csv
from tqdm import tqdm
from typing import List

# Assuming model files are included in the HuggingFace repository
from models.matchvoice_model import matchvoice_model
from matchvoice_dataset import MatchVoice_Dataset
from torch.utils.data import DataLoader


class MatchcommentaryPredictor:
    """Matchcommentary model inference class"""

    def __init__(self, model_path: str = "./", device: str = "cuda:0"):
        """
        Initialize Matchcommentary predictor

        Args:
            model_path: Path to model files
            device: Device to run on
        """
        self.device = device
        self.model = None
        self.load_model(model_path)

    def load_model(self, model_path: str):
        """Load the model and its checkpoint weights."""
        print("Loading Matchcommentary model...")

        # Initialize model
        self.model = matchvoice_model(
            llm_ckpt="meta-llama/Meta-Llama-3-8B-Instruct",
            tokenizer_ckpt="meta-llama/Meta-Llama-3-8B-Instruct",
            num_video_query_token=32,
            num_features=512,
            device=self.device,
            inference=True
        )

        # Load checkpoint
        checkpoint_path = os.path.join(model_path, "model_save_best_val_CIDEr.pth")
        if os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path, map_location="cpu")

            # Copy matching keys from the checkpoint into the model's state dict
            model_state_dict = self.model.state_dict()
            for key, value in checkpoint.items():
                if key in model_state_dict:
                    model_state_dict[key] = value
            self.model.load_state_dict(model_state_dict)
            print("Model checkpoint loaded successfully!")
        else:
            print(f"Warning: Model checkpoint file not found at {checkpoint_path}")

        self.model.eval()

    def predict_single(self, video_features: torch.Tensor) -> List[str]:
        """
        Predict commentary for a single video clip

        Args:
            video_features: Video feature tensor

        Returns:
            List of predicted commentary texts
        """
        with torch.no_grad():
            # Build input sample format
            samples = {
                'features': video_features.to(self.device),
                'caption_info': [["", "", "", "", "", ""]]  # Placeholder
            }
            predictions = self.model(samples)
            return predictions

    def predict_batch(self, feature_root: str, ann_root: str, output_csv: str,
                      batch_size: int = 4, num_workers: int = 2, generate_num: int = 1,
                      fps: float = 0.5, window: float = 15):
        """
        Batch prediction and save results to CSV file

        Args:
            feature_root: Root directory for video features
            ann_root: Root directory for annotation files
            output_csv: Output CSV file path
            batch_size: Batch size for processing
            num_workers: Number of data loading workers
            generate_num: Number of commentary generations per video clip
            fps: Feature extraction frame rate
            window: Video window size in seconds
        """
        print("Preparing dataset...")

        # Create dataset
        test_dataset = MatchVoice_Dataset(
            feature_root=feature_root,
            ann_root=ann_root,
            fps=fps,
            timestamp_key="gameTime",
            tokenizer_name="meta-llama/Meta-Llama-3-8B-Instruct",
            window=window,
            split_ratio=0.01,  # Use small subset for quick testing
            is_train=False
        )

        test_data_loader = DataLoader(
            test_dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            drop_last=False,
            shuffle=False,
            pin_memory=True,
            collate_fn=test_dataset.collater
        )

        print("Dataset preparation completed, starting prediction...")

        # Create output directory (fall back to "." when output_csv has no directory part)
        os.makedirs(os.path.dirname(output_csv) or ".", exist_ok=True)

        # Write CSV header
        headers = ['league', 'game', 'half', 'timestamp', 'type', 'anonymized']
        headers += [f'predicted_res_{i}' for i in range(generate_num)]
        with open(output_csv, 'w', newline='', encoding='utf-8') as file:
            writer = csv.writer(file)
            writer.writerow(headers)

        # Start prediction
        with torch.no_grad():
            for samples in tqdm(test_data_loader, desc="Prediction Progress"):
                all_predictions = []

                # Generate multiple predictions per clip
                for _ in range(generate_num):
                    predicted_res = self.model(samples)
                    all_predictions.append(predicted_res)

                # Write results
                caption_info = samples["caption_info"]
                with open(output_csv, 'a', newline='', encoding='utf-8') as file:
                    writer = csv.writer(file)
                    for info in zip(*all_predictions, caption_info):
                        # Reorder the caption_info fields to match the CSV header,
                        # then append the generated commentary texts.
                        row = [info[-1][4], info[-1][5], info[-1][0], info[-1][1],
                               info[-1][2], info[-1][3]] + list(info[:-1])
                        writer.writerow(row)

        print(f"Prediction completed! Results saved to: {output_csv}")


def main():
    """Main function"""
    parser = argparse.ArgumentParser(description="Matchcommentary Model Inference Script")
    parser.add_argument("--model_path", type=str, default="./",
                        help="Path to model files")
    parser.add_argument("--feature_root", type=str, default="./features",
                        help="Root directory for video features")
    parser.add_argument("--ann_root", type=str, default="./dataset/MatchTime/train",
                        help="Root directory for annotation files")
    parser.add_argument("--output_csv", type=str, default="./predictions.csv",
                        help="Output CSV file path")
    parser.add_argument("--batch_size", type=int, default=4,
                        help="Batch size for processing")
    parser.add_argument("--num_workers", type=int, default=2,
                        help="Number of data loading workers")
    parser.add_argument("--generate_num", type=int, default=1,
                        help="Number of commentary generations per video clip")
    parser.add_argument("--device", type=str, default="cuda:0",
                        help="Device to run on")
    parser.add_argument("--fps", type=float, default=0.5,
                        help="Feature extraction frame rate")
    parser.add_argument("--window", type=float, default=15,
                        help="Video window size in seconds")

    args = parser.parse_args()

    # Create predictor and run prediction
    predictor = MatchcommentaryPredictor(args.model_path, args.device)
    predictor.predict_batch(
        feature_root=args.feature_root,
        ann_root=args.ann_root,
        output_csv=args.output_csv,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        generate_num=args.generate_num,
        fps=args.fps,
        window=args.window
    )


if __name__ == "__main__":
    main()
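
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The script name and paths below are
# placeholders; substitute the ones from your checkout:
#
#   python inference.py \
#       --model_path ./ \
#       --feature_root ./features \
#       --ann_root ./dataset/MatchTime/train \
#       --output_csv ./predictions.csv \
#       --batch_size 4 --num_workers 2 --generate_num 3 --device cuda:0
#
# For a single clip, a hypothetical call might look like the following.
# The feature shape (1, T, 512) is an assumption based on num_features=512
# above; the real feature extractor's output may differ:
#
#   predictor = MatchcommentaryPredictor(model_path="./", device="cuda:0")
#   features = torch.randn(1, 30, 512)  # placeholder video features
#   print(predictor.predict_single(features)[0])
# ---------------------------------------------------------------------------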