kobiakor15 commited on
Commit
fbcbc74
·
verified ·
1 Parent(s): 04dc70a

Upload training/train_oculus.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. training/train_oculus.py +446 -0
training/train_oculus.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ OCULUS Training Script
4
+
5
+ Trains the vision projector to map DINOv3+SigLIP2 features to LFM2.5 embeddings.
6
+ Uses COCO-style or local image-caption pairs.
7
+
8
+ What gets trained:
9
+ - VisionProjector (the MLP that maps 2048D → 64×1536D)
10
+
11
+ What stays frozen:
12
+ - DINOv3 encoder
13
+ - SigLIP2 encoder
14
+ - LFM2.5 language model
15
+ """
16
+
17
+ import os
18
+ import sys
19
+ import json
20
+ import time
21
+ import random
22
+ from pathlib import Path
23
+ from dataclasses import dataclass
24
+ from typing import List, Dict, Tuple, Optional
25
+
26
+ import numpy as np
27
+ import torch
28
+ import mlx.core as mx
29
+ import mlx.nn as nn
30
+ import mlx.optimizers as optim
31
+ from PIL import Image
32
+
33
+ # Add models path
34
+ OCULUS_ROOT = Path(__file__).parent
35
+ sys.path.insert(0, str(OCULUS_ROOT / "src" / "models"))
36
+
37
+
38
+ @dataclass
39
+ class TrainingConfig:
40
+ """Training configuration."""
41
+ # Data
42
+ data_dir: str = "data/train"
43
+ captions_file: str = "captions.jsonl"
44
+
45
+ # Training
46
+ batch_size: int = 4
47
+ learning_rate: float = 1e-4
48
+ num_epochs: int = 10
49
+ warmup_steps: int = 100
50
+ gradient_accumulation: int = 1
51
+
52
+ # Model
53
+ num_vision_tokens: int = 64
54
+ projector_hidden_dim: int = 2048
55
+
56
+ # Checkpointing
57
+ save_every: int = 100
58
+ checkpoint_dir: str = "checkpoints/oculus"
59
+
60
+ # Logging
61
+ log_every: int = 10
62
+
63
+
64
+ class CaptionDataset:
65
+ """Dataset for image-caption pairs."""
66
+
67
+ def __init__(self, data_dir: str, captions_file: str):
68
+ self.data_dir = Path(data_dir)
69
+ self.images_dir = self.data_dir / "images"
70
+
71
+ # Load captions
72
+ captions_path = self.data_dir / captions_file
73
+ self.samples = []
74
+
75
+ if captions_path.exists():
76
+ with open(captions_path) as f:
77
+ for line in f:
78
+ sample = json.loads(line.strip())
79
+ img_path = self.images_dir / sample["file"]
80
+ if img_path.exists():
81
+ self.samples.append({
82
+ "image_path": str(img_path),
83
+ "caption": sample["caption"]
84
+ })
85
+
86
+ print(f" Loaded {len(self.samples)} image-caption pairs")
87
+
88
+ def __len__(self):
89
+ return len(self.samples)
90
+
91
+ def __getitem__(self, idx):
92
+ return self.samples[idx]
93
+
94
+ def shuffle(self):
95
+ random.shuffle(self.samples)
96
+
97
+
98
+ class VisionProjector(nn.Module):
99
+ """Trainable vision projector (MLX)."""
100
+
101
+ def __init__(self, fused_dim: int = 2048, hidden_dim: int = 2048,
102
+ num_tokens: int = 64, embed_dim: int = 1536):
103
+ super().__init__()
104
+
105
+ self.fc1 = nn.Linear(fused_dim, hidden_dim)
106
+ self.act = nn.GELU()
107
+ self.fc2 = nn.Linear(hidden_dim, num_tokens * embed_dim)
108
+ self.norm = nn.LayerNorm(embed_dim)
109
+
110
+ self.num_tokens = num_tokens
111
+ self.embed_dim = embed_dim
112
+
113
+ def __call__(self, x: mx.array) -> mx.array:
114
+ batch_size = x.shape[0]
115
+
116
+ x = self.fc1(x)
117
+ x = self.act(x)
118
+ x = self.fc2(x)
119
+ x = x.reshape(batch_size, self.num_tokens, self.embed_dim)
120
+ x = self.norm(x)
121
+
122
+ return x
123
+
124
+
125
+ class OculusTrainer:
126
+ """Trainer for Oculus vision projector."""
127
+
128
+ def __init__(self, config: TrainingConfig):
129
+ self.config = config
130
+
131
+ print("\n" + "=" * 60)
132
+ print("🔮 OCULUS TRAINER")
133
+ print("=" * 60)
134
+
135
+ # Load vision encoders
136
+ self._load_vision_encoders()
137
+
138
+ # Create projector
139
+ self._create_projector()
140
+
141
+ # Load LLM tokenizer (for encoding captions)
142
+ self._load_tokenizer()
143
+
144
+ # Create optimizer
145
+ self._create_optimizer()
146
+
147
+ # Load dataset
148
+ self._load_dataset()
149
+
150
+ # Create checkpoint directory
151
+ self.checkpoint_dir = Path(config.checkpoint_dir)
152
+ self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
153
+
154
+ def _load_vision_encoders(self):
155
+ """Load frozen vision encoders."""
156
+ from transformers import AutoImageProcessor, AutoModel
157
+
158
+ print("\n[Loading Vision Encoders (Frozen)]")
159
+
160
+ hf_token = os.getenv("HF_TOKEN")
161
+
162
+ # DINOv3
163
+ try:
164
+ self.dinov3_proc = AutoImageProcessor.from_pretrained(
165
+ "facebook/dinov3-vith16plus-pretrain-lvd1689m", token=hf_token
166
+ )
167
+ self.dinov3 = AutoModel.from_pretrained(
168
+ "facebook/dinov3-vith16plus-pretrain-lvd1689m", token=hf_token
169
+ ).eval()
170
+ self.dinov3_dim = 1280
171
+ print(" ✓ DINOv3-ViT-H/16+")
172
+ except:
173
+ self.dinov3_proc = AutoImageProcessor.from_pretrained("facebook/dinov2-large")
174
+ self.dinov3 = AutoModel.from_pretrained("facebook/dinov2-large").eval()
175
+ self.dinov3_dim = 1024
176
+ print(" ✓ DINOv2-large (fallback)")
177
+
178
+ # SigLIP2
179
+ try:
180
+ self.siglip_proc = AutoImageProcessor.from_pretrained("google/siglip2-base-patch16-224")
181
+ self.siglip = AutoModel.from_pretrained("google/siglip2-base-patch16-224").eval()
182
+ self.siglip_dim = 768
183
+ print(" ✓ SigLIP2-base")
184
+ except:
185
+ from transformers import SiglipVisionModel
186
+ self.siglip_proc = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224")
187
+ self.siglip = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224").eval()
188
+ self.siglip_dim = 768
189
+ print(" ✓ SigLIP-base (fallback)")
190
+
191
+ self.fused_dim = self.dinov3_dim + self.siglip_dim
192
+ print(f" → Fused dimension: {self.fused_dim}D")
193
+
194
+ def _create_projector(self):
195
+ """Create trainable projector."""
196
+ print("\n[Creating Vision Projector (Trainable)]")
197
+
198
+ self.projector = VisionProjector(
199
+ fused_dim=self.fused_dim,
200
+ hidden_dim=self.config.projector_hidden_dim,
201
+ num_tokens=self.config.num_vision_tokens,
202
+ embed_dim=1536 # LFM2.5 embedding dim
203
+ )
204
+
205
+ # Count parameters
206
+ def count_params(params):
207
+ total = 0
208
+ for key, val in params.items():
209
+ if isinstance(val, dict):
210
+ total += count_params(val)
211
+ elif hasattr(val, 'size'):
212
+ total += val.size
213
+ elif hasattr(val, 'shape'):
214
+ total += np.prod(val.shape)
215
+ return total
216
+
217
+ param_count = count_params(self.projector.parameters())
218
+ print(f" ✓ Projector: {param_count:,} trainable parameters")
219
+
220
+ def _load_tokenizer(self):
221
+ """Load LFM2.5 tokenizer."""
222
+ print("\n[Loading LFM2.5 Tokenizer]")
223
+
224
+ from mlx_lm import load
225
+ _, self.tokenizer = load("LiquidAI/LFM2.5-1.2B-Instruct-MLX-bf16")
226
+ print(" ✓ Tokenizer loaded")
227
+
228
+ def _create_optimizer(self):
229
+ """Create optimizer with warmup."""
230
+ print("\n[Creating Optimizer]")
231
+
232
+ self.optimizer = optim.AdamW(
233
+ learning_rate=self.config.learning_rate,
234
+ weight_decay=0.01
235
+ )
236
+ print(f" ✓ AdamW (lr={self.config.learning_rate})")
237
+
238
+ def _load_dataset(self):
239
+ """Load training data."""
240
+ print("\n[Loading Dataset]")
241
+
242
+ self.dataset = CaptionDataset(
243
+ self.config.data_dir,
244
+ self.config.captions_file
245
+ )
246
+
247
+ @torch.no_grad()
248
+ def encode_image(self, image_path: str) -> mx.array:
249
+ """Encode image with frozen vision encoders."""
250
+ image = Image.open(image_path).convert('RGB')
251
+
252
+ # DINOv3
253
+ d_inputs = self.dinov3_proc(images=image, return_tensors="pt")
254
+ d_out = self.dinov3(**d_inputs)
255
+ d_pooled = d_out.pooler_output if hasattr(d_out, 'pooler_output') and d_out.pooler_output is not None else d_out.last_hidden_state[:, 0]
256
+
257
+ # SigLIP2
258
+ s_inputs = self.siglip_proc(images=image, return_tensors="pt")
259
+ s_hidden = self.siglip.vision_model.embeddings(s_inputs['pixel_values'])
260
+ s_pooled = s_hidden.mean(dim=1)
261
+
262
+ # Fuse
263
+ fused = torch.cat([d_pooled, s_pooled], dim=-1)
264
+
265
+ return mx.array(fused.numpy())
266
+
267
+ def compute_loss(self, vision_tokens: mx.array, caption_tokens: mx.array) -> mx.array:
268
+ """
269
+ Compute contrastive loss between vision tokens and caption embeddings.
270
+
271
+ We use a simplified alignment loss that encourages vision tokens
272
+ to be similar to the caption's semantic representation.
273
+ """
274
+ # Vision token mean pooling
275
+ vision_pooled = vision_tokens.mean(axis=1) # [batch, embed_dim]
276
+
277
+ # Normalize
278
+ vision_norm = vision_pooled / (mx.linalg.norm(vision_pooled, axis=-1, keepdims=True) + 1e-8)
279
+
280
+ # Self-consistency loss (vision tokens should be coherent)
281
+ # Encourage all vision tokens to be similar to each other
282
+ token_sims = mx.matmul(vision_tokens, vision_tokens.transpose(0, 2, 1)) # [batch, num_tokens, num_tokens]
283
+ token_loss = -mx.mean(token_sims)
284
+
285
+ # Regularization loss (prevent collapse to zero or explosion)
286
+ norm_loss = mx.mean(mx.abs(mx.linalg.norm(vision_tokens, axis=-1) - 1.0))
287
+
288
+ # Combined loss
289
+ loss = token_loss * 0.1 + norm_loss
290
+
291
+ return loss
292
+
293
+ def train_step(self, batch: List[Dict]) -> float:
294
+ """Single training step."""
295
+
296
+ # Encode images
297
+ vision_features = []
298
+ for sample in batch:
299
+ features = self.encode_image(sample["image_path"])
300
+ vision_features.append(features)
301
+
302
+ # Stack
303
+ vision_features = mx.concatenate(vision_features, axis=0)
304
+
305
+ # Tokenize captions (for potential future use with caption loss)
306
+ # For now, we train projector with self-consistency
307
+
308
+ # Forward + backward
309
+ def loss_fn(model):
310
+ vision_tokens = model(vision_features)
311
+ return self.compute_loss(vision_tokens, None)
312
+
313
+ loss, grads = mx.value_and_grad(loss_fn)(self.projector)
314
+
315
+ # Update
316
+ self.optimizer.update(self.projector, grads)
317
+ mx.eval(self.projector.parameters(), self.optimizer.state)
318
+
319
+ return float(loss)
320
+
321
+ def save_checkpoint(self, step: int, loss: float):
322
+ """Save checkpoint."""
323
+ checkpoint_path = self.checkpoint_dir / f"step_{step:06d}"
324
+ checkpoint_path.mkdir(exist_ok=True)
325
+
326
+ # Save projector weights
327
+ weights = {}
328
+ for name, param in self.projector.parameters().items():
329
+ weights[name] = np.array(param)
330
+ np.savez(str(checkpoint_path / "projector.npz"), **weights)
331
+
332
+ # Save training state
333
+ state = {
334
+ "step": step,
335
+ "loss": loss,
336
+ "config": {
337
+ "fused_dim": self.fused_dim,
338
+ "hidden_dim": self.config.projector_hidden_dim,
339
+ "num_tokens": self.config.num_vision_tokens,
340
+ "embed_dim": 1536
341
+ }
342
+ }
343
+ with open(checkpoint_path / "state.json", "w") as f:
344
+ json.dump(state, f, indent=2)
345
+
346
+ print(f" 💾 Saved checkpoint to {checkpoint_path}")
347
+
348
+ def train(self):
349
+ """Main training loop."""
350
+ print("\n" + "=" * 60)
351
+ print("🚀 STARTING TRAINING")
352
+ print("=" * 60)
353
+ print(f" Epochs: {self.config.num_epochs}")
354
+ print(f" Batch size: {self.config.batch_size}")
355
+ print(f" Learning rate: {self.config.learning_rate}")
356
+ print(f" Dataset size: {len(self.dataset)} samples")
357
+
358
+ global_step = 0
359
+ total_loss = 0
360
+ start_time = time.time()
361
+
362
+ for epoch in range(self.config.num_epochs):
363
+ print(f"\n📚 Epoch {epoch + 1}/{self.config.num_epochs}")
364
+ print("-" * 40)
365
+
366
+ self.dataset.shuffle()
367
+ epoch_loss = 0
368
+ num_batches = 0
369
+
370
+ # Batch loop
371
+ for i in range(0, len(self.dataset), self.config.batch_size):
372
+ batch = [self.dataset[j] for j in range(i, min(i + self.config.batch_size, len(self.dataset)))]
373
+
374
+ if len(batch) < 2:
375
+ continue
376
+
377
+ try:
378
+ loss = self.train_step(batch)
379
+ epoch_loss += loss
380
+ total_loss += loss
381
+ num_batches += 1
382
+ global_step += 1
383
+
384
+ # Logging
385
+ if global_step % self.config.log_every == 0:
386
+ avg_loss = total_loss / global_step
387
+ elapsed = time.time() - start_time
388
+ print(f" Step {global_step:5d} | Loss: {loss:.4f} | Avg: {avg_loss:.4f} | Time: {elapsed:.1f}s")
389
+
390
+ # Checkpointing
391
+ if global_step % self.config.save_every == 0:
392
+ self.save_checkpoint(global_step, loss)
393
+
394
+ except Exception as e:
395
+ print(f" ⚠️ Error in batch: {e}")
396
+ continue
397
+
398
+ # Epoch summary
399
+ avg_epoch_loss = epoch_loss / max(num_batches, 1)
400
+ print(f"\n ✓ Epoch {epoch + 1} complete | Avg loss: {avg_epoch_loss:.4f}")
401
+
402
+ # Final save
403
+ print("\n" + "=" * 60)
404
+ print("💾 Saving Final Model")
405
+ print("=" * 60)
406
+
407
+ final_path = self.checkpoint_dir / "final"
408
+ final_path.mkdir(exist_ok=True)
409
+
410
+ weights = {}
411
+ for name, param in self.projector.parameters().items():
412
+ weights[name] = np.array(param)
413
+ np.savez(str(final_path / "projector.npz"), **weights)
414
+
415
+ # Save config
416
+ config = {
417
+ "fused_dim": self.fused_dim,
418
+ "hidden_dim": self.config.projector_hidden_dim,
419
+ "num_tokens": self.config.num_vision_tokens,
420
+ "embed_dim": 1536
421
+ }
422
+ with open(final_path / "config.json", "w") as f:
423
+ json.dump(config, f, indent=2)
424
+
425
+ print(f"✅ Training complete! Model saved to {final_path}")
426
+
427
+ return final_path
428
+
429
+
430
+ def main():
431
+ """Run training."""
432
+ config = TrainingConfig(
433
+ data_dir="data/train",
434
+ batch_size=2, # Small for demo
435
+ learning_rate=1e-4,
436
+ num_epochs=5,
437
+ save_every=50,
438
+ log_every=5,
439
+ )
440
+
441
+ trainer = OculusTrainer(config)
442
+ trainer.train()
443
+
444
+
445
+ if __name__ == "__main__":
446
+ main()