File size: 8,885 Bytes
4716563
 
24ea486
4716563
 
 
 
 
 
 
24ea486
72af8c3
24ea486
4716563
 
 
25bdf34
4716563
 
 
 
 
24ea486
25bdf34
 
4716563
 
 
25bdf34
 
4716563
 
 
 
 
 
42733e7
 
4716563
24ea486
 
 
 
55c158e
 
 
24ea486
55c158e
24ea486
 
55c158e
 
24ea486
72af8c3
24ea486
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4716563
 
 
 
24ea486
 
 
4716563
 
 
25bdf34
 
 
24ea486
 
25bdf34
24ea486
4716563
 
24ea486
4716563
24ea486
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25bdf34
 
 
 
 
78db21d
25bdf34
 
24ea486
 
 
12f4b61
24ea486
 
 
 
 
 
12f4b61
 
 
 
24ea486
 
 
4716563
 
 
24ea486
 
 
 
 
 
 
 
4716563
24ea486
25bdf34
 
4716563
25bdf34
 
24ea486
 
25bdf34
 
 
 
 
 
 
 
 
4716563
25bdf34
4716563
25bdf34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4716563
25bdf34
24ea486
25bdf34
 
 
4716563
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import os
import argparse
import sys
from typing import Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Fix import paths
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from data.polyvore import PolyvoreTripletDataset
from models.resnet_embedder import ResNetItemEmbedder
from utils.export import ensure_export_dir
from utils.advanced_metrics import AdvancedMetrics, calculate_triplet_metrics
import json


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the ResNet embedder training script.

    The default for ``--data_root`` is taken from the ``POLYVORE_ROOT``
    environment variable when it is set, otherwise a fixed path is used.

    Returns:
        The populated ``argparse.Namespace`` with attributes ``data_root``,
        ``epochs``, ``batch_size``, ``lr``, ``embedding_dim``, ``out``,
        ``early_stopping_patience`` and ``min_delta``.
    """
    default_data_root = os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore")

    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ("--data_root", {"type": str, "default": default_data_root}),
        ("--epochs", {"type": int, "default": 50}),
        ("--batch_size", {"type": int, "default": 16}),
        ("--lr", {"type": float, "default": 1e-3}),
        ("--embedding_dim", {"type": int, "default": 512}),
        ("--out", {"type": str, "default": "models/exports/resnet_item_embedder.pth"}),
        (
            "--early_stopping_patience",
            {"type": int, "default": 10, "help": "Early stopping patience"},
        ),
        (
            "--min_delta",
            {
                "type": float,
                "default": 1e-4,
                "help": "Minimum change to qualify as improvement",
            },
        ),
    )

    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()


def main() -> None:
    """Train the ResNet item embedder with a triplet margin loss.

    Steps, in order:

    1. Parse CLI arguments and pick the device (CUDA, then MPS, then CPU).
    2. Make sure ``<data_root>/splits/train.json`` exists; if it does not,
       try to run ``prepare_polyvore.main`` from the sibling ``scripts``
       directory and return early when that fails.
    3. Build the dataset, loader, model, AdamW optimizer and
       ``TripletMarginLoss`` (margin 0.2).
    4. Train for up to ``--epochs`` epochs. Each epoch writes a checkpoint to
       the output path; an epoch whose average loss improves on the best one
       by more than ``--min_delta`` also writes
       ``resnet_item_embedder_best.pth``. Training stops early after
       ``--early_stopping_patience`` epochs without improvement, or as soon
       as an epoch completes no batch successfully.
    5. Write ``resnet_metrics.json`` (config, loss history and the collector's
       metrics) into the export directory.

    Returns:
        None. Failures while preparing or loading the dataset are reported on
        stdout and end the run without raising.
    """
    args = parse_args()
    device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
    use_cuda = device == "cuda"
    if use_cuda:
        torch.backends.cudnn.benchmark = True

    print(f"πŸš€ Starting ResNet training on {device}")
    print(f"πŸ“ Data root: {args.data_root}")
    print(f"βš™οΈ  Config: {args.epochs} epochs, batch_size={args.batch_size}, lr={args.lr}")

    # Ensure splits exist; if missing, prepare from official splits
    splits_dir = os.path.join(args.data_root, "splits")
    triplet_path = os.path.join(splits_dir, "train.json")

    if not os.path.exists(triplet_path):
        print(f"⚠️  Triplet file not found: {triplet_path}")
        print("πŸ”§ Attempting to prepare dataset...")
        os.makedirs(splits_dir, exist_ok=True)
        try:
            # Resolve the scripts directory from the absolute path of this
            # file so the lookup does not depend on the current working
            # directory (a bare __file__ can be a relative path).
            project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
            sys.path.append(os.path.join(project_root, "scripts"))
            from prepare_polyvore import main as prepare_main
            print("βœ… Successfully imported prepare_polyvore")

            # Prepare dataset without random splits
            prepare_main()
            print("βœ… Dataset preparation completed")
        except Exception as e:
            print(f"❌ Failed to prepare dataset: {e}")
            print("πŸ’‘ Please ensure the dataset is prepared manually")
            return
    else:
        print(f"βœ… Found existing splits: {triplet_path}")

    try:
        dataset = PolyvoreTripletDataset(args.data_root, split="train")
        print(f"πŸ“Š Dataset loaded: {len(dataset)} samples")
    except Exception as e:
        print(f"❌ Failed to load dataset: {e}")
        return

    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=2, pin_memory=use_cuda)
    model = ResNetItemEmbedder(embedding_dim=args.embedding_dim).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
    criterion = nn.TripletMarginLoss(margin=0.2, p=2)

    print(f"πŸ—οΈ  Model created: {model.__class__.__name__}")
    print(f"πŸ“ˆ Total parameters: {sum(p.numel() for p in model.parameters()):,}")

    export_dir = ensure_export_dir(os.path.dirname(args.out) or "models/exports")
    best_loss = float("inf")
    history = []
    patience_counter = 0
    best_epoch = 0
    # Number of epochs actually started; stays 0 when --epochs is 0 so the
    # final report never reads an unbound loop variable.
    epochs_run = 0
    metrics_collector = AdvancedMetrics()

    # The checkpoint destination does not change between epochs, so resolve
    # it (and create its directory) once up front.
    out_path = args.out
    if not out_path.startswith("models/"):
        out_path = os.path.join(export_dir, os.path.basename(args.out))
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    best_path = os.path.join(export_dir, "resnet_item_embedder_best.pth")

    print(f"πŸ’Ύ Checkpoints will be saved to: {export_dir}")
    print(f"πŸ›‘ Early stopping patience: {args.early_stopping_patience} epochs")

    # NOTE(review): autocast is enabled on CUDA but no gradient scaler is
    # used; confirm whether loss scaling is needed for this model.
    autocast_device = "cuda" if use_cuda else "cpu"

    for epoch in range(args.epochs):
        epochs_run = epoch + 1
        model.train()
        running_loss = 0.0
        steps = 0
        last_loss = None  # loss of the most recent successful batch this epoch
        num_batches = len(loader)

        print(f"πŸ”„ Epoch {epoch+1}/{args.epochs}")

        for batch_idx, batch in enumerate(loader):
            try:
                # Expect batch as (anchor, positive, negative)
                anchor, positive, negative = batch
                anchor = anchor.to(device, memory_format=torch.channels_last, non_blocking=True)
                positive = positive.to(device, memory_format=torch.channels_last, non_blocking=True)
                negative = negative.to(device, memory_format=torch.channels_last, non_blocking=True)

                with torch.autocast(device_type=autocast_device, enabled=use_cuda):
                    emb_a = model(anchor)
                    emb_p = model(positive)
                    emb_n = model(negative)

                loss = criterion(emb_a, emb_p, emb_n)
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()

                # Account for the step as soon as the weights were updated so
                # a failure in metric collection below cannot drop this batch
                # from the epoch average.
                batch_loss = loss.item()
                running_loss += batch_loss
                steps += 1
                last_loss = batch_loss

                # Collect metrics
                # NOTE(review): the return value is not consumed anywhere in
                # this function; the call is kept in case the helper has
                # side effects - confirm and drop if it is pure.
                triplet_metrics = calculate_triplet_metrics(emb_a, emb_p, emb_n, margin=0.2)
                metrics_collector.add_batch(
                    predictions=torch.ones(emb_a.size(0)),  # Placeholder for compatibility
                    targets=torch.ones(emb_a.size(0)),      # Placeholder for compatibility
                    embeddings=emb_a.detach()  # Detach to avoid gradient issues
                )

                if batch_idx % 10 == 0:  # More frequent logging
                    print(f"  Batch {batch_idx}/{num_batches}: loss={batch_loss:.4f}")

            except Exception as e:
                # Best effort: report the failing batch and keep training.
                print(f"❌ Error in batch {batch_idx}: {e}")
                continue

        if steps == 0:
            # Nothing trained this epoch (empty loader or every batch failed).
            # An average of 0.0 here would be recorded as a bogus "best"
            # model, so stop instead of checkpointing.
            print(f"No batch completed successfully in epoch {epoch+1}; stopping training")
            break

        # Print final batch completion
        print(f"  βœ… Batch {num_batches-1}/{num_batches}: loss={last_loss:.4f}")
        print(f"  πŸ“Š Epoch {epoch+1} completed: {num_batches} batches processed")

        avg_loss = running_loss / steps

        # Save checkpoint
        torch.save({"state_dict": model.state_dict(), "epoch": epoch+1, "loss": avg_loss}, out_path)
        print(f"βœ… Epoch {epoch+1}/{args.epochs} avg_triplet_loss={avg_loss:.4f} saved -> {out_path}")

        history.append({"epoch": epoch + 1, "avg_triplet_loss": avg_loss})

        # Early stopping logic
        if avg_loss < best_loss - args.min_delta:
            best_loss = avg_loss
            best_epoch = epoch + 1
            patience_counter = 0
            torch.save({"state_dict": model.state_dict(), "epoch": epoch+1, "loss": avg_loss}, best_path)
            print(f"πŸ† New best model saved: {best_path} (loss: {avg_loss:.4f})")
        else:
            patience_counter += 1
            print(f"⏳ No improvement for {patience_counter} epochs (best: {best_loss:.4f} at epoch {best_epoch})")

            if patience_counter >= args.early_stopping_patience:
                print(f"πŸ›‘ Early stopping triggered after {patience_counter} epochs without improvement")
                print(f"πŸ† Best model was at epoch {best_epoch} with loss {best_loss:.4f}")
                break

    # Write comprehensive metrics
    metrics_path = os.path.join(export_dir, "resnet_metrics.json")

    # Get advanced metrics
    advanced_metrics = metrics_collector.calculate_all_metrics()

    final_metrics = {
        # best_loss is still +inf when no epoch finished; emit null rather
        # than the non-standard JSON token "Infinity".
        "best_triplet_loss": best_loss if best_epoch > 0 else None,
        "best_epoch": best_epoch,
        "total_epochs": epochs_run,
        "early_stopping_triggered": patience_counter >= args.early_stopping_patience,
        "patience_counter": patience_counter,
        "training_config": {
            "epochs": args.epochs,
            "batch_size": args.batch_size,
            "learning_rate": args.lr,
            "embedding_dim": args.embedding_dim,
            "early_stopping_patience": args.early_stopping_patience,
            "min_delta": args.min_delta
        },
        "history": history,
        "advanced_metrics": advanced_metrics
    }

    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(final_metrics, f, indent=2)

    print(f"πŸ“Š Training completed! Best loss: {best_loss:.4f} at epoch {best_epoch}")
    print(f"πŸ“ˆ Comprehensive metrics saved to: {metrics_path}")
    print(f"πŸ”¬ Advanced metrics: {advanced_metrics['summary']}")


# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()