#!/usr/bin/env python3
"""
OCULUS Training with COCO Captions
Trains the vision projector with proper caption alignment loss.
Uses image-caption pairs to learn meaningful vision โ†’ language mappings.
Training Objective:
- Align projected vision tokens with caption embeddings
- Contrastive loss between positive (matching) and negative pairs
"""
import os
import sys
import json
import time
import random
from pathlib import Path
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional
import numpy as np
import torch
import torch.nn.functional as F
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten  # used to flatten nested parameter dicts when saving checkpoints
from PIL import Image
OCULUS_ROOT = Path(__file__).parent.parent  # repo root (this script lives in training/)
sys.path.insert(0, str(OCULUS_ROOT / "src" / "models"))
@dataclass
class TrainingConfig:
"""Training configuration."""
# Data
data_dir: str = "data/coco"
captions_file: str = "train_captions.jsonl"
images_subdir: str = "images"
# Training
batch_size: int = 8
learning_rate: float = 2e-4
num_epochs: int = 3
warmup_steps: int = 500
max_samples: int = 10000 # Limit for faster training
# Model
num_vision_tokens: int = 64
projector_hidden_dim: int = 2048
lfm_embed_dim: int = 1536
# Loss
temperature: float = 0.07 # Contrastive temperature
# Checkpointing
save_every: int = 500
checkpoint_dir: str = "checkpoints/oculus_coco"
# Logging
log_every: int = 50
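# Expected caption file: one JSON object per line in data/coco/train_captions.jsonl,
# with a "file" field (image filename under data/coco/images/) and a "caption" field,
# e.g. (illustrative): {"file": "000000000139.jpg", "caption": "A kitchen with a stove and a sink."}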
class COCODataset:
"""COCO Captions dataset."""
def __init__(self, data_dir: str, captions_file: str, images_subdir: str, max_samples: Optional[int] = None):
self.data_dir = Path(data_dir)
self.images_dir = self.data_dir / images_subdir
# Load captions
captions_path = self.data_dir / captions_file
self.samples = []
if captions_path.exists():
with open(captions_path) as f:
for i, line in enumerate(f):
if max_samples and i >= max_samples:
break
sample = json.loads(line.strip())
img_path = self.images_dir / sample["file"]
if img_path.exists():
self.samples.append({
"image_path": str(img_path),
"caption": sample["caption"]
})
print(f" Loaded {len(self.samples):,} image-caption pairs")
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
return self.samples[idx]
def shuffle(self):
random.shuffle(self.samples)
def get_batch(self, start_idx: int, batch_size: int) -> List[Dict]:
return [self.samples[i] for i in range(start_idx, min(start_idx + batch_size, len(self.samples)))]
class VisionProjector(nn.Module):
"""Vision projector with improved architecture."""
def __init__(self, fused_dim: int = 2048, hidden_dim: int = 2048,
num_tokens: int = 64, embed_dim: int = 1536):
super().__init__()
# Three-layer MLP (fc1 -> fc2 -> fc3) with GELU activations and LayerNorm on the output tokens
self.fc1 = nn.Linear(fused_dim, hidden_dim)
self.act1 = nn.GELU()
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.act2 = nn.GELU()
self.fc3 = nn.Linear(hidden_dim, num_tokens * embed_dim)
self.norm = nn.LayerNorm(embed_dim)
self.num_tokens = num_tokens
self.embed_dim = embed_dim
def __call__(self, x: mx.array) -> mx.array:
batch_size = x.shape[0]
# Three-layer MLP
h = self.fc1(x)
h = self.act1(h)
h = self.fc2(h)
h = self.act2(h)
h = self.fc3(h)
# Reshape to tokens
h = h.reshape(batch_size, self.num_tokens, self.embed_dim)
h = self.norm(h)
return h
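# Quick shape check for the projector (illustrative sketch, not part of training):
#   proj = VisionProjector()                    # defaults: fused_dim=2048, num_tokens=64, embed_dim=1536
#   tokens = proj(mx.random.normal((2, 2048)))  # e.g. DINOv3 (1280-d) + SigLIP2 (768-d) fused features
#   print(tokens.shape)                         # (2, 64, 1536) -> [batch, num_vision_tokens, lfm_embed_dim]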
class OculusTrainer:
"""Trainer for Oculus with caption alignment."""
def __init__(self, config: TrainingConfig):
self.config = config
print("\n" + "=" * 60)
print("๐Ÿ”ฎ OCULUS TRAINER (COCO)")
print("=" * 60)
self._load_vision_encoders()
self._load_text_encoder()
self._create_projector()
self._create_optimizer()
self._load_dataset()
self.checkpoint_dir = Path(config.checkpoint_dir)
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
def _load_vision_encoders(self):
"""Load frozen vision encoders."""
from transformers import AutoImageProcessor, AutoModel
print("\n[Vision Encoders (Frozen)]")
hf_token = os.getenv("HF_TOKEN")
# DINOv3
try:
self.dinov3_proc = AutoImageProcessor.from_pretrained(
"facebook/dinov3-vith16plus-pretrain-lvd1689m", token=hf_token
)
self.dinov3 = AutoModel.from_pretrained(
"facebook/dinov3-vith16plus-pretrain-lvd1689m", token=hf_token
).eval()
self.dinov3_dim = 1280
print(" โœ“ DINOv3-ViT-H/16+")
except:
self.dinov3_proc = AutoImageProcessor.from_pretrained("facebook/dinov2-large")
self.dinov3 = AutoModel.from_pretrained("facebook/dinov2-large").eval()
self.dinov3_dim = 1024
print(" โœ“ DINOv2-large (fallback)")
# SigLIP2
try:
self.siglip_proc = AutoImageProcessor.from_pretrained("google/siglip2-base-patch16-224")
self.siglip = AutoModel.from_pretrained("google/siglip2-base-patch16-224").eval()
self.siglip_dim = 768
print(" โœ“ SigLIP2-base")
except:
from transformers import SiglipVisionModel
self.siglip_proc = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224")
self.siglip = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224").eval()
self.siglip_dim = 768
print(" โœ“ SigLIP-base (fallback)")
self.fused_dim = self.dinov3_dim + self.siglip_dim
print(f" โ†’ Fused: {self.fused_dim}D")
def _load_text_encoder(self):
"""Load text encoder for caption embeddings."""
print("\n[Text Encoder]")
from transformers import AutoTokenizer, AutoModel
# Compact sentence-transformer (MiniLM-L6, 384-d) for caption embeddings
self.text_tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
self.text_encoder = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2").eval()
self.text_embed_dim = 384
print(" โœ“ MiniLM-L6 for caption embeddings")
def _create_projector(self):
"""Create trainable projector."""
print("\n[Vision Projector (Trainable)]")
self.projector = VisionProjector(
fused_dim=self.fused_dim,
hidden_dim=self.config.projector_hidden_dim,
num_tokens=self.config.num_vision_tokens,
embed_dim=self.config.lfm_embed_dim
)
def count_params(params):
total = 0
for key, val in params.items():
if isinstance(val, dict):
total += count_params(val)
elif hasattr(val, 'size'):
total += val.size
return total
param_count = count_params(self.projector.parameters())
print(f" โœ“ {param_count:,} parameters")
def _create_optimizer(self):
"""Create optimizer."""
print("\n[Optimizer]")
self.optimizer = optim.AdamW(
learning_rate=self.config.learning_rate,
weight_decay=0.01
)
print(f" โœ“ AdamW (lr={self.config.learning_rate})")
def _load_dataset(self):
"""Load COCO dataset."""
print("\n[Dataset]")
self.dataset = COCODataset(
self.config.data_dir,
self.config.captions_file,
self.config.images_subdir,
max_samples=self.config.max_samples
)
@torch.no_grad()
def encode_image(self, image_path: str) -> mx.array:
"""Encode image with vision encoders."""
image = Image.open(image_path).convert('RGB')
# DINOv3
d_inputs = self.dinov3_proc(images=image, return_tensors="pt")
d_out = self.dinov3(**d_inputs)
d_pooled = d_out.pooler_output if hasattr(d_out, 'pooler_output') and d_out.pooler_output is not None else d_out.last_hidden_state[:, 0]
# SigLIP2: mean-pool the patch embeddings (the full vision encoder is not run here)
s_inputs = self.siglip_proc(images=image, return_tensors="pt")
s_hidden = self.siglip.vision_model.embeddings(s_inputs['pixel_values'])
s_pooled = s_hidden.mean(dim=1)
# Fuse: concatenate pooled DINOv3 and SigLIP features -> [1, dinov3_dim + siglip_dim]
fused = torch.cat([d_pooled, s_pooled], dim=-1)
return mx.array(fused.numpy())
@torch.no_grad()
def encode_caption(self, caption: str) -> np.ndarray:
"""Encode caption with text encoder."""
inputs = self.text_tokenizer(caption, return_tensors="pt", padding=True, truncation=True, max_length=77)
outputs = self.text_encoder(**inputs)
# Mean pooling
embeddings = outputs.last_hidden_state.mean(dim=1)
return embeddings.numpy()
def compute_loss(self, vision_tokens: mx.array, caption_embeds: mx.array) -> Tuple[mx.array, Dict]:
"""
Compute the training loss for the projected vision tokens.
Note: caption_embeds is accepted for caption alignment but is not used by the
current loss, which combines a batch-level self-similarity term with norm and
token-diversity regularizers.
Args:
vision_tokens: [batch, num_tokens, embed_dim]
caption_embeds: [batch, caption_dim] (currently unused)
"""
batch_size = vision_tokens.shape[0]
# Pool vision tokens
vision_pooled = vision_tokens.mean(axis=1) # [batch, embed_dim]
# NOTE: caption_embeds is not used here yet; explicit caption alignment is deferred.
# The terms below operate on the projected vision tokens only.
# Normalize
vision_norm = vision_pooled / (mx.linalg.norm(vision_pooled, axis=-1, keepdims=True) + 1e-8)
# Self-similarity loss (vision tokens should be coherent within batch)
sim_matrix = mx.matmul(vision_norm, vision_norm.T) # [batch, batch]
# Diagonal should be 1, off-diagonal should vary
identity = mx.eye(batch_size)
# Contrastive-like loss: encourage high self-similarity
pos_sim = mx.sum(sim_matrix * identity) / batch_size
neg_sim = mx.sum(sim_matrix * (1 - identity)) / (batch_size * (batch_size - 1) + 1e-8)
# We want pos_sim high and controlled neg_sim
contrastive_loss = -pos_sim + 0.5 * neg_sim
# Regularization: keep norms reasonable
norm_loss = mx.mean(mx.abs(mx.linalg.norm(vision_tokens, axis=-1) - 1.0))
# Diversity loss: tokens should be different from each other
token_sim = mx.matmul(
vision_tokens,
mx.transpose(vision_tokens, axes=(0, 2, 1))
) # [batch, num_tokens, num_tokens]
token_identity = mx.eye(vision_tokens.shape[1])
diversity_loss = mx.mean(token_sim * (1 - token_identity))
total_loss = contrastive_loss + 0.1 * norm_loss + 0.01 * diversity_loss
return total_loss, {
"contrastive": float(contrastive_loss),
"norm": float(norm_loss),
"diversity": float(diversity_loss)
}
def train_step(self, batch: List[Dict]) -> Tuple[float, Dict]:
"""Single training step."""
# Encode images
vision_features = []
caption_embeds = []
for sample in batch:
try:
v_feat = self.encode_image(sample["image_path"])
c_embed = self.encode_caption(sample["caption"])
vision_features.append(v_feat)
caption_embeds.append(c_embed)
except Exception:  # skip samples whose image or caption fails to encode
continue
if len(vision_features) < 2:
return 0.0, {}
# Stack
vision_features = mx.concatenate(vision_features, axis=0)
caption_embeds_mx = mx.array(np.concatenate(caption_embeds, axis=0))
# Use nn.value_and_grad for module gradient computation
def loss_fn(model):
vision_tokens = model(vision_features)
loss, _ = self.compute_loss(vision_tokens, caption_embeds_mx)
return loss
# Compute loss and gradients using MLX's value_and_grad for modules
loss_and_grad_fn = nn.value_and_grad(self.projector, loss_fn)
loss, grads = loss_and_grad_fn(self.projector)
# Update
self.optimizer.update(self.projector, grads)
mx.eval(self.projector.parameters(), self.optimizer.state)
return float(loss), {}
def save_checkpoint(self, step: int, loss: float):
"""Save checkpoint."""
checkpoint_path = self.checkpoint_dir / f"step_{step:06d}"
checkpoint_path.mkdir(exist_ok=True)
# Save projector
weights = {}
for name, param in tree_flatten(self.projector.parameters()):  # flatten nested params to "fc1.weight"-style keys
weights[name] = np.array(param)
np.savez(str(checkpoint_path / "projector.npz"), **weights)
# Save state
state = {
"step": step,
"loss": loss,
"config": {
"fused_dim": self.fused_dim,
"hidden_dim": self.config.projector_hidden_dim,
"num_tokens": self.config.num_vision_tokens,
"embed_dim": self.config.lfm_embed_dim
}
}
with open(checkpoint_path / "state.json", "w") as f:
json.dump(state, f, indent=2)
print(f" ๐Ÿ’พ Checkpoint: {checkpoint_path}")
def train(self):
"""Main training loop."""
print("\n" + "=" * 60)
print("๐Ÿš€ STARTING TRAINING")
print("=" * 60)
print(f" Dataset: {len(self.dataset):,} samples")
print(f" Epochs: {self.config.num_epochs}")
print(f" Batch size: {self.config.batch_size}")
print(f" Learning rate: {self.config.learning_rate}")
global_step = 0
best_loss = float('inf')
start_time = time.time()
for epoch in range(self.config.num_epochs):
print(f"\n๐Ÿ“š Epoch {epoch + 1}/{self.config.num_epochs}")
print("-" * 40)
self.dataset.shuffle()
epoch_loss = 0
num_batches = 0
for i in range(0, len(self.dataset), self.config.batch_size):
batch = self.dataset.get_batch(i, self.config.batch_size)
if len(batch) < 2:
continue
try:
loss, metrics = self.train_step(batch)
if loss == 0:
continue
epoch_loss += loss
num_batches += 1
global_step += 1
# Logging
if global_step % self.config.log_every == 0:
elapsed = time.time() - start_time
avg_loss = epoch_loss / num_batches
print(f" Step {global_step:5d} | Loss: {loss:.4f} | Avg: {avg_loss:.4f} | {elapsed:.0f}s")
# Checkpointing
if global_step % self.config.save_every == 0:
self.save_checkpoint(global_step, loss)
if loss < best_loss:
best_loss = loss
except Exception as e:
print(f" โš ๏ธ Batch error: {e}")
continue
avg_epoch_loss = epoch_loss / max(num_batches, 1)
print(f"\n โœ“ Epoch {epoch + 1} | Avg loss: {avg_epoch_loss:.4f}")
# Final save
print("\n" + "=" * 60)
print("๐Ÿ’พ Saving Final Model")
print("=" * 60)
final_path = self.checkpoint_dir / "final"
final_path.mkdir(exist_ok=True)
weights = {}
for name, param in tree_flatten(self.projector.parameters()):
weights[name] = np.array(param)
np.savez(str(final_path / "projector.npz"), **weights)
config = {
"fused_dim": self.fused_dim,
"hidden_dim": self.config.projector_hidden_dim,
"num_tokens": self.config.num_vision_tokens,
"embed_dim": self.config.lfm_embed_dim
}
with open(final_path / "config.json", "w") as f:
json.dump(config, f, indent=2)
print(f"โœ… Training complete! Model: {final_path}")
return final_path
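# Reloading a saved projector (a minimal sketch, assuming MLX's tree_unflatten and Module.update;
# paths are illustrative, and the config.json keys mirror VisionProjector's constructor arguments):
#   from mlx.utils import tree_unflatten
#   cfg = json.load(open("checkpoints/oculus_coco/final/config.json"))
#   proj = VisionProjector(**cfg)
#   data = np.load("checkpoints/oculus_coco/final/projector.npz")
#   proj.update(tree_unflatten([(k, mx.array(data[k])) for k in data.files]))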
def main():
# Check if COCO data exists
coco_dir = OCULUS_ROOT / "data" / "coco"
if not (coco_dir / "train_captions.jsonl").exists():
print("โŒ COCO data not found!")
print(" Run: python download_coco.py")
return
config = TrainingConfig(
data_dir="data/coco",
batch_size=4,
learning_rate=2e-4,
num_epochs=3,
max_samples=5000,
save_every=200,
log_every=25,
)
trainer = OculusTrainer(config)
trainer.train()
if __name__ == "__main__":
main()