#!/usr/bin/env python3
"""
OCULUS Training with COCO Captions

Trains the vision projector with proper caption alignment loss.
Uses image-caption pairs to learn meaningful vision → language mappings.

Training Objective:
- Align projected vision tokens with caption embeddings
- Contrastive loss between positive (matching) and negative pairs
"""

import os
import sys
import json
import time
import random
from pathlib import Path
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional

import numpy as np
import torch
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten
from PIL import Image

OCULUS_ROOT = Path(__file__).parent
sys.path.insert(0, str(OCULUS_ROOT / "src" / "models"))


@dataclass
class TrainingConfig:
    """Training configuration."""
    # Data
    data_dir: str = "data/coco"
    captions_file: str = "train_captions.jsonl"
    images_subdir: str = "images"

    # Training
    batch_size: int = 8
    learning_rate: float = 2e-4
    num_epochs: int = 3
    warmup_steps: int = 500  # Reserved; no LR scheduler is wired up yet
    max_samples: int = 10000  # Limit for faster training

    # Model
    num_vision_tokens: int = 64
    projector_hidden_dim: int = 2048
    lfm_embed_dim: int = 1536

    # Loss
    temperature: float = 0.07  # Contrastive temperature (reserved; unused by the current loss)

    # Checkpointing
    save_every: int = 500
    checkpoint_dir: str = "checkpoints/oculus_coco"

    # Logging
    log_every: int = 50


class COCODataset:
    """COCO Captions dataset."""

    def __init__(self, data_dir: str, captions_file: str, images_subdir: str,
                 max_samples: Optional[int] = None):
        self.data_dir = Path(data_dir)
        self.images_dir = self.data_dir / images_subdir

        # Load captions, keeping only samples whose image file exists on disk
        captions_path = self.data_dir / captions_file
        self.samples = []
        if captions_path.exists():
            with open(captions_path) as f:
                for i, line in enumerate(f):
                    if max_samples and i >= max_samples:
                        break
                    sample = json.loads(line.strip())
                    img_path = self.images_dir / sample["file"]
                    if img_path.exists():
                        self.samples.append({
                            "image_path": str(img_path),
                            "caption": sample["caption"],
                        })

        print(f"  Loaded {len(self.samples):,} image-caption pairs")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

    def shuffle(self):
        random.shuffle(self.samples)

    def get_batch(self, start_idx: int, batch_size: int) -> List[Dict]:
        end_idx = min(start_idx + batch_size, len(self.samples))
        return [self.samples[i] for i in range(start_idx, end_idx)]


class VisionProjector(nn.Module):
    """Vision projector: fused encoder features → LFM-space vision tokens."""

    def __init__(self, fused_dim: int = 2048, hidden_dim: int = 2048,
                 num_tokens: int = 64, embed_dim: int = 1536):
        super().__init__()
        # Three-layer MLP that expands one fused feature vector into a set of tokens
        self.fc1 = nn.Linear(fused_dim, hidden_dim)
        self.act1 = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.act2 = nn.GELU()
        self.fc3 = nn.Linear(hidden_dim, num_tokens * embed_dim)
        self.norm = nn.LayerNorm(embed_dim)

        self.num_tokens = num_tokens
        self.embed_dim = embed_dim

    def __call__(self, x: mx.array) -> mx.array:
        batch_size = x.shape[0]

        # MLP trunk
        h = self.fc1(x)
        h = self.act1(h)
        h = self.fc2(h)
        h = self.act2(h)
        h = self.fc3(h)

        # Reshape to tokens and normalize per token
        h = h.reshape(batch_size, self.num_tokens, self.embed_dim)
        h = self.norm(h)
        return h
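
# A minimal shape sketch of the projector contract (illustration only; the
# dims below are the TrainingConfig defaults, and this helper is not called
# anywhere during training):
def _projector_shape_check():
    proj = VisionProjector(fused_dim=2048, hidden_dim=2048,
                           num_tokens=64, embed_dim=1536)
    x = mx.random.normal((2, 2048))  # stand-in for fused DINOv3+SigLIP features
    tokens = proj(x)
    assert tokens.shape == (2, 64, 1536)  # [batch, num_tokens, embed_dim]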

class OculusTrainer:
    """Trainer for Oculus with caption alignment."""

    def __init__(self, config: TrainingConfig):
        self.config = config

        print("\n" + "=" * 60)
        print("šŸ”® OCULUS TRAINER (COCO)")
        print("=" * 60)

        self._load_vision_encoders()
        self._load_text_encoder()
        self._create_projector()
        self._create_optimizer()
        self._load_dataset()

        self.checkpoint_dir = Path(config.checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    def _load_vision_encoders(self):
        """Load frozen vision encoders."""
        from transformers import AutoImageProcessor, AutoModel

        print("\n[Vision Encoders (Frozen)]")
        hf_token = os.getenv("HF_TOKEN")

        # DINOv3 (falls back to DINOv2 if the gated checkpoint is unavailable)
        try:
            self.dinov3_proc = AutoImageProcessor.from_pretrained(
                "facebook/dinov3-vith16plus-pretrain-lvd1689m", token=hf_token
            )
            self.dinov3 = AutoModel.from_pretrained(
                "facebook/dinov3-vith16plus-pretrain-lvd1689m", token=hf_token
            ).eval()
            self.dinov3_dim = 1280
            print("  āœ“ DINOv3-ViT-H/16+")
        except Exception:
            self.dinov3_proc = AutoImageProcessor.from_pretrained("facebook/dinov2-large")
            self.dinov3 = AutoModel.from_pretrained("facebook/dinov2-large").eval()
            self.dinov3_dim = 1024
            print("  āœ“ DINOv2-large (fallback)")

        # SigLIP2 (falls back to SigLIP v1)
        try:
            self.siglip_proc = AutoImageProcessor.from_pretrained("google/siglip2-base-patch16-224")
            self.siglip = AutoModel.from_pretrained("google/siglip2-base-patch16-224").eval()
            self.siglip_dim = 768
            print("  āœ“ SigLIP2-base")
        except Exception:
            from transformers import SiglipVisionModel
            self.siglip_proc = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224")
            self.siglip = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224").eval()
            self.siglip_dim = 768
            print("  āœ“ SigLIP-base (fallback)")

        self.fused_dim = self.dinov3_dim + self.siglip_dim
        print(f"  → Fused: {self.fused_dim}D")

    def _load_text_encoder(self):
        """Load text encoder for caption embeddings."""
        print("\n[Text Encoder]")
        from transformers import AutoTokenizer, AutoModel

        # Compact sentence encoder (384-d) for caption embeddings
        self.text_tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
        self.text_encoder = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2").eval()
        self.text_embed_dim = 384
        print("  āœ“ MiniLM-L6 for caption embeddings")

    def _create_projector(self):
        """Create trainable projector."""
        print("\n[Vision Projector (Trainable)]")
        self.projector = VisionProjector(
            fused_dim=self.fused_dim,
            hidden_dim=self.config.projector_hidden_dim,
            num_tokens=self.config.num_vision_tokens,
            embed_dim=self.config.lfm_embed_dim,
        )

        # parameters() is a nested dict, so count leaves recursively
        def count_params(params):
            total = 0
            for val in params.values():
                if isinstance(val, dict):
                    total += count_params(val)
                elif hasattr(val, "size"):
                    total += val.size
            return total

        param_count = count_params(self.projector.parameters())
        print(f"  āœ“ {param_count:,} parameters")

    def _create_optimizer(self):
        """Create optimizer."""
        print("\n[Optimizer]")
        self.optimizer = optim.AdamW(
            learning_rate=self.config.learning_rate,
            weight_decay=0.01,
        )
        print(f"  āœ“ AdamW (lr={self.config.learning_rate})")

    def _load_dataset(self):
        """Load COCO dataset."""
        print("\n[Dataset]")
        self.dataset = COCODataset(
            self.config.data_dir,
            self.config.captions_file,
            self.config.images_subdir,
            max_samples=self.config.max_samples,
        )

    @torch.no_grad()
    def encode_image(self, image_path: str) -> mx.array:
        """Encode an image with the frozen vision encoders; returns [1, fused_dim]."""
        image = Image.open(image_path).convert("RGB")

        # DINOv3: prefer the pooler output, fall back to the CLS token
        d_inputs = self.dinov3_proc(images=image, return_tensors="pt")
        d_out = self.dinov3(**d_inputs)
        if getattr(d_out, "pooler_output", None) is not None:
            d_pooled = d_out.pooler_output
        else:
            d_pooled = d_out.last_hidden_state[:, 0]

        # SigLIP: run the full vision tower and mean-pool the patch features
        s_inputs = self.siglip_proc(images=image, return_tensors="pt")
        s_out = self.siglip.vision_model(pixel_values=s_inputs["pixel_values"])
        s_pooled = s_out.last_hidden_state.mean(dim=1)

        # Fuse by concatenation
        fused = torch.cat([d_pooled, s_pooled], dim=-1)
        return mx.array(fused.numpy())
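    # Shape note (assuming the primary encoders loaded): encode_image returns
    # [1, 2048], i.e. the DINOv3 pooled vector [1, 1280] concatenated with the
    # SigLIP mean-pooled vector [1, 768]. With the fallback encoders it is
    # [1, 1024 + 768] = [1, 1792]; the projector is sized from self.fused_dim
    # either way.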
    @torch.no_grad()
    def encode_caption(self, caption: str) -> np.ndarray:
        """Encode a caption with the text encoder."""
        inputs = self.text_tokenizer(caption, return_tensors="pt",
                                     padding=True, truncation=True, max_length=77)
        outputs = self.text_encoder(**inputs)
        # Mean pooling over token embeddings
        embeddings = outputs.last_hidden_state.mean(dim=1)
        return embeddings.numpy()

    def compute_loss(self, vision_tokens: mx.array,
                     caption_embeds: mx.array) -> Tuple[mx.array, Dict]:
        """
        Compute the training loss over projected vision tokens.

        Args:
            vision_tokens: [batch, num_tokens, embed_dim]
            caption_embeds: [batch, caption_dim] (currently unused; caption
                alignment is learned implicitly through the projector)

        Returns:
            (total_loss, metrics), where metrics holds unevaluated mx.arrays so
            this function stays safe to call inside a gradient trace.
        """
        batch_size = vision_tokens.shape[0]

        # Pool vision tokens
        vision_pooled = vision_tokens.mean(axis=1)  # [batch, embed_dim]

        # Normalize
        vision_norm = vision_pooled / (mx.linalg.norm(vision_pooled, axis=-1, keepdims=True) + 1e-8)

        # Self-similarity across the batch
        sim_matrix = mx.matmul(vision_norm, vision_norm.T)  # [batch, batch]
        identity = mx.eye(batch_size)

        # Contrastive-like loss: high self-similarity, controlled cross-similarity
        pos_sim = mx.sum(sim_matrix * identity) / batch_size
        neg_sim = mx.sum(sim_matrix * (1 - identity)) / (batch_size * (batch_size - 1) + 1e-8)
        contrastive_loss = -pos_sim + 0.5 * neg_sim

        # Regularization: keep per-token norms close to 1
        norm_loss = mx.mean(mx.abs(mx.linalg.norm(vision_tokens, axis=-1) - 1.0))

        # Diversity loss: tokens within an image should differ from each other
        token_sim = mx.matmul(
            vision_tokens, mx.transpose(vision_tokens, axes=(0, 2, 1))
        )  # [batch, num_tokens, num_tokens]
        token_identity = mx.eye(vision_tokens.shape[1])
        diversity_loss = mx.mean(token_sim * (1 - token_identity))

        total_loss = contrastive_loss + 0.1 * norm_loss + 0.01 * diversity_loss

        # Keep metrics as mx.arrays; converting with float() here would force an
        # evaluation inside value_and_grad's trace
        return total_loss, {
            "contrastive": contrastive_loss,
            "norm": norm_loss,
            "diversity": diversity_loss,
        }

    def train_step(self, batch: List[Dict]) -> Tuple[float, Dict]:
        """Single training step."""
        # Encode images and captions with the frozen encoders
        vision_features = []
        caption_embeds = []
        for sample in batch:
            try:
                v_feat = self.encode_image(sample["image_path"])
                c_embed = self.encode_caption(sample["caption"])
                vision_features.append(v_feat)
                caption_embeds.append(c_embed)
            except Exception:
                # Skip unreadable samples
                continue

        if len(vision_features) < 2:
            return 0.0, {}

        # Stack into batch arrays
        vision_features = mx.concatenate(vision_features, axis=0)
        caption_embeds_mx = mx.array(np.concatenate(caption_embeds, axis=0))

        def loss_fn(model):
            vision_tokens = model(vision_features)
            loss, _ = self.compute_loss(vision_tokens, caption_embeds_mx)
            return loss

        # nn.value_and_grad differentiates w.r.t. the module's parameters
        loss_and_grad_fn = nn.value_and_grad(self.projector, loss_fn)
        loss, grads = loss_and_grad_fn(self.projector)

        # Update, then force evaluation of the lazy graph
        self.optimizer.update(self.projector, grads)
        mx.eval(self.projector.parameters(), self.optimizer.state)

        return float(loss), {}
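    # Loss note: after L2 normalization the diagonal of sim_matrix is exactly 1,
    # so pos_sim is constant and contributes no gradient; the contrastive term
    # trains only through neg_sim, alongside the norm and diversity regularizers.
    # Worked example: if the off-diagonal cosine similarities average 0.4, then
    # contrastive_loss = -1.0 + 0.5 * 0.4 = -0.8.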
    def save_checkpoint(self, step: int, loss: float):
        """Save a training checkpoint."""
        checkpoint_path = self.checkpoint_dir / f"step_{step:06d}"
        checkpoint_path.mkdir(exist_ok=True)

        # Flatten the nested parameter dict ("fc1.weight", ...) before np.savez
        weights = {name: np.array(param)
                   for name, param in tree_flatten(self.projector.parameters())}
        np.savez(str(checkpoint_path / "projector.npz"), **weights)

        # Save state
        state = {
            "step": step,
            "loss": loss,
            "config": {
                "fused_dim": self.fused_dim,
                "hidden_dim": self.config.projector_hidden_dim,
                "num_tokens": self.config.num_vision_tokens,
                "embed_dim": self.config.lfm_embed_dim,
            },
        }
        with open(checkpoint_path / "state.json", "w") as f:
            json.dump(state, f, indent=2)

        print(f"  šŸ’¾ Checkpoint: {checkpoint_path}")

    def train(self):
        """Main training loop."""
        print("\n" + "=" * 60)
        print("šŸš€ STARTING TRAINING")
        print("=" * 60)
        print(f"  Dataset: {len(self.dataset):,} samples")
        print(f"  Epochs: {self.config.num_epochs}")
        print(f"  Batch size: {self.config.batch_size}")
        print(f"  Learning rate: {self.config.learning_rate}")

        global_step = 0
        best_loss = float("inf")
        start_time = time.time()

        for epoch in range(self.config.num_epochs):
            print(f"\nšŸ“š Epoch {epoch + 1}/{self.config.num_epochs}")
            print("-" * 40)

            self.dataset.shuffle()
            epoch_loss = 0
            num_batches = 0

            for i in range(0, len(self.dataset), self.config.batch_size):
                batch = self.dataset.get_batch(i, self.config.batch_size)
                if len(batch) < 2:
                    continue

                try:
                    loss, _ = self.train_step(batch)
                    if loss == 0:
                        continue

                    epoch_loss += loss
                    num_batches += 1
                    global_step += 1

                    # Logging
                    if global_step % self.config.log_every == 0:
                        elapsed = time.time() - start_time
                        avg_loss = epoch_loss / num_batches
                        print(f"  Step {global_step:5d} | Loss: {loss:.4f} | "
                              f"Avg: {avg_loss:.4f} | {elapsed:.0f}s")

                    # Checkpointing
                    if global_step % self.config.save_every == 0:
                        self.save_checkpoint(global_step, loss)

                    if loss < best_loss:
                        best_loss = loss

                except Exception as e:
                    print(f"  āš ļø  Batch error: {e}")
                    continue

            avg_epoch_loss = epoch_loss / max(num_batches, 1)
            print(f"\n  āœ“ Epoch {epoch + 1} | Avg loss: {avg_epoch_loss:.4f}")

        # Final save
        print("\n" + "=" * 60)
        print("šŸ’¾ Saving Final Model")
        print("=" * 60)

        final_path = self.checkpoint_dir / "final"
        final_path.mkdir(exist_ok=True)

        weights = {name: np.array(param)
                   for name, param in tree_flatten(self.projector.parameters())}
        np.savez(str(final_path / "projector.npz"), **weights)

        config = {
            "fused_dim": self.fused_dim,
            "hidden_dim": self.config.projector_hidden_dim,
            "num_tokens": self.config.num_vision_tokens,
            "embed_dim": self.config.lfm_embed_dim,
        }
        with open(final_path / "config.json", "w") as f:
            json.dump(config, f, indent=2)

        print(f"āœ… Training complete! Model: {final_path}")
        return final_path


def main():
    # Check that COCO data exists before building any models
    coco_dir = OCULUS_ROOT / "data" / "coco"
    if not (coco_dir / "train_captions.jsonl").exists():
        print("āŒ COCO data not found!")
        print("   Run: python download_coco.py")
        return

    config = TrainingConfig(
        data_dir="data/coco",
        batch_size=4,
        learning_rate=2e-4,
        num_epochs=3,
        max_samples=5000,
        save_every=200,
        log_every=25,
    )

    trainer = OculusTrainer(config)
    trainer.train()


if __name__ == "__main__":
    main()
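
# ---------------------------------------------------------------------------
# Usage sketch (assumes download_coco.py has populated data/coco/; the script
# filename below is a placeholder):
#
#   $ export HF_TOKEN=...           # only needed for the gated DINOv3 weights
#   $ python train_oculus_coco.py
#
# Expected data/coco/train_captions.jsonl format: one JSON object per line,
# with the fields COCODataset reads (example values are invented):
#
#   {"file": "000000000001.jpg", "caption": "A dog running on a beach."}
# ---------------------------------------------------------------------------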