#!/usr/bin/env python3
"""
Matchcommentary Model Inference Script - HuggingFace Version
For automatic soccer commentary generation
"""
import argparse
import csv
import os
from typing import List

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# Assuming the model code is shipped alongside this script in the
# HuggingFace repository.
from models.matchvoice_model import matchvoice_model
from matchvoice_dataset import MatchVoice_Dataset


class MatchcommentaryPredictor:
    """Matchcommentary model inference class."""

    def __init__(self, model_path: str = "./", device: str = "cuda:0"):
        """
        Initialize the Matchcommentary predictor.

        Args:
            model_path: Path to the directory containing the model checkpoint.
            device: Device to run inference on, e.g. "cuda:0" or "cpu".
        """
        self.device = device
        self.model = None
        self.load_model(model_path)

    def load_model(self, model_path: str):
        """Build the model and load its checkpoint."""
        print("Loading Matchcommentary model...")
        # Initialize the model around a Llama-3-8B-Instruct backbone with
        # 32 video query tokens and 512-dimensional visual features.
        self.model = matchvoice_model(
            llm_ckpt="meta-llama/Meta-Llama-3-8B-Instruct",
            tokenizer_ckpt="meta-llama/Meta-Llama-3-8B-Instruct",
            num_video_query_token=32,
            num_features=512,
            device=self.device,
            inference=True
        )
        # Load the checkpoint, copying over only the keys that exist in the
        # current model so unexpected checkpoint keys are skipped.
        checkpoint_path = os.path.join(model_path, "model_save_best_val_CIDEr.pth")
        if os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path, map_location="cpu")
            model_state_dict = self.model.state_dict()
            for key, value in checkpoint.items():
                if key in model_state_dict:
                    model_state_dict[key] = value
            self.model.load_state_dict(model_state_dict)
            print("Model checkpoint loaded successfully!")
        else:
            print(f"Warning: Model checkpoint file not found at {checkpoint_path}")
        self.model.eval()
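
    # Note: the key-filtering loop above is roughly equivalent to the more
    # compact variant below, sketched here assuming the checkpoint is a flat
    # state dict:
    #
    #   missing, unexpected = self.model.load_state_dict(checkpoint, strict=False)
    #   print(f"Missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")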

    def predict_single(self, video_features: torch.Tensor) -> List[str]:
        """
        Predict commentary for a single video clip.

        Args:
            video_features: Video feature tensor.

        Returns:
            List of predicted commentary texts.
        """
        with torch.no_grad():
            # Build the input sample format the model expects; the six empty
            # strings are a placeholder for the caption metadata
            # (half, timestamp, type, anonymized, league, game).
            samples = {
                'features': video_features.to(self.device),
                'caption_info': [["", "", "", "", "", ""]]
            }
            predictions = self.model(samples)
        return predictions
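
    # Example usage (the feature shape here is hypothetical; it assumes
    # per-frame 512-d features matching num_features above):
    #
    #   predictor = MatchcommentaryPredictor("./", device="cuda:0")
    #   feats = torch.randn(1, 8, 512)  # (batch, frames, num_features)
    #   print(predictor.predict_single(feats))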

    def predict_batch(self,
                      feature_root: str,
                      ann_root: str,
                      output_csv: str,
                      batch_size: int = 4,
                      num_workers: int = 2,
                      generate_num: int = 1,
                      fps: float = 0.5,
                      window: float = 15):
        """
        Run batch prediction and save the results to a CSV file.

        Args:
            feature_root: Root directory for video features.
            ann_root: Root directory for annotation files.
            output_csv: Output CSV file path.
            batch_size: Batch size for processing.
            num_workers: Number of data loading workers.
            generate_num: Number of commentary generations per video clip.
            fps: Feature extraction frame rate.
            window: Video window size in seconds.
        """
        print("Preparing dataset...")
        test_dataset = MatchVoice_Dataset(
            feature_root=feature_root,
            ann_root=ann_root,
            fps=fps,
            timestamp_key="gameTime",
            tokenizer_name="meta-llama/Meta-Llama-3-8B-Instruct",
            window=window,
            split_ratio=0.01,  # use a small subset for quick testing
            is_train=False
        )
        test_data_loader = DataLoader(
            test_dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            drop_last=False,
            shuffle=False,
            pin_memory=True,
            collate_fn=test_dataset.collater
        )
        print("Dataset preparation completed, starting prediction...")
        # Create the output directory only if the path contains one;
        # os.makedirs("") would raise for a bare filename.
        out_dir = os.path.dirname(output_csv)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        # Write the CSV header.
        headers = ['league', 'game', 'half', 'timestamp', 'type', 'anonymized']
        headers += [f'predicted_res_{i}' for i in range(generate_num)]
        with open(output_csv, 'w', newline='', encoding='utf-8') as file:
            writer = csv.writer(file)
            writer.writerow(headers)
        # Run prediction, appending rows after each batch so partial results
        # survive an interruption.
        with torch.no_grad():
            for samples in tqdm(test_data_loader, desc="Prediction Progress"):
                # Generate generate_num candidate commentaries per clip.
                all_predictions = []
                for _ in range(generate_num):
                    predicted_res = self.model(samples)
                    all_predictions.append(predicted_res)
                # Each caption_info entry is (half, timestamp, type,
                # anonymized, league, game); reorder it to match the header.
                caption_info = samples["caption_info"]
                with open(output_csv, 'a', newline='', encoding='utf-8') as file:
                    writer = csv.writer(file)
                    for info in zip(*all_predictions, caption_info):
                        row = [info[-1][4], info[-1][5], info[-1][0],
                               info[-1][1], info[-1][2], info[-1][3]] + list(info[:-1])
                        writer.writerow(row)
        print(f"Prediction completed! Results saved to: {output_csv}")


def main():
    """Parse command-line arguments and run batch prediction."""
    parser = argparse.ArgumentParser(description="Matchcommentary Model Inference Script")
    parser.add_argument("--model_path", type=str, default="./",
                        help="Path to model files")
    parser.add_argument("--feature_root", type=str, default="./features",
                        help="Root directory for video features")
    parser.add_argument("--ann_root", type=str, default="./dataset/MatchTime/train",
                        help="Root directory for annotation files")
    parser.add_argument("--output_csv", type=str, default="./predictions.csv",
                        help="Output CSV file path")
    parser.add_argument("--batch_size", type=int, default=4,
                        help="Batch size for processing")
    parser.add_argument("--num_workers", type=int, default=2,
                        help="Number of data loading workers")
    parser.add_argument("--generate_num", type=int, default=1,
                        help="Number of commentary generations per video clip")
    parser.add_argument("--device", type=str, default="cuda:0",
                        help="Device to run on")
    parser.add_argument("--fps", type=float, default=0.5,
                        help="Feature extraction frame rate")
    parser.add_argument("--window", type=float, default=15,
                        help="Video window size in seconds")
    args = parser.parse_args()
    # Create the predictor and run batch prediction.
    predictor = MatchcommentaryPredictor(args.model_path, args.device)
    predictor.predict_batch(
        feature_root=args.feature_root,
        ann_root=args.ann_root,
        output_csv=args.output_csv,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        generate_num=args.generate_num,
        fps=args.fps,
        window=args.window
    )


if __name__ == "__main__":
    main()
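
# Example invocation (the script name and paths below are placeholders;
# adjust them to your checkout):
#   python inference_hf.py \
#       --model_path ./ \
#       --feature_root ./features \
#       --ann_root ./dataset/MatchTime/train \
#       --output_csv ./outputs/predictions.csv \
#       --generate_num 3 --device cuda:0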