#!/usr/bin/env python3
"""
Matchcommentary Model Inference Script - HuggingFace Version
For automatic soccer commentary generation
"""
import argparse
import csv
import os
from typing import List

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# Assuming model files are included in the HuggingFace repository
from models.matchvoice_model import matchvoice_model
from matchvoice_dataset import MatchVoice_Dataset


class MatchcommentaryPredictor:
    """Matchcommentary model inference class"""
def __init__(self, model_path: str = "./", device: str = "cuda:0"):
"""
Initialize Matchcommentary predictor
Args:
model_path: Path to model files
device: Device to run on
"""
self.device = device
self.model = None
self.load_model(model_path)
def load_model(self, model_path: str):
"""Load the model"""
print("Loading Matchcommentary model...")
# Initialize model
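        # These hyperparameters must match the training configuration,
        # otherwise the checkpoint's parameter shapes will not line up
        # when the state dict is loaded below.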
self.model = matchvoice_model(
llm_ckpt="meta-llama/Meta-Llama-3-8B-Instruct",
tokenizer_ckpt="meta-llama/Meta-Llama-3-8B-Instruct",
num_video_query_token=32,
num_features=512,
device=self.device,
inference=True
)
# Load checkpoint
checkpoint_path = os.path.join(model_path, "model_save_best_val_CIDEr.pth")
if os.path.exists(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location="cpu")
            # Copy over only the parameters whose names exist in the current
            # model; extra keys in the checkpoint are silently ignored.
            model_state_dict = self.model.state_dict()
            for key, value in checkpoint.items():
                if key in model_state_dict:
                    model_state_dict[key] = value
            self.model.load_state_dict(model_state_dict)
print("Model checkpoint loaded successfully!")
else:
print(f"Warning: Model checkpoint file not found at {checkpoint_path}")
self.model.eval()
def predict_single(self, video_features: torch.Tensor) -> List[str]:
"""
Predict commentary for a single video clip
Args:
video_features: Video feature tensor
Returns:
List of predicted commentary texts
"""
with torch.no_grad():
            # Build the input sample format expected by the model; the
            # caption_info placeholder mirrors the dataset's six annotation
            # fields (half, timestamp, type, anonymized, league, game).
            samples = {
                'features': video_features.to(self.device),
                'caption_info': [["", "", "", "", "", ""]]  # Placeholder
}
predictions = self.model(samples)
return predictions
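
    # A minimal usage sketch (the tensor shape is an assumption, not fixed by
    # this file): with num_features=512 and the default fps=0.5 over a 15 s
    # window, a clip yields roughly 8 feature frames, so e.g.:
    #
    #   predictor = MatchcommentaryPredictor("./", "cuda:0")
    #   feats = torch.randn(1, 8, 512)  # (batch, frames, feature_dim) -- assumed layout
    #   print(predictor.predict_single(feats))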
def predict_batch(self,
feature_root: str,
ann_root: str,
output_csv: str,
batch_size: int = 4,
num_workers: int = 2,
generate_num: int = 1,
fps: float = 0.5,
window: float = 15):
"""
Batch prediction and save results to CSV file
Args:
feature_root: Root directory for video features
ann_root: Root directory for annotation files
output_csv: Output CSV file path
batch_size: Batch size for processing
num_workers: Number of data loading workers
generate_num: Number of commentary generations per video clip
fps: Feature extraction frame rate
window: Video window size in seconds
"""
print("Preparing dataset...")
# Create dataset
test_dataset = MatchVoice_Dataset(
feature_root=feature_root,
ann_root=ann_root,
fps=fps,
timestamp_key="gameTime",
tokenizer_name="meta-llama/Meta-Llama-3-8B-Instruct",
window=window,
split_ratio=0.01, # Use small subset for quick testing
is_train=False
)
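        # split_ratio=0.01 presumably keeps ~1% of the annotations for the
        # quick test above; raise it toward 1.0 to evaluate on the full split
        # (exact semantics depend on MatchVoice_Dataset, which ships with
        # this repo).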
test_data_loader = DataLoader(
test_dataset,
batch_size=batch_size,
num_workers=num_workers,
drop_last=False,
shuffle=False,
pin_memory=True,
collate_fn=test_dataset.collater
)
print("Dataset preparation completed, starting prediction...")
        # Create the output directory if the path includes one
        out_dir = os.path.dirname(output_csv)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
# Write CSV header
headers = ['league', 'game', 'half', 'timestamp', 'type', 'anonymized']
headers += [f'predicted_res_{i}' for i in range(generate_num)]
with open(output_csv, 'w', newline='', encoding='utf-8') as file:
writer = csv.writer(file)
writer.writerow(headers)
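        # Results are appended to the CSV batch by batch below, so partial
        # output survives if the run is interrupted.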
# Start prediction
with torch.no_grad():
for samples in tqdm(test_data_loader, desc="Prediction Progress"):
all_predictions = []
# Generate multiple predictions
for _ in range(generate_num):
predicted_res = self.model(samples)
all_predictions.append(predicted_res)
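                # Note: distinct generations per clip require stochastic
                # decoding inside the model (e.g., sampling); with greedy
                # decoding every pass would return identical text.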
                # Write results; each caption_info entry is ordered
                # (half, timestamp, type, anonymized, league, game)
                caption_info = samples["caption_info"]
                with open(output_csv, 'a', newline='', encoding='utf-8') as file:
                    writer = csv.writer(file)
                    for info in zip(*all_predictions, caption_info):
                        # info = (pred_0, ..., pred_{generate_num-1}, caption_entry);
                        # reorder the annotation fields to match the CSV header.
                        row = [info[-1][4], info[-1][5], info[-1][0],
                               info[-1][1], info[-1][2], info[-1][3]] + list(info[:-1])
                        writer.writerow(row)
print(f"Prediction completed! Results saved to: {output_csv}")
def main():
"""Main function"""
parser = argparse.ArgumentParser(description="Matchcommentary Model Inference Script")
parser.add_argument("--model_path", type=str, default="./",
help="Path to model files")
parser.add_argument("--feature_root", type=str, default="./features",
help="Root directory for video features")
parser.add_argument("--ann_root", type=str, default="./dataset/MatchTime/train",
help="Root directory for annotation files")
parser.add_argument("--output_csv", type=str, default="./predictions.csv",
help="Output CSV file path")
parser.add_argument("--batch_size", type=int, default=4,
help="Batch size for processing")
parser.add_argument("--num_workers", type=int, default=2,
help="Number of data loading workers")
parser.add_argument("--generate_num", type=int, default=1,
help="Number of commentary generations per video clip")
parser.add_argument("--device", type=str, default="cuda:0",
help="Device to run on")
parser.add_argument("--fps", type=float, default=0.5,
help="Feature extraction frame rate")
parser.add_argument("--window", type=float, default=15,
help="Video window size in seconds")
args = parser.parse_args()
# Create predictor and run prediction
predictor = MatchcommentaryPredictor(args.model_path, args.device)
predictor.predict_batch(
feature_root=args.feature_root,
ann_root=args.ann_root,
output_csv=args.output_csv,
batch_size=args.batch_size,
num_workers=args.num_workers,
generate_num=args.generate_num,
fps=args.fps,
window=args.window
)
if __name__ == "__main__":
main()
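
# Example invocation (paths are illustrative; adjust to your layout):
#   python inference.py --model_path ./ --feature_root ./features \
#       --ann_root ./dataset/MatchTime/train --output_csv ./predictions.csv \
#       --batch_size 4 --generate_num 1 --device cuda:0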