|
|
|
|
|
""" |
|
|
Matchcommentary Model Inference Script - HuggingFace Version |
|
|
For automatic soccer commentary generation |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import argparse |
|
|
import os |
|
|
import csv |
|
|
from tqdm import tqdm |
|
|
from typing import List, Dict, Any |
|
|
import json |
|
|
|
|
|
|
|
|
from models.matchvoice_model import matchvoice_model |
|
|
from matchvoice_dataset import MatchVoice_Dataset |
|
|
from torch.utils.data import DataLoader |
|
|
|
|
|
class MatchcommentaryPredictor:
    """Inference wrapper for the Matchcommentary (matchvoice) model.

    Builds the model, optionally loads a fine-tuned checkpoint, and offers
    single-clip and batch prediction of automatic soccer commentary.
    """

    def __init__(self, model_path: str = "./", device: str = "cuda:0"):
        """
        Initialize Matchcommentary predictor.

        Args:
            model_path: Directory expected to contain
                ``model_save_best_val_CIDEr.pth``.
            device: Torch device string to run on (e.g. ``"cuda:0"``).
        """
        self.device = device
        self.model = None
        self.load_model(model_path)

    def load_model(self, model_path: str):
        """Build the matchvoice model and load checkpoint weights if present.

        Only keys that exist in the freshly built model are copied from the
        checkpoint, so partial checkpoints still load without error.
        """
        print("Loading Matchcommentary model...")

        self.model = matchvoice_model(
            llm_ckpt="meta-llama/Meta-Llama-3-8B-Instruct",
            tokenizer_ckpt="meta-llama/Meta-Llama-3-8B-Instruct",
            num_video_query_token=32,
            num_features=512,
            device=self.device,
            inference=True
        )

        checkpoint_path = os.path.join(model_path, "model_save_best_val_CIDEr.pth")
        if os.path.exists(checkpoint_path):
            # NOTE(review): torch.load unpickles arbitrary objects — only load
            # checkpoints from trusted sources (weights_only=True would be
            # safer if the project's torch version supports it).
            checkpoint = torch.load(checkpoint_path, map_location="cpu")

            # Merge only the keys the current model actually has, ignoring any
            # extra keys in the checkpoint (e.g. from a different config).
            model_state_dict = self.model.state_dict()
            matched = {k: v for k, v in checkpoint.items() if k in model_state_dict}
            model_state_dict.update(matched)
            self.model.load_state_dict(model_state_dict)
            print("Model checkpoint loaded successfully!")
        else:
            print(f"Warning: Model checkpoint file not found at {checkpoint_path}")

        self.model.eval()

    def predict_single(self, video_features: torch.Tensor) -> List[str]:
        """
        Predict commentary for a single video clip.

        Args:
            video_features: Video feature tensor (moved to ``self.device``
                before the forward pass).

        Returns:
            List of predicted commentary texts.
        """
        with torch.no_grad():
            # Dummy caption_info: the model only needs real captions during
            # training; inference receives empty placeholders.
            samples = {
                'features': video_features.to(self.device),
                'caption_info': [["", "", "", "", "", ""]]
            }
            predictions = self.model(samples)
            return predictions

    def predict_batch(self,
                      feature_root: str,
                      ann_root: str,
                      output_csv: str,
                      batch_size: int = 4,
                      num_workers: int = 2,
                      generate_num: int = 1,
                      fps: float = 0.5,
                      window: float = 15):
        """
        Batch prediction and save results to CSV file.

        Args:
            feature_root: Root directory for video features.
            ann_root: Root directory for annotation files.
            output_csv: Output CSV file path.
            batch_size: Batch size for processing.
            num_workers: Number of data loading workers.
            generate_num: Number of commentary generations per video clip.
            fps: Feature extraction frame rate.
            window: Video window size in seconds.
        """
        print("Preparing dataset...")

        test_dataset = MatchVoice_Dataset(
            feature_root=feature_root,
            ann_root=ann_root,
            fps=fps,
            timestamp_key="gameTime",
            tokenizer_name="meta-llama/Meta-Llama-3-8B-Instruct",
            window=window,
            split_ratio=0.01,
            is_train=False
        )

        test_data_loader = DataLoader(
            test_dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            drop_last=False,
            shuffle=False,
            pin_memory=True,
            collate_fn=test_dataset.collater
        )

        print("Dataset preparation completed, starting prediction...")

        # os.makedirs("") raises FileNotFoundError, so only create the parent
        # directory when the path actually has one (bare filenames are valid).
        output_dir = os.path.dirname(output_csv)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        headers = ['league', 'game', 'half', 'timestamp', 'type', 'anonymized']
        headers += [f'predicted_res_{i}' for i in range(generate_num)]

        # Keep a single file handle open for the whole run instead of
        # re-opening the CSV in append mode once per batch.
        with open(output_csv, 'w', newline='', encoding='utf-8') as file:
            writer = csv.writer(file)
            writer.writerow(headers)

            with torch.no_grad():
                for samples in tqdm(test_data_loader, desc="Prediction Progress"):
                    # Run the model generate_num times to get multiple
                    # candidate commentaries per clip.
                    all_predictions = []
                    for _ in range(generate_num):
                        predicted_res = self.model(samples)
                        all_predictions.append(predicted_res)

                    # caption_info entries are ordered
                    # [half, timestamp, type, anonymized, league, game]
                    # (indices 0-5); reorder to match the CSV headers.
                    # TODO(review): confirm ordering against the dataset's
                    # collater.
                    caption_info = samples["caption_info"]
                    for info in zip(*all_predictions, caption_info):
                        row = [info[-1][4], info[-1][5], info[-1][0],
                               info[-1][1], info[-1][2], info[-1][3]] + list(info[:-1])
                        writer.writerow(row)

        print(f"Prediction completed! Results saved to: {output_csv}")
|
|
|
|
|
def main():
    """Parse command-line options and run batch commentary prediction."""
    parser = argparse.ArgumentParser(description="Matchcommentary Model Inference Script")

    # (flag, type, default, help) — registered in order so --help output
    # matches the documented interface.
    option_specs = [
        ("--model_path", str, "./", "Path to model files"),
        ("--feature_root", str, "./features", "Root directory for video features"),
        ("--ann_root", str, "./dataset/MatchTime/train", "Root directory for annotation files"),
        ("--output_csv", str, "./predictions.csv", "Output CSV file path"),
        ("--batch_size", int, 4, "Batch size for processing"),
        ("--num_workers", int, 2, "Number of data loading workers"),
        ("--generate_num", int, 1, "Number of commentary generations per video clip"),
        ("--device", str, "cuda:0", "Device to run on"),
        ("--fps", float, 0.5, "Feature extraction frame rate"),
        ("--window", float, 15, "Video window size in seconds"),
    ]
    for flag, arg_type, default_value, help_text in option_specs:
        parser.add_argument(flag, type=arg_type, default=default_value, help=help_text)

    args = parser.parse_args()

    predictor = MatchcommentaryPredictor(args.model_path, args.device)
    predictor.predict_batch(
        feature_root=args.feature_root,
        ann_root=args.ann_root,
        output_csv=args.output_csv,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        generate_num=args.generate_num,
        fps=args.fps,
        window=args.window,
    )
|
|
|
|
|
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()