AbstractPhil committed on
Commit
e889862
·
verified ·
1 Parent(s): ef67372

Create trainer.py

Files changed (1)
  1. trainer.py +402 -0
trainer.py ADDED
@@ -0,0 +1,402 @@
#!/usr/bin/env python3
"""
Train a CantorLinear classifier on pre-extracted ImageNet CLIP features.
Uses the AbstractPhil/imagenet-clip-features-orderly dataset from HuggingFace.
Author: AbstractPhil
License: MIT

Uses the geometricvocab GitHub implementation. Install it in a notebook cell:

    !pip uninstall -qy geometricvocab
    !pip install -q git+https://github.com/AbstractEyes/lattice_vocabulary.git
"""

import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from tqdm import tqdm
import wandb

# Import the CantorLinear layer.
# Adjust the import path as needed for your setup.
from geovocab2.train.model.layers.linear import CantorLinear, CantorLinearConfig


# ============================================================
# CONFIGURATION
# ============================================================

@dataclass
class TrainConfig:
    # Dataset
    dataset_name: str = "AbstractPhil/imagenet-clip-features-orderly"
    clip_dim: int = 512      # CLIP ViT-B/16 feature dimension
    num_classes: int = 1000  # ImageNet classes

    # Model
    hidden_dims: list = None  # e.g. [2048, 1024] for 2 hidden layers; None for direct
    cantor_depth: int = 8
    mask_mode: str = "alpha"
    alpha_mode: str = "sigmoid"
    alpha_min: float = 0.1
    alpha_max: float = 1.0
    per_output_alpha: bool = False
    dropout: float = 0.1

    # Training
    batch_size: int = 512
    num_epochs: int = 50
    learning_rate: float = 1e-3
    weight_decay: float = 1e-4
    warmup_epochs: int = 5

    # Optimizer
    alpha_lr_mult: float = 0.1  # Separate LR multiplier for alpha parameters

    # Logging
    use_wandb: bool = False
    wandb_project: str = "cantor-imagenet"
    log_every: int = 50
    eval_every: int = 500

    # System
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    num_workers: int = 4
    seed: int = 42

    def __post_init__(self):
        if self.hidden_dims is None:
            self.hidden_dims = []  # Direct CLIP → classes

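
# A configuration sketch (illustrative values, not the defaults above):
#
#     cfg = TrainConfig(hidden_dims=[2048, 1024], dropout=0.2)
#
# builds a two-hidden-layer head, while hidden_dims=None (the default)
# collapses to a single CantorLinear mapping clip_dim features straight
# to num_classes logits.
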
# ============================================================
# DATASET
# ============================================================

class CLIPFeaturesDataset(Dataset):
    """Wrapper for a HuggingFace dataset of CLIP features."""

    def __init__(self, hf_dataset):
        self.dataset = hf_dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        features = torch.tensor(item['clip_features'], dtype=torch.float32)
        label = torch.tensor(item['label'], dtype=torch.long)
        return features, label

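
# Each dataset row is assumed to carry a 'clip_features' column (a
# clip_dim-length float list) and an integer 'label'. Quick sanity check
# (illustrative; 'hf_split' stands in for a loaded HF split):
#
#     f, y = CLIPFeaturesDataset(hf_split)[0]
#     assert f.shape == (512,) and y.dtype == torch.long
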
# ============================================================
# MODEL
# ============================================================

class CantorCLIPClassifier(nn.Module):
    """
    Multi-layer classifier using CantorLinear layers.
    Maps CLIP features → [hidden layers] → ImageNet classes.
    """

    def __init__(self, cfg: TrainConfig):
        super().__init__()
        self.cfg = cfg

        # Build layers
        layers = []
        in_dim = cfg.clip_dim

        # Hidden layers
        for hidden_dim in cfg.hidden_dims:
            layers.append(CantorLinear(CantorLinearConfig(
                in_features=in_dim,
                out_features=hidden_dim,
                depth=cfg.cantor_depth,
                mask_mode=cfg.mask_mode,
                alpha_mode=cfg.alpha_mode,
                alpha_min=cfg.alpha_min,
                alpha_max=cfg.alpha_max,
                per_output_alpha=cfg.per_output_alpha
            )))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(cfg.dropout))
            in_dim = hidden_dim

        # Output layer
        layers.append(CantorLinear(CantorLinearConfig(
            in_features=in_dim,
            out_features=cfg.num_classes,
            depth=cfg.cantor_depth,
            mask_mode=cfg.mask_mode,
            alpha_mode=cfg.alpha_mode,
            alpha_min=cfg.alpha_min,
            alpha_max=cfg.alpha_max,
            per_output_alpha=cfg.per_output_alpha
        )))

        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        return self.classifier(x)

    def get_alpha_stats(self):
        """Collect alpha statistics from all CantorLinear layers."""
        stats = {
            "layer_names": [],
            "alpha_means": [],
            "alpha_stds": [],
            "mask_densities": []
        }

        for name, module in self.named_modules():
            if isinstance(module, CantorLinear):
                alpha_stats = module.get_alpha_stats()
                if alpha_stats:
                    stats["layer_names"].append(name)
                    stats["alpha_means"].append(alpha_stats["alpha_mean"])
                    stats["alpha_stds"].append(alpha_stats.get("alpha_std", 0.0))
                    stats["mask_densities"].append(module.mask.mean().item())

        return stats

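
# Shape sanity check (a sketch; assumes the geovocab2 import resolves and
# CantorLinear maps in_features -> out_features like nn.Linear):
#
#     model = CantorCLIPClassifier(TrainConfig())
#     logits = model(torch.randn(4, 512))  # expected: torch.Size([4, 1000])
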
# ============================================================
# TRAINING
# ============================================================

def train_epoch(model, dataloader, criterion, optimizer, scheduler, cfg, epoch):
    """Train for one epoch."""
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{cfg.num_epochs}")

    for batch_idx, (features, labels) in enumerate(pbar):
        features = features.to(cfg.device)
        labels = labels.to(cfg.device)

        # Forward
        optimizer.zero_grad()
        outputs = model(features)
        loss = criterion(outputs, labels)

        # Backward
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        # Metrics
        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        # Logging
        if batch_idx % cfg.log_every == 0:
            avg_loss = total_loss / (batch_idx + 1)
            acc = 100. * correct / total
            pbar.set_postfix({
                'loss': f'{avg_loss:.4f}',
                'acc': f'{acc:.2f}%'
            })

            if cfg.use_wandb:
                wandb.log({
                    'train/loss': avg_loss,
                    'train/acc': acc,
                    'train/lr': optimizer.param_groups[0]['lr']
                })

    return total_loss / len(dataloader), 100. * correct / total

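# Note: scheduler.step() above runs once per batch, not per epoch. This
# matches the step-based LambdaLR built in main(), whose horizon is
# len(train_loader) * num_epochs optimizer steps.
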
def evaluate(model, dataloader, criterion, cfg):
    """Evaluate model."""
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for features, labels in tqdm(dataloader, desc="Evaluating"):
            features = features.to(cfg.device)
            labels = labels.to(cfg.device)

            outputs = model(features)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    avg_loss = total_loss / len(dataloader)
    acc = 100. * correct / total

    return avg_loss, acc

def main():
    cfg = TrainConfig()

    # Set seed
    torch.manual_seed(cfg.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(cfg.seed)

    print("=" * 60)
    print("CantorLinear ImageNet CLIP Features Training")
    print("=" * 60)
    print("\nConfiguration:")
    print(f"  Dataset: {cfg.dataset_name}")
    print(f"  CLIP dim: {cfg.clip_dim}")
    print(f"  Hidden dims: {cfg.hidden_dims if cfg.hidden_dims else 'Direct'}")
    print(f"  Cantor depth: {cfg.cantor_depth}")
    print(f"  Batch size: {cfg.batch_size}")
    print(f"  Learning rate: {cfg.learning_rate}")
    print(f"  Device: {cfg.device}")

    # Initialize wandb
    if cfg.use_wandb:
        wandb.init(project=cfg.wandb_project, config=vars(cfg))

    # Load dataset
    print("\nLoading dataset...")
    dataset = load_dataset(cfg.dataset_name, name="clip_vit_b16", split="train")

    # Split into train/val (90/10)
    dataset = dataset.train_test_split(test_size=0.1, seed=cfg.seed)
    train_dataset = CLIPFeaturesDataset(dataset['train'])
    val_dataset = CLIPFeaturesDataset(dataset['test'])

    print(f"Train samples: {len(train_dataset)}")
    print(f"Val samples: {len(val_dataset)}")

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=True
    )

    # Create model
    print("\nBuilding model...")
    model = CantorCLIPClassifier(cfg).to(cfg.device)

    # Print model info
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Alpha statistics
    stats = model.get_alpha_stats()
    if stats['alpha_means']:
        print(f"CantorLinear layers: {len(stats['alpha_means'])}")
        print(f"Avg mask density: {sum(stats['mask_densities'])/len(stats['mask_densities']):.4f}")

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()

    # Separate learning rates for alpha parameters
    alpha_params = []
    other_params = []
    for name, param in model.named_parameters():
        if 'alpha' in name:
            alpha_params.append(param)
        else:
            other_params.append(param)

    optimizer = optim.AdamW([
        {'params': other_params, 'lr': cfg.learning_rate},
        {'params': alpha_params, 'lr': cfg.learning_rate * cfg.alpha_lr_mult}
    ], weight_decay=cfg.weight_decay)

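    # With the defaults this yields two AdamW groups: alpha parameters train
    # at 1e-3 * 0.1 = 1e-4, everything else at 1e-3. Grouping is by name
    # substring, so any parameter whose name contains 'alpha' lands in the
    # low-LR group.
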
    # Learning rate scheduler with warmup
    total_steps = len(train_loader) * cfg.num_epochs
    warmup_steps = len(train_loader) * cfg.warmup_epochs

    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps
        else:
            return 0.5 * (1 + math.cos(math.pi * (step - warmup_steps) / (total_steps - warmup_steps)))

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

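    # Schedule shape, as a multiplier on each group's base LR:
    #   warmup (first warmup_epochs): step / warmup_steps, linear 0 -> 1
    #   decay: 0.5 * (1 + cos(pi * t)), t = post-warmup progress in [0, 1],
    #   so the multiplier is 0.5 halfway through decay and ~0 at the end.
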
    # Training loop
    print("\nStarting training...")
    best_val_acc = 0.0

    for epoch in range(cfg.num_epochs):
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, scheduler, cfg, epoch
        )

        val_loss, val_acc = evaluate(model, val_loader, criterion, cfg)

        print(f"\nEpoch {epoch+1}/{cfg.num_epochs}")
        print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"  Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

        # Log alpha evolution (default to 0.0 so the wandb call below never
        # hits an undefined name if no CantorLinear layer reports stats)
        stats = model.get_alpha_stats()
        mean_alpha = 0.0
        mean_density = 0.0
        if stats['alpha_means']:
            mean_alpha = sum(stats['alpha_means']) / len(stats['alpha_means'])
            mean_density = sum(stats['mask_densities']) / len(stats['mask_densities'])
            print(f"  Mean Alpha: {mean_alpha:.4f} | Mean Density: {mean_density:.4f}")

        if cfg.use_wandb:
            wandb.log({
                'val/loss': val_loss,
                'val/acc': val_acc,
                'alpha/mean': mean_alpha,
                'alpha/density': mean_density,
                'epoch': epoch
            })

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'config': cfg
            }, 'best_cantor_imagenet.pt')
            print(f"  ✓ New best model saved! (Val Acc: {val_acc:.2f}%)")

    print("\n" + "=" * 60)
    print(f"Training complete! Best Val Acc: {best_val_acc:.2f}%")
    print("=" * 60)

    if cfg.use_wandb:
        wandb.finish()


if __name__ == "__main__":
    main()