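"""Train a ViT-based outfit compatibility model on Polyvore outfit triplets.

Item images are embedded with a frozen ResNet, zero-padded into fixed-length
token grids, and mean-pooled through the ViT encoder; a triplet loss with
cosine distance pulls compatible outfits together. Includes LR warmup,
gradient clipping, capped validation, early stopping, and metrics export.
"""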
import os
import argparse
import sys
from typing import List

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Make the repository root importable when this script is run directly
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from data.polyvore import PolyvoreOutfitTripletDataset
from models.vit_outfit import OutfitCompatibilityModel
from models.resnet_embedder import ResNetItemEmbedder
from utils.export import ensure_export_dir
import json


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser()
    p.add_argument("--data_root", type=str, default=os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore"))
    p.add_argument("--epochs", type=int, default=50)
    p.add_argument("--batch_size", type=int, default=4)
    p.add_argument("--lr", type=float, default=5e-4)
    p.add_argument("--embedding_dim", type=int, default=512)
    p.add_argument("--triplet_margin", type=float, default=0.5)
    p.add_argument("--export", type=str, default="models/exports/vit_outfit_model.pth")
    p.add_argument("--eval_every", type=int, default=1)
    p.add_argument("--skip_validation", action="store_true", help="Skip validation for faster training")
    p.add_argument("--max_samples", type=int, default=5000, help="Maximum number of training samples (for better quality)")
    p.add_argument("--early_stopping_patience", type=int, default=5, help="Early stopping patience")
    p.add_argument("--min_delta", type=float, default=1e-4, help="Minimum change to qualify as improvement")
    p.add_argument("--gradient_clip", type=float, default=1.0, help="Gradient clipping value")
    p.add_argument("--warmup_epochs", type=int, default=2, help="Learning rate warmup epochs")
    return p.parse_args()


def embed_outfit(imgs: List[torch.Tensor], embedder: ResNetItemEmbedder, device: str, max_len: int = 4) -> torch.Tensor:
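    """Embed up to ``max_len`` item images with the frozen ResNet embedder.

    Outfits have variable length, so the result is zero-padded to a fixed
    (max_len, D) token grid that the ViT encoder can batch; padding rows
    remain all-zero.
    """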
    if len(imgs) == 0:
        return torch.zeros((max_len, embedder.proj.out_features), device=device)
    k = min(len(imgs), max_len)
    x = torch.stack(imgs[:k], dim=0).to(device)
    with torch.no_grad():
        e = embedder(x)  # (k, D)
    tokens = torch.zeros((max_len, e.shape[-1]), device=device)
    tokens[:k] = e
    return tokens


def main() -> None:
    args = parse_args()
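    # Prefer CUDA, then Apple MPS, then CPU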
    device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
    if device == "cuda":
        torch.backends.cudnn.benchmark = True

    print(f"πŸš€ Starting ViT Outfit training on {device}")
    print(f"πŸ“ Data root: {args.data_root}")
    print(f"βš™οΈ  Config: {args.epochs} epochs, batch_size={args.batch_size}, lr={args.lr}")

    # Ensure outfit triplets exist
    splits_dir = os.path.join(args.data_root, "splits")
    trip_path = os.path.join(splits_dir, "outfit_triplets_train.json")
    
    if not os.path.exists(trip_path):
        print(f"⚠️  Outfit triplet file not found: {trip_path}")
        print("πŸ”§ Attempting to prepare dataset...")
        os.makedirs(splits_dir, exist_ok=True)
        try:
            # Try to import and run the prepare script
            sys.path.append(os.path.join(os.path.dirname(os.path.dirname(__file__)), "scripts"))
            from prepare_polyvore import main as prepare_main
            print("βœ… Successfully imported prepare_polyvore")
            
            # Prepare dataset without random splits
            prepare_main()
            print("βœ… Dataset preparation completed")
        except Exception as e:
            print(f"❌ Failed to prepare dataset: {e}")
            print("πŸ’‘ Please ensure the dataset is prepared manually")
            return
    else:
        print(f"βœ… Found existing outfit triplets: {trip_path}")

    try:
        dataset = PolyvoreOutfitTripletDataset(args.data_root, split="train")
        # Limit dataset size for faster training/testing
        max_samples = min(len(dataset), args.max_samples)
        print(f"πŸ” Debug: Original dataset size: {len(dataset)}, max_samples: {args.max_samples}")
        if len(dataset) > max_samples:
            dataset.samples = dataset.samples[:max_samples]
            print(f"πŸ“Š Dataset limited to {max_samples} samples for faster training")
        else:
            print(f"πŸ“Š Dataset loaded: {len(dataset)} samples (no limiting needed)")
    except Exception as e:
        print(f"❌ Failed to load dataset: {e}")
        return

    def collate(batch):
        # Outfits vary in item count, so return the raw list and pad
        # per-sample inside the training loop instead of stacking here.
        return batch

    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=(device == "cuda"),
        collate_fn=collate,
    )

    model = OutfitCompatibilityModel(embedding_dim=args.embedding_dim).to(device)
    embedder = ResNetItemEmbedder(embedding_dim=args.embedding_dim).to(device).eval()
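    # Freeze the ResNet so only the ViT compatibility model is trained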
    for p in embedder.parameters():
        p.requires_grad_(False)

    print(f"πŸ—οΈ  Models created:")
    print(f"  - ViT Outfit: {model.__class__.__name__}")
    print(f"  - ResNet Embedder: {embedder.__class__.__name__}")
    print(f"πŸ“ˆ Total parameters: {sum(p.numel() for p in model.parameters()):,}")

    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=5e-2)
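    # Cosine distance 1 - cos(x, y) lies in [0, 2], so triplet_margin is on that scale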
    triplet = nn.TripletMarginWithDistanceLoss(
        distance_function=lambda x, y: 1 - nn.functional.cosine_similarity(x, y),
        margin=args.triplet_margin,
    )
    
    # Learning rate warmup folded into OneCycleLR's ramp-up phase
    total_steps = len(loader) * args.epochs
    warmup_steps = len(loader) * args.warmup_epochs
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=args.lr,
        total_steps=total_steps,
        # pct_start must lie strictly inside (0, 1); clamp so the scheduler
        # does not crash when warmup_epochs is 0 or >= epochs
        pct_start=min(max(warmup_steps / max(1, total_steps), 0.01), 0.9),
        anneal_strategy="cos",
    )

    export_dir = ensure_export_dir(os.path.dirname(args.export) or "models/exports")
    best_loss = float("inf")
    hist = []
    patience_counter = 0
    best_epoch = 0
    
    print(f"πŸ’Ύ Checkpoints will be saved to: {export_dir}")
    print(f"πŸ›‘ Early stopping patience: {args.early_stopping_patience} epochs")
    
    for epoch in range(args.epochs):
        model.train()
        running_loss = 0.0
        steps = 0
        
        print(f"πŸ”„ Epoch {epoch+1}/{args.epochs}")
        
        for batch_idx, batch in enumerate(loader):
            try:
                # batch: List[(ga_imgs, gb_imgs, bd_imgs)]
                anchor_tokens = []
                positive_tokens = []
                negative_tokens = []
                
                for ga, gb, bd in batch:
                    ta = embed_outfit(ga, embedder, device)
                    tb = embed_outfit(gb, embedder, device)
                    tn = embed_outfit(bd, embedder, device)
                    anchor_tokens.append(ta.unsqueeze(0))
                    positive_tokens.append(tb.unsqueeze(0))
                    negative_tokens.append(tn.unsqueeze(0))
                
                A = torch.cat(anchor_tokens, dim=0)  # (B, N, D)
                P = torch.cat(positive_tokens, dim=0)
                N = torch.cat(negative_tokens, dim=0)

                # get outfit-level embeddings via ViT encoder pooled output
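                # Mixed precision on CUDA only; MPS and CPU paths run in full precision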
                with torch.autocast(device_type=("cuda" if device=="cuda" else "cpu"), enabled=(device=="cuda")):
                    ea = model.encoder(A).mean(dim=1)
                    ep = model.encoder(P).mean(dim=1)
                    en = model.encoder(N).mean(dim=1)
                    loss = triplet(ea, ep, en)
                
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                
                # Gradient clipping for stability
                if args.gradient_clip > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip)
                
                optimizer.step()
                scheduler.step()  # Update learning rate
                
                # Collect metrics (simplified for ViT training)
                # Note: ViT training uses outfit-level embeddings, not classification predictions
                # So we skip the problematic metrics collection that expects binary targets
                
                running_loss += loss.item()
                steps += 1
                
                if batch_idx % 10 == 0:  # log every 10 batches
                    print(f"  Batch {batch_idx}/{len(loader)}: loss={loss.item():.4f}")
                    
            except Exception as e:
                print(f"❌ Error in batch {batch_idx}: {e}")
                continue

        # Summarize the epoch from the running stats; the last `loss` may be
        # stale or undefined if batches failed, so avoid printing it directly.
        if steps > 0:
            print(f"  πŸ“Š Epoch {epoch+1} completed: {steps}/{len(loader)} batches, avg loss={running_loss/steps:.4f}")
        else:
            print(f"  ⚠️  Epoch {epoch+1}: no batches completed successfully")

        avg_loss = running_loss / max(1, steps)
        
        # Fast validation with limited samples to prevent hanging
        val_path = os.path.join(args.data_root, "splits", "outfit_triplets_valid.json")
        val_loss = None
        
        if not args.skip_validation and os.path.exists(val_path) and (epoch + 1) % args.eval_every == 0:
            try:
                print(f"  πŸ” Starting validation (limited to 50 samples for speed)...")
                val_ds = PolyvoreOutfitTripletDataset(args.data_root, split="valid")
                # Limit validation to first 50 samples to prevent hanging
                val_samples = val_ds.samples[:50]
                val_ds.samples = val_samples
                val_loader = DataLoader(
                    val_ds,
                    batch_size=min(args.batch_size, 8),
                    shuffle=False,
                    num_workers=0,
                    collate_fn=lambda x: x,
                )
                model.eval()
                losses = []
                
                with torch.no_grad():
                    for i, vbatch in enumerate(val_loader):
                        if i >= 10:  # Limit to 10 batches max for speed
                            break
                        anchor_tokens = []
                        positive_tokens = []
                        negative_tokens = []
                        for ga, gb, bd in vbatch:
                            ta = embed_outfit(ga, embedder, device)
                            tb = embed_outfit(gb, embedder, device)
                            tn = embed_outfit(bd, embedder, device)
                            anchor_tokens.append(ta.unsqueeze(0))
                            positive_tokens.append(tb.unsqueeze(0))
                            negative_tokens.append(tn.unsqueeze(0))
                        A = torch.cat(anchor_tokens, dim=0)
                        P = torch.cat(positive_tokens, dim=0)
                        N = torch.cat(negative_tokens, dim=0)
                        ea = model.encoder(A).mean(dim=1)
                        ep = model.encoder(P).mean(dim=1)
                        en = model.encoder(N).mean(dim=1)
                        l = triplet(ea, ep, en).item()
                        losses.append(l)
                
                val_loss = sum(losses) / max(1, len(losses))
                print(f"  πŸ“Š Validation loss: {val_loss:.4f} (from {len(losses)} batches)")
                
            except Exception as e:
                print(f"  ⚠️  Validation failed: {e}")
                val_loss = None

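        # Redirect exports that live outside models/ into the canonical export dir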
        out_path = args.export
        if not out_path.startswith("models/"):
            out_path = os.path.join(export_dir, os.path.basename(args.export))
        
        # Save checkpoint
        torch.save({"state_dict": model.state_dict(), "epoch": epoch+1, "loss": avg_loss}, out_path)
        
        if val_loss is not None:
            print(f"βœ… Epoch {epoch+1}/{args.epochs} triplet_loss={avg_loss:.4f} val_triplet_loss={val_loss:.4f} saved -> {out_path}")
            hist.append({"epoch": epoch + 1, "triplet_loss": float(avg_loss), "val_triplet_loss": float(val_loss)})
            
            # Early stopping logic
            if val_loss < best_loss - args.min_delta:
                best_loss = val_loss
                best_epoch = epoch + 1
                patience_counter = 0
                best_path = os.path.join(export_dir, "vit_outfit_model_best.pth")
                torch.save({"state_dict": model.state_dict(), "epoch": epoch+1, "loss": avg_loss, "val_loss": val_loss}, best_path)
                print(f"πŸ† New best model saved: {best_path} (val_loss: {val_loss:.4f})")
            else:
                patience_counter += 1
                print(f"⏳ No improvement for {patience_counter} epochs (best: {best_loss:.4f} at epoch {best_epoch})")
                
                if patience_counter >= args.early_stopping_patience:
                    print(f"πŸ›‘ Early stopping triggered after {patience_counter} epochs without improvement")
                    print(f"πŸ† Best model was at epoch {best_epoch} with val_loss {best_loss:.4f}")
                    break
        else:
            print(f"βœ… Epoch {epoch+1}/{args.epochs} triplet_loss={avg_loss:.4f} saved -> {out_path}")
            hist.append({"epoch": epoch + 1, "triplet_loss": float(avg_loss)})

    # Write comprehensive metrics
    metrics_path = os.path.join(export_dir, "vit_metrics.json")
    
    # Placeholder advanced metrics: ViT triplet training produces no
    # classification predictions, so all counters stay at zero.
    advanced_metrics = {
        "total_predictions": 0,
        "total_targets": 0, 
        "total_scores": 0,
        "total_embeddings": 0,
        "total_outfit_scores": 0
    }
    
    final_metrics = {
        "best_val_triplet_loss": best_loss if best_loss != float("inf") else None,
        "best_epoch": best_epoch,
        "total_epochs": epoch + 1,
        "early_stopping_triggered": patience_counter >= args.early_stopping_patience,
        "patience_counter": patience_counter,
        "training_config": {
            "epochs": args.epochs,
            "batch_size": args.batch_size,
            "learning_rate": args.lr,
            "embedding_dim": args.embedding_dim,
            "triplet_margin": args.triplet_margin,
            "early_stopping_patience": args.early_stopping_patience,
            "min_delta": args.min_delta
        },
        "history": hist,
        "advanced_metrics": advanced_metrics
    }
    
    with open(metrics_path, "w") as f:
        json.dump(final_metrics, f, indent=2)
    
    # Always save a best model (use final model if no validation was done)
    if best_loss == float("inf"):
        best_path = os.path.join(export_dir, "vit_outfit_model_best.pth")
        torch.save({"state_dict": model.state_dict(), "epoch": epoch+1, "loss": avg_loss}, best_path)
        print(f"πŸ† Final model saved as best: {best_path} (loss: {avg_loss:.4f})")
    
    print(f"πŸ“Š Training completed! Best val_loss: {best_loss:.4f} at epoch {best_epoch}")
    print(f"πŸ“ˆ Comprehensive metrics saved to: {metrics_path}")
    print(f"πŸ”¬ Advanced metrics: {advanced_metrics}")


if __name__ == "__main__":
    main()