# recomendation/training/train_resnet.py
# Ali Mohsin — "folder reorganise" (72af8c3)
import argparse
import json
import math
import os
import sys
from typing import Optional, Sequence, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Fix import paths
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.polyvore import PolyvoreTripletDataset
from models.resnet_embedder import ResNetItemEmbedder
from utils.export import ensure_export_dir
from utils.advanced_metrics import AdvancedMetrics, calculate_triplet_metrics
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options for ResNet item-embedder training.

    Args:
        argv: Arguments to parse. ``None`` (the default) parses ``sys.argv[1:]``
            exactly as before; passing an explicit list lets tests and notebooks
            call this without touching ``sys.argv``.

    Returns:
        Namespace with ``data_root``, ``epochs``, ``batch_size``, ``lr``,
        ``embedding_dim``, ``out``, ``early_stopping_patience`` and ``min_delta``.
    """
    p = argparse.ArgumentParser(description="Train the ResNet item embedder on Polyvore triplets.")
    p.add_argument(
        "--data_root",
        type=str,
        # $POLYVORE_ROOT is read each time the parser is built.
        default=os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore"),
        help="Polyvore dataset root (default: $POLYVORE_ROOT, else /home/user/app/data/Polyvore)",
    )
    p.add_argument("--epochs", type=int, default=50, help="Maximum number of training epochs")
    p.add_argument("--batch_size", type=int, default=16, help="Triplets per batch")
    p.add_argument("--lr", type=float, default=1e-3, help="AdamW learning rate")
    p.add_argument("--embedding_dim", type=int, default=512, help="Dimension of the output item embedding")
    p.add_argument(
        "--out",
        type=str,
        default="models/exports/resnet_item_embedder.pth",
        help="Path of the latest-epoch checkpoint",
    )
    p.add_argument("--early_stopping_patience", type=int, default=10, help="Early stopping patience")
    p.add_argument("--min_delta", type=float, default=1e-4, help="Minimum change to qualify as improvement")
    return p.parse_args(argv)
def _select_device() -> str:
    """Return "cuda" if available, else "mps" if available, else "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def _ensure_splits(data_root: str) -> bool:
    """Make sure ``<data_root>/splits/train.json`` exists, running prepare_polyvore if not.

    Returns:
        True when training can proceed, False when dataset preparation failed.
    """
    splits_dir = os.path.join(data_root, "splits")
    triplet_path = os.path.join(splits_dir, "train.json")
    if os.path.exists(triplet_path):
        print(f"✅ Found existing splits: {triplet_path}")
        return True

    print(f"⚠️ Triplet file not found: {triplet_path}")
    print("🔧 Attempting to prepare dataset...")
    os.makedirs(splits_dir, exist_ok=True)
    try:
        # Resolve scripts/ from this file's *absolute* location (as the module-level
        # sys.path fix does); a bare dirname(__file__) is "" when the script is run
        # by relative path, which would point at ./scripts of the cwd instead.
        package_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        sys.path.append(os.path.join(package_root, "scripts"))
        from prepare_polyvore import main as prepare_main
        print("✅ Successfully imported prepare_polyvore")
        # NOTE(review): prepare_main() is called with no arguments; if it parses
        # sys.argv itself it will see this script's flags — confirm.
        prepare_main()
        print("✅ Dataset preparation completed")
    except Exception as e:
        print(f"❌ Failed to prepare dataset: {e}")
        print("💡 Please ensure the dataset is prepared manually")
        return False
    return True


def _make_grad_scaler(enabled: bool):
    """Build a CUDA loss scaler for fp16 autocast; when disabled it is a pass-through."""
    try:
        # Device-generic API (newer torch releases).
        return torch.amp.GradScaler("cuda", enabled=enabled)
    except (AttributeError, TypeError):
        # Older releases only ship the CUDA-specific class.
        return torch.cuda.amp.GradScaler(enabled=enabled)


def _train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: optim.Optimizer,
    criterion: nn.Module,
    scaler,
    metrics_collector,
    device: str,
) -> Tuple[float, int]:
    """Run one optimisation pass over ``loader``.

    Batches that raise are logged and skipped (best-effort, as before).

    Returns:
        ``(avg_loss, steps)``: ``steps`` counts batches that trained successfully
        and ``avg_loss`` is their mean triplet loss, or NaN when none succeeded, so
        an epoch that trained nothing can never look like a 0.0-loss epoch.
    """
    model.train()
    use_amp = device == "cuda"
    num_batches = len(loader)
    running_loss = 0.0
    steps = 0
    last_idx, last_loss = -1, None
    for batch_idx, batch in enumerate(loader):
        try:
            # Expect batch as (anchor, positive, negative)
            anchor, positive, negative = batch
            anchor = anchor.to(device, memory_format=torch.channels_last, non_blocking=True)
            positive = positive.to(device, memory_format=torch.channels_last, non_blocking=True)
            negative = negative.to(device, memory_format=torch.channels_last, non_blocking=True)
            with torch.autocast(device_type=("cuda" if use_amp else "cpu"), enabled=use_amp):
                emb_a = model(anchor)
                emb_p = model(positive)
                emb_n = model(negative)
                loss = criterion(emb_a, emb_p, emb_n)
            optimizer.zero_grad(set_to_none=True)
            # fp16 autocast needs loss scaling so small gradients do not underflow;
            # with the scaler disabled (non-CUDA) these reduce to backward()/step().
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            # Collect metrics
            metrics_collector.add_batch(
                predictions=torch.ones(emb_a.size(0)),  # Placeholder for compatibility
                targets=torch.ones(emb_a.size(0)),  # Placeholder for compatibility
                # Detached fp32 CPU copy, matching the CPU placeholders above.
                # NOTE(review): the collector presumably retains every batch until
                # calculate_all_metrics(); device tensors would then pin GPU memory.
                embeddings=emb_a.detach().float().cpu(),
            )
            loss_value = loss.item()  # one host sync per batch instead of two
        except Exception as e:
            print(f"❌ Error in batch {batch_idx}: {e}")
            continue
        running_loss += loss_value
        steps += 1
        last_idx, last_loss = batch_idx, loss_value
        if batch_idx % 10 == 0:  # More frequent logging
            print(f" Batch {batch_idx}/{num_batches}: loss={loss_value:.4f}")

    if last_loss is not None:
        # Report the last batch that actually trained (the final batch may have failed).
        print(f" ✅ Batch {last_idx}/{num_batches}: loss={last_loss:.4f}")
    avg_loss = running_loss / steps if steps else float("nan")
    return avg_loss, steps


def main() -> None:
    """Train ``ResNetItemEmbedder`` on Polyvore triplets with early stopping.

    Side effects, all derived from the CLI options:
      * the latest-epoch checkpoint is written to ``--out`` every epoch (re-rooted
        into the export directory unless the path starts with ``models/``);
      * each new best checkpoint goes to ``resnet_item_embedder_best.pth``;
      * a JSON report is written to ``resnet_metrics.json`` in the export directory.

    Returns early, without raising, if the dataset cannot be prepared or loaded.
    """
    args = parse_args()
    device = _select_device()
    if device == "cuda":
        torch.backends.cudnn.benchmark = True
    print(f"🚀 Starting ResNet training on {device}")
    print(f"📁 Data root: {args.data_root}")
    print(f"⚙️ Config: {args.epochs} epochs, batch_size={args.batch_size}, lr={args.lr}")

    # Ensure splits exist; if missing, prepare from official splits
    if not _ensure_splits(args.data_root):
        return
    try:
        dataset = PolyvoreTripletDataset(args.data_root, split="train")
        print(f"📊 Dataset loaded: {len(dataset)} samples")
    except Exception as e:
        print(f"❌ Failed to load dataset: {e}")
        return

    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=(device == "cuda"),
    )
    model = ResNetItemEmbedder(embedding_dim=args.embedding_dim).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
    criterion = nn.TripletMarginLoss(margin=0.2, p=2)
    scaler = _make_grad_scaler(enabled=(device == "cuda"))
    print(f"🏗️ Model created: {model.__class__.__name__}")
    print(f"📈 Total parameters: {sum(p.numel() for p in model.parameters()):,}")

    export_dir = ensure_export_dir(os.path.dirname(args.out) or "models/exports")
    # Checkpoint paths do not change between epochs, so resolve them once.
    out_path = args.out
    if not out_path.startswith("models/"):
        out_path = os.path.join(export_dir, os.path.basename(args.out))
    out_dir = os.path.dirname(out_path)
    best_path = os.path.join(export_dir, "resnet_item_embedder_best.pth")

    best_loss = float("inf")
    best_epoch = 0
    patience_counter = 0
    epochs_run = 0  # stays 0 when --epochs <= 0 (previously a NameError below)
    early_stopped = False
    history = []
    metrics_collector = AdvancedMetrics()
    print(f"💾 Checkpoints will be saved to: {export_dir}")
    print(f"🛑 Early stopping patience: {args.early_stopping_patience} epochs")

    for epoch in range(1, args.epochs + 1):
        print(f"🔄 Epoch {epoch}/{args.epochs}")
        avg_loss, steps = _train_one_epoch(
            model, loader, optimizer, criterion, scaler, metrics_collector, device
        )
        epochs_run = epoch
        print(f" 📊 Epoch {epoch} completed: {steps}/{len(loader)} batches processed")
        if steps == 0:
            print(f"⚠️ Epoch {epoch}: no batch trained successfully; average loss is undefined")

        # Ensure the output directory exists, then save the latest checkpoint.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        checkpoint = {"state_dict": model.state_dict(), "epoch": epoch, "loss": avg_loss}
        torch.save(checkpoint, out_path)
        print(f"✅ Epoch {epoch}/{args.epochs} avg_triplet_loss={avg_loss:.4f} saved -> {out_path}")
        history.append(
            {"epoch": epoch, "avg_triplet_loss": avg_loss if math.isfinite(avg_loss) else None}
        )

        # Early stopping logic. A NaN average (no successful batch) compares False,
        # so an epoch that trained nothing counts as "no improvement".
        if avg_loss < best_loss - args.min_delta:
            best_loss = avg_loss
            best_epoch = epoch
            patience_counter = 0
            torch.save(checkpoint, best_path)
            print(f"🏆 New best model saved: {best_path} (loss: {avg_loss:.4f})")
        else:
            patience_counter += 1
            print(f"⏳ No improvement for {patience_counter} epochs (best: {best_loss:.4f} at epoch {best_epoch})")
            if patience_counter >= args.early_stopping_patience:
                early_stopped = True
                print(f"🛑 Early stopping triggered after {patience_counter} epochs without improvement")
                print(f"🏆 Best model was at epoch {best_epoch} with loss {best_loss:.4f}")
                break

    # Write comprehensive metrics
    metrics_path = os.path.join(export_dir, "resnet_metrics.json")
    advanced_metrics = metrics_collector.calculate_all_metrics()
    final_metrics = {
        # inf (no epoch ever improved) is not valid JSON; store null instead.
        "best_triplet_loss": best_loss if math.isfinite(best_loss) else None,
        "best_epoch": best_epoch,
        "total_epochs": epochs_run,
        "early_stopping_triggered": early_stopped,
        "patience_counter": patience_counter,
        "training_config": {
            "epochs": args.epochs,
            "batch_size": args.batch_size,
            "learning_rate": args.lr,
            "embedding_dim": args.embedding_dim,
            "early_stopping_patience": args.early_stopping_patience,
            "min_delta": args.min_delta,
        },
        "history": history,
        "advanced_metrics": advanced_metrics,
    }
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(final_metrics, f, indent=2)
    print(f"📊 Training completed! Best loss: {best_loss:.4f} at epoch {best_epoch}")
    print(f"📈 Comprehensive metrics saved to: {metrics_path}")
    print(f"🔬 Advanced metrics: {advanced_metrics['summary']}")
if __name__ == "__main__":
    # Launch training only when this file is executed directly, never on import.
    main()