Spaces:
Paused
Paused
Ali Mohsin
commited on
Commit
ยท
1a93f46
1
Parent(s):
3305d4a
faster
Browse files- app.py +4 -4
- train_resnet.py +1 -1
- train_vit_triplet.py +1 -1
app.py
CHANGED
|
@@ -649,7 +649,7 @@ def start_training_simple(dataset_size: str, res_epochs: int, vit_epochs: int):
|
|
| 649 |
log_message += f"\n๐ Starting ResNet training on {dataset_size} samples...\n"
|
| 650 |
resnet_result = subprocess.run([
|
| 651 |
"python", "train_resnet.py", "--data_root", DATASET_ROOT, "--epochs", str(res_epochs),
|
| 652 |
-
"--batch_size", "
|
| 653 |
] + dataset_args, capture_output=True, text=True, check=False)
|
| 654 |
|
| 655 |
if resnet_result.returncode == 0:
|
|
@@ -675,7 +675,7 @@ def start_training_simple(dataset_size: str, res_epochs: int, vit_epochs: int):
|
|
| 675 |
log_message += f"\n๐ Starting ViT training on {dataset_size} samples...\n"
|
| 676 |
vit_result = subprocess.run([
|
| 677 |
"python", "train_vit_triplet.py", "--data_root", DATASET_ROOT, "--epochs", str(vit_epochs),
|
| 678 |
-
"--batch_size", "
|
| 679 |
] + dataset_args, capture_output=True, text=True, check=False)
|
| 680 |
|
| 681 |
if vit_result.returncode == 0:
|
|
@@ -827,7 +827,7 @@ with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendatio
|
|
| 827 |
|
| 828 |
# Training parameters
|
| 829 |
resnet_epochs = gr.Slider(1, 100, value=20, step=1, label="Epochs")
|
| 830 |
-
resnet_batch_size = gr.Slider(8, 128, value=
|
| 831 |
resnet_lr = gr.Slider(1e-5, 1e-2, value=1e-3, step=1e-5, label="Learning Rate")
|
| 832 |
resnet_optimizer = gr.Dropdown(
|
| 833 |
choices=["adamw", "adam", "sgd", "rmsprop"],
|
|
@@ -849,7 +849,7 @@ with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendatio
|
|
| 849 |
|
| 850 |
# Training parameters
|
| 851 |
vit_epochs = gr.Slider(1, 100, value=30, step=1, label="Epochs")
|
| 852 |
-
vit_batch_size = gr.Slider(4, 64, value=
|
| 853 |
vit_lr = gr.Slider(1e-5, 1e-2, value=5e-4, step=1e-5, label="Learning Rate")
|
| 854 |
vit_optimizer = gr.Dropdown(
|
| 855 |
choices=["adamw", "adam", "sgd", "rmsprop"],
|
|
|
|
| 649 |
log_message += f"\n๐ Starting ResNet training on {dataset_size} samples...\n"
|
| 650 |
resnet_result = subprocess.run([
|
| 651 |
"python", "train_resnet.py", "--data_root", DATASET_ROOT, "--epochs", str(res_epochs),
|
| 652 |
+
"--batch_size", "8", "--out", os.path.join(export_dir, "resnet_item_embedder.pth")
|
| 653 |
] + dataset_args, capture_output=True, text=True, check=False)
|
| 654 |
|
| 655 |
if resnet_result.returncode == 0:
|
|
|
|
| 675 |
log_message += f"\n๐ Starting ViT training on {dataset_size} samples...\n"
|
| 676 |
vit_result = subprocess.run([
|
| 677 |
"python", "train_vit_triplet.py", "--data_root", DATASET_ROOT, "--epochs", str(vit_epochs),
|
| 678 |
+
"--batch_size", "8", "--export", os.path.join(export_dir, "vit_outfit_model.pth")
|
| 679 |
] + dataset_args, capture_output=True, text=True, check=False)
|
| 680 |
|
| 681 |
if vit_result.returncode == 0:
|
|
|
|
| 827 |
|
| 828 |
# Training parameters
|
| 829 |
resnet_epochs = gr.Slider(1, 100, value=20, step=1, label="Epochs")
|
| 830 |
+
resnet_batch_size = gr.Slider(8, 128, value=8, step=8, label="Batch Size")
|
| 831 |
resnet_lr = gr.Slider(1e-5, 1e-2, value=1e-3, step=1e-5, label="Learning Rate")
|
| 832 |
resnet_optimizer = gr.Dropdown(
|
| 833 |
choices=["adamw", "adam", "sgd", "rmsprop"],
|
|
|
|
| 849 |
|
| 850 |
# Training parameters
|
| 851 |
vit_epochs = gr.Slider(1, 100, value=30, step=1, label="Epochs")
|
| 852 |
+
vit_batch_size = gr.Slider(4, 64, value=8, step=4, label="Batch Size")
|
| 853 |
vit_lr = gr.Slider(1e-5, 1e-2, value=5e-4, step=1e-5, label="Learning Rate")
|
| 854 |
vit_optimizer = gr.Dropdown(
|
| 855 |
choices=["adamw", "adam", "sgd", "rmsprop"],
|
train_resnet.py
CHANGED
|
@@ -21,7 +21,7 @@ def parse_args() -> argparse.Namespace:
|
|
| 21 |
p = argparse.ArgumentParser()
|
| 22 |
p.add_argument("--data_root", type=str, default=os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore"))
|
| 23 |
p.add_argument("--epochs", type=int, default=20)
|
| 24 |
-
p.add_argument("--batch_size", type=int, default=
|
| 25 |
p.add_argument("--lr", type=float, default=1e-3)
|
| 26 |
p.add_argument("--embedding_dim", type=int, default=512)
|
| 27 |
p.add_argument("--out", type=str, default="models/exports/resnet_item_embedder.pth")
|
|
|
|
| 21 |
p = argparse.ArgumentParser()
|
| 22 |
p.add_argument("--data_root", type=str, default=os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore"))
|
| 23 |
p.add_argument("--epochs", type=int, default=20)
|
| 24 |
+
p.add_argument("--batch_size", type=int, default=8)
|
| 25 |
p.add_argument("--lr", type=float, default=1e-3)
|
| 26 |
p.add_argument("--embedding_dim", type=int, default=512)
|
| 27 |
p.add_argument("--out", type=str, default="models/exports/resnet_item_embedder.pth")
|
train_vit_triplet.py
CHANGED
|
@@ -22,7 +22,7 @@ def parse_args() -> argparse.Namespace:
|
|
| 22 |
p = argparse.ArgumentParser()
|
| 23 |
p.add_argument("--data_root", type=str, default=os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore"))
|
| 24 |
p.add_argument("--epochs", type=int, default=30)
|
| 25 |
-
p.add_argument("--batch_size", type=int, default=
|
| 26 |
p.add_argument("--lr", type=float, default=5e-4)
|
| 27 |
p.add_argument("--embedding_dim", type=int, default=512)
|
| 28 |
p.add_argument("--triplet_margin", type=float, default=0.3)
|
|
|
|
| 22 |
p = argparse.ArgumentParser()
|
| 23 |
p.add_argument("--data_root", type=str, default=os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore"))
|
| 24 |
p.add_argument("--epochs", type=int, default=30)
|
| 25 |
+
p.add_argument("--batch_size", type=int, default=8)
|
| 26 |
p.add_argument("--lr", type=float, default=5e-4)
|
| 27 |
p.add_argument("--embedding_dim", type=int, default=512)
|
| 28 |
p.add_argument("--triplet_margin", type=float, default=0.3)
|