Ali Mohsin committed
Commit 3b3cac8
1 Parent(s): 24ea486
Files changed (2)
  1. app.py +54 -10
  2. utils/data_fetch.py +23 -8
app.py CHANGED
@@ -415,6 +415,9 @@ def gradio_recommend(files: List[str], occasion: str, weather: str, num_outfits:
 
 
 def start_training_advanced(
+    # Dataset size
+    dataset_size: str,
+
     # ResNet parameters
     resnet_epochs: int, resnet_batch_size: int, resnet_lr: float, resnet_optimizer: str,
     resnet_weight_decay: float, resnet_triplet_margin: float, resnet_embedding_dim: int,
@@ -511,10 +514,16 @@ def start_training_advanced(
 
     # Train ResNet with custom parameters
     log_message = f"πŸš€ Starting ResNet training with custom parameters...\n"
+    log_message += f"Dataset Size: {dataset_size} samples\n"
     log_message += f"Backbone: {resnet_backbone}, Embedding Dim: {resnet_embedding_dim}\n"
     log_message += f"Epochs: {resnet_epochs}, Batch Size: {resnet_batch_size}, LR: {resnet_lr}\n"
     log_message += f"Optimizer: {resnet_optimizer}, Triplet Margin: {resnet_triplet_margin}\n"
 
+    # Add dataset size limit if not full
+    dataset_args = []
+    if dataset_size != "full":
+        dataset_args = ["--max_samples", dataset_size]
+
     resnet_cmd = [
         "python", "train_resnet.py",
         "--data_root", DATASET_ROOT,
@@ -525,7 +534,7 @@ def start_training_advanced(
         "--triplet_margin", str(resnet_triplet_margin),
         "--embedding_dim", str(resnet_embedding_dim),
         "--out", os.path.join(export_dir, "resnet_item_embedder_custom.pth")
-    ]
+    ] + dataset_args
 
     if resnet_backbone != "resnet50":
         resnet_cmd.extend(["--backbone", resnet_backbone])
@@ -540,6 +549,7 @@ def start_training_advanced(
 
     # Train ViT with custom parameters
     log_message += f"πŸš€ Starting ViT training with custom parameters...\n"
+    log_message += f"Dataset Size: {dataset_size} samples\n"
     log_message += f"Layers: {vit_num_layers}, Heads: {vit_num_heads}, FF Multiplier: {vit_ff_multiplier}\n"
     log_message += f"Epochs: {vit_epochs}, Batch Size: {vit_batch_size}, LR: {vit_lr}\n"
     log_message += f"Optimizer: {vit_optimizer}, Triplet Margin: {vit_triplet_margin}\n"
@@ -554,7 +564,7 @@ def start_training_advanced(
         "--triplet_margin", str(vit_triplet_margin),
         "--embedding_dim", str(vit_embedding_dim),
         "--export", os.path.join(export_dir, "vit_outfit_model_custom.pth")
-    ]
+    ] + dataset_args
 
     result = subprocess.run(vit_cmd, capture_output=True, text=True, check=False)
 
@@ -593,7 +603,7 @@ def start_training_advanced(
     return log_message
 
 
-def start_training_simple(res_epochs: int, vit_epochs: int):
+def start_training_simple(dataset_size: str, res_epochs: int, vit_epochs: int):
     """Start simple training with basic parameters."""
     log_message = "Starting training..."
     def _runner():
@@ -605,16 +615,21 @@ def start_training_simple(res_epochs: int, vit_epochs: int):
             return
         export_dir = os.getenv("EXPORT_DIR", "models/exports")
         os.makedirs(export_dir, exist_ok=True)
-        log_message = "Training ResNet…\n"
+        log_message = f"Training ResNet on {dataset_size} samples...\n"
+        # Add dataset size limit if not full
+        dataset_args = []
+        if dataset_size != "full":
+            dataset_args = ["--max_samples", dataset_size]
+
         subprocess.run([
             "python", "train_resnet.py", "--data_root", DATASET_ROOT, "--epochs", str(res_epochs),
             "--out", os.path.join(export_dir, "resnet_item_embedder.pth")
-        ], check=False)
-        log_message += "\nTraining ViT (triplet)…\n"
+        ] + dataset_args, check=False)
+        log_message += f"\nTraining ViT (triplet) on {dataset_size} samples...\n"
         subprocess.run([
             "python", "train_vit_triplet.py", "--data_root", DATASET_ROOT, "--epochs", str(vit_epochs),
             "--export", os.path.join(export_dir, "vit_outfit_model.pth")
-        ], check=False)
+        ] + dataset_args, check=False)
         service.reload_models()
         log_message += "\nDone. Artifacts in models/exports."
 
@@ -644,6 +659,7 @@ def start_training_simple(res_epochs: int, vit_epochs: int):
 
 with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendation") as demo:
     gr.Markdown("## πŸ† Dressify – Advanced Outfit Recommendation System\n*Research-grade, self-contained outfit recommendation with comprehensive training controls*")
+    gr.Markdown("πŸ’‘ **Pro Tip**: Start with 2000 samples for quick testing, then increase to 50000+ for production training!")
 
     with gr.Tab("🎨 Recommend"):
         inp2 = gr.Files(label="Upload wardrobe images", file_types=["image"], file_count="multiple")
@@ -660,6 +676,16 @@ with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendation") as demo:
         gr.Markdown("### 🎯 Comprehensive Training Parameter Control\nCustomize every aspect of model training for research and experimentation.")
 
         with gr.Row():
+            with gr.Column(scale=1):
+                gr.Markdown("#### πŸ“Š Dataset Size Control")
+                gr.Markdown("Start small for testing, increase for production training")
+                dataset_size = gr.Dropdown(
+                    choices=["2000", "5000", "10000", "25000", "50000", "full"],
+                    value="2000",
+                    label="Training Dataset Size"
+                )
+                gr.Markdown("**2000**: Quick testing (~2-5 min)\n**5000**: Fast validation (~5-10 min)\n**10000**: Good validation (~10-20 min)\n**25000+**: Production training")
+
             with gr.Column(scale=1):
                 gr.Markdown("#### πŸ–ΌοΈ ResNet Item Embedder")
 
@@ -768,6 +794,9 @@ with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendation") as demo:
     start_advanced_btn.click(
         fn=start_training_advanced,
         inputs=[
+            # Dataset size
+            dataset_size,
+
             # ResNet parameters
             resnet_epochs, resnet_batch_size, resnet_lr, resnet_optimizer,
             resnet_weight_decay, resnet_triplet_margin, resnet_embedding_dim,
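Editor's note on the hunk above: the wiring depends on Gradio passing component values to the handler positionally, in the order they appear in inputs, which is why dataset_size is listed first to match the new leading parameter of start_training_advanced. A minimal, self-contained sketch of that contract (component names here are illustrative, not taken from the commit):

import gradio as gr

def run(dataset_size: str, res_epochs: int, vit_epochs: int) -> str:
    # Dropdown values arrive as strings ("2000" or "full"); sliders arrive as numbers.
    return f"size={dataset_size}, resnet={res_epochs}, vit={vit_epochs}"

with gr.Blocks() as demo:
    size = gr.Dropdown(choices=["2000", "full"], value="2000", label="Size")
    res = gr.Slider(1, 50, value=10, step=1, label="ResNet epochs")
    vit = gr.Slider(1, 100, value=20, step=1, label="ViT epochs")
    out = gr.Textbox(label="Result")
    # Inputs are matched to run()'s parameters by position, not by name.
    gr.Button("Run").click(fn=run, inputs=[size, res, vit], outputs=out)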
@@ -838,11 +867,26 @@ with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendation") as demo:
 
     with gr.Tab("πŸ”§ Simple Training"):
         gr.Markdown("### πŸš€ Quick Training with Default Parameters\nFast training with proven configurations for immediate results.")
-        epochs_res = gr.Slider(1, 50, value=10, step=1, label="ResNet epochs")
-        epochs_vit = gr.Slider(1, 100, value=20, step=1, label="ViT epochs")
+
+        with gr.Row():
+            with gr.Column(scale=1):
+                gr.Markdown("#### πŸ“Š Dataset Size Control")
+                gr.Markdown("Start small for testing, increase for production training")
+                dataset_size = gr.Dropdown(
+                    choices=["2000", "5000", "10000", "25000", "50000", "full"],
+                    value="2000",
+                    label="Training Dataset Size"
+                )
+                gr.Markdown("**2000**: Quick testing (~2-5 min)\n**5000**: Fast validation (~5-10 min)\n**10000**: Good validation (~10-20 min)\n**25000+**: Production training")
+
+            with gr.Column(scale=1):
+                gr.Markdown("#### βš™οΈ Training Parameters")
+                epochs_res = gr.Slider(1, 50, value=10, step=1, label="ResNet epochs")
+                epochs_vit = gr.Slider(1, 100, value=20, step=1, label="ViT epochs")
+
         train_log = gr.Textbox(label="Training Log", lines=10)
         start_btn = gr.Button("Start Training")
-        start_btn.click(fn=start_training_simple, inputs=[epochs_res, epochs_vit], outputs=train_log)
+        start_btn.click(fn=start_training_simple, inputs=[dataset_size, epochs_res, epochs_vit], outputs=train_log)
 
     with gr.Tab("πŸ“Š Embed (Debug)"):
         inp = gr.Files(label="Upload Items (multiple images)")
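Both training commands now append --max_samples <n> whenever a finite size is chosen, but the scripts that consume the flag (train_resnet.py, train_vit_triplet.py) are not part of this commit. A minimal sketch of how such a script could honor it, assuming an argparse CLI and a PyTorch dataset; build_item_dataset is a hypothetical loader, and note that type=int converts the dropdown's string value ("2000", "5000", ...):

import argparse

import torch
from torch.utils.data import Subset

from dataset import build_item_dataset  # hypothetical loader for the data root

def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_root", required=True)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--out", required=True)
    # The UI passes the dropdown value as a string; type=int parses it to an integer.
    parser.add_argument("--max_samples", type=int, default=None,
                        help="Optional cap on training samples; omit to use the full set")
    args = parser.parse_args()

    dataset = build_item_dataset(args.data_root)
    if args.max_samples is not None and args.max_samples < len(dataset):
        # Draw a random subset so the cap does not bias toward directory order.
        indices = torch.randperm(len(dataset))[: args.max_samples].tolist()
        dataset = Subset(dataset, indices)
    # ... training loop over `dataset`, checkpoint written to args.out ...

if __name__ == "__main__":
    main()

Under this guard, choosing a numeric size larger than the dataset simply falls through to the full set, so only the "full" sentinel needs special handling on the UI side.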
 
utils/data_fetch.py CHANGED
@@ -170,8 +170,13 @@ def ensure_dataset_ready() -> Optional[str]:
     for meta_file in metadata_files:
         meta_path = os.path.join(root, meta_file)
         if os.path.exists(meta_path):
-            size_mb = os.path.getsize(meta_path) / (1024 * 1024)
-            print(f"πŸ“‹ {meta_file}: {size_mb:.1f} MB")
+            size_bytes = os.path.getsize(meta_path)
+            if size_bytes < 1024 * 1024:  # Less than 1MB
+                size_kb = size_bytes / 1024
+                print(f"πŸ“‹ {meta_file}: {size_kb:.1f} KB")
+            else:
+                size_mb = size_bytes / (1024 * 1024)
+                print(f"πŸ“‹ {meta_file}: {size_mb:.1f} MB")
         else:
             print(f"⚠️ Missing: {meta_file}")
 
@@ -209,10 +214,15 @@ def check_dataset_structure(root: str) -> dict:
     for meta_file in metadata_files:
         meta_path = os.path.join(root, meta_file)
         if os.path.exists(meta_path):
-            size_mb = os.path.getsize(meta_path) / (1024 * 1024)
-            structure["metadata"][meta_file] = {"exists": True, "size_mb": size_mb}
+            size_bytes = os.path.getsize(meta_path)
+            if size_bytes < 1024 * 1024:  # Less than 1MB
+                size_kb = size_bytes / 1024
+                structure["metadata"][meta_file] = {"exists": True, "size_kb": size_kb}
+            else:
+                size_mb = size_bytes / (1024 * 1024)
+                structure["metadata"][meta_file] = {"exists": True, "size_mb": size_mb}
         else:
-            structure["metadata"][meta_file] = {"exists": False, "size_mb": 0}
+            structure["metadata"][meta_file] = {"exists": False, "size_mb": 0, "size_kb": 0}
 
     # Check for splits
     split_locations = [
@@ -229,10 +239,15 @@ def check_dataset_structure(root: str) -> dict:
         for split_file in files:
             split_path = os.path.join(location_path, split_file)
             if os.path.exists(split_path):
-                size_mb = os.path.getsize(split_path) / (1024 * 1024)
-                structure["splits"][location][split_file] = {"exists": True, "size_mb": size_mb}
+                size_bytes = os.path.getsize(split_path)
+                if size_bytes < 1024 * 1024:  # Less than 1MB
+                    size_kb = size_bytes / 1024
+                    structure["splits"][location][split_file] = {"exists": True, "size_kb": size_kb}
+                else:
+                    size_mb = size_bytes / (1024 * 1024)
+                    structure["splits"][location][split_file] = {"exists": True, "size_mb": size_mb}
             else:
-                structure["splits"][location][split_file] = {"exists": False, "size_mb": 0}
+                structure["splits"][location][split_file] = {"exists": False, "size_mb": 0, "size_kb": 0}
     else:
         structure["splits"][location] = "directory_not_found"
 
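The bytes-to-KB/MB branching above now appears three times across utils/data_fetch.py. If a follow-up cleanup were wanted, it could be factored into one helper; a minimal sketch (format_size_entry is a hypothetical name, not part of this commit):

import os

def format_size_entry(path: str) -> dict:
    """Return {"size_kb": ...} for files under 1 MB, else {"size_mb": ...}."""
    size_bytes = os.path.getsize(path)
    if size_bytes < 1024 * 1024:  # Less than 1MB, report in KB
        return {"size_kb": size_bytes / 1024}
    return {"size_mb": size_bytes / (1024 * 1024)}

# Usage inside check_dataset_structure:
#     structure["metadata"][meta_file] = {"exists": True, **format_size_entry(meta_path)}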
 
 