abocide commited on
Commit
1ea8d66
·
verified ·
1 Parent(s): d49ef20

Upload folder using huggingface_hub

Browse files
README.md ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Matchcommentary: Automatic Soccer Game Commentary Generation Model
2
+
3
+ ## Model Overview
4
+
5
+ Matchcommentary is a multimodal learning-based automatic soccer game commentary generation model that generates fluent soccer commentary text based on video features. The model combines visual feature extraction, Q-Former architecture, and large language models to achieve high-quality soccer commentary generation.
6
+
7
+ ## Model Architecture
8
+
9
+ - **Base Model**: LLaMA-3-8B-Instruct
10
+ - **Vision Encoder**: Q-Former architecture
11
+ - **Feature Dimension**: 512-dimensional video features
12
+ - **Window Size**: 15-second video clips
13
+ - **Query Tokens**: 32 video query tokens
14
+
15
+ ## Usage
16
+
17
+ ### Install Dependencies
18
+
19
+ ```bash
20
+ pip install torch transformers einops pycocoevalcap opencv-python numpy
21
+ ```
22
+
23
+ ### Quick Start
24
+
25
+ ```python
26
+ from models.matchvoice_model import matchvoice_model
27
+ from matchvoice_dataset import MatchVoice_Dataset
28
+ import torch
29
+
30
+ # Load model
31
+ model = matchvoice_model(
32
+ llm_ckpt="meta-llama/Meta-Llama-3-8B-Instruct",
33
+ tokenizer_ckpt="meta-llama/Meta-Llama-3-8B-Instruct",
34
+ num_video_query_token=32,
35
+ num_features=512,
36
+ device="cuda:0",
37
+ inference=True
38
+ )
39
+
40
+ # Load checkpoint
41
+ checkpoint = torch.load("model_save_best_val_CIDEr.pth", map_location="cpu")
42
+ model.load_state_dict(checkpoint)
43
+ model.eval()
44
+
45
+ # Perform inference (requires prepared video features)
46
+ with torch.no_grad():
47
+ predictions = model(samples)
48
+ ```
49
+
50
+ ### Complete Inference Pipeline
51
+
52
+ Using the provided `inference1.py` script:
53
+
54
+ ```bash
55
+ python inference1.py \
56
+ --feature_root ./features \
57
+ --ann_root ./dataset/MatchTime/train \
58
+ --model_ckpt model_save_best_val_CIDEr.pth \
59
+ --window 15 \
60
+ --batch_size 4 \
61
+ --num_video_query_token 32 \
62
+ --num_features 512 \
63
+ --csv_output_path ./inference_result/predictions.csv
64
+ ```
65
+
66
+ ## Input Data Format
67
+
68
+ The model expects the following input format:
69
+
70
+ 1. **Video Features**: ResNet_PCA512 features with shape `[batch_size, time_length, feature_dim]`
71
+ 2. **Timestamp Information**: Metadata including game time, event type, etc.
72
+ 3. **Attention Mask**: For handling variable-length sequences
73
+
74
+ ## Output Format
75
+
76
+ The model outputs a CSV file with the following columns:
77
+ - `league`: League and season information
78
+ - `game`: Game name
79
+ - `half`: First/second half
80
+ - `timestamp`: Event timestamp
81
+ - `type`: Soccer event type
82
+ - `anonymized`: Ground truth annotation
83
+ - `predicted_res_{i}`: Model prediction results
84
+
85
+ ## Model Features
86
+
87
+ - Supports multiple video feature formats (ResNet, C3D, CLIP, etc.)
88
+ - Soccer-specific vocabulary constraint generation
89
+ - Supports both batch inference and single video inference
90
+ - Q-Former-based multimodal fusion architecture
91
+
92
+ ## Performance Metrics
93
+
94
+ Evaluation results on the MatchTime dataset:
95
+ - Achieved best validation CIDEr score
96
+ - Supports real-time soccer commentary generation
97
+
98
+
99
+
ckpt/model_save_best_val_CIDEr.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6aea3d4f7776b9b1c40b518fe1ce0b5ed6a7d3c8c60f55113e9ed08d281439ba
3
+ size 2186901790
config.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": ["MatchcommentaryModel"],
3
+ "model_type": "Matchcommentary",
4
+ "llm_ckpt": "meta-llama/Meta-Llama-3-8B-Instruct",
5
+ "tokenizer_ckpt": "meta-llama/Meta-Llama-3-8B-Instruct",
6
+ "max_frame_pos": 128,
7
+ "window": 15,
8
+ "num_query_tokens": 32,
9
+ "num_video_query_token": 32,
10
+ "num_features": 512,
11
+ "fps": 0.5,
12
+ "max_token_length": 128,
13
+ "feature_subdir": "ResNET_PCA512",
14
+ "torch_dtype": "float16",
15
+ "transformers_version": "4.42.3",
16
+ "description": "MatchcommentaryModel model for automatic soccer game commentary generation, trained on MatchTime dataset"
17
+ }
inference.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Matchcommentary Model Inference Script - HuggingFace Version
4
+ For automatic soccer commentary generation
5
+ """
6
+
7
+ import torch
8
+ import argparse
9
+ import os
10
+ import csv
11
+ from tqdm import tqdm
12
+ from typing import List, Dict, Any
13
+ import json
14
+
15
+ # Assuming model files are included in the HuggingFace repository
16
+ from models.matchvoice_model import matchvoice_model
17
+ from matchvoice_dataset import MatchVoice_Dataset
18
+ from torch.utils.data import DataLoader
19
+
20
+ class MatchcommentaryPredictor:
21
+ """Matchcommentary model inference class"""
22
+
23
+ def __init__(self, model_path: str = "./", device: str = "cuda:0"):
24
+ """
25
+ Initialize Matchcommentary predictor
26
+
27
+ Args:
28
+ model_path: Path to model files
29
+ device: Device to run on
30
+ """
31
+ self.device = device
32
+ self.model = None
33
+ self.load_model(model_path)
34
+
35
+ def load_model(self, model_path: str):
36
+ """Load the model"""
37
+ print("Loading Matchcommentary model...")
38
+
39
+ # Initialize model
40
+ self.model = matchvoice_model(
41
+ llm_ckpt="meta-llama/Meta-Llama-3-8B-Instruct",
42
+ tokenizer_ckpt="meta-llama/Meta-Llama-3-8B-Instruct",
43
+ num_video_query_token=32,
44
+ num_features=512,
45
+ device=self.device,
46
+ inference=True
47
+ )
48
+
49
+ # Load checkpoint
50
+ checkpoint_path = os.path.join(model_path, "model_save_best_val_CIDEr.pth")
51
+ if os.path.exists(checkpoint_path):
52
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
53
+
54
+ # Load state dict
55
+ model_state_dict = self.model.state_dict()
56
+ for key, value in checkpoint.items():
57
+ if key in model_state_dict:
58
+ model_state_dict[key] = value
59
+
60
+ self.model.load_state_dict(model_state_dict)
61
+ print("Model checkpoint loaded successfully!")
62
+ else:
63
+ print(f"Warning: Model checkpoint file not found at {checkpoint_path}")
64
+
65
+ self.model.eval()
66
+
67
+ def predict_single(self, video_features: torch.Tensor) -> List[str]:
68
+ """
69
+ Predict commentary for a single video clip
70
+
71
+ Args:
72
+ video_features: Video feature tensor
73
+
74
+ Returns:
75
+ List of predicted commentary texts
76
+ """
77
+ with torch.no_grad():
78
+ # Build input sample format
79
+ samples = {
80
+ 'features': video_features.to(self.device),
81
+ 'caption_info': [["", "", "", "", "", ""]] # Placeholder
82
+ }
83
+
84
+ predictions = self.model(samples)
85
+ return predictions
86
+
87
+ def predict_batch(self,
88
+ feature_root: str,
89
+ ann_root: str,
90
+ output_csv: str,
91
+ batch_size: int = 4,
92
+ num_workers: int = 2,
93
+ generate_num: int = 1,
94
+ fps: float = 0.5,
95
+ window: float = 15):
96
+ """
97
+ Batch prediction and save results to CSV file
98
+
99
+ Args:
100
+ feature_root: Root directory for video features
101
+ ann_root: Root directory for annotation files
102
+ output_csv: Output CSV file path
103
+ batch_size: Batch size for processing
104
+ num_workers: Number of data loading workers
105
+ generate_num: Number of commentary generations per video clip
106
+ fps: Feature extraction frame rate
107
+ window: Video window size in seconds
108
+ """
109
+ print("Preparing dataset...")
110
+
111
+ # Create dataset
112
+ test_dataset = MatchVoice_Dataset(
113
+ feature_root=feature_root,
114
+ ann_root=ann_root,
115
+ fps=fps,
116
+ timestamp_key="gameTime",
117
+ tokenizer_name="meta-llama/Meta-Llama-3-8B-Instruct",
118
+ window=window,
119
+ split_ratio=0.01, # Use small subset for quick testing
120
+ is_train=False
121
+ )
122
+
123
+ test_data_loader = DataLoader(
124
+ test_dataset,
125
+ batch_size=batch_size,
126
+ num_workers=num_workers,
127
+ drop_last=False,
128
+ shuffle=False,
129
+ pin_memory=True,
130
+ collate_fn=test_dataset.collater
131
+ )
132
+
133
+ print("Dataset preparation completed, starting prediction...")
134
+
135
+ # Create output directory
136
+ os.makedirs(os.path.dirname(output_csv), exist_ok=True)
137
+
138
+ # Write CSV header
139
+ headers = ['league', 'game', 'half', 'timestamp', 'type', 'anonymized']
140
+ headers += [f'predicted_res_{i}' for i in range(generate_num)]
141
+
142
+ with open(output_csv, 'w', newline='', encoding='utf-8') as file:
143
+ writer = csv.writer(file)
144
+ writer.writerow(headers)
145
+
146
+ # Start prediction
147
+ with torch.no_grad():
148
+ for samples in tqdm(test_data_loader, desc="Prediction Progress"):
149
+ all_predictions = []
150
+
151
+ # Generate multiple predictions
152
+ for _ in range(generate_num):
153
+ predicted_res = self.model(samples)
154
+ all_predictions.append(predicted_res)
155
+
156
+ # Write results
157
+ caption_info = samples["caption_info"]
158
+ with open(output_csv, 'a', newline='', encoding='utf-8') as file:
159
+ writer = csv.writer(file)
160
+ for info in zip(*all_predictions, caption_info):
161
+ row = [info[-1][4], info[-1][5], info[-1][0],
162
+ info[-1][1], info[-1][2], info[-1][3]] + list(info[:-1])
163
+ writer.writerow(row)
164
+
165
+ print(f"Prediction completed! Results saved to: {output_csv}")
166
+
167
+ def main():
168
+ """Main function"""
169
+ parser = argparse.ArgumentParser(description="Matchcommentary Model Inference Script")
170
+ parser.add_argument("--model_path", type=str, default="./",
171
+ help="Path to model files")
172
+ parser.add_argument("--feature_root", type=str, default="./features",
173
+ help="Root directory for video features")
174
+ parser.add_argument("--ann_root", type=str, default="./dataset/MatchTime/train",
175
+ help="Root directory for annotation files")
176
+ parser.add_argument("--output_csv", type=str, default="./predictions.csv",
177
+ help="Output CSV file path")
178
+ parser.add_argument("--batch_size", type=int, default=4,
179
+ help="Batch size for processing")
180
+ parser.add_argument("--num_workers", type=int, default=2,
181
+ help="Number of data loading workers")
182
+ parser.add_argument("--generate_num", type=int, default=1,
183
+ help="Number of commentary generations per video clip")
184
+ parser.add_argument("--device", type=str, default="cuda:0",
185
+ help="Device to run on")
186
+ parser.add_argument("--fps", type=float, default=0.5,
187
+ help="Feature extraction frame rate")
188
+ parser.add_argument("--window", type=float, default=15,
189
+ help="Video window size in seconds")
190
+
191
+ args = parser.parse_args()
192
+
193
+ # Create predictor and run prediction
194
+ predictor = MatchcommentaryPredictor(args.model_path, args.device)
195
+ predictor.predict_batch(
196
+ feature_root=args.feature_root,
197
+ ann_root=args.ann_root,
198
+ output_csv=args.output_csv,
199
+ batch_size=args.batch_size,
200
+ num_workers=args.num_workers,
201
+ generate_num=args.generate_num,
202
+ fps=args.fps,
203
+ window=args.window
204
+ )
205
+
206
+ if __name__ == "__main__":
207
+ main()
model_card.md ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags:
4
+ - multimodal
5
+ - video-understanding
6
+ - sports
7
+ - commentary-generation
8
+ - llama3
9
+ - soccer
10
+ language:
11
+ - en
12
+ datasets:
13
+ - MatchTime
14
+ pipeline_tag: text-generation
15
+ ---
16
+
17
+ # Matchcommentary: Automatic Soccer Game Commentary Generation
18
+
19
+ ## Model Description
20
+
21
+ Matchcommentary is a multimodal model designed for automatic soccer game commentary generation. It combines video feature understanding with large language models to generate fluent and contextually appropriate soccer commentary.
22
+
23
+ ## Architecture
24
+
25
+ The model consists of:
26
+ - **Vision Encoder**: Q-Former architecture for processing video features
27
+ - **Language Model**: LLaMA-3-8B-Instruct for text generation
28
+ - **Feature Fusion**: Cross-attention mechanism between visual and textual information
29
+ - **Domain Adaptation**: Soccer-specific vocabulary constraints
30
+
31
+ ## Intended Use
32
+
33
+ ### Primary Use Cases
34
+ - Automatic soccer game commentary generation
35
+ - Sports video understanding and description
36
+ - Multimodal video-to-text generation
37
+
38
+ ### Limitations
39
+ - Trained specifically on soccer/football content
40
+ - Requires pre-extracted video features
41
+ - Performance may vary on different video qualities or angles
42
+
43
+ ## Training Data
44
+
45
+ The model was trained on the MatchTime dataset, which contains:
46
+ - Soccer game videos with corresponding commentary
47
+ - Multiple leagues and seasons
48
+ - Temporal alignment between visual events and commentary
49
+
50
+ ## Performance
51
+
52
+ The model achieves state-of-the-art performance on the MatchTime benchmark, with the best validation CIDEr score among tested configurations.
53
+
54
+ ## Usage
55
+
56
+ ```python
57
+ from models.matchvoice_model import matchvoice_model
58
+ import torch
59
+
60
+ # Load model
61
+ model = matchvoice_model(
62
+ llm_ckpt="meta-llama/Meta-Llama-3-8B-Instruct",
63
+ tokenizer_ckpt="meta-llama/Meta-Llama-3-8B-Instruct",
64
+ num_video_query_token=32,
65
+ num_features=512,
66
+ device="cuda:0",
67
+ inference=True
68
+ )
69
+
70
+ # Load checkpoint
71
+ checkpoint = torch.load("model_save_best_val_CIDEr.pth")
72
+ model.load_state_dict(checkpoint)
73
+ model.eval()
74
+
75
+ # Generate commentary
76
+ with torch.no_grad():
77
+ commentary = model(video_samples)
78
+ ```
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ transformers>=4.42.3
3
+ einops>=0.8.0
4
+ numpy>=1.26.3
5
+ opencv-python>=4.10.0
6
+ pycocoevalcap>=1.2
7
+ pycocotools>=2.0.8
8
+ pillow>=10.4.0
9
+ pyyaml>=6.0.2
10
+ requests>=2.32.3
11
+ safetensors>=0.4.4
12
+ huggingface-hub>=0.24.6
13
+ tqdm
14
+ argparse
soccer_words_llama3.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:654f03e1d4678cd0c3e8ca587af027e4bc14489e94e90bd30ad856242dab2d94
3
+ size 9092