Ali Mohsin commited on
Commit
12f4b61
Β·
1 Parent(s): 6e90fc7
Files changed (3) hide show
  1. app.py +54 -8
  2. train_resnet.py +5 -1
  3. train_vit_triplet.py +5 -1
app.py CHANGED
@@ -546,10 +546,25 @@ def start_training_advanced(
546
  result = subprocess.run(resnet_cmd, capture_output=True, text=True, check=False)
547
 
548
  if result.returncode == 0:
549
- log_message += "βœ… ResNet training completed successfully!\n\n"
 
550
  else:
551
  log_message += f"❌ ResNet training failed: {result.stderr}\n\n"
552
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
553
 
554
  # Train ViT with custom parameters
555
  log_message += f"πŸš€ Starting ViT training with custom parameters...\n"
@@ -573,7 +588,8 @@ def start_training_advanced(
573
  result = subprocess.run(vit_cmd, capture_output=True, text=True, check=False)
574
 
575
  if result.returncode == 0:
576
- log_message += "βœ… ViT training completed successfully!\n\n"
 
577
  log_message += "πŸŽ‰ All training completed! Models saved to models/exports/\n"
578
  log_message += "πŸ”„ Reloading models for inference...\n"
579
  service.reload_models()
@@ -629,15 +645,45 @@ def start_training_simple(dataset_size: str, res_epochs: int, vit_epochs: int):
629
  if dataset_size != "full":
630
  dataset_args = ["--max_samples", dataset_size]
631
 
632
- subprocess.run([
 
 
633
  "python", "train_resnet.py", "--data_root", DATASET_ROOT, "--epochs", str(res_epochs),
634
  "--out", os.path.join(export_dir, "resnet_item_embedder.pth")
635
- ] + dataset_args, check=False)
636
- log_message += f"\nTraining ViT (triplet) on {dataset_size} samples...\n"
637
- subprocess.run([
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
638
  "python", "train_vit_triplet.py", "--data_root", DATASET_ROOT, "--epochs", str(vit_epochs),
639
  "--export", os.path.join(export_dir, "vit_outfit_model.pth")
640
- ] + dataset_args, check=False)
 
 
 
 
 
 
 
641
  service.reload_models()
642
  log_message += "\nDone. Artifacts in models/exports."
643
 
 
546
  result = subprocess.run(resnet_cmd, capture_output=True, text=True, check=False)
547
 
548
  if result.returncode == 0:
549
+ log_message += "βœ… ResNet training completed successfully!\n"
550
+ log_message += f"πŸ“Š ResNet Output:\n{result.stdout}\n\n"
551
  else:
552
  log_message += f"❌ ResNet training failed: {result.stderr}\n\n"
553
+ return log_message
554
+
555
+ # Wait a moment for file system sync and ensure ResNet is fully saved
556
+ import time
557
+ time.sleep(3)
558
+ log_message += "⏳ Waiting for ResNet checkpoint to be fully saved...\n"
559
+
560
+ # Verify ResNet checkpoint exists before proceeding
561
+ resnet_checkpoint = os.path.join(export_dir, "resnet_item_embedder_custom.pth")
562
+ if not os.path.exists(resnet_checkpoint):
563
+ log_message += f"❌ ResNet checkpoint not found at {resnet_checkpoint}\n"
564
+ log_message += "Cannot proceed with ViT training without ResNet embeddings.\n"
565
+ return log_message
566
+
567
+ log_message += f"βœ… ResNet checkpoint verified: {resnet_checkpoint}\n"
568
 
569
  # Train ViT with custom parameters
570
  log_message += f"πŸš€ Starting ViT training with custom parameters...\n"
 
588
  result = subprocess.run(vit_cmd, capture_output=True, text=True, check=False)
589
 
590
  if result.returncode == 0:
591
+ log_message += "βœ… ViT training completed successfully!\n"
592
+ log_message += f"πŸ“Š ViT Output:\n{result.stdout}\n\n"
593
  log_message += "πŸŽ‰ All training completed! Models saved to models/exports/\n"
594
  log_message += "πŸ”„ Reloading models for inference...\n"
595
  service.reload_models()
 
645
  if dataset_size != "full":
646
  dataset_args = ["--max_samples", dataset_size]
647
 
648
+ # Train ResNet first and wait for completion
649
+ log_message += f"\nπŸš€ Starting ResNet training on {dataset_size} samples...\n"
650
+ resnet_result = subprocess.run([
651
  "python", "train_resnet.py", "--data_root", DATASET_ROOT, "--epochs", str(res_epochs),
652
  "--out", os.path.join(export_dir, "resnet_item_embedder.pth")
653
+ ] + dataset_args, capture_output=True, text=True, check=False)
654
+
655
+ if resnet_result.returncode == 0:
656
+ log_message += "βœ… ResNet training completed successfully!\n"
657
+ log_message += f"πŸ“Š ResNet Output:\n{resnet_result.stdout}\n"
658
+ else:
659
+ log_message += f"❌ ResNet training failed: {resnet_result.stderr}\n"
660
+ return log_message
661
+
662
+ # Wait a moment for file system sync
663
+ import time
664
+ time.sleep(2)
665
+
666
+ # Verify ResNet checkpoint exists before proceeding
667
+ resnet_checkpoint = os.path.join(export_dir, "resnet_item_embedder.pth")
668
+ if not os.path.exists(resnet_checkpoint):
669
+ log_message += f"❌ ResNet checkpoint not found at {resnet_checkpoint}\n"
670
+ log_message += "Cannot proceed with ViT training without ResNet embeddings.\n"
671
+ return log_message
672
+
673
+ log_message += f"βœ… ResNet checkpoint verified: {resnet_checkpoint}\n"
674
+
675
+ log_message += f"\nπŸš€ Starting ViT training on {dataset_size} samples...\n"
676
+ vit_result = subprocess.run([
677
  "python", "train_vit_triplet.py", "--data_root", DATASET_ROOT, "--epochs", str(vit_epochs),
678
  "--export", os.path.join(export_dir, "vit_outfit_model.pth")
679
+ ] + dataset_args, capture_output=True, text=True, check=False)
680
+
681
+ if vit_result.returncode == 0:
682
+ log_message += "βœ… ViT training completed successfully!\n"
683
+ log_message += f"πŸ“Š ViT Output:\n{vit_result.stdout}\n"
684
+ else:
685
+ log_message += f"❌ ViT training failed: {vit_result.stderr}\n"
686
+ return log_message
687
  service.reload_models()
688
  log_message += "\nDone. Artifacts in models/exports."
689
 
train_resnet.py CHANGED
@@ -111,13 +111,17 @@ def main() -> None:
111
  running_loss += loss.item()
112
  steps += 1
113
 
114
- if batch_idx % 100 == 0:
115
  print(f" Batch {batch_idx}/{len(loader)}: loss={loss.item():.4f}")
116
 
117
  except Exception as e:
118
  print(f"❌ Error in batch {batch_idx}: {e}")
119
  continue
120
 
 
 
 
 
121
  avg_loss = running_loss / max(1, steps)
122
 
123
  # Save checkpoint with better path handling
 
111
  running_loss += loss.item()
112
  steps += 1
113
 
114
+ if batch_idx % 10 == 0: # More frequent logging
115
  print(f" Batch {batch_idx}/{len(loader)}: loss={loss.item():.4f}")
116
 
117
  except Exception as e:
118
  print(f"❌ Error in batch {batch_idx}: {e}")
119
  continue
120
 
121
+ # Print final batch completion
122
+ print(f" βœ… Batch {len(loader)-1}/{len(loader)}: loss={loss.item():.4f}")
123
+ print(f" πŸ“Š Epoch {epoch+1} completed: {len(loader)} batches processed")
124
+
125
  avg_loss = running_loss / max(1, steps)
126
 
127
  # Save checkpoint with better path handling
train_vit_triplet.py CHANGED
@@ -148,13 +148,17 @@ def main() -> None:
148
  running_loss += loss.item()
149
  steps += 1
150
 
151
- if batch_idx % 50 == 0:
152
  print(f" Batch {batch_idx}/{len(loader)}: loss={loss.item():.4f}")
153
 
154
  except Exception as e:
155
  print(f"❌ Error in batch {batch_idx}: {e}")
156
  continue
157
 
 
 
 
 
158
  avg_loss = running_loss / max(1, steps)
159
 
160
  # Simple validation using a subset of training data as a proxy if no val split here
 
148
  running_loss += loss.item()
149
  steps += 1
150
 
151
+ if batch_idx % 10 == 0: # More frequent logging
152
  print(f" Batch {batch_idx}/{len(loader)}: loss={loss.item():.4f}")
153
 
154
  except Exception as e:
155
  print(f"❌ Error in batch {batch_idx}: {e}")
156
  continue
157
 
158
+ # Print final batch completion
159
+ print(f" βœ… Batch {len(loader)-1}/{len(loader)}: loss={loss.item():.4f}")
160
+ print(f" πŸ“Š Epoch {epoch+1} completed: {len(loader)} batches processed")
161
+
162
  avg_loss = running_loss / max(1, steps)
163
 
164
  # Simple validation using a subset of training data as a proxy if no val split here