import os
import argparse
import sys
import json
from typing import List

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Fix import paths so the project packages resolve when run as a script
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from data.polyvore import PolyvoreOutfitTripletDataset
from models.vit_outfit import OutfitCompatibilityModel
from models.resnet_embedder import ResNetItemEmbedder
from utils.export import ensure_export_dir
from utils.advanced_metrics import AdvancedMetrics, calculate_outfit_compatibility_metrics

def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser()
    p.add_argument("--data_root", type=str, default=os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore"))
    p.add_argument("--epochs", type=int, default=50)
    p.add_argument("--batch_size", type=int, default=4)
    p.add_argument("--lr", type=float, default=5e-4)
    p.add_argument("--embedding_dim", type=int, default=512)
    p.add_argument("--triplet_margin", type=float, default=0.5)
    p.add_argument("--export", type=str, default="models/exports/vit_outfit_model.pth")
    p.add_argument("--eval_every", type=int, default=1)
    p.add_argument("--skip_validation", action="store_true", help="Skip validation for faster training")
    p.add_argument("--max_samples", type=int, default=5000, help="Cap on the number of training samples (keeps runs fast)")
    p.add_argument("--early_stopping_patience", type=int, default=5, help="Epochs without improvement before early stopping")
    p.add_argument("--min_delta", type=float, default=1e-4, help="Minimum decrease in val loss to count as improvement")
    p.add_argument("--gradient_clip", type=float, default=1.0, help="Max gradient norm (0 disables clipping)")
    p.add_argument("--warmup_epochs", type=int, default=2, help="Learning rate warmup epochs")
    return p.parse_args()
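
# Example invocation (script name illustrative; POLYVORE_ROOT can be exported
# in the environment instead of passing --data_root):
#   python train_vit_outfit.py --data_root /data/Polyvore --epochs 20 --batch_size 8
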
def embed_outfit(imgs: List[torch.Tensor], embedder: ResNetItemEmbedder, device: str, max_len: int = 4) -> torch.Tensor:
    if len(imgs) == 0:
        return torch.zeros((max_len, embedder.proj.out_features), device=device)
    k = min(len(imgs), max_len)
    x = torch.stack(imgs[:k], dim=0).to(device)
    with torch.no_grad():
        e = embedder(x)  # (k, D)
    tokens = torch.zeros((max_len, e.shape[-1]), device=device)
    tokens[:k] = e
    return tokens
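
# Note: embed_outfit truncates outfits to max_len items and zero-pads shorter
# ones into a fixed (max_len, D) token grid. The ViT encoder below therefore
# attends over the zero padding as well; an attention mask would avoid that,
# but is omitted here on the assumption that the encoder accepts only the
# token tensor.
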
def main() -> None:
    args = parse_args()
    device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
    if device == "cuda":
        torch.backends.cudnn.benchmark = True
    print(f"🚀 Starting ViT Outfit training on {device}")
    print(f"📁 Data root: {args.data_root}")
    print(f"⚙️ Config: {args.epochs} epochs, batch_size={args.batch_size}, lr={args.lr}")
    # Ensure outfit triplets exist
    splits_dir = os.path.join(args.data_root, "splits")
    trip_path = os.path.join(splits_dir, "outfit_triplets_train.json")
    if not os.path.exists(trip_path):
        print(f"⚠️ Outfit triplet file not found: {trip_path}")
        print("🔧 Attempting to prepare dataset...")
        os.makedirs(splits_dir, exist_ok=True)
        try:
            # Try to import and run the prepare script
            sys.path.append(os.path.join(os.path.dirname(os.path.dirname(__file__)), "scripts"))
            from prepare_polyvore import main as prepare_main
            print("✅ Successfully imported prepare_polyvore")
            # Prepare dataset without random splits
            prepare_main()
            print("✅ Dataset preparation completed")
        except Exception as e:
            print(f"❌ Failed to prepare dataset: {e}")
            print("💡 Please ensure the dataset is prepared manually")
            return
    else:
        print(f"✅ Found existing outfit triplets: {trip_path}")
    try:
        dataset = PolyvoreOutfitTripletDataset(args.data_root, split="train")
        # Limit dataset size for faster training/testing
        max_samples = min(len(dataset), args.max_samples)
        print(f"🔍 Debug: original dataset size: {len(dataset)}, max_samples: {args.max_samples}")
        if len(dataset) > max_samples:
            dataset.samples = dataset.samples[:max_samples]
            print(f"📊 Dataset limited to {max_samples} samples for faster training")
        else:
            print(f"📊 Dataset loaded: {len(dataset)} samples (no limiting needed)")
    except Exception as e:
        print(f"❌ Failed to load dataset: {e}")
        return
    def collate(batch):
        return batch  # variable length handled inside the training loop

    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=2, pin_memory=(device == "cuda"), collate_fn=collate)
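    # The identity collate keeps each sample as-is: outfits contain varying
    # numbers of item images, so the default collate (which stacks tensors)
    # would generally fail on ragged lists. Padding to a fixed token count
    # happens later, inside embed_outfit.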
    model = OutfitCompatibilityModel(embedding_dim=args.embedding_dim).to(device)
    embedder = ResNetItemEmbedder(embedding_dim=args.embedding_dim).to(device).eval()
    for p in embedder.parameters():
        p.requires_grad_(False)  # the ResNet embedder stays frozen; only the ViT is trained

    print("🏗️ Models created:")
    print(f" - ViT Outfit: {model.__class__.__name__}")
    print(f" - ResNet Embedder: {embedder.__class__.__name__}")
    print(f"📊 Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=5e-2)
    triplet = nn.TripletMarginWithDistanceLoss(distance_function=lambda x, y: 1 - nn.functional.cosine_similarity(x, y), margin=args.triplet_margin)
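    # Cosine distance 1 - cos(x, y) ranges over [0, 2], so the margin asks the
    # anchor-positive distance to be at least `triplet_margin` smaller than the
    # anchor-negative distance before the loss reaches zero.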
    # Learning rate scheduler with warmup
    total_steps = len(loader) * args.epochs
    warmup_steps = len(loader) * args.warmup_epochs
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=args.lr,
        total_steps=total_steps,
        pct_start=warmup_steps / total_steps,
        anneal_strategy="cos",
    )
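    # OneCycleLR is stepped once per batch in the loop below, so it expects
    # exactly total_steps updates. Early stopping or skipped (failed) batches
    # simply end the schedule early, which is benign: the LR never finishes
    # annealing but no error is raised.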
    export_dir = ensure_export_dir(os.path.dirname(args.export) or "models/exports")
    best_loss = float("inf")
    hist = []
    patience_counter = 0
    best_epoch = 0
    metrics_collector = AdvancedMetrics()  # instantiated but unused in this loop; see the metrics note near the end
    print(f"💾 Checkpoints will be saved to: {export_dir}")
    print(f"🛑 Early stopping patience: {args.early_stopping_patience} epochs")
    for epoch in range(args.epochs):
        model.train()
        running_loss = 0.0
        steps = 0
        print(f"🔄 Epoch {epoch+1}/{args.epochs}")
        for batch_idx, batch in enumerate(loader):
            try:
                # batch: List[(ga_imgs, gb_imgs, bd_imgs)]
                anchor_tokens = []
                positive_tokens = []
                negative_tokens = []
                for ga, gb, bd in batch:
                    ta = embed_outfit(ga, embedder, device)
                    tb = embed_outfit(gb, embedder, device)
                    tn = embed_outfit(bd, embedder, device)
                    anchor_tokens.append(ta.unsqueeze(0))
                    positive_tokens.append(tb.unsqueeze(0))
                    negative_tokens.append(tn.unsqueeze(0))
                A = torch.cat(anchor_tokens, dim=0)  # (B, N, D)
                P = torch.cat(positive_tokens, dim=0)
                N = torch.cat(negative_tokens, dim=0)
| with torch.autocast(device_type=("cuda" if device=="cuda" else "cpu"), enabled=(device=="cuda")): | |
| ea = model.encoder(A).mean(dim=1) | |
| ep = model.encoder(P).mean(dim=1) | |
| en = model.encoder(N).mean(dim=1) | |
| loss = triplet(ea, ep, en) | |
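                # Note: on CUDA, autocast runs this forward pass in reduced
                # precision without a GradScaler. That is often fine for this
                # loss, but if fp16 gradients underflow, pairing backward/step
                # with torch.cuda.amp.GradScaler is the standard remedy.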
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                # Gradient clipping for stability
                if args.gradient_clip > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip)
                optimizer.step()
                scheduler.step()  # update learning rate once per batch
                # Metrics collection is intentionally skipped: this loop trains
                # outfit-level embeddings, not classification predictions, so
                # the binary-target metrics pipeline does not apply here.
                running_loss += loss.item()
                steps += 1
                if batch_idx % 10 == 0:  # frequent logging
                    print(f" Batch {batch_idx}/{len(loader)}: loss={loss.item():.4f}")
            except Exception as e:
                print(f"❌ Error in batch {batch_idx}: {e}")
                continue
        # Print final batch completion (guarded: `loss` is unbound if every batch failed)
        if steps > 0:
            print(f" ✅ Final batch loss: {loss.item():.4f}")
        print(f" 📊 Epoch {epoch+1} completed: {len(loader)} batches processed")
        avg_loss = running_loss / max(1, steps)

        # Fast validation with limited samples to prevent hanging
        val_path = os.path.join(args.data_root, "splits", "outfit_triplets_valid.json")
        val_loss = None
        if not args.skip_validation and os.path.exists(val_path) and (epoch + 1) % args.eval_every == 0:
            try:
                print(" 🔍 Starting validation (limited to 50 samples for speed)...")
                val_ds = PolyvoreOutfitTripletDataset(args.data_root, split="valid")
                # Limit validation to the first 50 samples to prevent hanging
                val_ds.samples = val_ds.samples[:50]
                val_loader = DataLoader(val_ds, batch_size=min(args.batch_size, 8), shuffle=False, num_workers=0, collate_fn=lambda x: x)
                model.eval()
                losses = []
                with torch.no_grad():
                    for i, vbatch in enumerate(val_loader):
                        if i >= 10:  # limit to 10 batches max for speed
                            break
                        anchor_tokens = []
                        positive_tokens = []
                        negative_tokens = []
                        for ga, gb, bd in vbatch:
                            ta = embed_outfit(ga, embedder, device)
                            tb = embed_outfit(gb, embedder, device)
                            tn = embed_outfit(bd, embedder, device)
                            anchor_tokens.append(ta.unsqueeze(0))
                            positive_tokens.append(tb.unsqueeze(0))
                            negative_tokens.append(tn.unsqueeze(0))
                        A = torch.cat(anchor_tokens, dim=0)
                        P = torch.cat(positive_tokens, dim=0)
                        N = torch.cat(negative_tokens, dim=0)
                        ea = model.encoder(A).mean(dim=1)
                        ep = model.encoder(P).mean(dim=1)
                        en = model.encoder(N).mean(dim=1)
                        losses.append(triplet(ea, ep, en).item())
                val_loss = sum(losses) / max(1, len(losses))
                print(f" 📊 Validation loss: {val_loss:.4f} (from {len(losses)} batches)")
            except Exception as e:
                print(f" ⚠️ Validation failed: {e}")
                val_loss = None
        out_path = args.export
        if not out_path.startswith("models/"):
            out_path = os.path.join(export_dir, os.path.basename(args.export))
        # Save a checkpoint every epoch
        torch.save({"state_dict": model.state_dict(), "epoch": epoch + 1, "loss": avg_loss}, out_path)
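        # The checkpoint is a plain dict, so reloading it elsewhere would look
        # like (illustrative):
        #   ckpt = torch.load(out_path, map_location="cpu")
        #   model.load_state_dict(ckpt["state_dict"])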
        if val_loss is not None:
            print(f"✅ Epoch {epoch+1}/{args.epochs} triplet_loss={avg_loss:.4f} val_triplet_loss={val_loss:.4f} saved -> {out_path}")
            hist.append({"epoch": epoch + 1, "triplet_loss": float(avg_loss), "val_triplet_loss": float(val_loss)})
            # Early stopping logic
            if val_loss < best_loss - args.min_delta:
                best_loss = val_loss
                best_epoch = epoch + 1
                patience_counter = 0
                best_path = os.path.join(export_dir, "vit_outfit_model_best.pth")
                torch.save({"state_dict": model.state_dict(), "epoch": epoch + 1, "loss": avg_loss, "val_loss": val_loss}, best_path)
                print(f"🏆 New best model saved: {best_path} (val_loss: {val_loss:.4f})")
            else:
                patience_counter += 1
                print(f"⏳ No improvement for {patience_counter} epochs (best: {best_loss:.4f} at epoch {best_epoch})")
                if patience_counter >= args.early_stopping_patience:
                    print(f"🛑 Early stopping triggered after {patience_counter} epochs without improvement")
                    print(f"🏆 Best model was at epoch {best_epoch} with val_loss {best_loss:.4f}")
                    break
        else:
            print(f"✅ Epoch {epoch+1}/{args.epochs} triplet_loss={avg_loss:.4f} saved -> {out_path}")
            hist.append({"epoch": epoch + 1, "triplet_loss": float(avg_loss)})
    # Write comprehensive metrics
    metrics_path = os.path.join(export_dir, "vit_metrics.json")
    # This trainer optimizes outfit-level triplet loss rather than a
    # classifier, so no classification metrics were collected; the advanced
    # metrics block is recorded as empty counts.
    advanced_metrics = {
        "total_predictions": 0,
        "total_targets": 0,
        "total_scores": 0,
        "total_embeddings": 0,
        "total_outfit_scores": 0,
    }
    final_metrics = {
        "best_val_triplet_loss": best_loss if best_loss != float("inf") else None,
        "best_epoch": best_epoch,
        "total_epochs": epoch + 1,
        "early_stopping_triggered": patience_counter >= args.early_stopping_patience,
        "patience_counter": patience_counter,
        "training_config": {
            "epochs": args.epochs,
            "batch_size": args.batch_size,
            "learning_rate": args.lr,
            "embedding_dim": args.embedding_dim,
            "triplet_margin": args.triplet_margin,
            "early_stopping_patience": args.early_stopping_patience,
            "min_delta": args.min_delta,
        },
        "history": hist,
        "advanced_metrics": advanced_metrics,
    }
    with open(metrics_path, "w") as f:
        json.dump(final_metrics, f, indent=2)
    # Always save a best model (fall back to the final model if no validation ran)
    if best_loss == float("inf"):
        best_path = os.path.join(export_dir, "vit_outfit_model_best.pth")
        torch.save({"state_dict": model.state_dict(), "epoch": epoch + 1, "loss": avg_loss}, best_path)
        print(f"🏆 Final model saved as best: {best_path} (loss: {avg_loss:.4f})")
        print("🎉 Training completed! No validation loss was recorded.")
    else:
        print(f"🎉 Training completed! Best val_loss: {best_loss:.4f} at epoch {best_epoch}")
    print(f"📊 Comprehensive metrics saved to: {metrics_path}")
    print(f"🔬 Advanced metrics: {advanced_metrics}")
if __name__ == "__main__":
    main()