Ali Mohsin committed
Commit 227af5e · Parent: 58e8faf

Fixed 1 million more errors

Files changed (3):
  1. app.py +12 -10
  2. train_vit_triplet.py +23 -8
  3. utils/hf_utils.py +15 -1
app.py CHANGED
@@ -272,15 +272,15 @@ def _background_bootstrap():
         BOOT_STATUS = "training-resnet"
         subprocess.run([
             "python", "train_resnet.py", "--data_root", ds_root, "--epochs", "3",
-            "--batch_size", "8", "--lr", "1e-3", "--early_stopping_patience", "3",
+            "--batch_size", "4", "--lr", "1e-3", "--early_stopping_patience", "3",
             "--out", os.path.join(export_dir, "resnet_item_embedder.pth")
         ], check=False)
     if not os.path.exists(vit_ckpt):
         BOOT_STATUS = "training-vit"
         subprocess.run([
             "python", "train_vit_triplet.py", "--data_root", ds_root, "--epochs", "3",
-            "--batch_size", "8", "--lr", "5e-4", "--early_stopping_patience", "3",
-            "--export", os.path.join(export_dir, "vit_outfit_model.pth")
+            "--batch_size", "4", "--lr", "5e-4", "--early_stopping_patience", "3",
+            "--skip_validation", "--max_samples", "200", "--export", os.path.join(export_dir, "vit_outfit_model.pth")
         ], check=False)
     service.reload_models()
     BOOT_STATUS = "ready"
@@ -445,7 +445,7 @@ def start_training_advanced(
     resnet_backbone: str, resnet_use_pretrained: bool, resnet_dropout: float,

     # ViT parameters
-    vit_epochs: int, vit_batch_size: int, vit_lr: float, vit_optimizer: str,
+    vit_epochs: int, vit_batch_size: int, vit_max_samples: int, vit_lr: float, vit_optimizer: str,
     vit_weight_decay: float, vit_triplet_margin: float, vit_embedding_dim: int,
     vit_num_layers: int, vit_num_heads: int, vit_ff_multiplier: int, vit_dropout: float,

@@ -599,6 +599,7 @@ def start_training_advanced(
         "--data_root", DATASET_ROOT,
         "--epochs", str(vit_epochs),
         "--batch_size", str(vit_batch_size),
+        "--max_samples", str(vit_max_samples),
         "--lr", str(vit_lr),
         "--weight_decay", str(vit_weight_decay),
         "--triplet_margin", str(vit_triplet_margin),
@@ -682,7 +683,7 @@ def start_training_simple(dataset_size: str, res_epochs: int, vit_epochs: int):
     log_message += f"\n🚀 Starting ResNet training on {dataset_size} samples...\n"
     resnet_result = subprocess.run([
         "python", "train_resnet.py", "--data_root", DATASET_ROOT, "--epochs", str(res_epochs),
-        "--batch_size", "8", "--lr", "1e-3", "--early_stopping_patience", "3",
+        "--batch_size", "4", "--lr", "1e-3", "--early_stopping_patience", "3",
         "--out", os.path.join(export_dir, "resnet_item_embedder.pth")
     ] + dataset_args, capture_output=True, text=True, check=False)

@@ -709,8 +710,8 @@ def start_training_simple(dataset_size: str, res_epochs: int, vit_epochs: int):
     log_message += f"\n🚀 Starting ViT training on {dataset_size} samples...\n"
     vit_result = subprocess.run([
         "python", "train_vit_triplet.py", "--data_root", DATASET_ROOT, "--epochs", str(vit_epochs),
-        "--batch_size", "8", "--lr", "5e-4", "--early_stopping_patience", "3",
-        "--export", os.path.join(export_dir, "vit_outfit_model.pth")
+        "--batch_size", "4", "--lr", "5e-4", "--early_stopping_patience", "3",
+        "--skip_validation", "--max_samples", "200", "--export", os.path.join(export_dir, "vit_outfit_model.pth")
     ] + dataset_args, capture_output=True, text=True, check=False)

     if vit_result.returncode == 0:
@@ -876,7 +877,7 @@ with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendatio

     # Training parameters
     resnet_epochs = gr.Slider(1, 100, value=20, step=1, label="Epochs")
-    resnet_batch_size = gr.Slider(8, 128, value=8, step=8, label="Batch Size")
+    resnet_batch_size = gr.Slider(4, 128, value=4, step=4, label="Batch Size")
     resnet_lr = gr.Slider(1e-5, 1e-2, value=1e-3, step=1e-5, label="Learning Rate")
     resnet_optimizer = gr.Dropdown(
         choices=["adamw", "adam", "sgd", "rmsprop"],
@@ -898,7 +899,8 @@ with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendatio

     # Training parameters
     vit_epochs = gr.Slider(1, 100, value=30, step=1, label="Epochs")
-    vit_batch_size = gr.Slider(4, 64, value=8, step=4, label="Batch Size")
+    vit_batch_size = gr.Slider(2, 64, value=4, step=2, label="Batch Size")
+    vit_max_samples = gr.Slider(100, 5000, value=500, step=100, label="Max Training Samples")
     vit_lr = gr.Slider(1e-5, 1e-2, value=5e-4, step=1e-5, label="Learning Rate")
     vit_optimizer = gr.Dropdown(
         choices=["adamw", "adam", "sgd", "rmsprop"],
@@ -978,7 +980,7 @@ with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendatio
     resnet_backbone, resnet_use_pretrained, resnet_dropout,

     # ViT parameters
-    vit_epochs, vit_batch_size, vit_lr, vit_optimizer,
+    vit_epochs, vit_batch_size, vit_max_samples, vit_lr, vit_optimizer,
     vit_weight_decay, vit_triplet_margin, vit_embedding_dim,
     vit_num_layers, vit_num_heads, vit_ff_multiplier, vit_dropout,
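A note on the wiring above: vit_max_samples only takes effect because three pieces change together: the gr.Slider definition, the start_training_advanced parameter list, and the .click() inputs list. Gradio passes inputs to the handler positionally, so the slider's position in the inputs list must match its position in the signature. A minimal, self-contained sketch of that pattern (the single-button layout and handler below are illustrative, not the app's actual UI):

# Minimal sketch of the slider -> handler -> subprocess wiring used in app.py.
# The layout and handler here are illustrative; the real app passes many more inputs.
import subprocess
import gradio as gr

def start_training(vit_batch_size: int, vit_max_samples: int) -> str:
    # Each UI value is stringified into a CLI flag for the trainer process.
    result = subprocess.run(
        ["python", "train_vit_triplet.py",
         "--batch_size", str(vit_batch_size),
         "--max_samples", str(vit_max_samples)],
        capture_output=True, text=True, check=False,
    )
    return result.stdout if result.returncode == 0 else result.stderr

with gr.Blocks() as demo:
    batch_size = gr.Slider(2, 64, value=4, step=2, label="Batch Size")
    max_samples = gr.Slider(100, 5000, value=500, step=100, label="Max Training Samples")
    log = gr.Textbox(label="Training Log")
    # Inputs are positional: their order here must match start_training's signature.
    gr.Button("Train").click(start_training, inputs=[batch_size, max_samples], outputs=log)

if __name__ == "__main__":
    demo.launch()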
train_vit_triplet.py CHANGED
@@ -23,12 +23,14 @@ def parse_args() -> argparse.Namespace:
     p = argparse.ArgumentParser()
     p.add_argument("--data_root", type=str, default=os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore"))
     p.add_argument("--epochs", type=int, default=50)
-    p.add_argument("--batch_size", type=int, default=16)
+    p.add_argument("--batch_size", type=int, default=4)
     p.add_argument("--lr", type=float, default=5e-4)
     p.add_argument("--embedding_dim", type=int, default=512)
     p.add_argument("--triplet_margin", type=float, default=0.3)
     p.add_argument("--export", type=str, default="models/exports/vit_outfit_model.pth")
     p.add_argument("--eval_every", type=int, default=1)
+    p.add_argument("--skip_validation", action="store_true", help="Skip validation for faster training")
+    p.add_argument("--max_samples", type=int, default=500, help="Maximum number of training samples (for faster testing)")
     p.add_argument("--early_stopping_patience", type=int, default=10, help="Early stopping patience")
     p.add_argument("--min_delta", type=float, default=1e-4, help="Minimum change to qualify as improvement")
     return p.parse_args()
@@ -82,7 +84,14 @@ def main() -> None:

     try:
         dataset = PolyvoreOutfitTripletDataset(args.data_root, split="train")
-        print(f"📊 Dataset loaded: {len(dataset)} samples")
+        # Limit dataset size for faster training/testing
+        max_samples = min(len(dataset), args.max_samples)
+        print(f"🔍 Debug: Original dataset size: {len(dataset)}, max_samples: {args.max_samples}")
+        if len(dataset) > max_samples:
+            dataset.samples = dataset.samples[:max_samples]
+            print(f"📊 Dataset limited to {max_samples} samples for faster training")
+        else:
+            print(f"📊 Dataset loaded: {len(dataset)} samples (no limiting needed)")
     except Exception as e:
         print(f"❌ Failed to load dataset: {e}")
         return
@@ -172,20 +181,25 @@ def main() -> None:

         avg_loss = running_loss / max(1, steps)

-        # Simple validation using a subset of training data as a proxy if no val split here
-        # For true 70/10/10, prepare_polyvore.py will create outfit_triplets_valid.json
+        # Fast validation with limited samples to prevent hanging
         val_path = os.path.join(args.data_root, "splits", "outfit_triplets_valid.json")
         val_loss = None

-        if os.path.exists(val_path) and (epoch + 1) % args.eval_every == 0:
+        if not args.skip_validation and os.path.exists(val_path) and (epoch + 1) % args.eval_every == 0:
             try:
+                print(f"  🔍 Starting validation (limited to 50 samples for speed)...")
                 val_ds = PolyvoreOutfitTripletDataset(args.data_root, split="valid")
-                val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=2, collate_fn=lambda x: x)
+                # Limit validation to first 50 samples to prevent hanging
+                val_samples = val_ds.samples[:50]
+                val_ds.samples = val_samples
+                val_loader = DataLoader(val_ds, batch_size=min(args.batch_size, 8), shuffle=False, num_workers=0, collate_fn=lambda x: x)
                 model.eval()
                 losses = []

                 with torch.no_grad():
-                    for vbatch in val_loader:
+                    for i, vbatch in enumerate(val_loader):
+                        if i >= 10:  # Limit to 10 batches max for speed
+                            break
                         anchor_tokens = []
                         positive_tokens = []
                         negative_tokens = []
@@ -206,10 +220,11 @@ def main() -> None:
                         losses.append(l)

                 val_loss = sum(losses) / max(1, len(losses))
-                print(f"  📊 Validation loss: {val_loss:.4f}")
+                print(f"  📊 Validation loss: {val_loss:.4f} (from {len(losses)} batches)")
             except Exception as e:
                 print(f"  ⚠️ Validation failed: {e}")
+                val_loss = None

         out_path = args.export
         if not out_path.startswith("models/"):
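The --max_samples cap relies on the dataset exposing its triplet index as a plain samples list that __len__ and __getitem__ read from; truncating that list in place shrinks the dataset without rebuilding it. A minimal sketch of the pattern under that assumption (TinyTripletDataset is a stand-in, not the real PolyvoreOutfitTripletDataset):

# Stand-in dataset illustrating the in-place truncation used by --max_samples.
# Assumes a list-backed Dataset, as the trainer assumes for PolyvoreOutfitTripletDataset.
from torch.utils.data import Dataset

class TinyTripletDataset(Dataset):
    def __init__(self, samples):
        self.samples = samples  # list of (anchor, positive, negative) ids

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

ds = TinyTripletDataset([(i, i + 1, i + 2) for i in range(1000)])
max_samples = min(len(ds), 200)        # mirrors min(len(dataset), args.max_samples)
ds.samples = ds.samples[:max_samples]  # in-place cap; the DataLoader sees the smaller size
assert len(ds) == 200

The validation changes follow the same truncation idea; switching the validation DataLoader to num_workers=0 also avoids spawning worker processes, a plausible cause of the hangs the comments mention on a small CPU Space.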
utils/hf_utils.py CHANGED
@@ -10,10 +10,24 @@ class HFModelManager:
     """Utility class for managing model checkpoints on Hugging Face Hub."""

     def __init__(self, token: Optional[str] = None, username: Optional[str] = None):
-        self.api = HfApi(token=token or os.getenv("HF_TOKEN"))
+        self.token = token or os.getenv("HF_TOKEN")
         self.username = username or os.getenv("HF_USERNAME")
+
+        if not self.token:
+            raise ValueError("HF_TOKEN environment variable must be set")
         if not self.username:
             raise ValueError("HF_USERNAME environment variable must be set")
+
+        # Set up authentication
+        try:
+            from huggingface_hub import login
+            login(token=self.token, write_permission=True)
+            print("✅ Hugging Face authentication successful")
+        except Exception as e:
+            print(f"⚠️ Hugging Face authentication failed: {e}")
+            raise
+
+        self.api = HfApi(token=self.token)

     def create_model_repo(self, model_name: str, private: bool = False) -> str:
         """Create a new model repository."""
  """Create a new model repository."""