MSherbinii commited on
Commit
c3981cb
·
verified ·
1 Parent(s): 463a80f

Add HF-adapted training script with Accelerate

Browse files
Files changed (1) hide show
  1. train_hf.py +315 -0
train_hf.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HuggingFace-adapted IPAD Training Script
3
+ Trains on HF infrastructure with ZeroGPU, Accelerate, and automatic checkpointing
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch.optim import Adam
9
+ from torch.cuda.amp import autocast, GradScaler
10
+ from pathlib import Path
11
+ import json
12
+ from datetime import datetime
13
+ from tqdm import tqdm
14
+ import wandb
15
+ from typing import Dict, Optional
16
+ import os
17
+
18
+ # HF infrastructure
19
+ from huggingface_hub import HfApi, create_repo
20
+ from accelerate import Accelerator
21
+
22
+ # Local imports
23
+ from IPAD.model.video_swin_transformer import VST
24
+ from IPAD.model.entropy_loss import EntropyLossEncap
25
+ from dataset import create_dataloaders, download_and_extract_dataset
26
+
27
+ class IPADTrainer:
28
+ """
29
+ IPAD Model Trainer with HF Integration
30
+ """
31
+
32
+ def __init__(
33
+ self,
34
+ device_name: str = "S01",
35
+ mem_dim: int = 2000,
36
+ shrink_thres: float = 0.0025,
37
+ lr: float = 1e-4,
38
+ batch_size: int = 4,
39
+ epochs: int = 200,
40
+ entropy_loss_weight: float = 0.0002,
41
+ period_loss_weight: float = 0.02,
42
+ checkpoint_dir: str = "./checkpoints",
43
+ wandb_project: Optional[str] = "ipad-vad",
44
+ hf_repo: Optional[str] = "MSherbinii/ipad-vad-checkpoints"
45
+ ):
46
+ self.device_name = device_name
47
+ self.mem_dim = mem_dim
48
+ self.shrink_thres = shrink_thres
49
+ self.lr = lr
50
+ self.batch_size = batch_size
51
+ self.epochs = epochs
52
+ self.entropy_loss_weight = entropy_loss_weight
53
+ self.period_loss_weight = period_loss_weight
54
+ self.checkpoint_dir = Path(checkpoint_dir)
55
+ self.checkpoint_dir.mkdir(exist_ok=True, parents=True)
56
+ self.wandb_project = wandb_project
57
+ self.hf_repo = hf_repo
58
+
59
+ # Initialize Accelerator for distributed training
60
+ self.accelerator = Accelerator(
61
+ mixed_precision='fp16',
62
+ gradient_accumulation_steps=1,
63
+ log_with="wandb" if wandb_project else None
64
+ )
65
+
66
+ # Model
67
+ self.model = VST(mem_dim=mem_dim, shrink_thres=shrink_thres)
68
+
69
+ # Losses
70
+ self.recon_criterion = nn.MSELoss()
71
+ self.entropy_criterion = EntropyLossEncap()
72
+ self.period_criterion = nn.CrossEntropyLoss()
73
+
74
+ # Optimizer
75
+ self.optimizer = Adam(self.model.parameters(), lr=lr)
76
+
77
+ # HF API
78
+ self.hf_api = HfApi()
79
+ if hf_repo:
80
+ try:
81
+ create_repo(hf_repo, repo_type="model", private=False, exist_ok=True)
82
+ except:
83
+ pass
84
+
85
+ def setup_data(self, dataset_path: str):
86
+ """Setup dataloaders"""
87
+ self.train_loader, self.test_loader = create_dataloaders(
88
+ dataset_path=dataset_path,
89
+ device_name=self.device_name,
90
+ batch_size=self.batch_size,
91
+ num_workers=4,
92
+ clip_length=16,
93
+ frame_size=(256, 256)
94
+ )
95
+
96
+ # Prepare with Accelerator
97
+ self.model, self.optimizer, self.train_loader, self.test_loader = \
98
+ self.accelerator.prepare(
99
+ self.model, self.optimizer, self.train_loader, self.test_loader
100
+ )
101
+
102
+ def train_epoch(self, epoch: int) -> Dict[str, float]:
103
+ """Train for one epoch"""
104
+ self.model.train()
105
+ total_loss = 0.0
106
+ recon_loss_sum = 0.0
107
+ entropy_loss_sum = 0.0
108
+ period_loss_sum = 0.0
109
+
110
+ pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}/{self.epochs}")
111
+
112
+ for batch_idx, clips in enumerate(pbar):
113
+ # clips shape: [B, C, T, H, W]
114
+
115
+ with self.accelerator.autocast():
116
+ # Forward pass
117
+ outputs = self.model(clips)
118
+ reconstructed = outputs['output']
119
+ att = outputs['att']
120
+ period_pred = outputs['recon_index']
121
+
122
+ # Reconstruction loss
123
+ recon_loss = self.recon_criterion(reconstructed, clips)
124
+
125
+ # Entropy loss on attention weights
126
+ entropy_loss = self.entropy_criterion(att)
127
+
128
+ # Period classification loss
129
+ # Create pseudo-labels (uniform distribution for now)
130
+ # In full implementation, this would use actual period annotations
131
+ period_labels = torch.randint(0, 200, (clips.size(0),)).to(clips.device)
132
+ period_loss = self.period_criterion(period_pred, period_labels)
133
+
134
+ # Combined loss
135
+ loss = (recon_loss +
136
+ self.entropy_loss_weight * entropy_loss +
137
+ self.period_loss_weight * period_loss)
138
+
139
+ # Backward pass
140
+ self.accelerator.backward(loss)
141
+ self.optimizer.step()
142
+ self.optimizer.zero_grad()
143
+
144
+ # Accumulate losses
145
+ total_loss += loss.item()
146
+ recon_loss_sum += recon_loss.item()
147
+ entropy_loss_sum += entropy_loss.item()
148
+ period_loss_sum += period_loss.item()
149
+
150
+ # Update progress bar
151
+ pbar.set_postfix({
152
+ 'loss': f'{loss.item():.4f}',
153
+ 'recon': f'{recon_loss.item():.4f}',
154
+ 'entropy': f'{entropy_loss.item():.6f}',
155
+ 'period': f'{period_loss.item():.4f}'
156
+ })
157
+
158
+ num_batches = len(self.train_loader)
159
+ return {
160
+ 'train_loss': total_loss / num_batches,
161
+ 'train_recon_loss': recon_loss_sum / num_batches,
162
+ 'train_entropy_loss': entropy_loss_sum / num_batches,
163
+ 'train_period_loss': period_loss_sum / num_batches
164
+ }
165
+
166
+ @torch.no_grad()
167
+ def validate(self) -> Dict[str, float]:
168
+ """Validate on test set"""
169
+ self.model.eval()
170
+ total_loss = 0.0
171
+ recon_loss_sum = 0.0
172
+
173
+ for clips in tqdm(self.test_loader, desc="Validating"):
174
+ with self.accelerator.autocast():
175
+ outputs = self.model(clips)
176
+ reconstructed = outputs['output']
177
+
178
+ recon_loss = self.recon_criterion(reconstructed, clips)
179
+ total_loss += recon_loss.item()
180
+ recon_loss_sum += recon_loss.item()
181
+
182
+ num_batches = len(self.test_loader)
183
+ return {
184
+ 'val_loss': total_loss / num_batches,
185
+ 'val_recon_loss': recon_loss_sum / num_batches
186
+ }
187
+
188
+ def save_checkpoint(self, epoch: int, metrics: Dict[str, float]):
189
+ """Save checkpoint locally and upload to HF Hub"""
190
+ checkpoint_name = f"{self.device_name}_epoch_{epoch:03d}.pth"
191
+ checkpoint_path = self.checkpoint_dir / checkpoint_name
192
+
193
+ # Save checkpoint
194
+ checkpoint = {
195
+ 'epoch': epoch,
196
+ 'model_state_dict': self.accelerator.unwrap_model(self.model).state_dict(),
197
+ 'optimizer_state_dict': self.optimizer.state_dict(),
198
+ 'metrics': metrics,
199
+ 'config': {
200
+ 'device_name': self.device_name,
201
+ 'mem_dim': self.mem_dim,
202
+ 'shrink_thres': self.shrink_thres,
203
+ 'lr': self.lr,
204
+ 'batch_size': self.batch_size
205
+ }
206
+ }
207
+
208
+ torch.save(checkpoint, checkpoint_path)
209
+ print(f"💾 Checkpoint saved: {checkpoint_path}")
210
+
211
+ # Upload to HF Hub
212
+ if self.hf_repo:
213
+ try:
214
+ self.hf_api.upload_file(
215
+ path_or_fileobj=str(checkpoint_path),
216
+ path_in_repo=f"checkpoints/{checkpoint_name}",
217
+ repo_id=self.hf_repo,
218
+ repo_type="model",
219
+ commit_message=f"Epoch {epoch} - {self.device_name}"
220
+ )
221
+ print(f"☁️ Uploaded to HF Hub: {self.hf_repo}")
222
+ except Exception as e:
223
+ print(f"⚠️ Failed to upload to HF Hub: {e}")
224
+
225
+ def train(self, dataset_path: str):
226
+ """Full training loop"""
227
+ print(f"\n🚀 Starting training for {self.device_name}")
228
+ print(f"📊 Epochs: {self.epochs}, Batch Size: {self.batch_size}, LR: {self.lr}")
229
+
230
+ # Setup data
231
+ self.setup_data(dataset_path)
232
+
233
+ # Initialize wandb
234
+ if self.wandb_project:
235
+ self.accelerator.init_trackers(
236
+ project_name=self.wandb_project,
237
+ config={
238
+ 'device_name': self.device_name,
239
+ 'mem_dim': self.mem_dim,
240
+ 'lr': self.lr,
241
+ 'batch_size': self.batch_size,
242
+ 'epochs': self.epochs
243
+ }
244
+ )
245
+
246
+ # Training loop
247
+ best_val_loss = float('inf')
248
+
249
+ for epoch in range(1, self.epochs + 1):
250
+ # Train
251
+ train_metrics = self.train_epoch(epoch)
252
+
253
+ # Validate every 10 epochs
254
+ if epoch % 10 == 0:
255
+ val_metrics = self.validate()
256
+ metrics = {**train_metrics, **val_metrics}
257
+
258
+ # Save best model
259
+ if val_metrics['val_loss'] < best_val_loss:
260
+ best_val_loss = val_metrics['val_loss']
261
+ self.save_checkpoint(epoch, metrics)
262
+
263
+ # Log metrics
264
+ if self.wandb_project:
265
+ self.accelerator.log(metrics, step=epoch)
266
+
267
+ print(f"\n📊 Epoch {epoch} - Train Loss: {train_metrics['train_loss']:.4f}, Val Loss: {val_metrics['val_loss']:.4f}")
268
+
269
+ # Save checkpoint every 50 epochs
270
+ if epoch % 50 == 0:
271
+ self.save_checkpoint(epoch, train_metrics)
272
+
273
+ print(f"\n✅ Training complete for {self.device_name}!")
274
+ print(f"📂 Checkpoints saved to: {self.checkpoint_dir}")
275
+ if self.hf_repo:
276
+ print(f"☁️ Model available at: https://huggingface.co/{self.hf_repo}")
277
+
278
+
279
+ def main():
280
+ """Main training entry point"""
281
+ import argparse
282
+
283
+ parser = argparse.ArgumentParser(description="Train IPAD VAD model on HF infrastructure")
284
+ parser.add_argument("--device", type=str, default="S01", help="Device name (S01-S12, R01-R04)")
285
+ parser.add_argument("--epochs", type=int, default=200, help="Number of epochs")
286
+ parser.add_argument("--batch-size", type=int, default=4, help="Batch size")
287
+ parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
288
+ parser.add_argument("--mem-dim", type=int, default=2000, help="Memory dimension")
289
+ parser.add_argument("--no-wandb", action="store_true", help="Disable wandb logging")
290
+ parser.add_argument("--dataset-path", type=str, default=None, help="Path to dataset (downloads if not provided)")
291
+
292
+ args = parser.parse_args()
293
+
294
+ # Download dataset if needed
295
+ if args.dataset_path is None:
296
+ dataset_path = download_and_extract_dataset()
297
+ else:
298
+ dataset_path = Path(args.dataset_path)
299
+
300
+ # Create trainer
301
+ trainer = IPADTrainer(
302
+ device_name=args.device,
303
+ epochs=args.epochs,
304
+ batch_size=args.batch_size,
305
+ lr=args.lr,
306
+ mem_dim=args.mem_dim,
307
+ wandb_project=None if args.no_wandb else "ipad-vad"
308
+ )
309
+
310
+ # Train
311
+ trainer.train(str(dataset_path))
312
+
313
+
314
+ if __name__ == "__main__":
315
+ main()