import os
import sys
import json
import argparse

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Fix import paths so sibling packages (data/, models/, utils/) resolve
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from data.polyvore import PolyvoreTripletDataset
from models.resnet_embedder import ResNetItemEmbedder
from utils.export import ensure_export_dir
from utils.advanced_metrics import AdvancedMetrics, calculate_triplet_metrics


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser()
    p.add_argument("--data_root", type=str, default=os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore"))
    p.add_argument("--epochs", type=int, default=50)
    p.add_argument("--batch_size", type=int, default=16)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--embedding_dim", type=int, default=512)
    p.add_argument("--out", type=str, default="models/exports/resnet_item_embedder.pth")
    p.add_argument("--early_stopping_patience", type=int, default=10, help="Early stopping patience")
    p.add_argument("--min_delta", type=float, default=1e-4, help="Minimum change to qualify as improvement")
    return p.parse_args()


def main() -> None:
    args = parse_args()
    device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
    if device == "cuda":
        torch.backends.cudnn.benchmark = True

    print(f"🚀 Starting ResNet training on {device}")
    print(f"📁 Data root: {args.data_root}")
    print(f"⚙️ Config: {args.epochs} epochs, batch_size={args.batch_size}, lr={args.lr}")

    # Ensure splits exist; if missing, prepare them from the official splits
    splits_dir = os.path.join(args.data_root, "splits")
    triplet_path = os.path.join(splits_dir, "train.json")
    if not os.path.exists(triplet_path):
        print(f"⚠️ Triplet file not found: {triplet_path}")
        print("🔧 Attempting to prepare dataset...")
        os.makedirs(splits_dir, exist_ok=True)
        try:
            # Try to import and run the prepare script
            sys.path.append(os.path.join(os.path.dirname(os.path.dirname(__file__)), "scripts"))
            from prepare_polyvore import main as prepare_main
            print("✅ Successfully imported prepare_polyvore")
            # Prepare dataset without random splits
            prepare_main()
            print("✅ Dataset preparation completed")
        except Exception as e:
            print(f"❌ Failed to prepare dataset: {e}")
            print("💡 Please ensure the dataset is prepared manually")
            return
    else:
        print(f"✅ Found existing splits: {triplet_path}")

    try:
        dataset = PolyvoreTripletDataset(args.data_root, split="train")
        print(f"📊 Dataset loaded: {len(dataset)} samples")
    except Exception as e:
        print(f"❌ Failed to load dataset: {e}")
        return

    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=(device == "cuda"),
    )

    model = ResNetItemEmbedder(embedding_dim=args.embedding_dim).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
    criterion = nn.TripletMarginLoss(margin=0.2, p=2)

    print(f"🏗️ Model created: {model.__class__.__name__}")
    print(f"📈 Total parameters: {sum(p.numel() for p in model.parameters()):,}")

    export_dir = ensure_export_dir(os.path.dirname(args.out) or "models/exports")
    best_loss = float("inf")
    history = []
    patience_counter = 0
    best_epoch = 0
    metrics_collector = AdvancedMetrics()

    print(f"💾 Checkpoints will be saved to: {export_dir}")
    print(f"🛑 Early stopping patience: {args.early_stopping_patience} epochs")

    for epoch in range(args.epochs):
        model.train()
        running_loss = 0.0
        steps = 0
        print(f"🔄 Epoch {epoch+1}/{args.epochs}")

        for batch_idx, batch in enumerate(loader):
            try:
                # Expect batch as (anchor, positive, negative)
                anchor, positive, negative = batch
                # channels_last memory format can speed up convolutions on CUDA
                anchor = anchor.to(device, memory_format=torch.channels_last, non_blocking=True)
                positive = positive.to(device, memory_format=torch.channels_last, non_blocking=True)
                negative = negative.to(device, memory_format=torch.channels_last, non_blocking=True)

                # Mixed precision is enabled on CUDA only; on CPU/MPS autocast is a no-op here
                with torch.autocast(device_type=("cuda" if device == "cuda" else "cpu"), enabled=(device == "cuda")):
                    emb_a = model(anchor)
                    emb_p = model(positive)
                    emb_n = model(negative)
                    loss = criterion(emb_a, emb_p, emb_n)

                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()

                # Collect metrics (per-batch triplet stats are computed but not logged here)
                triplet_metrics = calculate_triplet_metrics(emb_a, emb_p, emb_n, margin=0.2)
                metrics_collector.add_batch(
                    predictions=torch.ones(emb_a.size(0)),  # Placeholder for compatibility
                    targets=torch.ones(emb_a.size(0)),      # Placeholder for compatibility
                    embeddings=emb_a.detach(),              # Detach to avoid gradient issues
                )

                running_loss += loss.item()
                steps += 1
                if batch_idx % 10 == 0:  # More frequent logging
                    print(f" Batch {batch_idx}/{len(loader)}: loss={loss.item():.4f}")
            except Exception as e:
                print(f"❌ Error in batch {batch_idx}: {e}")
                continue

        # Report epoch completion; guard so `loss` is only read if at least one batch succeeded
        if steps > 0:
            print(f" ✅ Batch {len(loader)-1}/{len(loader)}: loss={loss.item():.4f}")
        print(f" 📊 Epoch {epoch+1} completed: {len(loader)} batches processed")
        avg_loss = running_loss / max(1, steps)

        # Save checkpoint with better path handling
        out_path = args.out
        if not out_path.startswith("models/"):
            out_path = os.path.join(export_dir, os.path.basename(args.out))
        # Ensure the output directory exists
        os.makedirs(os.path.dirname(out_path), exist_ok=True)

        # Save checkpoint
        torch.save({"state_dict": model.state_dict(), "epoch": epoch + 1, "loss": avg_loss}, out_path)
        print(f"✅ Epoch {epoch+1}/{args.epochs} avg_triplet_loss={avg_loss:.4f} saved -> {out_path}")
        history.append({"epoch": epoch + 1, "avg_triplet_loss": avg_loss})

        # Early stopping logic
        if avg_loss < best_loss - args.min_delta:
            best_loss = avg_loss
            best_epoch = epoch + 1
            patience_counter = 0
            best_path = os.path.join(export_dir, "resnet_item_embedder_best.pth")
            torch.save({"state_dict": model.state_dict(), "epoch": epoch + 1, "loss": avg_loss}, best_path)
            print(f"🏆 New best model saved: {best_path} (loss: {avg_loss:.4f})")
        else:
            patience_counter += 1
            print(f"⏳ No improvement for {patience_counter} epochs (best: {best_loss:.4f} at epoch {best_epoch})")
            if patience_counter >= args.early_stopping_patience:
                print(f"🛑 Early stopping triggered after {patience_counter} epochs without improvement")
                print(f"🏆 Best model was at epoch {best_epoch} with loss {best_loss:.4f}")
                break

    # Write comprehensive metrics
    metrics_path = os.path.join(export_dir, "resnet_metrics.json")

    # Get advanced metrics
    advanced_metrics = metrics_collector.calculate_all_metrics()

    final_metrics = {
        "best_triplet_loss": best_loss,
        "best_epoch": best_epoch,
        "total_epochs": epoch + 1,
        "early_stopping_triggered": patience_counter >= args.early_stopping_patience,
        "patience_counter": patience_counter,
        "training_config": {
            "epochs": args.epochs,
            "batch_size": args.batch_size,
            "learning_rate": args.lr,
            "embedding_dim": args.embedding_dim,
            "early_stopping_patience": args.early_stopping_patience,
            "min_delta": args.min_delta,
        },
        "history": history,
        "advanced_metrics": advanced_metrics,
    }

    with open(metrics_path, "w") as f:
        json.dump(final_metrics, f, indent=2)

    print(f"📊 Training completed! Best loss: {best_loss:.4f} at epoch {best_epoch}")
    print(f"📈 Comprehensive metrics saved to: {metrics_path}")
    print(f"🔬 Advanced metrics: {advanced_metrics['summary']}")


if __name__ == "__main__":
    main()
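# Example invocation, a minimal sketch. The script path below is an assumption
# (the actual filename is not given in this file); the data root matches the
# POLYVORE_ROOT default in parse_args(), and all flag values shown are just the
# documented defaults:
#
#   POLYVORE_ROOT=/home/user/app/data/Polyvore python path/to/this_script.py \
#       --epochs 50 --batch_size 16 --lr 1e-3 --embedding_dim 512 \
#       --out models/exports/resnet_item_embedder.pth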