File size: 7,792 Bytes
1ea8d66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
#!/usr/bin/env python3
"""
Matchcommentary Model Inference Script - HuggingFace Version
For automatic soccer commentary generation
"""

import torch
import argparse
import os
import csv
from tqdm import tqdm
from typing import List, Dict, Any
import json

# Assuming model files are included in the HuggingFace repository
from models.matchvoice_model import matchvoice_model
from matchvoice_dataset import MatchVoice_Dataset
from torch.utils.data import DataLoader

class MatchcommentaryPredictor:
    """Matchcommentary model inference class.

    Wraps the matchvoice model for single-clip and batch soccer-commentary
    generation, writing batch results to a CSV file.
    """
    
    def __init__(self, model_path: str = "./", device: str = "cuda:0"):
        """
        Initialize Matchcommentary predictor
        
        Args:
            model_path: Path to model files (directory containing the checkpoint)
            device: Device to run on
        """
        self.device = device
        self.model = None
        self.load_model(model_path)
    
    def load_model(self, model_path: str):
        """Build the matchvoice model and load the best-validation checkpoint.

        Performs a partial load: only checkpoint entries whose keys exist in
        the freshly constructed model are applied; non-matching entries are
        counted and reported rather than silently dropped. If the checkpoint
        file is missing, the randomly initialized model is kept and a warning
        is printed. The model is always switched to eval mode.

        Args:
            model_path: Directory expected to contain
                ``model_save_best_val_CIDEr.pth``.
        """
        print("Loading Matchcommentary model...")
        
        # Initialize model (inference mode; Llama-3-8B-Instruct backbone)
        self.model = matchvoice_model(
            llm_ckpt="meta-llama/Meta-Llama-3-8B-Instruct",
            tokenizer_ckpt="meta-llama/Meta-Llama-3-8B-Instruct",
            num_video_query_token=32,
            num_features=512,
            device=self.device,
            inference=True
        )
        
        # Load checkpoint on CPU first; the model was already placed on
        # self.device during construction.
        checkpoint_path = os.path.join(model_path, "model_save_best_val_CIDEr.pth")
        if os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path, map_location="cpu")
            
            # Partial load: keep the model's own weights for any key the
            # checkpoint does not provide, and report keys that match nothing.
            model_state_dict = self.model.state_dict()
            matched = {k: v for k, v in checkpoint.items() if k in model_state_dict}
            skipped = len(checkpoint) - len(matched)
            if skipped:
                print(f"Warning: {skipped} checkpoint entries did not match the model and were skipped")
            model_state_dict.update(matched)
            
            self.model.load_state_dict(model_state_dict)
            print("Model checkpoint loaded successfully!")
        else:
            print(f"Warning: Model checkpoint file not found at {checkpoint_path}")
        
        self.model.eval()
    
    def predict_single(self, video_features: torch.Tensor) -> List[str]:
        """
        Predict commentary for a single video clip
        
        Args:
            video_features: Video feature tensor (moved to self.device here)
            
        Returns:
            List of predicted commentary texts
        """
        with torch.no_grad():
            # Build input sample format expected by the model's forward pass;
            # caption_info is a placeholder since there is no ground truth.
            samples = {
                'features': video_features.to(self.device),
                'caption_info': [["", "", "", "", "", ""]]  # Placeholder
            }
            
            predictions = self.model(samples)
            return predictions
    
    def predict_batch(self, 
                     feature_root: str,
                     ann_root: str, 
                     output_csv: str,
                     batch_size: int = 4,
                     num_workers: int = 2,
                     generate_num: int = 1,
                     fps: float = 0.5,
                     window: float = 15):
        """
        Batch prediction and save results to CSV file
        
        Args:
            feature_root: Root directory for video features
            ann_root: Root directory for annotation files
            output_csv: Output CSV file path
            batch_size: Batch size for processing
            num_workers: Number of data loading workers
            generate_num: Number of commentary generations per video clip
            fps: Feature extraction frame rate
            window: Video window size in seconds
        """
        print("Preparing dataset...")
        
        # Create dataset
        test_dataset = MatchVoice_Dataset(
            feature_root=feature_root,
            ann_root=ann_root,
            fps=fps,
            timestamp_key="gameTime",
            tokenizer_name="meta-llama/Meta-Llama-3-8B-Instruct",
            window=window,
            split_ratio=0.01,  # Use small subset for quick testing
            is_train=False
        )
        
        test_data_loader = DataLoader(
            test_dataset, 
            batch_size=batch_size, 
            num_workers=num_workers, 
            drop_last=False, 
            shuffle=False, 
            pin_memory=True, 
            collate_fn=test_dataset.collater
        )
        
        print("Dataset preparation completed, starting prediction...")
        
        # Create output directory; skip when output_csv has no directory part
        # (os.makedirs("") would raise FileNotFoundError).
        output_dir = os.path.dirname(output_csv)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        
        # CSV header: six metadata columns plus one column per generation.
        headers = ['league', 'game', 'half', 'timestamp', 'type', 'anonymized']
        headers += [f'predicted_res_{i}' for i in range(generate_num)]
        
        # Open the output file once for the whole run instead of reopening it
        # in append mode for every batch; flush per batch so partial results
        # survive a mid-run crash, as with the previous per-batch reopen.
        with open(output_csv, 'w', newline='', encoding='utf-8') as file, torch.no_grad():
            writer = csv.writer(file)
            writer.writerow(headers)
            
            for samples in tqdm(test_data_loader, desc="Prediction Progress"):
                # Generate generate_num independent predictions per clip.
                all_predictions = [self.model(samples) for _ in range(generate_num)]
                
                # caption_info entries appear to be ordered
                # [half, timestamp, type, anonymized, league, game], given the
                # reordering into the header columns below — TODO confirm
                # against MatchVoice_Dataset.collater.
                caption_info = samples["caption_info"]
                for info in zip(*all_predictions, caption_info):
                    meta = info[-1]
                    row = [meta[4], meta[5], meta[0],
                           meta[1], meta[2], meta[3]] + list(info[:-1])
                    writer.writerow(row)
                file.flush()
        
        print(f"Prediction completed! Results saved to: {output_csv}")

def main():
    """Command-line entry point: parse arguments, then run batch inference."""
    parser = argparse.ArgumentParser(description="Matchcommentary Model Inference Script")

    # Table-driven CLI definition: (flag, type, default, help).
    cli_options = [
        ("--model_path", str, "./", "Path to model files"),
        ("--feature_root", str, "./features", "Root directory for video features"),
        ("--ann_root", str, "./dataset/MatchTime/train", "Root directory for annotation files"),
        ("--output_csv", str, "./predictions.csv", "Output CSV file path"),
        ("--batch_size", int, 4, "Batch size for processing"),
        ("--num_workers", int, 2, "Number of data loading workers"),
        ("--generate_num", int, 1, "Number of commentary generations per video clip"),
        ("--device", str, "cuda:0", "Device to run on"),
        ("--fps", float, 0.5, "Feature extraction frame rate"),
        ("--window", float, 15, "Video window size in seconds"),
    ]
    for flag, arg_type, default, help_text in cli_options:
        parser.add_argument(flag, type=arg_type, default=default, help=help_text)

    args = parser.parse_args()

    # Build the predictor, then run the full batch-inference pipeline.
    predictor = MatchcommentaryPredictor(args.model_path, args.device)
    predictor.predict_batch(
        feature_root=args.feature_root,
        ann_root=args.ann_root,
        output_csv=args.output_csv,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        generate_num=args.generate_num,
        fps=args.fps,
        window=args.window,
    )

# Script entry point: only run inference when executed directly, not on import.
if __name__ == "__main__":
    main()