Ali Mohsin committed on
Commit
4a5ec80
·
1 Parent(s): 25bdf34

Updated new changes

Browse files
Files changed (2) hide show
  1. app.py +61 -10
  2. inference.py +127 -36
app.py CHANGED
@@ -256,8 +256,8 @@ def _background_bootstrap():
256
  import sys
257
  argv_bak = sys.argv
258
  try:
259
- # Use official splits from nondisjoint/ and disjoint/ folders with default size limit (160 samples)
260
- sys.argv = ["prepare_polyvore.py", "--root", ds_root, "--max_samples", "160"]
261
  prepare_main()
262
  finally:
263
  sys.argv = argv_bak
@@ -390,6 +390,20 @@ def _stitch_strip(imgs: List[Image.Image], height: int = 256, pad: int = 6, bg=(
390
 
391
 
392
  def gradio_recommend(files: List[str], occasion: str, weather: str, num_outfits: int):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
  # Return stitched outfit images and a JSON with details
394
  if not files:
395
  return [], {"error": "No files uploaded"}
@@ -402,6 +416,11 @@ def gradio_recommend(files: List[str], occasion: str, weather: str, num_outfits:
402
  for i in range(len(images))
403
  ]
404
  res = service.compose_outfits(items, context={"occasion": occasion, "weather": weather, "num_outfits": int(num_outfits)})
 
 
 
 
 
405
  # Prepare stitched previews
406
  strips: List[Image.Image] = []
407
  for r in res:
@@ -595,7 +614,19 @@ def start_training_advanced(
595
  log_message += "πŸŽ‰ All training completed! Models saved to models/exports/\n"
596
  log_message += "πŸ”„ Reloading models for inference...\n"
597
  service.reload_models()
598
- log_message += "βœ… Models reloaded and ready for inference!\n"
 
 
 
 
 
 
 
 
 
 
 
 
599
 
600
  # Auto-upload to HF Hub if token is available
601
  hf_token = os.getenv("HF_TOKEN")
@@ -689,7 +720,21 @@ def start_training_simple(dataset_size: str, res_epochs: int, vit_epochs: int):
689
  log_message += f"❌ ViT training failed: {vit_result.stderr}\n"
690
  return log_message
691
  service.reload_models()
692
- log_message += "\nDone. Artifacts in models/exports."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
693
 
694
  # Auto-upload to HF Hub if token is available
695
  hf_token = os.getenv("HF_TOKEN")
@@ -740,12 +785,12 @@ with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendatio
740
 
741
  with gr.Row():
742
  gr.Markdown("#### πŸ“Š **Current Behavior**")
743
- gr.Markdown("β€’ **Bootstrap**: Downloads full dataset (53K outfits) + generates splits with **160 samples by default**\nβ€’ **Training**: Uses 160 samples (ultra-fast testing!)\nβ€’ **Apply Button**: Regenerates splits with your selected size limit")
744
 
745
  with gr.Row():
746
  global_dataset_size = gr.Dropdown(
747
  choices=["160", "2000", "5000", "10000", "25000", "50000", "full"],
748
- value="160",
749
  label="Global Dataset Size (Affects Prep + Training)"
750
  )
751
  gr.Markdown("**160**: Ultra-fast testing (~30 sec prep, ~1-2 min training)\n**2000**: Fast testing (~1-2 min prep, ~2-5 min training)\n**5000**: Fast testing (~2-3 min prep, ~5-10 min training)\n**10000**: Good testing (~3-5 min prep, ~10-20 min training)\n**full**: Production (~5-10 min prep, ~1-4 hours training)")
@@ -753,11 +798,11 @@ with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendatio
753
  with gr.Row():
754
  # Apply dataset size button
755
  apply_size_btn = gr.Button("πŸ”„ Apply Dataset Size & Regenerate Splits", variant="primary")
756
- size_status = gr.Textbox(label="Dataset Size Status", value="Dataset size: 160 samples (click Apply to regenerate splits)", interactive=False)
757
 
758
  # Current dataset info
759
  gr.Markdown("#### πŸ“Š **Current Dataset Status**")
760
- gr.Markdown("β€’ **Full dataset downloaded**: 53,306 outfits (required for system)\nβ€’ **Splits generated**: **160 samples by default** (ultra-fast testing!)\nβ€’ **Training will use**: 160 samples (ultra-fast!)\nβ€’ **Scale up**: Use Apply button to increase to larger sizes")
761
 
762
  def apply_dataset_size(size: str):
763
  """Apply global dataset size and regenerate splits."""
@@ -810,7 +855,7 @@ with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendatio
810
  gr.Markdown("#### πŸ“Š Dataset Size Control")
811
  gr.Markdown("Start small for testing, increase for production training")
812
  dataset_size = gr.Dropdown(
813
- choices=["2000", "5000", "10000", "25000", "50000", "full"],
814
  value="2000",
815
  label="Training Dataset Size"
816
  )
@@ -1003,7 +1048,7 @@ with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendatio
1003
  gr.Markdown("#### πŸ“Š Dataset Size Control")
1004
  gr.Markdown("Start small for testing, increase for production training")
1005
  dataset_size = gr.Dropdown(
1006
- choices=["2000", "5000", "10000", "25000", "50000", "full"],
1007
  value="2000",
1008
  label="Training Dataset Size"
1009
  )
@@ -1032,6 +1077,12 @@ with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendatio
1032
  refresh_status = gr.Button("πŸ”„ Refresh Status")
1033
  refresh_status.click(fn=lambda: BOOT_STATUS, inputs=[], outputs=status)
1034
 
 
 
 
 
 
 
1035
  # System info
1036
  gr.Markdown("#### πŸ’» System Information")
1037
  device_info = gr.Textbox(label="Device", value=lambda: f"Device: {service.device}")
 
256
  import sys
257
  argv_bak = sys.argv
258
  try:
259
+ # Use official splits from nondisjoint/ and disjoint/ folders with default size limit (2000 samples for better early stopping)
260
+ sys.argv = ["prepare_polyvore.py", "--root", ds_root, "--max_samples", "2000"]
261
  prepare_main()
262
  finally:
263
  sys.argv = argv_bak
 
390
 
391
 
392
  def gradio_recommend(files: List[str], occasion: str, weather: str, num_outfits: int):
393
+ # Check model status first
394
+ model_status = service.get_model_status()
395
+ if not model_status["can_recommend"]:
396
+ error_msg = "❌ Models not ready for recommendations!\n\n"
397
+ error_msg += "**Model Status:**\n"
398
+ error_msg += f"- ResNet: {'βœ… Loaded' if model_status['resnet_loaded'] else '❌ Not loaded'}\n"
399
+ error_msg += f"- ViT: {'βœ… Loaded' if model_status['vit_loaded'] else '❌ Not loaded'}\n\n"
400
+ error_msg += "**Errors:**\n"
401
+ for error in model_status["errors"]:
402
+ error_msg += f"- {error}\n\n"
403
+ error_msg += "**Solution:**\n"
404
+ error_msg += "Please train the models first using the 'Simple Training' or 'Advanced Training' tabs, or ensure trained checkpoints are available."
405
+ return [], {"error": error_msg, "model_status": model_status}
406
+
407
  # Return stitched outfit images and a JSON with details
408
  if not files:
409
  return [], {"error": "No files uploaded"}
 
416
  for i in range(len(images))
417
  ]
418
  res = service.compose_outfits(items, context={"occasion": occasion, "weather": weather, "num_outfits": int(num_outfits)})
419
+
420
+ # Check if compose_outfits returned an error
421
+ if res and isinstance(res[0], dict) and "error" in res[0]:
422
+ return [], res[0]
423
+
424
  # Prepare stitched previews
425
  strips: List[Image.Image] = []
426
  for r in res:
 
614
  log_message += "πŸŽ‰ All training completed! Models saved to models/exports/\n"
615
  log_message += "πŸ”„ Reloading models for inference...\n"
616
  service.reload_models()
617
+
618
+ # Check if models loaded successfully
619
+ model_status = service.get_model_status()
620
+ if model_status["can_recommend"]:
621
+ log_message += "βœ… Models reloaded and ready for inference!\n"
622
+ log_message += "πŸŽ‰ You can now generate outfit recommendations!\n"
623
+ else:
624
+ log_message += "⚠️ Models reloaded but validation failed!\n"
625
+ log_message += "**Model Status:**\n"
626
+ log_message += f"- ResNet: {'βœ… Loaded' if model_status['resnet_loaded'] else '❌ Failed'}\n"
627
+ log_message += f"- ViT: {'βœ… Loaded' if model_status['vit_loaded'] else '❌ Failed'}\n"
628
+ for error in model_status["errors"]:
629
+ log_message += f"- {error}\n"
630
 
631
  # Auto-upload to HF Hub if token is available
632
  hf_token = os.getenv("HF_TOKEN")
 
720
  log_message += f"❌ ViT training failed: {vit_result.stderr}\n"
721
  return log_message
722
  service.reload_models()
723
+
724
+ # Check if models loaded successfully
725
+ model_status = service.get_model_status()
726
+ if model_status["can_recommend"]:
727
+ log_message += "\nβœ… Training completed! Models reloaded and ready for inference.\n"
728
+ log_message += "πŸŽ‰ You can now generate outfit recommendations!\n"
729
+ else:
730
+ log_message += "\n⚠️ Training completed but models failed to load properly!\n"
731
+ log_message += "**Model Status:**\n"
732
+ log_message += f"- ResNet: {'βœ… Loaded' if model_status['resnet_loaded'] else '❌ Failed'}\n"
733
+ log_message += f"- ViT: {'βœ… Loaded' if model_status['vit_loaded'] else '❌ Failed'}\n"
734
+ for error in model_status["errors"]:
735
+ log_message += f"- {error}\n"
736
+
737
+ log_message += "\nArtifacts saved to models/exports/"
738
 
739
  # Auto-upload to HF Hub if token is available
740
  hf_token = os.getenv("HF_TOKEN")
 
785
 
786
  with gr.Row():
787
  gr.Markdown("#### πŸ“Š **Current Behavior**")
788
+ gr.Markdown("β€’ **Bootstrap**: Downloads full dataset (53K outfits) + generates splits with **2000 samples by default**\nβ€’ **Training**: Uses 2000 samples (good for early stopping demonstration!)\nβ€’ **Apply Button**: Regenerates splits with your selected size limit")
789
 
790
  with gr.Row():
791
  global_dataset_size = gr.Dropdown(
792
  choices=["160", "2000", "5000", "10000", "25000", "50000", "full"],
793
+ value="2000",
794
  label="Global Dataset Size (Affects Prep + Training)"
795
  )
796
  gr.Markdown("**160**: Ultra-fast testing (~30 sec prep, ~1-2 min training)\n**2000**: Fast testing (~1-2 min prep, ~2-5 min training)\n**5000**: Fast testing (~2-3 min prep, ~5-10 min training)\n**10000**: Good testing (~3-5 min prep, ~10-20 min training)\n**full**: Production (~5-10 min prep, ~1-4 hours training)")
 
798
  with gr.Row():
799
  # Apply dataset size button
800
  apply_size_btn = gr.Button("πŸ”„ Apply Dataset Size & Regenerate Splits", variant="primary")
801
+ size_status = gr.Textbox(label="Dataset Size Status", value="Dataset size: 2000 samples (click Apply to regenerate splits)", interactive=False)
802
 
803
  # Current dataset info
804
  gr.Markdown("#### πŸ“Š **Current Dataset Status**")
805
+ gr.Markdown("β€’ **Full dataset downloaded**: 53,306 outfits (required for system)\nβ€’ **Splits generated**: **2000 samples by default** (good for early stopping!)\nβ€’ **Training will use**: 2000 samples (good for early stopping demonstration!)\nβ€’ **Scale up**: Use Apply button to increase to larger sizes")
806
 
807
  def apply_dataset_size(size: str):
808
  """Apply global dataset size and regenerate splits."""
 
855
  gr.Markdown("#### πŸ“Š Dataset Size Control")
856
  gr.Markdown("Start small for testing, increase for production training")
857
  dataset_size = gr.Dropdown(
858
+ choices=["160", "2000", "5000", "10000", "25000", "50000", "full"],
859
  value="2000",
860
  label="Training Dataset Size"
861
  )
 
1048
  gr.Markdown("#### πŸ“Š Dataset Size Control")
1049
  gr.Markdown("Start small for testing, increase for production training")
1050
  dataset_size = gr.Dropdown(
1051
+ choices=["160", "2000", "5000", "10000", "25000", "50000", "full"],
1052
  value="2000",
1053
  label="Training Dataset Size"
1054
  )
 
1077
  refresh_status = gr.Button("πŸ”„ Refresh Status")
1078
  refresh_status.click(fn=lambda: BOOT_STATUS, inputs=[], outputs=status)
1079
 
1080
+ # Model Status
1081
+ gr.Markdown("#### πŸ€– Model Status")
1082
+ model_status = gr.JSON(label="Model Loading Status", value=lambda: service.get_model_status())
1083
+ refresh_models = gr.Button("πŸ”„ Refresh Model Status")
1084
+ refresh_models.click(fn=lambda: service.get_model_status(), inputs=[], outputs=model_status)
1085
+
1086
  # System info
1087
  gr.Markdown("#### πŸ’» System Information")
1088
  device_info = gr.Textbox(label="Device", value=lambda: f"Device: {service.device}")
inference.py CHANGED
@@ -27,20 +27,43 @@ class InferenceService:
27
  self.embed_dim = int(os.getenv("EMBED_DIM", "512"))
28
  self.resnet_version = "resnet_v1"
29
  self.vit_version = "vit_v1"
30
-
31
- self.resnet = self._load_resnet().to(self.device).eval()
32
- self.vit = self._load_vit().to(self.device).eval()
33
-
 
 
 
 
 
 
 
 
 
 
 
 
34
  for m in [self.resnet, self.vit]:
35
- for p in m.parameters():
36
- p.requires_grad_(False)
 
 
 
 
 
 
 
 
 
 
37
 
38
- def _load_resnet(self) -> nn.Module:
39
  strategy = os.getenv("MODEL_LOAD_STRATEGY", "state_dict")
40
  ckpt_path = os.getenv("RESNET_CHECKPOINT", "models/exports/resnet_item_embedder.pth")
41
- model = ResNetItemEmbedder(embedding_dim=self.embed_dim)
42
  if strategy == "random":
43
- return model
 
44
 
45
  # Try to download from Hugging Face Hub first
46
  try:
@@ -52,34 +75,48 @@ class InferenceService:
52
  local_dir_use_symlinks=False
53
  )
54
  print(f"πŸ“₯ Downloaded ResNet from HF Hub: {hf_path}")
 
55
  state = torch.load(hf_path, map_location="cpu")
56
  state_dict = state.get("state_dict", state) if isinstance(state, dict) else state
57
  model.load_state_dict(state_dict, strict=False)
58
- return model
 
59
  except Exception as e:
60
  print(f"❌ Failed to download ResNet from HF Hub: {e}")
61
- print("⚠️ WARNING: Using untrained ResNet model!")
62
- print("🚨 Recommendations will not be meaningful without trained weights!")
63
 
64
- # Fallback to local checkpoints
65
  best_path = os.path.join(os.path.dirname(ckpt_path), "resnet_item_embedder_best.pth")
66
  if os.path.exists(best_path):
67
- ckpt_to_use = best_path
68
- else:
69
- ckpt_to_use = ckpt_path
70
- if os.path.exists(ckpt_to_use):
71
- state = torch.load(ckpt_to_use, map_location="cpu")
 
 
 
 
 
 
 
 
72
  state_dict = state.get("state_dict", state) if isinstance(state, dict) else state
73
  model.load_state_dict(state_dict, strict=False)
74
- return model
75
- return model
 
 
 
 
 
76
 
77
- def _load_vit(self) -> nn.Module:
78
  strategy = os.getenv("MODEL_LOAD_STRATEGY", "state_dict")
79
  ckpt_path = os.getenv("VIT_CHECKPOINT", "models/exports/vit_outfit_model.pth")
80
- model = OutfitCompatibilityModel(embedding_dim=self.embed_dim)
81
  if strategy == "random":
82
- return model
 
83
 
84
  # Try to download from Hugging Face Hub first
85
  try:
@@ -91,32 +128,66 @@ class InferenceService:
91
  local_dir_use_symlinks=False
92
  )
93
  print(f"πŸ“₯ Downloaded ViT from HF Hub: {hf_path}")
 
94
  state = torch.load(hf_path, map_location="cpu")
95
  state_dict = state.get("state_dict", state) if isinstance(state, dict) else state
96
  model.load_state_dict(state_dict, strict=False)
97
- return model
 
98
  except Exception as e:
99
  print(f"❌ Failed to download ViT from HF Hub: {e}")
100
- print("⚠️ WARNING: Using untrained ViT model!")
101
- print("🚨 Recommendations will not be meaningful without trained weights!")
102
 
103
- # Fallback to local checkpoints
104
  best_path = os.path.join(os.path.dirname(ckpt_path), "vit_outfit_model_best.pth")
105
- ckpt_to_use = best_path if os.path.exists(best_path) else ckpt_path
106
- if os.path.exists(ckpt_to_use):
107
- state = torch.load(ckpt_to_use, map_location="cpu")
 
 
 
 
 
 
 
 
 
 
 
108
  state_dict = state.get("state_dict", state) if isinstance(state, dict) else state
109
  model.load_state_dict(state_dict, strict=False)
110
- return model
111
- return model
 
 
 
 
 
112
 
113
  def reload_models(self) -> None:
114
  """Reload weights from current checkpoint locations (used after background training)."""
115
- self.resnet = self._load_resnet().to(self.device).eval()
116
- self.vit = self._load_vit().to(self.device).eval()
 
 
 
 
 
 
 
 
117
  for m in [self.resnet, self.vit]:
118
- for p in m.parameters():
119
- p.requires_grad_(False)
 
 
 
 
 
 
 
 
 
 
120
 
121
  @torch.inference_mode()
122
  def embed_images(self, images: List[Image.Image]) -> List[np.ndarray]:
@@ -132,6 +203,16 @@ class InferenceService:
132
 
133
  @torch.inference_mode()
134
  def compose_outfits(self, items: List[Dict[str, Any]], context: Dict[str, Any]) -> List[Dict[str, Any]]:
 
 
 
 
 
 
 
 
 
 
135
  # 1) Ensure embeddings for each input item
136
  proc_items: List[Dict[str, Any]] = []
137
  for it in items:
@@ -248,5 +329,15 @@ class InferenceService:
248
  for subset, score in topk
249
  ]
250
  return results
 
 
 
 
 
 
 
 
 
 
251
 
252
 
 
27
  self.embed_dim = int(os.getenv("EMBED_DIM", "512"))
28
  self.resnet_version = "resnet_v1"
29
  self.vit_version = "vit_v1"
30
+
31
+ # Model loading status tracking
32
+ self.models_loaded = False
33
+ self.model_errors = []
34
+
35
+ # Load models with validation
36
+ self.resnet, self.resnet_loaded = self._load_resnet()
37
+ self.vit, self.vit_loaded = self._load_vit()
38
+
39
+ # Move to device and set eval mode
40
+ if self.resnet_loaded:
41
+ self.resnet = self.resnet.to(self.device).eval()
42
+ if self.vit_loaded:
43
+ self.vit = self.vit.to(self.device).eval()
44
+
45
+ # Disable gradients
46
  for m in [self.resnet, self.vit]:
47
+ if m is not None:
48
+ for p in m.parameters():
49
+ p.requires_grad_(False)
50
+
51
+ # Update overall status
52
+ self.models_loaded = self.resnet_loaded and self.vit_loaded
53
+ if not self.models_loaded:
54
+ self.model_errors = []
55
+ if not self.resnet_loaded:
56
+ self.model_errors.append("ResNet: No trained weights found")
57
+ if not self.vit_loaded:
58
+ self.model_errors.append("ViT: No trained weights found")
59
 
60
+ def _load_resnet(self) -> tuple[nn.Module, bool]:
61
  strategy = os.getenv("MODEL_LOAD_STRATEGY", "state_dict")
62
  ckpt_path = os.getenv("RESNET_CHECKPOINT", "models/exports/resnet_item_embedder.pth")
63
+
64
  if strategy == "random":
65
+ print("⚠️ Random strategy selected - no trained weights will be loaded!")
66
+ return ResNetItemEmbedder(embedding_dim=self.embed_dim), False
67
 
68
  # Try to download from Hugging Face Hub first
69
  try:
 
75
  local_dir_use_symlinks=False
76
  )
77
  print(f"πŸ“₯ Downloaded ResNet from HF Hub: {hf_path}")
78
+ model = ResNetItemEmbedder(embedding_dim=self.embed_dim)
79
  state = torch.load(hf_path, map_location="cpu")
80
  state_dict = state.get("state_dict", state) if isinstance(state, dict) else state
81
  model.load_state_dict(state_dict, strict=False)
82
+ print("βœ… ResNet model loaded successfully from HF Hub")
83
+ return model, True
84
  except Exception as e:
85
  print(f"❌ Failed to download ResNet from HF Hub: {e}")
 
 
86
 
87
+ # Check for local best checkpoint first
88
  best_path = os.path.join(os.path.dirname(ckpt_path), "resnet_item_embedder_best.pth")
89
  if os.path.exists(best_path):
90
+ print(f"πŸ“ Loading ResNet from best checkpoint: {best_path}")
91
+ model = ResNetItemEmbedder(embedding_dim=self.embed_dim)
92
+ state = torch.load(best_path, map_location="cpu")
93
+ state_dict = state.get("state_dict", state) if isinstance(state, dict) else state
94
+ model.load_state_dict(state_dict, strict=False)
95
+ print("βœ… ResNet model loaded successfully from best checkpoint")
96
+ return model, True
97
+
98
+ # Check for regular checkpoint
99
+ if os.path.exists(ckpt_path):
100
+ print(f"πŸ“ Loading ResNet from checkpoint: {ckpt_path}")
101
+ model = ResNetItemEmbedder(embedding_dim=self.embed_dim)
102
+ state = torch.load(ckpt_path, map_location="cpu")
103
  state_dict = state.get("state_dict", state) if isinstance(state, dict) else state
104
  model.load_state_dict(state_dict, strict=False)
105
+ print("βœ… ResNet model loaded successfully from checkpoint")
106
+ return model, True
107
+
108
+ print("❌ CRITICAL: No trained ResNet weights found!")
109
+ print("🚨 Cannot provide recommendations without trained weights!")
110
+ print("πŸ’‘ Please train the ResNet model first using the training tabs.")
111
+ return ResNetItemEmbedder(embedding_dim=self.embed_dim), False
112
 
113
+ def _load_vit(self) -> tuple[nn.Module, bool]:
114
  strategy = os.getenv("MODEL_LOAD_STRATEGY", "state_dict")
115
  ckpt_path = os.getenv("VIT_CHECKPOINT", "models/exports/vit_outfit_model.pth")
116
+
117
  if strategy == "random":
118
+ print("⚠️ Random strategy selected - no trained weights will be loaded!")
119
+ return OutfitCompatibilityModel(embedding_dim=self.embed_dim), False
120
 
121
  # Try to download from Hugging Face Hub first
122
  try:
 
128
  local_dir_use_symlinks=False
129
  )
130
  print(f"πŸ“₯ Downloaded ViT from HF Hub: {hf_path}")
131
+ model = OutfitCompatibilityModel(embedding_dim=self.embed_dim)
132
  state = torch.load(hf_path, map_location="cpu")
133
  state_dict = state.get("state_dict", state) if isinstance(state, dict) else state
134
  model.load_state_dict(state_dict, strict=False)
135
+ print("βœ… ViT model loaded successfully from HF Hub")
136
+ return model, True
137
  except Exception as e:
138
  print(f"❌ Failed to download ViT from HF Hub: {e}")
 
 
139
 
140
+ # Check for local best checkpoint first
141
  best_path = os.path.join(os.path.dirname(ckpt_path), "vit_outfit_model_best.pth")
142
+ if os.path.exists(best_path):
143
+ print(f"πŸ“ Loading ViT from best checkpoint: {best_path}")
144
+ model = OutfitCompatibilityModel(embedding_dim=self.embed_dim)
145
+ state = torch.load(best_path, map_location="cpu")
146
+ state_dict = state.get("state_dict", state) if isinstance(state, dict) else state
147
+ model.load_state_dict(state_dict, strict=False)
148
+ print("βœ… ViT model loaded successfully from best checkpoint")
149
+ return model, True
150
+
151
+ # Check for regular checkpoint
152
+ if os.path.exists(ckpt_path):
153
+ print(f"πŸ“ Loading ViT from checkpoint: {ckpt_path}")
154
+ model = OutfitCompatibilityModel(embedding_dim=self.embed_dim)
155
+ state = torch.load(ckpt_path, map_location="cpu")
156
  state_dict = state.get("state_dict", state) if isinstance(state, dict) else state
157
  model.load_state_dict(state_dict, strict=False)
158
+ print("βœ… ViT model loaded successfully from checkpoint")
159
+ return model, True
160
+
161
+ print("❌ CRITICAL: No trained ViT weights found!")
162
+ print("🚨 Cannot provide recommendations without trained weights!")
163
+ print("πŸ’‘ Please train the ViT model first using the training tabs.")
164
+ return OutfitCompatibilityModel(embedding_dim=self.embed_dim), False
165
 
166
  def reload_models(self) -> None:
167
  """Reload weights from current checkpoint locations (used after background training)."""
168
+ self.resnet, self.resnet_loaded = self._load_resnet()
169
+ self.vit, self.vit_loaded = self._load_vit()
170
+
171
+ # Move to device and set eval mode
172
+ if self.resnet_loaded:
173
+ self.resnet = self.resnet.to(self.device).eval()
174
+ if self.vit_loaded:
175
+ self.vit = self.vit.to(self.device).eval()
176
+
177
+ # Disable gradients
178
  for m in [self.resnet, self.vit]:
179
+ if m is not None:
180
+ for p in m.parameters():
181
+ p.requires_grad_(False)
182
+
183
+ # Update overall status
184
+ self.models_loaded = self.resnet_loaded and self.vit_loaded
185
+ if not self.models_loaded:
186
+ self.model_errors = []
187
+ if not self.resnet_loaded:
188
+ self.model_errors.append("ResNet: No trained weights found")
189
+ if not self.vit_loaded:
190
+ self.model_errors.append("ViT: No trained weights found")
191
 
192
  @torch.inference_mode()
193
  def embed_images(self, images: List[Image.Image]) -> List[np.ndarray]:
 
203
 
204
  @torch.inference_mode()
205
  def compose_outfits(self, items: List[Dict[str, Any]], context: Dict[str, Any]) -> List[Dict[str, Any]]:
206
+ # Validate that models are properly loaded
207
+ if not self.models_loaded:
208
+ error_msg = f"❌ Cannot provide recommendations: Models not properly loaded. Errors: {self.model_errors}"
209
+ print(error_msg)
210
+ return [{
211
+ "error": "Models not trained or loaded properly",
212
+ "details": self.model_errors,
213
+ "message": "Please ensure models are trained and checkpoints exist before generating recommendations."
214
+ }]
215
+
216
  # 1) Ensure embeddings for each input item
217
  proc_items: List[Dict[str, Any]] = []
218
  for it in items:
 
329
  for subset, score in topk
330
  ]
331
  return results
332
+
333
+ def get_model_status(self) -> Dict[str, Any]:
334
+ """Get current model loading status and errors."""
335
+ return {
336
+ "models_loaded": self.models_loaded,
337
+ "resnet_loaded": self.resnet_loaded,
338
+ "vit_loaded": self.vit_loaded,
339
+ "errors": self.model_errors,
340
+ "can_recommend": self.models_loaded
341
+ }
342
 
343