Spaces:
Paused
Paused
Ali Mohsin
commited on
Commit
Β·
8d2d202
1
Parent(s):
4a2dbbf
Final fixes 5666444
Browse files- app.py +37 -0
- inference.py +64 -12
app.py
CHANGED
|
@@ -296,6 +296,43 @@ threading.Thread(target=_background_bootstrap, daemon=True).start()
|
|
| 296 |
def health() -> dict:
|
| 297 |
return {"status": "ok", "device": service.device, "resnet": service.resnet_version, "vit": service.vit_version}
|
| 298 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
|
| 300 |
@app.post("/embed")
|
| 301 |
def embed(req: EmbedRequest, x_api_key: Optional[str] = Header(None)) -> dict:
|
|
|
|
| 296 |
def health() -> dict:
|
| 297 |
return {"status": "ok", "device": service.device, "resnet": service.resnet_version, "vit": service.vit_version}
|
| 298 |
|
| 299 |
+
@app.get("/model-status")
|
| 300 |
+
def model_status() -> dict:
|
| 301 |
+
"""Get detailed model loading status."""
|
| 302 |
+
return service.get_model_status()
|
| 303 |
+
|
| 304 |
+
@app.post("/reload-models")
|
| 305 |
+
def reload_models() -> dict:
|
| 306 |
+
"""Force reload models - useful for debugging."""
|
| 307 |
+
try:
|
| 308 |
+
service.force_reload_models()
|
| 309 |
+
return {"status": "success", "message": "Models reloaded successfully"}
|
| 310 |
+
except Exception as e:
|
| 311 |
+
return {"status": "error", "message": str(e)}
|
| 312 |
+
|
| 313 |
+
@app.post("/test-recommend")
|
| 314 |
+
def test_recommend() -> dict:
|
| 315 |
+
"""Test recommendation with dummy data to debug the issue."""
|
| 316 |
+
try:
|
| 317 |
+
# Create dummy items for testing
|
| 318 |
+
dummy_items = [
|
| 319 |
+
{"id": "test_1", "image": None, "category": "shirt"},
|
| 320 |
+
{"id": "test_2", "image": None, "category": "pants"},
|
| 321 |
+
{"id": "test_3", "image": None, "category": "shoes"}
|
| 322 |
+
]
|
| 323 |
+
|
| 324 |
+
# Try to get recommendations
|
| 325 |
+
result = service.compose_outfits(dummy_items, {"num_outfits": 1})
|
| 326 |
+
|
| 327 |
+
return {
|
| 328 |
+
"status": "success",
|
| 329 |
+
"model_status": service.get_model_status(),
|
| 330 |
+
"result": result,
|
| 331 |
+
"result_length": len(result) if result else 0
|
| 332 |
+
}
|
| 333 |
+
except Exception as e:
|
| 334 |
+
return {"status": "error", "message": str(e), "model_status": service.get_model_status()}
|
| 335 |
+
|
| 336 |
|
| 337 |
@app.post("/embed")
|
| 338 |
def embed(req: EmbedRequest, x_api_key: Optional[str] = Header(None)) -> dict:
|
inference.py
CHANGED
|
@@ -45,8 +45,8 @@ class InferenceService:
|
|
| 45 |
# Disable gradients
|
| 46 |
for m in [self.resnet, self.vit]:
|
| 47 |
if m is not None:
|
| 48 |
-
|
| 49 |
-
|
| 50 |
|
| 51 |
# Update overall status
|
| 52 |
self.models_loaded = self.resnet_loaded and self.vit_loaded
|
|
@@ -177,8 +177,8 @@ class InferenceService:
|
|
| 177 |
# Disable gradients
|
| 178 |
for m in [self.resnet, self.vit]:
|
| 179 |
if m is not None:
|
| 180 |
-
|
| 181 |
-
|
| 182 |
|
| 183 |
# Update overall status
|
| 184 |
self.models_loaded = self.resnet_loaded and self.vit_loaded
|
|
@@ -202,12 +202,12 @@ class InferenceService:
|
|
| 202 |
return []
|
| 203 |
|
| 204 |
try:
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
result = [e.detach().cpu().numpy().astype(np.float32) for e in emb]
|
| 212 |
print(f"π DEBUG: Successfully generated {len(result)} embeddings")
|
| 213 |
return result
|
|
@@ -257,6 +257,8 @@ class InferenceService:
|
|
| 257 |
if len(proc_items) < 2:
|
| 258 |
print("π DEBUG: Returning empty array - not enough items (< 2)")
|
| 259 |
return []
|
|
|
|
|
|
|
| 260 |
|
| 261 |
# 2) Candidate generation with outfit templates
|
| 262 |
rng = np.random.default_rng(int(context.get("seed", 42)))
|
|
@@ -299,6 +301,11 @@ class InferenceService:
|
|
| 299 |
# Enhanced category-aware pools with diversity checks
|
| 300 |
def cat_str(i: int) -> str:
|
| 301 |
return (proc_items[i].get("category") or "").lower()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
|
| 303 |
def extract_color_from_category(category: str) -> str:
|
| 304 |
"""Extract color information from category name"""
|
|
@@ -356,6 +363,7 @@ class InferenceService:
|
|
| 356 |
def get_category_type(cat: str) -> str:
|
| 357 |
"""Map category to outfit slot type with comprehensive taxonomy"""
|
| 358 |
cat_lower = cat.lower().strip()
|
|
|
|
| 359 |
|
| 360 |
# Upper body items (tops, outerwear)
|
| 361 |
upper_keywords = [
|
|
@@ -405,14 +413,23 @@ class InferenceService:
|
|
| 405 |
return "other"
|
| 406 |
|
| 407 |
# Create category pools
|
|
|
|
| 408 |
uppers = [i for i in ids if get_category_type(cat_str(i)) == "upper"]
|
| 409 |
bottoms = [i for i in ids if get_category_type(cat_str(i)) == "bottom"]
|
| 410 |
shoes = [i for i in ids if get_category_type(cat_str(i)) == "shoe"]
|
| 411 |
accs = [i for i in ids if get_category_type(cat_str(i)) == "accessory"]
|
| 412 |
others = [i for i in ids if get_category_type(cat_str(i)) == "other"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 413 |
|
| 414 |
candidates: List[List[int]] = []
|
| 415 |
num_samples = max(num_outfits * 12, 24)
|
|
|
|
| 416 |
|
| 417 |
def has_category_diversity(subset: List[int]) -> bool:
|
| 418 |
"""Check if subset has good category diversity"""
|
|
@@ -468,7 +485,11 @@ class InferenceService:
|
|
| 468 |
subset = list(set(subset))
|
| 469 |
if len(subset) >= 3: # At least 3 items for a valid outfit
|
| 470 |
candidates.append(subset)
|
|
|
|
|
|
|
| 471 |
|
|
|
|
|
|
|
| 472 |
# 3) Score using ViT
|
| 473 |
def score_subset(idx_subset: List[int]) -> float:
|
| 474 |
embs = torch.tensor(
|
|
@@ -609,7 +630,7 @@ class InferenceService:
|
|
| 609 |
})
|
| 610 |
|
| 611 |
return results
|
| 612 |
-
|
| 613 |
def get_model_status(self) -> Dict[str, Any]:
|
| 614 |
"""Get current model loading status and errors."""
|
| 615 |
return {
|
|
@@ -617,7 +638,38 @@ class InferenceService:
|
|
| 617 |
"resnet_loaded": self.resnet_loaded,
|
| 618 |
"vit_loaded": self.vit_loaded,
|
| 619 |
"errors": self.model_errors,
|
| 620 |
-
"can_recommend": self.models_loaded
|
|
|
|
|
|
|
| 621 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 622 |
|
| 623 |
|
|
|
|
| 45 |
# Disable gradients
|
| 46 |
for m in [self.resnet, self.vit]:
|
| 47 |
if m is not None:
|
| 48 |
+
for p in m.parameters():
|
| 49 |
+
p.requires_grad_(False)
|
| 50 |
|
| 51 |
# Update overall status
|
| 52 |
self.models_loaded = self.resnet_loaded and self.vit_loaded
|
|
|
|
| 177 |
# Disable gradients
|
| 178 |
for m in [self.resnet, self.vit]:
|
| 179 |
if m is not None:
|
| 180 |
+
for p in m.parameters():
|
| 181 |
+
p.requires_grad_(False)
|
| 182 |
|
| 183 |
# Update overall status
|
| 184 |
self.models_loaded = self.resnet_loaded and self.vit_loaded
|
|
|
|
| 202 |
return []
|
| 203 |
|
| 204 |
try:
|
| 205 |
+
batch = torch.stack([self.transform(img) for img in images])
|
| 206 |
+
batch = batch.to(self.device, memory_format=torch.channels_last)
|
| 207 |
+
use_amp = (self.device == "cuda")
|
| 208 |
+
with torch.autocast(device_type=("cuda" if use_amp else "cpu"), enabled=use_amp):
|
| 209 |
+
emb = self.resnet(batch)
|
| 210 |
+
emb = nn.functional.normalize(emb, dim=-1)
|
| 211 |
result = [e.detach().cpu().numpy().astype(np.float32) for e in emb]
|
| 212 |
print(f"π DEBUG: Successfully generated {len(result)} embeddings")
|
| 213 |
return result
|
|
|
|
| 257 |
if len(proc_items) < 2:
|
| 258 |
print("π DEBUG: Returning empty array - not enough items (< 2)")
|
| 259 |
return []
|
| 260 |
+
|
| 261 |
+
print("π DEBUG: Starting candidate generation...")
|
| 262 |
|
| 263 |
# 2) Candidate generation with outfit templates
|
| 264 |
rng = np.random.default_rng(int(context.get("seed", 42)))
|
|
|
|
| 301 |
# Enhanced category-aware pools with diversity checks
|
| 302 |
def cat_str(i: int) -> str:
|
| 303 |
return (proc_items[i].get("category") or "").lower()
|
| 304 |
+
|
| 305 |
+
print("π DEBUG: Building category pools...")
|
| 306 |
+
# Debug: Print all categories
|
| 307 |
+
for i in range(len(proc_items)):
|
| 308 |
+
print(f"π DEBUG: Item {i}: category='{proc_items[i].get('category')}' -> cat_str='{cat_str(i)}'")
|
| 309 |
|
| 310 |
def extract_color_from_category(category: str) -> str:
|
| 311 |
"""Extract color information from category name"""
|
|
|
|
| 363 |
def get_category_type(cat: str) -> str:
|
| 364 |
"""Map category to outfit slot type with comprehensive taxonomy"""
|
| 365 |
cat_lower = cat.lower().strip()
|
| 366 |
+
print(f"π DEBUG: Mapping category '{cat}' -> '{cat_lower}'")
|
| 367 |
|
| 368 |
# Upper body items (tops, outerwear)
|
| 369 |
upper_keywords = [
|
|
|
|
| 413 |
return "other"
|
| 414 |
|
| 415 |
# Create category pools
|
| 416 |
+
print("π DEBUG: Building category pools...")
|
| 417 |
uppers = [i for i in ids if get_category_type(cat_str(i)) == "upper"]
|
| 418 |
bottoms = [i for i in ids if get_category_type(cat_str(i)) == "bottom"]
|
| 419 |
shoes = [i for i in ids if get_category_type(cat_str(i)) == "shoe"]
|
| 420 |
accs = [i for i in ids if get_category_type(cat_str(i)) == "accessory"]
|
| 421 |
others = [i for i in ids if get_category_type(cat_str(i)) == "other"]
|
| 422 |
+
|
| 423 |
+
print(f"π DEBUG: Category pools - uppers: {len(uppers)}, bottoms: {len(bottoms)}, shoes: {len(shoes)}, accessories: {len(accs)}, others: {len(others)}")
|
| 424 |
+
|
| 425 |
+
# Check if we have the minimum required items
|
| 426 |
+
if len(uppers) == 0 or len(bottoms) == 0 or len(shoes) == 0:
|
| 427 |
+
print(f"π DEBUG: Missing required categories - uppers: {len(uppers)}, bottoms: {len(bottoms)}, shoes: {len(shoes)}")
|
| 428 |
+
return []
|
| 429 |
|
| 430 |
candidates: List[List[int]] = []
|
| 431 |
num_samples = max(num_outfits * 12, 24)
|
| 432 |
+
print(f"π DEBUG: Generating {num_samples} candidate outfits...")
|
| 433 |
|
| 434 |
def has_category_diversity(subset: List[int]) -> bool:
|
| 435 |
"""Check if subset has good category diversity"""
|
|
|
|
| 485 |
subset = list(set(subset))
|
| 486 |
if len(subset) >= 3: # At least 3 items for a valid outfit
|
| 487 |
candidates.append(subset)
|
| 488 |
+
if len(candidates) % 10 == 0: # Log every 10 candidates
|
| 489 |
+
print(f"π DEBUG: Generated {len(candidates)} candidates so far...")
|
| 490 |
|
| 491 |
+
print(f"π DEBUG: Generated {len(candidates)} total candidates")
|
| 492 |
+
|
| 493 |
# 3) Score using ViT
|
| 494 |
def score_subset(idx_subset: List[int]) -> float:
|
| 495 |
embs = torch.tensor(
|
|
|
|
| 630 |
})
|
| 631 |
|
| 632 |
return results
|
| 633 |
+
|
| 634 |
def get_model_status(self) -> Dict[str, Any]:
|
| 635 |
"""Get current model loading status and errors."""
|
| 636 |
return {
|
|
|
|
| 638 |
"resnet_loaded": self.resnet_loaded,
|
| 639 |
"vit_loaded": self.vit_loaded,
|
| 640 |
"errors": self.model_errors,
|
| 641 |
+
"can_recommend": self.models_loaded,
|
| 642 |
+
"resnet_model": self.resnet is not None,
|
| 643 |
+
"vit_model": self.vit is not None
|
| 644 |
}
|
| 645 |
+
|
| 646 |
+
def force_reload_models(self) -> None:
|
| 647 |
+
"""Force reload models and update status - useful for debugging."""
|
| 648 |
+
print("π Force reloading models...")
|
| 649 |
+
self.resnet, self.resnet_loaded = self._load_resnet()
|
| 650 |
+
self.vit, self.vit_loaded = self._load_vit()
|
| 651 |
+
|
| 652 |
+
# Move to device and set eval mode
|
| 653 |
+
if self.resnet_loaded:
|
| 654 |
+
self.resnet = self.resnet.to(self.device).eval()
|
| 655 |
+
if self.vit_loaded:
|
| 656 |
+
self.vit = self.vit.to(self.device).eval()
|
| 657 |
+
|
| 658 |
+
# Disable gradients
|
| 659 |
+
for m in [self.resnet, self.vit]:
|
| 660 |
+
if m is not None:
|
| 661 |
+
for p in m.parameters():
|
| 662 |
+
p.requires_grad_(False)
|
| 663 |
+
|
| 664 |
+
# Update overall status
|
| 665 |
+
self.models_loaded = self.resnet_loaded and self.vit_loaded
|
| 666 |
+
print(f"π Models reloaded: resnet={self.resnet_loaded}, vit={self.vit_loaded}, overall={self.models_loaded}")
|
| 667 |
+
|
| 668 |
+
if not self.models_loaded:
|
| 669 |
+
self.model_errors = []
|
| 670 |
+
if not self.resnet_loaded:
|
| 671 |
+
self.model_errors.append("ResNet: No trained weights found")
|
| 672 |
+
if not self.vit_loaded:
|
| 673 |
+
self.model_errors.append("ViT: No trained weights found")
|
| 674 |
|
| 675 |
|