Spaces:
Paused
Paused
Ali Mohsin
commited on
Commit
Β·
12f4b61
1
Parent(s):
6e90fc7
updates
Browse files- app.py +54 -8
- train_resnet.py +5 -1
- train_vit_triplet.py +5 -1
app.py
CHANGED
|
@@ -546,10 +546,25 @@ def start_training_advanced(
|
|
| 546 |
result = subprocess.run(resnet_cmd, capture_output=True, text=True, check=False)
|
| 547 |
|
| 548 |
if result.returncode == 0:
|
| 549 |
-
log_message += "β
ResNet training completed successfully!\n
|
|
|
|
| 550 |
else:
|
| 551 |
log_message += f"β ResNet training failed: {result.stderr}\n\n"
|
| 552 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 553 |
|
| 554 |
# Train ViT with custom parameters
|
| 555 |
log_message += f"π Starting ViT training with custom parameters...\n"
|
|
@@ -573,7 +588,8 @@ def start_training_advanced(
|
|
| 573 |
result = subprocess.run(vit_cmd, capture_output=True, text=True, check=False)
|
| 574 |
|
| 575 |
if result.returncode == 0:
|
| 576 |
-
log_message += "β
ViT training completed successfully!\n
|
|
|
|
| 577 |
log_message += "π All training completed! Models saved to models/exports/\n"
|
| 578 |
log_message += "π Reloading models for inference...\n"
|
| 579 |
service.reload_models()
|
|
@@ -629,15 +645,45 @@ def start_training_simple(dataset_size: str, res_epochs: int, vit_epochs: int):
|
|
| 629 |
if dataset_size != "full":
|
| 630 |
dataset_args = ["--max_samples", dataset_size]
|
| 631 |
|
| 632 |
-
|
|
|
|
|
|
|
| 633 |
"python", "train_resnet.py", "--data_root", DATASET_ROOT, "--epochs", str(res_epochs),
|
| 634 |
"--out", os.path.join(export_dir, "resnet_item_embedder.pth")
|
| 635 |
-
] + dataset_args, check=False)
|
| 636 |
-
|
| 637 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 638 |
"python", "train_vit_triplet.py", "--data_root", DATASET_ROOT, "--epochs", str(vit_epochs),
|
| 639 |
"--export", os.path.join(export_dir, "vit_outfit_model.pth")
|
| 640 |
-
] + dataset_args, check=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 641 |
service.reload_models()
|
| 642 |
log_message += "\nDone. Artifacts in models/exports."
|
| 643 |
|
|
|
|
| 546 |
result = subprocess.run(resnet_cmd, capture_output=True, text=True, check=False)
|
| 547 |
|
| 548 |
if result.returncode == 0:
|
| 549 |
+
log_message += "β
ResNet training completed successfully!\n"
|
| 550 |
+
log_message += f"π ResNet Output:\n{result.stdout}\n\n"
|
| 551 |
else:
|
| 552 |
log_message += f"β ResNet training failed: {result.stderr}\n\n"
|
| 553 |
+
return log_message
|
| 554 |
+
|
| 555 |
+
# Wait a moment for file system sync and ensure ResNet is fully saved
|
| 556 |
+
import time
|
| 557 |
+
time.sleep(3)
|
| 558 |
+
log_message += "β³ Waiting for ResNet checkpoint to be fully saved...\n"
|
| 559 |
+
|
| 560 |
+
# Verify ResNet checkpoint exists before proceeding
|
| 561 |
+
resnet_checkpoint = os.path.join(export_dir, "resnet_item_embedder_custom.pth")
|
| 562 |
+
if not os.path.exists(resnet_checkpoint):
|
| 563 |
+
log_message += f"β ResNet checkpoint not found at {resnet_checkpoint}\n"
|
| 564 |
+
log_message += "Cannot proceed with ViT training without ResNet embeddings.\n"
|
| 565 |
+
return log_message
|
| 566 |
+
|
| 567 |
+
log_message += f"β
ResNet checkpoint verified: {resnet_checkpoint}\n"
|
| 568 |
|
| 569 |
# Train ViT with custom parameters
|
| 570 |
log_message += f"π Starting ViT training with custom parameters...\n"
|
|
|
|
| 588 |
result = subprocess.run(vit_cmd, capture_output=True, text=True, check=False)
|
| 589 |
|
| 590 |
if result.returncode == 0:
|
| 591 |
+
log_message += "β
ViT training completed successfully!\n"
|
| 592 |
+
log_message += f"π ViT Output:\n{result.stdout}\n\n"
|
| 593 |
log_message += "π All training completed! Models saved to models/exports/\n"
|
| 594 |
log_message += "π Reloading models for inference...\n"
|
| 595 |
service.reload_models()
|
|
|
|
| 645 |
if dataset_size != "full":
|
| 646 |
dataset_args = ["--max_samples", dataset_size]
|
| 647 |
|
| 648 |
+
# Train ResNet first and wait for completion
|
| 649 |
+
log_message += f"\nπ Starting ResNet training on {dataset_size} samples...\n"
|
| 650 |
+
resnet_result = subprocess.run([
|
| 651 |
"python", "train_resnet.py", "--data_root", DATASET_ROOT, "--epochs", str(res_epochs),
|
| 652 |
"--out", os.path.join(export_dir, "resnet_item_embedder.pth")
|
| 653 |
+
] + dataset_args, capture_output=True, text=True, check=False)
|
| 654 |
+
|
| 655 |
+
if resnet_result.returncode == 0:
|
| 656 |
+
log_message += "β
ResNet training completed successfully!\n"
|
| 657 |
+
log_message += f"π ResNet Output:\n{resnet_result.stdout}\n"
|
| 658 |
+
else:
|
| 659 |
+
log_message += f"β ResNet training failed: {resnet_result.stderr}\n"
|
| 660 |
+
return log_message
|
| 661 |
+
|
| 662 |
+
# Wait a moment for file system sync
|
| 663 |
+
import time
|
| 664 |
+
time.sleep(2)
|
| 665 |
+
|
| 666 |
+
# Verify ResNet checkpoint exists before proceeding
|
| 667 |
+
resnet_checkpoint = os.path.join(export_dir, "resnet_item_embedder.pth")
|
| 668 |
+
if not os.path.exists(resnet_checkpoint):
|
| 669 |
+
log_message += f"β ResNet checkpoint not found at {resnet_checkpoint}\n"
|
| 670 |
+
log_message += "Cannot proceed with ViT training without ResNet embeddings.\n"
|
| 671 |
+
return log_message
|
| 672 |
+
|
| 673 |
+
log_message += f"β
ResNet checkpoint verified: {resnet_checkpoint}\n"
|
| 674 |
+
|
| 675 |
+
log_message += f"\nπ Starting ViT training on {dataset_size} samples...\n"
|
| 676 |
+
vit_result = subprocess.run([
|
| 677 |
"python", "train_vit_triplet.py", "--data_root", DATASET_ROOT, "--epochs", str(vit_epochs),
|
| 678 |
"--export", os.path.join(export_dir, "vit_outfit_model.pth")
|
| 679 |
+
] + dataset_args, capture_output=True, text=True, check=False)
|
| 680 |
+
|
| 681 |
+
if vit_result.returncode == 0:
|
| 682 |
+
log_message += "β
ViT training completed successfully!\n"
|
| 683 |
+
log_message += f"π ViT Output:\n{vit_result.stdout}\n"
|
| 684 |
+
else:
|
| 685 |
+
log_message += f"β ViT training failed: {vit_result.stderr}\n"
|
| 686 |
+
return log_message
|
| 687 |
service.reload_models()
|
| 688 |
log_message += "\nDone. Artifacts in models/exports."
|
| 689 |
|
train_resnet.py
CHANGED
|
@@ -111,13 +111,17 @@ def main() -> None:
|
|
| 111 |
running_loss += loss.item()
|
| 112 |
steps += 1
|
| 113 |
|
| 114 |
-
if batch_idx %
|
| 115 |
print(f" Batch {batch_idx}/{len(loader)}: loss={loss.item():.4f}")
|
| 116 |
|
| 117 |
except Exception as e:
|
| 118 |
print(f"β Error in batch {batch_idx}: {e}")
|
| 119 |
continue
|
| 120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
avg_loss = running_loss / max(1, steps)
|
| 122 |
|
| 123 |
# Save checkpoint with better path handling
|
|
|
|
| 111 |
running_loss += loss.item()
|
| 112 |
steps += 1
|
| 113 |
|
| 114 |
+
if batch_idx % 10 == 0: # More frequent logging
|
| 115 |
print(f" Batch {batch_idx}/{len(loader)}: loss={loss.item():.4f}")
|
| 116 |
|
| 117 |
except Exception as e:
|
| 118 |
print(f"β Error in batch {batch_idx}: {e}")
|
| 119 |
continue
|
| 120 |
|
| 121 |
+
# Print final batch completion
|
| 122 |
+
print(f" β
Batch {len(loader)-1}/{len(loader)}: loss={loss.item():.4f}")
|
| 123 |
+
print(f" π Epoch {epoch+1} completed: {len(loader)} batches processed")
|
| 124 |
+
|
| 125 |
avg_loss = running_loss / max(1, steps)
|
| 126 |
|
| 127 |
# Save checkpoint with better path handling
|
train_vit_triplet.py
CHANGED
|
@@ -148,13 +148,17 @@ def main() -> None:
|
|
| 148 |
running_loss += loss.item()
|
| 149 |
steps += 1
|
| 150 |
|
| 151 |
-
if batch_idx %
|
| 152 |
print(f" Batch {batch_idx}/{len(loader)}: loss={loss.item():.4f}")
|
| 153 |
|
| 154 |
except Exception as e:
|
| 155 |
print(f"β Error in batch {batch_idx}: {e}")
|
| 156 |
continue
|
| 157 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
avg_loss = running_loss / max(1, steps)
|
| 159 |
|
| 160 |
# Simple validation using a subset of training data as a proxy if no val split here
|
|
|
|
| 148 |
running_loss += loss.item()
|
| 149 |
steps += 1
|
| 150 |
|
| 151 |
+
if batch_idx % 10 == 0: # More frequent logging
|
| 152 |
print(f" Batch {batch_idx}/{len(loader)}: loss={loss.item():.4f}")
|
| 153 |
|
| 154 |
except Exception as e:
|
| 155 |
print(f"β Error in batch {batch_idx}: {e}")
|
| 156 |
continue
|
| 157 |
|
| 158 |
+
# Print final batch completion
|
| 159 |
+
print(f" β
Batch {len(loader)-1}/{len(loader)}: loss={loss.item():.4f}")
|
| 160 |
+
print(f" π Epoch {epoch+1} completed: {len(loader)} batches processed")
|
| 161 |
+
|
| 162 |
avg_loss = running_loss / max(1, steps)
|
| 163 |
|
| 164 |
# Simple validation using a subset of training data as a proxy if no val split here
|