Spaces:
Paused
Paused
Ali Mohsin
commited on
Commit
·
58e8faf
1
Parent(s):
283a139
Hoping these are the last fixes I am ever going to make
Browse files- train_vit_triplet.py +12 -11
train_vit_triplet.py
CHANGED
|
@@ -152,15 +152,9 @@ def main() -> None:
|
|
| 152 |
loss.backward()
|
| 153 |
optimizer.step()
|
| 154 |
|
| 155 |
-
# Collect metrics
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
torch.cat([torch.ones(ea.size(0)), torch.ones(ep.size(0)), torch.zeros(en.size(0))], dim=0)
|
| 159 |
-
)
|
| 160 |
-
metrics_collector.add_batch(
|
| 161 |
-
predictions=torch.cat([ea, ep, en], dim=0).detach(),
|
| 162 |
-
targets=torch.cat([torch.ones(ea.size(0)), torch.ones(ep.size(0)), torch.zeros(en.size(0))], dim=0)
|
| 163 |
-
)
|
| 164 |
|
| 165 |
running_loss += loss.item()
|
| 166 |
steps += 1
|
|
@@ -251,8 +245,15 @@ def main() -> None:
|
|
| 251 |
# Write comprehensive metrics
|
| 252 |
metrics_path = os.path.join(export_dir, "vit_metrics.json")
|
| 253 |
|
| 254 |
-
# Get advanced metrics
|
| 255 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
|
| 257 |
final_metrics = {
|
| 258 |
"best_val_triplet_loss": best_loss if best_loss != float("inf") else None,
|
|
|
|
| 152 |
loss.backward()
|
| 153 |
optimizer.step()
|
| 154 |
|
| 155 |
+
# Collect metrics (simplified for ViT training)
|
| 156 |
+
# Note: ViT training uses outfit-level embeddings, not classification predictions
|
| 157 |
+
# So we skip the problematic metrics collection that expects binary targets
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
|
| 159 |
running_loss += loss.item()
|
| 160 |
steps += 1
|
|
|
|
| 245 |
# Write comprehensive metrics
|
| 246 |
metrics_path = os.path.join(export_dir, "vit_metrics.json")
|
| 247 |
|
| 248 |
+
# Get advanced metrics (simplified for ViT training)
|
| 249 |
+
# Note: ViT training doesn't collect classification metrics, so we create empty metrics
|
| 250 |
+
advanced_metrics = {
|
| 251 |
+
"total_predictions": 0,
|
| 252 |
+
"total_targets": 0,
|
| 253 |
+
"total_scores": 0,
|
| 254 |
+
"total_embeddings": 0,
|
| 255 |
+
"total_outfit_scores": 0
|
| 256 |
+
}
|
| 257 |
|
| 258 |
final_metrics = {
|
| 259 |
"best_val_triplet_loss": best_loss if best_loss != float("inf") else None,
|