recomendation / training /train_vit.py
Ali Mohsin
folder reorganise
72af8c3
import os
import argparse
import sys
from typing import List
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
# Fix import paths
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.polyvore import PolyvoreOutfitTripletDataset
from models.vit_outfit import OutfitCompatibilityModel
from models.resnet_embedder import ResNetItemEmbedder
from utils.export import ensure_export_dir
from utils.advanced_metrics import AdvancedMetrics, calculate_outfit_compatibility_metrics
import json
def parse_args(argv: "List[str] | None" = None) -> argparse.Namespace:
    """Parse command-line options for ViT outfit training.

    Args:
        argv: Argument list to parse, without the program name. ``None`` (the
            default) parses ``sys.argv[1:]`` exactly as before; pass an explicit
            list for programmatic use or tests.

    Returns:
        The parsed ``argparse.Namespace``. ``--data_root`` defaults to the
        ``POLYVORE_ROOT`` environment variable when it is set.
    """
    p = argparse.ArgumentParser(
        description="Train the ViT outfit-compatibility encoder with a triplet loss."
    )
    p.add_argument("--data_root", type=str, default=os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore"),
                   help="Polyvore dataset root (defaults to $POLYVORE_ROOT)")
    p.add_argument("--epochs", type=int, default=50, help="Number of training epochs")
    p.add_argument("--batch_size", type=int, default=4, help="Outfit triplets per training batch")
    p.add_argument("--lr", type=float, default=5e-4, help="Peak learning rate for the one-cycle schedule")
    p.add_argument("--embedding_dim", type=int, default=512, help="Item/outfit embedding dimension")
    p.add_argument("--triplet_margin", type=float, default=0.5, help="Margin of the cosine triplet loss")
    p.add_argument("--export", type=str, default="models/exports/vit_outfit_model.pth",
                   help="Path of the per-epoch checkpoint")
    p.add_argument("--eval_every", type=int, default=1, help="Run validation every N epochs")
    p.add_argument("--skip_validation", action="store_true", help="Skip validation for faster training")
    p.add_argument("--max_samples", type=int, default=5000,
                   help="Maximum number of training samples to use (truncates the train split)")
    p.add_argument("--early_stopping_patience", type=int, default=5, help="Early stopping patience")
    p.add_argument("--min_delta", type=float, default=1e-4, help="Minimum change to qualify as improvement")
    p.add_argument("--gradient_clip", type=float, default=1.0, help="Gradient clipping value")
    p.add_argument("--warmup_epochs", type=int, default=2, help="Learning rate warmup epochs")
    return p.parse_args(argv)
def embed_outfit(imgs: List[torch.Tensor], embedder: ResNetItemEmbedder, device: str, max_len: int = 4) -> torch.Tensor:
    """Convert one outfit's item images into a fixed-size token matrix.

    Only the first ``max_len`` images are embedded; any extras are dropped, and
    rows with no image stay zero. The embedder runs under ``torch.no_grad()``,
    so no gradient flows back into it.

    Args:
        imgs: Item image tensors for one outfit; they are stacked, so all must
            share a shape.
        embedder: Item embedder; for an empty outfit only its
            ``proj.out_features`` is read to size the output.
        device: Device the images are moved to and the output is created on.
        max_len: Number of token rows in the output.

    Returns:
        Tensor with ``max_len`` rows whose width is the embedder's output width.
    """
    if len(imgs) == 0:
        # No items: return pure padding sized from the projection layer.
        return torch.zeros((max_len, embedder.proj.out_features), device=device)

    kept = imgs[:max_len]
    stacked = torch.stack(kept, dim=0).to(device)
    with torch.no_grad():
        item_embeddings = embedder(stacked)

    padded = torch.zeros((max_len, item_embeddings.shape[-1]), device=device)
    padded[:len(kept)] = item_embeddings
    return padded
def _identity_collate(batch):
    """Return the list of dataset samples unchanged.

    Outfits have a variable number of item images, so samples are not stacked
    here; padding happens later in ``embed_outfit``. This is a module-level
    function rather than a closure so that DataLoader worker processes can
    pickle it under the ``spawn`` start method (the default on macOS/Windows).
    """
    return batch


def _triplet_tokens(batch, embedder: ResNetItemEmbedder, device: str):
    """Embed each (anchor, positive, negative) outfit of ``batch`` into token tensors.

    Every outfit goes through ``embed_outfit`` with its default ``max_len``. The
    per-outfit token matrices are stacked along a new leading batch axis.

    Returns:
        ``(A, P, N)``: stacked anchor, positive and negative token tensors.
    """
    anchors, positives, negatives = [], [], []
    for ga, gb, bd in batch:
        anchors.append(embed_outfit(ga, embedder, device))
        positives.append(embed_outfit(gb, embedder, device))
        negatives.append(embed_outfit(bd, embedder, device))
    return torch.stack(anchors, dim=0), torch.stack(positives, dim=0), torch.stack(negatives, dim=0)


def _pooled_outfit_embeddings(model: OutfitCompatibilityModel, A: torch.Tensor, P: torch.Tensor, N: torch.Tensor):
    """Encode each token tensor with ``model.encoder`` and mean-pool over dim 1."""
    # NOTE(review): assumes encoder output is (B, tokens, D). The mean then also
    # covers positions of the all-zero padding rows that embed_outfit adds for
    # short outfits, because no mask is passed. Confirm this is intended.
    return model.encoder(A).mean(dim=1), model.encoder(P).mean(dim=1), model.encoder(N).mean(dim=1)


def _ensure_outfit_triplets(data_root: str) -> bool:
    """Make sure ``<data_root>/splits/outfit_triplets_train.json`` exists.

    If the file is missing, this imports ``main`` from ``prepare_polyvore`` in
    the sibling ``scripts`` directory and runs it.

    Returns:
        False if preparation raised an exception (the caller should abort);
        True otherwise.
    """
    splits_dir = os.path.join(data_root, "splits")
    trip_path = os.path.join(splits_dir, "outfit_triplets_train.json")
    if os.path.exists(trip_path):
        print(f"βœ… Found existing outfit triplets: {trip_path}")
        return True

    print(f"βš οΈ Outfit triplet file not found: {trip_path}")
    print("πŸ”§ Attempting to prepare dataset...")
    os.makedirs(splits_dir, exist_ok=True)
    try:
        # Use abspath (as for the project root at import time) so the scripts
        # dir resolves correctly even when __file__ is a relative path.
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        sys.path.append(os.path.join(project_root, "scripts"))
        from prepare_polyvore import main as prepare_main
        print("βœ… Successfully imported prepare_polyvore")
        # Prepare dataset without random splits (author's note).
        # NOTE(review): if prepare_main parses sys.argv, it will see this
        # script's flags. Also, SystemExit is not caught below. Confirm.
        prepare_main()
        print("βœ… Dataset preparation completed")
    except Exception as e:
        print(f"❌ Failed to prepare dataset: {e}")
        print("πŸ’‘ Please ensure the dataset is prepared manually")
        return False
    return True


def _load_train_dataset(data_root: str, max_samples_arg: int):
    """Load the train split, truncated to at most ``max_samples_arg`` samples.

    A non-positive ``max_samples_arg`` means "no limit". Previously a negative
    value sliced from the end of the dataset, and 0 left it empty.

    Returns:
        The dataset, or None (after printing the error) if loading failed.
    """
    try:
        dataset = PolyvoreOutfitTripletDataset(data_root, split="train")
        limit = max_samples_arg if max_samples_arg > 0 else len(dataset)
        max_samples = min(len(dataset), limit)
        print(f"πŸ” Debug: Original dataset size: {len(dataset)}, max_samples: {max_samples_arg}")
        if len(dataset) > max_samples:
            dataset.samples = dataset.samples[:max_samples]
            print(f"πŸ“Š Dataset limited to {max_samples} samples for faster training")
        else:
            print(f"πŸ“Š Dataset loaded: {len(dataset)} samples (no limiting needed)")
    except Exception as e:
        print(f"❌ Failed to load dataset: {e}")
        return None
    return dataset


def _validate(model, embedder, val_ds, triplet, device: str, batch_size: int, max_batches: int = 10):
    """Average the triplet loss over at most ``max_batches`` batches of ``val_ds``.

    Puts ``model`` in eval mode and runs under ``torch.no_grad()`` without
    autocast.

    Returns:
        ``(mean_loss, n_batches)``. ``mean_loss`` is None when no batch was
        evaluated (e.g. an empty validation set), so it cannot be mistaken for a
        perfect 0.0 loss.
    """
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0,
                            collate_fn=_identity_collate)
    model.eval()
    losses = []
    with torch.no_grad():
        for i, vbatch in enumerate(val_loader):
            if i >= max_batches:  # Limit batch count for speed
                break
            A, P, N = _triplet_tokens(vbatch, embedder, device)
            ea, ep, en = _pooled_outfit_embeddings(model, A, P, N)
            losses.append(triplet(ea, ep, en).item())
    if not losses:
        return None, 0
    return sum(losses) / len(losses), len(losses)


def main() -> None:
    """Train the outfit ViT encoder with a cosine triplet loss on Polyvore triplets.

    Item images are embedded by a frozen ``ResNetItemEmbedder``. Only
    ``OutfitCompatibilityModel`` parameters are optimized, with the triplet loss
    computed on mean-pooled ``model.encoder`` outputs.

    Outputs:
        - A checkpoint at the export path every epoch.
        - ``vit_outfit_model_best.pth`` whenever the validation loss improves by
          more than ``--min_delta``, or from the final weights if no validation
          loss was ever recorded.
        - ``vit_metrics.json`` in the export directory.

    Returns early, after printing a message, if data preparation or loading
    fails or there is nothing to train.
    """
    args = parse_args()
    device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
    if device == "cuda":
        torch.backends.cudnn.benchmark = True
    print(f"πŸš€ Starting ViT Outfit training on {device}")
    print(f"πŸ“ Data root: {args.data_root}")
    print(f"βš™οΈ Config: {args.epochs} epochs, batch_size={args.batch_size}, lr={args.lr}")

    if not _ensure_outfit_triplets(args.data_root):
        return
    dataset = _load_train_dataset(args.data_root, args.max_samples)
    if dataset is None:
        return

    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=(device == "cuda"),
        collate_fn=_identity_collate,  # module-level so spawn-based workers can pickle it
    )
    model = OutfitCompatibilityModel(embedding_dim=args.embedding_dim).to(device)
    # Frozen item embedder: its parameters never receive gradients.
    embedder = ResNetItemEmbedder(embedding_dim=args.embedding_dim).to(device).eval()
    for param in embedder.parameters():
        param.requires_grad_(False)
    print("πŸ—οΈ Models created:")
    print(f" - ViT Outfit: {model.__class__.__name__}")
    print(f" - ResNet Embedder: {embedder.__class__.__name__}")
    print(f"πŸ“ˆ Total parameters: {sum(p.numel() for p in model.parameters()):,}")

    steps_per_epoch = len(loader)
    if steps_per_epoch == 0 or args.epochs <= 0:
        # OneCycleLR needs total_steps > 0, and the summary below needs at least one epoch.
        print(f"❌ Nothing to train: {steps_per_epoch} batches per epoch, {args.epochs} epochs")
        return

    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=5e-2)
    # Cosine distance (1 - cosine similarity) between outfit embeddings.
    triplet = nn.TripletMarginWithDistanceLoss(
        distance_function=lambda x, y: 1 - nn.functional.cosine_similarity(x, y),
        margin=args.triplet_margin,
    )

    # Learning rate scheduler with warmup, stepped once per successful batch.
    total_steps = steps_per_epoch * args.epochs
    # Keep warmup strictly shorter than the run. OneCycleLR rejects
    # pct_start > 1, and pct_start == 1 leaves a zero-length annealing phase
    # that can divide by zero on the final step.
    warmup_steps = max(0, min(steps_per_epoch * args.warmup_epochs, total_steps - 1))
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=args.lr,
        total_steps=total_steps,
        pct_start=warmup_steps / total_steps,
        anneal_strategy='cos',
    )

    # fp16 autocast on CUDA without loss scaling can underflow small gradients.
    # GradScaler is a transparent no-op when disabled (CPU / MPS).
    use_amp = device == "cuda"
    grad_scaler_cls = getattr(getattr(torch, "amp", None), "GradScaler", None)
    if grad_scaler_cls is not None:
        scaler = grad_scaler_cls("cuda", enabled=use_amp)
    else:  # older PyTorch without torch.amp.GradScaler
        scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    autocast_device = "cuda" if use_amp else "cpu"

    export_dir = ensure_export_dir(os.path.dirname(args.export) or "models/exports")
    # Loop-invariant paths, computed once.
    out_path = args.export
    if not out_path.startswith("models/"):
        out_path = os.path.join(export_dir, os.path.basename(args.export))
    best_path = os.path.join(export_dir, "vit_outfit_model_best.pth")
    val_path = os.path.join(args.data_root, "splits", "outfit_triplets_valid.json")
    eval_every = max(1, args.eval_every)  # avoid modulo-by-zero for --eval_every 0
    val_ds = None  # validation subset: loaded lazily, then reused across epochs

    best_loss = float("inf")
    hist = []
    patience_counter = 0
    best_epoch = 0
    print(f"πŸ’Ύ Checkpoints will be saved to: {export_dir}")
    print(f"πŸ›‘ Early stopping patience: {args.early_stopping_patience} epochs")

    for epoch in range(args.epochs):
        model.train()
        running_loss = 0.0
        steps = 0
        last_batch_idx, last_loss = None, None
        print(f"πŸ”„ Epoch {epoch+1}/{args.epochs}")
        for batch_idx, batch in enumerate(loader):
            try:
                # Embedding runs outside autocast (full precision); only the encoder uses AMP.
                A, P, N = _triplet_tokens(batch, embedder, device)
                with torch.autocast(device_type=autocast_device, enabled=use_amp):
                    ea, ep, en = _pooled_outfit_embeddings(model, A, P, N)
                    loss = triplet(ea, ep, en)
                optimizer.zero_grad(set_to_none=True)
                scaler.scale(loss).backward()
                # Gradient clipping for stability, applied to the unscaled gradients.
                if args.gradient_clip > 0:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip)
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()  # Update learning rate
                batch_loss = loss.item()
                running_loss += batch_loss
                steps += 1
                last_batch_idx, last_loss = batch_idx, batch_loss
                if batch_idx % 10 == 0:
                    print(f" Batch {batch_idx}/{len(loader)}: loss={batch_loss:.4f}")
            except Exception as e:
                # Best-effort: report and skip the failing batch.
                print(f"❌ Error in batch {batch_idx}: {e}")

        # Report the last batch that actually succeeded. `loss` may be unbound
        # or stale if batches failed.
        if last_loss is not None:
            print(f" βœ… Batch {last_batch_idx}/{len(loader)}: loss={last_loss:.4f}")
        print(f" πŸ“Š Epoch {epoch+1} completed: {steps}/{len(loader)} batches processed")
        avg_loss = running_loss / max(1, steps)

        # Fast validation on a small, fixed subset to prevent hanging.
        val_loss = None
        if not args.skip_validation and os.path.exists(val_path) and (epoch + 1) % eval_every == 0:
            try:
                print(" πŸ” Starting validation (limited to 50 samples for speed)...")
                if val_ds is None:
                    candidate = PolyvoreOutfitTripletDataset(args.data_root, split="valid")
                    # Limit validation to first 50 samples to prevent hanging
                    candidate.samples = candidate.samples[:50]
                    val_ds = candidate
                val_loss, n_val_batches = _validate(model, embedder, val_ds, triplet, device,
                                                    min(args.batch_size, 8))
                if val_loss is not None:
                    print(f" πŸ“Š Validation loss: {val_loss:.4f} (from {n_val_batches} batches)")
                else:
                    print(" Validation produced no batches; skipping early-stopping check this epoch")
            except Exception as e:
                print(f" βš οΈ Validation failed: {e}")
                val_loss = None

        # Save checkpoint
        torch.save({"state_dict": model.state_dict(), "epoch": epoch+1, "loss": avg_loss}, out_path)
        if val_loss is not None:
            print(f"βœ… Epoch {epoch+1}/{args.epochs} triplet_loss={avg_loss:.4f} val_triplet_loss={val_loss:.4f} saved -> {out_path}")
            hist.append({"epoch": epoch + 1, "triplet_loss": float(avg_loss), "val_triplet_loss": float(val_loss)})
            # Early stopping logic
            if val_loss < best_loss - args.min_delta:
                best_loss = val_loss
                best_epoch = epoch + 1
                patience_counter = 0
                torch.save({"state_dict": model.state_dict(), "epoch": epoch+1, "loss": avg_loss, "val_loss": val_loss}, best_path)
                print(f"πŸ† New best model saved: {best_path} (val_loss: {val_loss:.4f})")
            else:
                patience_counter += 1
                print(f"⏳ No improvement for {patience_counter} epochs (best: {best_loss:.4f} at epoch {best_epoch})")
                if patience_counter >= args.early_stopping_patience:
                    print(f"πŸ›‘ Early stopping triggered after {patience_counter} epochs without improvement")
                    print(f"πŸ† Best model was at epoch {best_epoch} with val_loss {best_loss:.4f}")
                    break
        else:
            print(f"βœ… Epoch {epoch+1}/{args.epochs} triplet_loss={avg_loss:.4f} saved -> {out_path}")
            hist.append({"epoch": epoch + 1, "triplet_loss": float(avg_loss)})

    # Write comprehensive metrics
    metrics_path = os.path.join(export_dir, "vit_metrics.json")
    # Placeholder: this triplet training collects no classification predictions,
    # so every advanced-metric counter is zero.
    advanced_metrics = {
        "total_predictions": 0,
        "total_targets": 0,
        "total_scores": 0,
        "total_embeddings": 0,
        "total_outfit_scores": 0
    }
    final_metrics = {
        "best_val_triplet_loss": best_loss if best_loss != float("inf") else None,
        "best_epoch": best_epoch,
        "total_epochs": epoch + 1,  # epoch is bound: args.epochs >= 1 was checked above
        "early_stopping_triggered": patience_counter >= args.early_stopping_patience,
        "patience_counter": patience_counter,
        "training_config": {
            "epochs": args.epochs,
            "batch_size": args.batch_size,
            "learning_rate": args.lr,
            "embedding_dim": args.embedding_dim,
            "triplet_margin": args.triplet_margin,
            "early_stopping_patience": args.early_stopping_patience,
            "min_delta": args.min_delta
        },
        "history": hist,
        "advanced_metrics": advanced_metrics
    }
    with open(metrics_path, "w") as f:
        json.dump(final_metrics, f, indent=2)

    if best_loss == float("inf"):
        # No finite validation loss was recorded: promote the final weights to "best".
        torch.save({"state_dict": model.state_dict(), "epoch": epoch+1, "loss": avg_loss}, best_path)
        print(f"πŸ† Final model saved as best: {best_path} (loss: {avg_loss:.4f})")
        print(f"Training completed! No validation loss recorded; final triplet_loss: {avg_loss:.4f}")
    else:
        print(f"πŸ“Š Training completed! Best val_loss: {best_loss:.4f} at epoch {best_epoch}")
    print(f"πŸ“ˆ Comprehensive metrics saved to: {metrics_path}")
    print(f"πŸ”¬ Advanced metrics: {advanced_metrics}")
# Script entry point: run training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()