Spaces:
Paused
Paused
| import os | |
| import argparse | |
| import sys | |
| from typing import Tuple | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader | |
| # Fix import paths | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from data.polyvore import PolyvoreTripletDataset | |
| from models.resnet_embedder import ResNetItemEmbedder | |
| from utils.export import ensure_export_dir | |
| from utils.advanced_metrics import AdvancedMetrics, calculate_triplet_metrics | |
| import json | |
| def parse_args() -> argparse.Namespace: | |
| p = argparse.ArgumentParser() | |
| p.add_argument("--data_root", type=str, default=os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore")) | |
| p.add_argument("--epochs", type=int, default=50) | |
| p.add_argument("--batch_size", type=int, default=16) | |
| p.add_argument("--lr", type=float, default=1e-3) | |
| p.add_argument("--embedding_dim", type=int, default=512) | |
| p.add_argument("--out", type=str, default="models/exports/resnet_item_embedder.pth") | |
| p.add_argument("--early_stopping_patience", type=int, default=10, help="Early stopping patience") | |
| p.add_argument("--min_delta", type=float, default=1e-4, help="Minimum change to qualify as improvement") | |
| return p.parse_args() | |
| def main() -> None: | |
| args = parse_args() | |
| device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu") | |
| if device == "cuda": | |
| torch.backends.cudnn.benchmark = True | |
| print(f"π Starting ResNet training on {device}") | |
| print(f"π Data root: {args.data_root}") | |
| print(f"βοΈ Config: {args.epochs} epochs, batch_size={args.batch_size}, lr={args.lr}") | |
| # Ensure splits exist; if missing, prepare from official splits | |
| splits_dir = os.path.join(args.data_root, "splits") | |
| triplet_path = os.path.join(splits_dir, "train.json") | |
| if not os.path.exists(triplet_path): | |
| print(f"β οΈ Triplet file not found: {triplet_path}") | |
| print("π§ Attempting to prepare dataset...") | |
| os.makedirs(splits_dir, exist_ok=True) | |
| try: | |
| # Try to import and run the prepare script | |
| sys.path.append(os.path.join(os.path.dirname(os.path.dirname(__file__)), "scripts")) | |
| from prepare_polyvore import main as prepare_main | |
| print("β Successfully imported prepare_polyvore") | |
| # Prepare dataset without random splits | |
| prepare_main() | |
| print("β Dataset preparation completed") | |
| except Exception as e: | |
| print(f"β Failed to prepare dataset: {e}") | |
| print("π‘ Please ensure the dataset is prepared manually") | |
| return | |
| else: | |
| print(f"β Found existing splits: {triplet_path}") | |
| try: | |
| dataset = PolyvoreTripletDataset(args.data_root, split="train") | |
| print(f"π Dataset loaded: {len(dataset)} samples") | |
| except Exception as e: | |
| print(f"β Failed to load dataset: {e}") | |
| return | |
| loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=2, pin_memory=(device=="cuda")) | |
| model = ResNetItemEmbedder(embedding_dim=args.embedding_dim).to(device) | |
| optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4) | |
| criterion = nn.TripletMarginLoss(margin=0.2, p=2) | |
| print(f"ποΈ Model created: {model.__class__.__name__}") | |
| print(f"π Total parameters: {sum(p.numel() for p in model.parameters()):,}") | |
| export_dir = ensure_export_dir(os.path.dirname(args.out) or "models/exports") | |
| best_loss = float("inf") | |
| history = [] | |
| patience_counter = 0 | |
| best_epoch = 0 | |
| metrics_collector = AdvancedMetrics() | |
| print(f"πΎ Checkpoints will be saved to: {export_dir}") | |
| print(f"π Early stopping patience: {args.early_stopping_patience} epochs") | |
| for epoch in range(args.epochs): | |
| model.train() | |
| running_loss = 0.0 | |
| steps = 0 | |
| print(f"π Epoch {epoch+1}/{args.epochs}") | |
| for batch_idx, batch in enumerate(loader): | |
| try: | |
| # Expect batch as (anchor, positive, negative) | |
| anchor, positive, negative = batch | |
| anchor = anchor.to(device, memory_format=torch.channels_last, non_blocking=True) | |
| positive = positive.to(device, memory_format=torch.channels_last, non_blocking=True) | |
| negative = negative.to(device, memory_format=torch.channels_last, non_blocking=True) | |
| with torch.autocast(device_type=("cuda" if device=="cuda" else "cpu"), enabled=(device=="cuda")): | |
| emb_a = model(anchor) | |
| emb_p = model(positive) | |
| emb_n = model(negative) | |
| loss = criterion(emb_a, emb_p, emb_n) | |
| optimizer.zero_grad(set_to_none=True) | |
| loss.backward() | |
| optimizer.step() | |
| # Collect metrics | |
| triplet_metrics = calculate_triplet_metrics(emb_a, emb_p, emb_n, margin=0.2) | |
| metrics_collector.add_batch( | |
| predictions=torch.ones(emb_a.size(0)), # Placeholder for compatibility | |
| targets=torch.ones(emb_a.size(0)), # Placeholder for compatibility | |
| embeddings=emb_a.detach() # Detach to avoid gradient issues | |
| ) | |
| running_loss += loss.item() | |
| steps += 1 | |
| if batch_idx % 10 == 0: # More frequent logging | |
| print(f" Batch {batch_idx}/{len(loader)}: loss={loss.item():.4f}") | |
| except Exception as e: | |
| print(f"β Error in batch {batch_idx}: {e}") | |
| continue | |
| # Print final batch completion | |
| print(f" β Batch {len(loader)-1}/{len(loader)}: loss={loss.item():.4f}") | |
| print(f" π Epoch {epoch+1} completed: {len(loader)} batches processed") | |
| avg_loss = running_loss / max(1, steps) | |
| # Save checkpoint with better path handling | |
| out_path = args.out | |
| if not out_path.startswith("models/"): | |
| out_path = os.path.join(export_dir, os.path.basename(args.out)) | |
| # Ensure the output directory exists | |
| os.makedirs(os.path.dirname(out_path), exist_ok=True) | |
| # Save checkpoint | |
| torch.save({"state_dict": model.state_dict(), "epoch": epoch+1, "loss": avg_loss}, out_path) | |
| print(f"β Epoch {epoch+1}/{args.epochs} avg_triplet_loss={avg_loss:.4f} saved -> {out_path}") | |
| history.append({"epoch": epoch + 1, "avg_triplet_loss": avg_loss}) | |
| # Early stopping logic | |
| if avg_loss < best_loss - args.min_delta: | |
| best_loss = avg_loss | |
| best_epoch = epoch + 1 | |
| patience_counter = 0 | |
| best_path = os.path.join(export_dir, "resnet_item_embedder_best.pth") | |
| torch.save({"state_dict": model.state_dict(), "epoch": epoch+1, "loss": avg_loss}, best_path) | |
| print(f"π New best model saved: {best_path} (loss: {avg_loss:.4f})") | |
| else: | |
| patience_counter += 1 | |
| print(f"β³ No improvement for {patience_counter} epochs (best: {best_loss:.4f} at epoch {best_epoch})") | |
| if patience_counter >= args.early_stopping_patience: | |
| print(f"π Early stopping triggered after {patience_counter} epochs without improvement") | |
| print(f"π Best model was at epoch {best_epoch} with loss {best_loss:.4f}") | |
| break | |
| # Write comprehensive metrics | |
| metrics_path = os.path.join(export_dir, "resnet_metrics.json") | |
| # Get advanced metrics | |
| advanced_metrics = metrics_collector.calculate_all_metrics() | |
| final_metrics = { | |
| "best_triplet_loss": best_loss, | |
| "best_epoch": best_epoch, | |
| "total_epochs": epoch + 1, | |
| "early_stopping_triggered": patience_counter >= args.early_stopping_patience, | |
| "patience_counter": patience_counter, | |
| "training_config": { | |
| "epochs": args.epochs, | |
| "batch_size": args.batch_size, | |
| "learning_rate": args.lr, | |
| "embedding_dim": args.embedding_dim, | |
| "early_stopping_patience": args.early_stopping_patience, | |
| "min_delta": args.min_delta | |
| }, | |
| "history": history, | |
| "advanced_metrics": advanced_metrics | |
| } | |
| with open(metrics_path, "w") as f: | |
| json.dump(final_metrics, f, indent=2) | |
| print(f"π Training completed! Best loss: {best_loss:.4f} at epoch {best_epoch}") | |
| print(f"π Comprehensive metrics saved to: {metrics_path}") | |
| print(f"π¬ Advanced metrics: {advanced_metrics['summary']}") | |
| if __name__ == "__main__": | |
| main() | |