Ali Mohsin commited on
Commit
1a93f46
ยท
1 Parent(s): 3305d4a
Files changed (3) hide show
  1. app.py +4 -4
  2. train_resnet.py +1 -1
  3. train_vit_triplet.py +1 -1
app.py CHANGED
@@ -649,7 +649,7 @@ def start_training_simple(dataset_size: str, res_epochs: int, vit_epochs: int):
649
  log_message += f"\n๐Ÿš€ Starting ResNet training on {dataset_size} samples...\n"
650
  resnet_result = subprocess.run([
651
  "python", "train_resnet.py", "--data_root", DATASET_ROOT, "--epochs", str(res_epochs),
652
- "--batch_size", "20", "--out", os.path.join(export_dir, "resnet_item_embedder.pth")
653
  ] + dataset_args, capture_output=True, text=True, check=False)
654
 
655
  if resnet_result.returncode == 0:
@@ -675,7 +675,7 @@ def start_training_simple(dataset_size: str, res_epochs: int, vit_epochs: int):
675
  log_message += f"\n๐Ÿš€ Starting ViT training on {dataset_size} samples...\n"
676
  vit_result = subprocess.run([
677
  "python", "train_vit_triplet.py", "--data_root", DATASET_ROOT, "--epochs", str(vit_epochs),
678
- "--batch_size", "20", "--export", os.path.join(export_dir, "vit_outfit_model.pth")
679
  ] + dataset_args, capture_output=True, text=True, check=False)
680
 
681
  if vit_result.returncode == 0:
@@ -827,7 +827,7 @@ with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendatio
827
 
828
  # Training parameters
829
  resnet_epochs = gr.Slider(1, 100, value=20, step=1, label="Epochs")
830
- resnet_batch_size = gr.Slider(8, 128, value=20, step=8, label="Batch Size")
831
  resnet_lr = gr.Slider(1e-5, 1e-2, value=1e-3, step=1e-5, label="Learning Rate")
832
  resnet_optimizer = gr.Dropdown(
833
  choices=["adamw", "adam", "sgd", "rmsprop"],
@@ -849,7 +849,7 @@ with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendatio
849
 
850
  # Training parameters
851
  vit_epochs = gr.Slider(1, 100, value=30, step=1, label="Epochs")
852
- vit_batch_size = gr.Slider(4, 64, value=20, step=4, label="Batch Size")
853
  vit_lr = gr.Slider(1e-5, 1e-2, value=5e-4, step=1e-5, label="Learning Rate")
854
  vit_optimizer = gr.Dropdown(
855
  choices=["adamw", "adam", "sgd", "rmsprop"],
 
649
  log_message += f"\n๐Ÿš€ Starting ResNet training on {dataset_size} samples...\n"
650
  resnet_result = subprocess.run([
651
  "python", "train_resnet.py", "--data_root", DATASET_ROOT, "--epochs", str(res_epochs),
652
+ "--batch_size", "8", "--out", os.path.join(export_dir, "resnet_item_embedder.pth")
653
  ] + dataset_args, capture_output=True, text=True, check=False)
654
 
655
  if resnet_result.returncode == 0:
 
675
  log_message += f"\n๐Ÿš€ Starting ViT training on {dataset_size} samples...\n"
676
  vit_result = subprocess.run([
677
  "python", "train_vit_triplet.py", "--data_root", DATASET_ROOT, "--epochs", str(vit_epochs),
678
+ "--batch_size", "8", "--export", os.path.join(export_dir, "vit_outfit_model.pth")
679
  ] + dataset_args, capture_output=True, text=True, check=False)
680
 
681
  if vit_result.returncode == 0:
 
827
 
828
  # Training parameters
829
  resnet_epochs = gr.Slider(1, 100, value=20, step=1, label="Epochs")
830
+ resnet_batch_size = gr.Slider(8, 128, value=8, step=8, label="Batch Size")
831
  resnet_lr = gr.Slider(1e-5, 1e-2, value=1e-3, step=1e-5, label="Learning Rate")
832
  resnet_optimizer = gr.Dropdown(
833
  choices=["adamw", "adam", "sgd", "rmsprop"],
 
849
 
850
  # Training parameters
851
  vit_epochs = gr.Slider(1, 100, value=30, step=1, label="Epochs")
852
+ vit_batch_size = gr.Slider(4, 64, value=8, step=4, label="Batch Size")
853
  vit_lr = gr.Slider(1e-5, 1e-2, value=5e-4, step=1e-5, label="Learning Rate")
854
  vit_optimizer = gr.Dropdown(
855
  choices=["adamw", "adam", "sgd", "rmsprop"],
train_resnet.py CHANGED
@@ -21,7 +21,7 @@ def parse_args() -> argparse.Namespace:
21
  p = argparse.ArgumentParser()
22
  p.add_argument("--data_root", type=str, default=os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore"))
23
  p.add_argument("--epochs", type=int, default=20)
24
- p.add_argument("--batch_size", type=int, default=20)
25
  p.add_argument("--lr", type=float, default=1e-3)
26
  p.add_argument("--embedding_dim", type=int, default=512)
27
  p.add_argument("--out", type=str, default="models/exports/resnet_item_embedder.pth")
 
21
  p = argparse.ArgumentParser()
22
  p.add_argument("--data_root", type=str, default=os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore"))
23
  p.add_argument("--epochs", type=int, default=20)
24
+ p.add_argument("--batch_size", type=int, default=8)
25
  p.add_argument("--lr", type=float, default=1e-3)
26
  p.add_argument("--embedding_dim", type=int, default=512)
27
  p.add_argument("--out", type=str, default="models/exports/resnet_item_embedder.pth")
train_vit_triplet.py CHANGED
@@ -22,7 +22,7 @@ def parse_args() -> argparse.Namespace:
22
  p = argparse.ArgumentParser()
23
  p.add_argument("--data_root", type=str, default=os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore"))
24
  p.add_argument("--epochs", type=int, default=30)
25
- p.add_argument("--batch_size", type=int, default=20)
26
  p.add_argument("--lr", type=float, default=5e-4)
27
  p.add_argument("--embedding_dim", type=int, default=512)
28
  p.add_argument("--triplet_margin", type=float, default=0.3)
 
22
  p = argparse.ArgumentParser()
23
  p.add_argument("--data_root", type=str, default=os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore"))
24
  p.add_argument("--epochs", type=int, default=30)
25
+ p.add_argument("--batch_size", type=int, default=8)
26
  p.add_argument("--lr", type=float, default=5e-4)
27
  p.add_argument("--embedding_dim", type=int, default=512)
28
  p.add_argument("--triplet_margin", type=float, default=0.3)