Ali Mohsin commited on
Commit
aa9a482
Β·
1 Parent(s): 9b6249a
Files changed (4) hide show
  1. API_DOCUMENTATION.md +1 -0
  2. app.py +4 -4
  3. inference.py +10 -10
  4. utils/tag_system.py +1 -0
API_DOCUMENTATION.md CHANGED
@@ -409,3 +409,4 @@ fetch(`${BASE_URL}/compose`, {
409
 
410
  For API support, please contact: support@dressify.com
411
 
 
 
409
 
410
  For API support, please contact: support@dressify.com
411
 
412
+
app.py CHANGED
@@ -870,9 +870,9 @@ def start_training_simple(dataset_size: str, res_epochs: int, vit_epochs: int):
870
  # Train ResNet first and wait for completion
871
  log_message += f"\nπŸš€ Starting ResNet training on {dataset_size} samples...\n"
872
  resnet_result = subprocess.run([
873
- "python", "train_resnet.py", "--data_root", DATASET_ROOT, "--epochs", str(res_epochs),
874
  "--batch_size", "4", "--lr", "1e-3", "--early_stopping_patience", "3",
875
- "--out", os.path.join(export_dir, "resnet_item_embedder.pth")
876
  ] + dataset_args, capture_output=True, text=True, check=False)
877
 
878
  if resnet_result.returncode == 0:
@@ -897,7 +897,7 @@ def start_training_simple(dataset_size: str, res_epochs: int, vit_epochs: int):
897
 
898
  log_message += f"\nπŸš€ Starting ViT training on {dataset_size} samples...\n"
899
  vit_result = subprocess.run([
900
- "python", "train_vit_triplet.py", "--data_root", DATASET_ROOT, "--epochs", str(vit_epochs),
901
  "--batch_size", "4", "--lr", "5e-4", "--early_stopping_patience", "5",
902
  "--max_samples", "5000", "--triplet_margin", "0.5", "--gradient_clip", "1.0",
903
  "--warmup_epochs", "2", "--export", os.path.join(export_dir, "vit_outfit_model.pth")
@@ -909,7 +909,7 @@ def start_training_simple(dataset_size: str, res_epochs: int, vit_epochs: int):
909
  else:
910
  log_message += f"❌ ViT training failed: {vit_result.stderr}\n"
911
  return log_message
912
- service.reload_models()
913
 
914
  # Check if models loaded successfully
915
  model_status = service.get_model_status()
 
870
  # Train ResNet first and wait for completion
871
  log_message += f"\nπŸš€ Starting ResNet training on {dataset_size} samples...\n"
872
  resnet_result = subprocess.run([
873
+ "python", "train_resnet.py", "--data_root", DATASET_ROOT, "--epochs", str(res_epochs),
874
  "--batch_size", "4", "--lr", "1e-3", "--early_stopping_patience", "3",
875
+ "--out", os.path.join(export_dir, "resnet_item_embedder.pth")
876
  ] + dataset_args, capture_output=True, text=True, check=False)
877
 
878
  if resnet_result.returncode == 0:
 
897
 
898
  log_message += f"\nπŸš€ Starting ViT training on {dataset_size} samples...\n"
899
  vit_result = subprocess.run([
900
+ "python", "train_vit_triplet.py", "--data_root", DATASET_ROOT, "--epochs", str(vit_epochs),
901
  "--batch_size", "4", "--lr", "5e-4", "--early_stopping_patience", "5",
902
  "--max_samples", "5000", "--triplet_margin", "0.5", "--gradient_clip", "1.0",
903
  "--warmup_epochs", "2", "--export", os.path.join(export_dir, "vit_outfit_model.pth")
 
909
  else:
910
  log_message += f"❌ ViT training failed: {vit_result.stderr}\n"
911
  return log_message
912
+ service.reload_models()
913
 
914
  # Check if models loaded successfully
915
  model_status = service.get_model_status()
inference.py CHANGED
@@ -58,8 +58,8 @@ class InferenceService:
58
  # Disable gradients
59
  for m in [self.resnet, self.vit]:
60
  if m is not None:
61
- for p in m.parameters():
62
- p.requires_grad_(False)
63
 
64
  # Update overall status
65
  self.models_loaded = self.resnet_loaded and self.vit_loaded
@@ -298,8 +298,8 @@ class InferenceService:
298
  # Disable gradients
299
  for m in [self.resnet, self.vit]:
300
  if m is not None:
301
- for p in m.parameters():
302
- p.requires_grad_(False)
303
 
304
  # Update overall status
305
  self.models_loaded = self.resnet_loaded and self.vit_loaded
@@ -323,12 +323,12 @@ class InferenceService:
323
  return []
324
 
325
  try:
326
- batch = torch.stack([self.transform(img) for img in images])
327
- batch = batch.to(self.device, memory_format=torch.channels_last)
328
- use_amp = (self.device == "cuda")
329
- with torch.autocast(device_type=("cuda" if use_amp else "cpu"), enabled=use_amp):
330
- emb = self.resnet(batch)
331
- emb = nn.functional.normalize(emb, dim=-1)
332
  result = [e.detach().cpu().numpy().astype(np.float32) for e in emb]
333
  print(f"πŸ” DEBUG: Successfully generated {len(result)} embeddings")
334
  return result
 
58
  # Disable gradients
59
  for m in [self.resnet, self.vit]:
60
  if m is not None:
61
+ for p in m.parameters():
62
+ p.requires_grad_(False)
63
 
64
  # Update overall status
65
  self.models_loaded = self.resnet_loaded and self.vit_loaded
 
298
  # Disable gradients
299
  for m in [self.resnet, self.vit]:
300
  if m is not None:
301
+ for p in m.parameters():
302
+ p.requires_grad_(False)
303
 
304
  # Update overall status
305
  self.models_loaded = self.resnet_loaded and self.vit_loaded
 
323
  return []
324
 
325
  try:
326
+ batch = torch.stack([self.transform(img) for img in images])
327
+ batch = batch.to(self.device, memory_format=torch.channels_last)
328
+ use_amp = (self.device == "cuda")
329
+ with torch.autocast(device_type=("cuda" if use_amp else "cpu"), enabled=use_amp):
330
+ emb = self.resnet(batch)
331
+ emb = nn.functional.normalize(emb, dim=-1)
332
  result = [e.detach().cpu().numpy().astype(np.float32) for e in emb]
333
  print(f"πŸ” DEBUG: Successfully generated {len(result)} embeddings")
334
  return result
utils/tag_system.py CHANGED
@@ -516,3 +516,4 @@ def validate_tags(tags: Dict[str, Any]) -> tuple[bool, List[str]]:
516
 
517
  return len(errors) == 0, errors
518
 
 
 
516
 
517
  return len(errors) == 0, errors
518
 
519
+