Ali Mohsin committed on
Commit
58e8faf
·
1 Parent(s): 283a139

Hoping these are the last fixes I am ever going to make

Browse files
Files changed (1) hide show
  1. train_vit_triplet.py +12 -11
train_vit_triplet.py CHANGED
@@ -152,15 +152,9 @@ def main() -> None:
152
  loss.backward()
153
  optimizer.step()
154
 
155
- # Collect metrics
156
- compatibility_metrics = calculate_outfit_compatibility_metrics(
157
- torch.cat([ea, ep, en], dim=0).detach(),
158
- torch.cat([torch.ones(ea.size(0)), torch.ones(ep.size(0)), torch.zeros(en.size(0))], dim=0)
159
- )
160
- metrics_collector.add_batch(
161
- predictions=torch.cat([ea, ep, en], dim=0).detach(),
162
- targets=torch.cat([torch.ones(ea.size(0)), torch.ones(ep.size(0)), torch.zeros(en.size(0))], dim=0)
163
- )
164
 
165
  running_loss += loss.item()
166
  steps += 1
@@ -251,8 +245,15 @@ def main() -> None:
251
  # Write comprehensive metrics
252
  metrics_path = os.path.join(export_dir, "vit_metrics.json")
253
 
254
- # Get advanced metrics
255
- advanced_metrics = metrics_collector.calculate_all_metrics()
 
 
 
 
 
 
 
256
 
257
  final_metrics = {
258
  "best_val_triplet_loss": best_loss if best_loss != float("inf") else None,
 
152
  loss.backward()
153
  optimizer.step()
154
 
155
+ # Collect metrics (simplified for ViT training)
156
+ # Note: ViT training uses outfit-level embeddings, not classification predictions
157
+ # So we skip the problematic metrics collection that expects binary targets
 
 
 
 
 
 
158
 
159
  running_loss += loss.item()
160
  steps += 1
 
245
  # Write comprehensive metrics
246
  metrics_path = os.path.join(export_dir, "vit_metrics.json")
247
 
248
+ # Get advanced metrics (simplified for ViT training)
249
+ # Note: ViT training doesn't collect classification metrics, so we create empty metrics
250
+ advanced_metrics = {
251
+ "total_predictions": 0,
252
+ "total_targets": 0,
253
+ "total_scores": 0,
254
+ "total_embeddings": 0,
255
+ "total_outfit_scores": 0
256
+ }
257
 
258
  final_metrics = {
259
  "best_val_triplet_loss": best_loss if best_loss != float("inf") else None,