Spaces:
Paused
Paused
Ali Mohsin
commited on
Commit
Β·
4a5ec80
1
Parent(s):
25bdf34
Updated new changes
Browse files- app.py +61 -10
- inference.py +127 -36
app.py
CHANGED
|
@@ -256,8 +256,8 @@ def _background_bootstrap():
|
|
| 256 |
import sys
|
| 257 |
argv_bak = sys.argv
|
| 258 |
try:
|
| 259 |
-
# Use official splits from nondisjoint/ and disjoint/ folders with default size limit (
|
| 260 |
-
sys.argv = ["prepare_polyvore.py", "--root", ds_root, "--max_samples", "
|
| 261 |
prepare_main()
|
| 262 |
finally:
|
| 263 |
sys.argv = argv_bak
|
|
@@ -390,6 +390,20 @@ def _stitch_strip(imgs: List[Image.Image], height: int = 256, pad: int = 6, bg=(
|
|
| 390 |
|
| 391 |
|
| 392 |
def gradio_recommend(files: List[str], occasion: str, weather: str, num_outfits: int):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
# Return stitched outfit images and a JSON with details
|
| 394 |
if not files:
|
| 395 |
return [], {"error": "No files uploaded"}
|
|
@@ -402,6 +416,11 @@ def gradio_recommend(files: List[str], occasion: str, weather: str, num_outfits:
|
|
| 402 |
for i in range(len(images))
|
| 403 |
]
|
| 404 |
res = service.compose_outfits(items, context={"occasion": occasion, "weather": weather, "num_outfits": int(num_outfits)})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
# Prepare stitched previews
|
| 406 |
strips: List[Image.Image] = []
|
| 407 |
for r in res:
|
|
@@ -595,7 +614,19 @@ def start_training_advanced(
|
|
| 595 |
log_message += "π All training completed! Models saved to models/exports/\n"
|
| 596 |
log_message += "π Reloading models for inference...\n"
|
| 597 |
service.reload_models()
|
| 598 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 599 |
|
| 600 |
# Auto-upload to HF Hub if token is available
|
| 601 |
hf_token = os.getenv("HF_TOKEN")
|
|
@@ -689,7 +720,21 @@ def start_training_simple(dataset_size: str, res_epochs: int, vit_epochs: int):
|
|
| 689 |
log_message += f"β ViT training failed: {vit_result.stderr}\n"
|
| 690 |
return log_message
|
| 691 |
service.reload_models()
|
| 692 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 693 |
|
| 694 |
# Auto-upload to HF Hub if token is available
|
| 695 |
hf_token = os.getenv("HF_TOKEN")
|
|
@@ -740,12 +785,12 @@ with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendatio
|
|
| 740 |
|
| 741 |
with gr.Row():
|
| 742 |
gr.Markdown("#### π **Current Behavior**")
|
| 743 |
-
gr.Markdown("β’ **Bootstrap**: Downloads full dataset (53K outfits) + generates splits with **
|
| 744 |
|
| 745 |
with gr.Row():
|
| 746 |
global_dataset_size = gr.Dropdown(
|
| 747 |
choices=["160", "2000", "5000", "10000", "25000", "50000", "full"],
|
| 748 |
-
value="
|
| 749 |
label="Global Dataset Size (Affects Prep + Training)"
|
| 750 |
)
|
| 751 |
gr.Markdown("**160**: Ultra-fast testing (~30 sec prep, ~1-2 min training)\n**2000**: Fast testing (~1-2 min prep, ~2-5 min training)\n**5000**: Fast testing (~2-3 min prep, ~5-10 min training)\n**10000**: Good testing (~3-5 min prep, ~10-20 min training)\n**full**: Production (~5-10 min prep, ~1-4 hours training)")
|
|
@@ -753,11 +798,11 @@ with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendatio
|
|
| 753 |
with gr.Row():
|
| 754 |
# Apply dataset size button
|
| 755 |
apply_size_btn = gr.Button("π Apply Dataset Size & Regenerate Splits", variant="primary")
|
| 756 |
-
size_status = gr.Textbox(label="Dataset Size Status", value="Dataset size:
|
| 757 |
|
| 758 |
# Current dataset info
|
| 759 |
gr.Markdown("#### π **Current Dataset Status**")
|
| 760 |
-
gr.Markdown("β’ **Full dataset downloaded**: 53,306 outfits (required for system)\nβ’ **Splits generated**: **
|
| 761 |
|
| 762 |
def apply_dataset_size(size: str):
|
| 763 |
"""Apply global dataset size and regenerate splits."""
|
|
@@ -810,7 +855,7 @@ with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendatio
|
|
| 810 |
gr.Markdown("#### π Dataset Size Control")
|
| 811 |
gr.Markdown("Start small for testing, increase for production training")
|
| 812 |
dataset_size = gr.Dropdown(
|
| 813 |
-
choices=["2000", "5000", "10000", "25000", "50000", "full"],
|
| 814 |
value="2000",
|
| 815 |
label="Training Dataset Size"
|
| 816 |
)
|
|
@@ -1003,7 +1048,7 @@ with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendatio
|
|
| 1003 |
gr.Markdown("#### π Dataset Size Control")
|
| 1004 |
gr.Markdown("Start small for testing, increase for production training")
|
| 1005 |
dataset_size = gr.Dropdown(
|
| 1006 |
-
choices=["2000", "5000", "10000", "25000", "50000", "full"],
|
| 1007 |
value="2000",
|
| 1008 |
label="Training Dataset Size"
|
| 1009 |
)
|
|
@@ -1032,6 +1077,12 @@ with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendatio
|
|
| 1032 |
refresh_status = gr.Button("π Refresh Status")
|
| 1033 |
refresh_status.click(fn=lambda: BOOT_STATUS, inputs=[], outputs=status)
|
| 1034 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1035 |
# System info
|
| 1036 |
gr.Markdown("#### π» System Information")
|
| 1037 |
device_info = gr.Textbox(label="Device", value=lambda: f"Device: {service.device}")
|
|
|
|
| 256 |
import sys
|
| 257 |
argv_bak = sys.argv
|
| 258 |
try:
|
| 259 |
+
# Use official splits from nondisjoint/ and disjoint/ folders with default size limit (2000 samples for better early stopping)
|
| 260 |
+
sys.argv = ["prepare_polyvore.py", "--root", ds_root, "--max_samples", "2000"]
|
| 261 |
prepare_main()
|
| 262 |
finally:
|
| 263 |
sys.argv = argv_bak
|
|
|
|
| 390 |
|
| 391 |
|
| 392 |
def gradio_recommend(files: List[str], occasion: str, weather: str, num_outfits: int):
|
| 393 |
+
# Check model status first
|
| 394 |
+
model_status = service.get_model_status()
|
| 395 |
+
if not model_status["can_recommend"]:
|
| 396 |
+
error_msg = "β Models not ready for recommendations!\n\n"
|
| 397 |
+
error_msg += "**Model Status:**\n"
|
| 398 |
+
error_msg += f"- ResNet: {'β
Loaded' if model_status['resnet_loaded'] else 'β Not loaded'}\n"
|
| 399 |
+
error_msg += f"- ViT: {'β
Loaded' if model_status['vit_loaded'] else 'β Not loaded'}\n\n"
|
| 400 |
+
error_msg += "**Errors:**\n"
|
| 401 |
+
for error in model_status["errors"]:
|
| 402 |
+
error_msg += f"- {error}\n\n"
|
| 403 |
+
error_msg += "**Solution:**\n"
|
| 404 |
+
error_msg += "Please train the models first using the 'Simple Training' or 'Advanced Training' tabs, or ensure trained checkpoints are available."
|
| 405 |
+
return [], {"error": error_msg, "model_status": model_status}
|
| 406 |
+
|
| 407 |
# Return stitched outfit images and a JSON with details
|
| 408 |
if not files:
|
| 409 |
return [], {"error": "No files uploaded"}
|
|
|
|
| 416 |
for i in range(len(images))
|
| 417 |
]
|
| 418 |
res = service.compose_outfits(items, context={"occasion": occasion, "weather": weather, "num_outfits": int(num_outfits)})
|
| 419 |
+
|
| 420 |
+
# Check if compose_outfits returned an error
|
| 421 |
+
if res and isinstance(res[0], dict) and "error" in res[0]:
|
| 422 |
+
return [], res[0]
|
| 423 |
+
|
| 424 |
# Prepare stitched previews
|
| 425 |
strips: List[Image.Image] = []
|
| 426 |
for r in res:
|
|
|
|
| 614 |
log_message += "π All training completed! Models saved to models/exports/\n"
|
| 615 |
log_message += "π Reloading models for inference...\n"
|
| 616 |
service.reload_models()
|
| 617 |
+
|
| 618 |
+
# Check if models loaded successfully
|
| 619 |
+
model_status = service.get_model_status()
|
| 620 |
+
if model_status["can_recommend"]:
|
| 621 |
+
log_message += "β
Models reloaded and ready for inference!\n"
|
| 622 |
+
log_message += "π You can now generate outfit recommendations!\n"
|
| 623 |
+
else:
|
| 624 |
+
log_message += "β οΈ Models reloaded but validation failed!\n"
|
| 625 |
+
log_message += "**Model Status:**\n"
|
| 626 |
+
log_message += f"- ResNet: {'β
Loaded' if model_status['resnet_loaded'] else 'β Failed'}\n"
|
| 627 |
+
log_message += f"- ViT: {'β
Loaded' if model_status['vit_loaded'] else 'β Failed'}\n"
|
| 628 |
+
for error in model_status["errors"]:
|
| 629 |
+
log_message += f"- {error}\n"
|
| 630 |
|
| 631 |
# Auto-upload to HF Hub if token is available
|
| 632 |
hf_token = os.getenv("HF_TOKEN")
|
|
|
|
| 720 |
log_message += f"β ViT training failed: {vit_result.stderr}\n"
|
| 721 |
return log_message
|
| 722 |
service.reload_models()
|
| 723 |
+
|
| 724 |
+
# Check if models loaded successfully
|
| 725 |
+
model_status = service.get_model_status()
|
| 726 |
+
if model_status["can_recommend"]:
|
| 727 |
+
log_message += "\nβ
Training completed! Models reloaded and ready for inference.\n"
|
| 728 |
+
log_message += "π You can now generate outfit recommendations!\n"
|
| 729 |
+
else:
|
| 730 |
+
log_message += "\nβ οΈ Training completed but models failed to load properly!\n"
|
| 731 |
+
log_message += "**Model Status:**\n"
|
| 732 |
+
log_message += f"- ResNet: {'β
Loaded' if model_status['resnet_loaded'] else 'β Failed'}\n"
|
| 733 |
+
log_message += f"- ViT: {'β
Loaded' if model_status['vit_loaded'] else 'β Failed'}\n"
|
| 734 |
+
for error in model_status["errors"]:
|
| 735 |
+
log_message += f"- {error}\n"
|
| 736 |
+
|
| 737 |
+
log_message += "\nArtifacts saved to models/exports/"
|
| 738 |
|
| 739 |
# Auto-upload to HF Hub if token is available
|
| 740 |
hf_token = os.getenv("HF_TOKEN")
|
|
|
|
| 785 |
|
| 786 |
with gr.Row():
|
| 787 |
gr.Markdown("#### π **Current Behavior**")
|
| 788 |
+
gr.Markdown("β’ **Bootstrap**: Downloads full dataset (53K outfits) + generates splits with **2000 samples by default**\nβ’ **Training**: Uses 2000 samples (good for early stopping demonstration!)\nβ’ **Apply Button**: Regenerates splits with your selected size limit")
|
| 789 |
|
| 790 |
with gr.Row():
|
| 791 |
global_dataset_size = gr.Dropdown(
|
| 792 |
choices=["160", "2000", "5000", "10000", "25000", "50000", "full"],
|
| 793 |
+
value="2000",
|
| 794 |
label="Global Dataset Size (Affects Prep + Training)"
|
| 795 |
)
|
| 796 |
gr.Markdown("**160**: Ultra-fast testing (~30 sec prep, ~1-2 min training)\n**2000**: Fast testing (~1-2 min prep, ~2-5 min training)\n**5000**: Fast testing (~2-3 min prep, ~5-10 min training)\n**10000**: Good testing (~3-5 min prep, ~10-20 min training)\n**full**: Production (~5-10 min prep, ~1-4 hours training)")
|
|
|
|
| 798 |
with gr.Row():
|
| 799 |
# Apply dataset size button
|
| 800 |
apply_size_btn = gr.Button("π Apply Dataset Size & Regenerate Splits", variant="primary")
|
| 801 |
+
size_status = gr.Textbox(label="Dataset Size Status", value="Dataset size: 2000 samples (click Apply to regenerate splits)", interactive=False)
|
| 802 |
|
| 803 |
# Current dataset info
|
| 804 |
gr.Markdown("#### π **Current Dataset Status**")
|
| 805 |
+
gr.Markdown("β’ **Full dataset downloaded**: 53,306 outfits (required for system)\nβ’ **Splits generated**: **2000 samples by default** (good for early stopping!)\nβ’ **Training will use**: 2000 samples (good for early stopping demonstration!)\nβ’ **Scale up**: Use Apply button to increase to larger sizes")
|
| 806 |
|
| 807 |
def apply_dataset_size(size: str):
|
| 808 |
"""Apply global dataset size and regenerate splits."""
|
|
|
|
| 855 |
gr.Markdown("#### π Dataset Size Control")
|
| 856 |
gr.Markdown("Start small for testing, increase for production training")
|
| 857 |
dataset_size = gr.Dropdown(
|
| 858 |
+
choices=["160", "2000", "5000", "10000", "25000", "50000", "full"],
|
| 859 |
value="2000",
|
| 860 |
label="Training Dataset Size"
|
| 861 |
)
|
|
|
|
| 1048 |
gr.Markdown("#### π Dataset Size Control")
|
| 1049 |
gr.Markdown("Start small for testing, increase for production training")
|
| 1050 |
dataset_size = gr.Dropdown(
|
| 1051 |
+
choices=["160", "2000", "5000", "10000", "25000", "50000", "full"],
|
| 1052 |
value="2000",
|
| 1053 |
label="Training Dataset Size"
|
| 1054 |
)
|
|
|
|
| 1077 |
refresh_status = gr.Button("π Refresh Status")
|
| 1078 |
refresh_status.click(fn=lambda: BOOT_STATUS, inputs=[], outputs=status)
|
| 1079 |
|
| 1080 |
+
# Model Status
|
| 1081 |
+
gr.Markdown("#### π€ Model Status")
|
| 1082 |
+
model_status = gr.JSON(label="Model Loading Status", value=lambda: service.get_model_status())
|
| 1083 |
+
refresh_models = gr.Button("π Refresh Model Status")
|
| 1084 |
+
refresh_models.click(fn=lambda: service.get_model_status(), inputs=[], outputs=model_status)
|
| 1085 |
+
|
| 1086 |
# System info
|
| 1087 |
gr.Markdown("#### π» System Information")
|
| 1088 |
device_info = gr.Textbox(label="Device", value=lambda: f"Device: {service.device}")
|
inference.py
CHANGED
|
@@ -27,20 +27,43 @@ class InferenceService:
|
|
| 27 |
self.embed_dim = int(os.getenv("EMBED_DIM", "512"))
|
| 28 |
self.resnet_version = "resnet_v1"
|
| 29 |
self.vit_version = "vit_v1"
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
self.
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
for m in [self.resnet, self.vit]:
|
| 35 |
-
|
| 36 |
-
p.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
-
def _load_resnet(self) -> nn.Module:
|
| 39 |
strategy = os.getenv("MODEL_LOAD_STRATEGY", "state_dict")
|
| 40 |
ckpt_path = os.getenv("RESNET_CHECKPOINT", "models/exports/resnet_item_embedder.pth")
|
| 41 |
-
|
| 42 |
if strategy == "random":
|
| 43 |
-
|
|
|
|
| 44 |
|
| 45 |
# Try to download from Hugging Face Hub first
|
| 46 |
try:
|
|
@@ -52,34 +75,48 @@ class InferenceService:
|
|
| 52 |
local_dir_use_symlinks=False
|
| 53 |
)
|
| 54 |
print(f"π₯ Downloaded ResNet from HF Hub: {hf_path}")
|
|
|
|
| 55 |
state = torch.load(hf_path, map_location="cpu")
|
| 56 |
state_dict = state.get("state_dict", state) if isinstance(state, dict) else state
|
| 57 |
model.load_state_dict(state_dict, strict=False)
|
| 58 |
-
|
|
|
|
| 59 |
except Exception as e:
|
| 60 |
print(f"β Failed to download ResNet from HF Hub: {e}")
|
| 61 |
-
print("β οΈ WARNING: Using untrained ResNet model!")
|
| 62 |
-
print("π¨ Recommendations will not be meaningful without trained weights!")
|
| 63 |
|
| 64 |
-
#
|
| 65 |
best_path = os.path.join(os.path.dirname(ckpt_path), "resnet_item_embedder_best.pth")
|
| 66 |
if os.path.exists(best_path):
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
state_dict = state.get("state_dict", state) if isinstance(state, dict) else state
|
| 73 |
model.load_state_dict(state_dict, strict=False)
|
| 74 |
-
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
-
def _load_vit(self) -> nn.Module:
|
| 78 |
strategy = os.getenv("MODEL_LOAD_STRATEGY", "state_dict")
|
| 79 |
ckpt_path = os.getenv("VIT_CHECKPOINT", "models/exports/vit_outfit_model.pth")
|
| 80 |
-
|
| 81 |
if strategy == "random":
|
| 82 |
-
|
|
|
|
| 83 |
|
| 84 |
# Try to download from Hugging Face Hub first
|
| 85 |
try:
|
|
@@ -91,32 +128,66 @@ class InferenceService:
|
|
| 91 |
local_dir_use_symlinks=False
|
| 92 |
)
|
| 93 |
print(f"π₯ Downloaded ViT from HF Hub: {hf_path}")
|
|
|
|
| 94 |
state = torch.load(hf_path, map_location="cpu")
|
| 95 |
state_dict = state.get("state_dict", state) if isinstance(state, dict) else state
|
| 96 |
model.load_state_dict(state_dict, strict=False)
|
| 97 |
-
|
|
|
|
| 98 |
except Exception as e:
|
| 99 |
print(f"β Failed to download ViT from HF Hub: {e}")
|
| 100 |
-
print("β οΈ WARNING: Using untrained ViT model!")
|
| 101 |
-
print("π¨ Recommendations will not be meaningful without trained weights!")
|
| 102 |
|
| 103 |
-
#
|
| 104 |
best_path = os.path.join(os.path.dirname(ckpt_path), "vit_outfit_model_best.pth")
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
state_dict = state.get("state_dict", state) if isinstance(state, dict) else state
|
| 109 |
model.load_state_dict(state_dict, strict=False)
|
| 110 |
-
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
def reload_models(self) -> None:
|
| 114 |
"""Reload weights from current checkpoint locations (used after background training)."""
|
| 115 |
-
self.resnet = self._load_resnet()
|
| 116 |
-
self.vit = self._load_vit()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
for m in [self.resnet, self.vit]:
|
| 118 |
-
|
| 119 |
-
p.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
@torch.inference_mode()
|
| 122 |
def embed_images(self, images: List[Image.Image]) -> List[np.ndarray]:
|
|
@@ -132,6 +203,16 @@ class InferenceService:
|
|
| 132 |
|
| 133 |
@torch.inference_mode()
|
| 134 |
def compose_outfits(self, items: List[Dict[str, Any]], context: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
# 1) Ensure embeddings for each input item
|
| 136 |
proc_items: List[Dict[str, Any]] = []
|
| 137 |
for it in items:
|
|
@@ -248,5 +329,15 @@ class InferenceService:
|
|
| 248 |
for subset, score in topk
|
| 249 |
]
|
| 250 |
return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
|
| 252 |
|
|
|
|
| 27 |
self.embed_dim = int(os.getenv("EMBED_DIM", "512"))
|
| 28 |
self.resnet_version = "resnet_v1"
|
| 29 |
self.vit_version = "vit_v1"
|
| 30 |
+
|
| 31 |
+
# Model loading status tracking
|
| 32 |
+
self.models_loaded = False
|
| 33 |
+
self.model_errors = []
|
| 34 |
+
|
| 35 |
+
# Load models with validation
|
| 36 |
+
self.resnet, self.resnet_loaded = self._load_resnet()
|
| 37 |
+
self.vit, self.vit_loaded = self._load_vit()
|
| 38 |
+
|
| 39 |
+
# Move to device and set eval mode
|
| 40 |
+
if self.resnet_loaded:
|
| 41 |
+
self.resnet = self.resnet.to(self.device).eval()
|
| 42 |
+
if self.vit_loaded:
|
| 43 |
+
self.vit = self.vit.to(self.device).eval()
|
| 44 |
+
|
| 45 |
+
# Disable gradients
|
| 46 |
for m in [self.resnet, self.vit]:
|
| 47 |
+
if m is not None:
|
| 48 |
+
for p in m.parameters():
|
| 49 |
+
p.requires_grad_(False)
|
| 50 |
+
|
| 51 |
+
# Update overall status
|
| 52 |
+
self.models_loaded = self.resnet_loaded and self.vit_loaded
|
| 53 |
+
if not self.models_loaded:
|
| 54 |
+
self.model_errors = []
|
| 55 |
+
if not self.resnet_loaded:
|
| 56 |
+
self.model_errors.append("ResNet: No trained weights found")
|
| 57 |
+
if not self.vit_loaded:
|
| 58 |
+
self.model_errors.append("ViT: No trained weights found")
|
| 59 |
|
| 60 |
+
def _load_resnet(self) -> tuple[nn.Module, bool]:
|
| 61 |
strategy = os.getenv("MODEL_LOAD_STRATEGY", "state_dict")
|
| 62 |
ckpt_path = os.getenv("RESNET_CHECKPOINT", "models/exports/resnet_item_embedder.pth")
|
| 63 |
+
|
| 64 |
if strategy == "random":
|
| 65 |
+
print("β οΈ Random strategy selected - no trained weights will be loaded!")
|
| 66 |
+
return ResNetItemEmbedder(embedding_dim=self.embed_dim), False
|
| 67 |
|
| 68 |
# Try to download from Hugging Face Hub first
|
| 69 |
try:
|
|
|
|
| 75 |
local_dir_use_symlinks=False
|
| 76 |
)
|
| 77 |
print(f"π₯ Downloaded ResNet from HF Hub: {hf_path}")
|
| 78 |
+
model = ResNetItemEmbedder(embedding_dim=self.embed_dim)
|
| 79 |
state = torch.load(hf_path, map_location="cpu")
|
| 80 |
state_dict = state.get("state_dict", state) if isinstance(state, dict) else state
|
| 81 |
model.load_state_dict(state_dict, strict=False)
|
| 82 |
+
print("β
ResNet model loaded successfully from HF Hub")
|
| 83 |
+
return model, True
|
| 84 |
except Exception as e:
|
| 85 |
print(f"β Failed to download ResNet from HF Hub: {e}")
|
|
|
|
|
|
|
| 86 |
|
| 87 |
+
# Check for local best checkpoint first
|
| 88 |
best_path = os.path.join(os.path.dirname(ckpt_path), "resnet_item_embedder_best.pth")
|
| 89 |
if os.path.exists(best_path):
|
| 90 |
+
print(f"π Loading ResNet from best checkpoint: {best_path}")
|
| 91 |
+
model = ResNetItemEmbedder(embedding_dim=self.embed_dim)
|
| 92 |
+
state = torch.load(best_path, map_location="cpu")
|
| 93 |
+
state_dict = state.get("state_dict", state) if isinstance(state, dict) else state
|
| 94 |
+
model.load_state_dict(state_dict, strict=False)
|
| 95 |
+
print("β
ResNet model loaded successfully from best checkpoint")
|
| 96 |
+
return model, True
|
| 97 |
+
|
| 98 |
+
# Check for regular checkpoint
|
| 99 |
+
if os.path.exists(ckpt_path):
|
| 100 |
+
print(f"π Loading ResNet from checkpoint: {ckpt_path}")
|
| 101 |
+
model = ResNetItemEmbedder(embedding_dim=self.embed_dim)
|
| 102 |
+
state = torch.load(ckpt_path, map_location="cpu")
|
| 103 |
state_dict = state.get("state_dict", state) if isinstance(state, dict) else state
|
| 104 |
model.load_state_dict(state_dict, strict=False)
|
| 105 |
+
print("β
ResNet model loaded successfully from checkpoint")
|
| 106 |
+
return model, True
|
| 107 |
+
|
| 108 |
+
print("β CRITICAL: No trained ResNet weights found!")
|
| 109 |
+
print("π¨ Cannot provide recommendations without trained weights!")
|
| 110 |
+
print("π‘ Please train the ResNet model first using the training tabs.")
|
| 111 |
+
return ResNetItemEmbedder(embedding_dim=self.embed_dim), False
|
| 112 |
|
| 113 |
+
def _load_vit(self) -> tuple[nn.Module, bool]:
|
| 114 |
strategy = os.getenv("MODEL_LOAD_STRATEGY", "state_dict")
|
| 115 |
ckpt_path = os.getenv("VIT_CHECKPOINT", "models/exports/vit_outfit_model.pth")
|
| 116 |
+
|
| 117 |
if strategy == "random":
|
| 118 |
+
print("β οΈ Random strategy selected - no trained weights will be loaded!")
|
| 119 |
+
return OutfitCompatibilityModel(embedding_dim=self.embed_dim), False
|
| 120 |
|
| 121 |
# Try to download from Hugging Face Hub first
|
| 122 |
try:
|
|
|
|
| 128 |
local_dir_use_symlinks=False
|
| 129 |
)
|
| 130 |
print(f"π₯ Downloaded ViT from HF Hub: {hf_path}")
|
| 131 |
+
model = OutfitCompatibilityModel(embedding_dim=self.embed_dim)
|
| 132 |
state = torch.load(hf_path, map_location="cpu")
|
| 133 |
state_dict = state.get("state_dict", state) if isinstance(state, dict) else state
|
| 134 |
model.load_state_dict(state_dict, strict=False)
|
| 135 |
+
print("β
ViT model loaded successfully from HF Hub")
|
| 136 |
+
return model, True
|
| 137 |
except Exception as e:
|
| 138 |
print(f"β Failed to download ViT from HF Hub: {e}")
|
|
|
|
|
|
|
| 139 |
|
| 140 |
+
# Check for local best checkpoint first
|
| 141 |
best_path = os.path.join(os.path.dirname(ckpt_path), "vit_outfit_model_best.pth")
|
| 142 |
+
if os.path.exists(best_path):
|
| 143 |
+
print(f"π Loading ViT from best checkpoint: {best_path}")
|
| 144 |
+
model = OutfitCompatibilityModel(embedding_dim=self.embed_dim)
|
| 145 |
+
state = torch.load(best_path, map_location="cpu")
|
| 146 |
+
state_dict = state.get("state_dict", state) if isinstance(state, dict) else state
|
| 147 |
+
model.load_state_dict(state_dict, strict=False)
|
| 148 |
+
print("β
ViT model loaded successfully from best checkpoint")
|
| 149 |
+
return model, True
|
| 150 |
+
|
| 151 |
+
# Check for regular checkpoint
|
| 152 |
+
if os.path.exists(ckpt_path):
|
| 153 |
+
print(f"π Loading ViT from checkpoint: {ckpt_path}")
|
| 154 |
+
model = OutfitCompatibilityModel(embedding_dim=self.embed_dim)
|
| 155 |
+
state = torch.load(ckpt_path, map_location="cpu")
|
| 156 |
state_dict = state.get("state_dict", state) if isinstance(state, dict) else state
|
| 157 |
model.load_state_dict(state_dict, strict=False)
|
| 158 |
+
print("β
ViT model loaded successfully from checkpoint")
|
| 159 |
+
return model, True
|
| 160 |
+
|
| 161 |
+
print("β CRITICAL: No trained ViT weights found!")
|
| 162 |
+
print("π¨ Cannot provide recommendations without trained weights!")
|
| 163 |
+
print("π‘ Please train the ViT model first using the training tabs.")
|
| 164 |
+
return OutfitCompatibilityModel(embedding_dim=self.embed_dim), False
|
| 165 |
|
| 166 |
def reload_models(self) -> None:
|
| 167 |
"""Reload weights from current checkpoint locations (used after background training)."""
|
| 168 |
+
self.resnet, self.resnet_loaded = self._load_resnet()
|
| 169 |
+
self.vit, self.vit_loaded = self._load_vit()
|
| 170 |
+
|
| 171 |
+
# Move to device and set eval mode
|
| 172 |
+
if self.resnet_loaded:
|
| 173 |
+
self.resnet = self.resnet.to(self.device).eval()
|
| 174 |
+
if self.vit_loaded:
|
| 175 |
+
self.vit = self.vit.to(self.device).eval()
|
| 176 |
+
|
| 177 |
+
# Disable gradients
|
| 178 |
for m in [self.resnet, self.vit]:
|
| 179 |
+
if m is not None:
|
| 180 |
+
for p in m.parameters():
|
| 181 |
+
p.requires_grad_(False)
|
| 182 |
+
|
| 183 |
+
# Update overall status
|
| 184 |
+
self.models_loaded = self.resnet_loaded and self.vit_loaded
|
| 185 |
+
if not self.models_loaded:
|
| 186 |
+
self.model_errors = []
|
| 187 |
+
if not self.resnet_loaded:
|
| 188 |
+
self.model_errors.append("ResNet: No trained weights found")
|
| 189 |
+
if not self.vit_loaded:
|
| 190 |
+
self.model_errors.append("ViT: No trained weights found")
|
| 191 |
|
| 192 |
@torch.inference_mode()
|
| 193 |
def embed_images(self, images: List[Image.Image]) -> List[np.ndarray]:
|
|
|
|
| 203 |
|
| 204 |
@torch.inference_mode()
|
| 205 |
def compose_outfits(self, items: List[Dict[str, Any]], context: Dict[str, Any]) -> List[Dict[str, Any]]:
|
| 206 |
+
# Validate that models are properly loaded
|
| 207 |
+
if not self.models_loaded:
|
| 208 |
+
error_msg = f"β Cannot provide recommendations: Models not properly loaded. Errors: {self.model_errors}"
|
| 209 |
+
print(error_msg)
|
| 210 |
+
return [{
|
| 211 |
+
"error": "Models not trained or loaded properly",
|
| 212 |
+
"details": self.model_errors,
|
| 213 |
+
"message": "Please ensure models are trained and checkpoints exist before generating recommendations."
|
| 214 |
+
}]
|
| 215 |
+
|
| 216 |
# 1) Ensure embeddings for each input item
|
| 217 |
proc_items: List[Dict[str, Any]] = []
|
| 218 |
for it in items:
|
|
|
|
| 329 |
for subset, score in topk
|
| 330 |
]
|
| 331 |
return results
|
| 332 |
+
|
| 333 |
+
def get_model_status(self) -> Dict[str, Any]:
|
| 334 |
+
"""Get current model loading status and errors."""
|
| 335 |
+
return {
|
| 336 |
+
"models_loaded": self.models_loaded,
|
| 337 |
+
"resnet_loaded": self.resnet_loaded,
|
| 338 |
+
"vit_loaded": self.vit_loaded,
|
| 339 |
+
"errors": self.model_errors,
|
| 340 |
+
"can_recommend": self.models_loaded
|
| 341 |
+
}
|
| 342 |
|
| 343 |
|