Ali Mohsin committed
Commit 1d3b4c2 · 1 Parent(s): c0eeb7b

gooooooooo

Files changed (1)
  1. inference.py +87 -2
inference.py CHANGED
@@ -6,6 +6,7 @@ import torch
 import torch.nn as nn
 from PIL import Image
 from huggingface_hub import hf_hub_download
+import clip
 
 from utils.transforms import build_inference_transform
 from models.resnet_embedder import ResNetItemEmbedder
@@ -32,6 +33,10 @@ class InferenceService:
         self.models_loaded = False
         self.model_errors = []
 
+        # Load CLIP for category detection
+        self.clip_model, self.clip_preprocess = None, None
+        self._load_clip()
+
         # Load models with validation
         self.resnet, self.resnet_loaded = self._load_resnet()
         self.vit, self.vit_loaded = self._load_vit()
@@ -57,6 +62,75 @@ class InferenceService:
         if not self.vit_loaded:
             self.model_errors.append("ViT: No trained weights found")
 
+    def _load_clip(self) -> None:
+        """Load CLIP model for category detection."""
+        try:
+            print("🔄 Loading CLIP model for category detection...")
+            self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=self.device)
+            print("✅ CLIP model loaded successfully")
+        except Exception as e:
+            print(f"❌ Failed to load CLIP model: {e}")
+            self.clip_model, self.clip_preprocess = None, None
+
+    def _detect_category_with_clip(self, image: Image.Image) -> str:
+        """Detect clothing category using CLIP."""
+        if self.clip_model is None or self.clip_preprocess is None:
+            return "other"
+
+        try:
+            # Define clothing categories with descriptions
+            categories = [
+                "a shirt, t-shirt, blouse, or top",
+                "pants, jeans, trousers, or bottoms",
+                "shoes, sneakers, boots, or footwear",
+                "a jacket, blazer, coat, or outerwear",
+                "a dress or gown",
+                "a skirt or shorts",
+                "a sweater, hoodie, or pullover",
+                "a watch, ring, necklace, or jewelry",
+                "a bag, purse, or handbag",
+                "a hat, cap, or headwear",
+                "a belt or accessory"
+            ]
+
+            # Prepare image and text
+            image_input = self.clip_preprocess(image).unsqueeze(0).to(self.device)
+            text_inputs = clip.tokenize(categories).to(self.device)
+
+            # Get predictions
+            with torch.no_grad():
+                image_features = self.clip_model.encode_image(image_input)
+                text_features = self.clip_model.encode_text(text_inputs)
+
+                # Compute similarity
+                similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
+                values, indices = similarity[0].topk(1)
+
+            # Map to outfit categories
+            category_map = {
+                0: "shirt",      # shirt, t-shirt, blouse, top
+                1: "pants",      # pants, jeans, trousers, bottoms
+                2: "shoes",      # shoes, sneakers, boots, footwear
+                3: "jacket",     # jacket, blazer, coat, outerwear
+                4: "dress",      # dress, gown
+                5: "pants",      # skirt, shorts (map to pants for outfit logic)
+                6: "shirt",      # sweater, hoodie, pullover (map to shirt)
+                7: "accessory",  # watch, ring, necklace, jewelry
+                8: "accessory",  # bag, purse, handbag
+                9: "accessory",  # hat, cap, headwear
+                10: "accessory"  # belt, accessory
+            }
+
+            predicted_category = category_map.get(indices[0].item(), "other")
+            confidence = values[0].item()
+
+            print(f"🔍 CLIP detected: '{predicted_category}' (confidence: {confidence:.3f})")
+            return predicted_category
+
+        except Exception as e:
+            print(f"❌ CLIP category detection failed: {e}")
+            return "other"
+
     def _load_resnet(self) -> tuple[nn.Module, bool]:
         strategy = os.getenv("MODEL_LOAD_STRATEGY", "state_dict")
         ckpt_path = os.getenv("RESNET_CHECKPOINT", "models/exports/resnet_item_embedder.pth")
@@ -236,6 +310,17 @@ class InferenceService:
         proc_items: List[Dict[str, Any]] = []
         for i, it in enumerate(items):
             print(f"🔍 DEBUG: Processing item {i}: id={it.get('id')}, has_image={it.get('image') is not None}, has_embedding={it.get('embedding') is not None}")
+
+            # Auto-detect category using CLIP if not provided or is None
+            category = it.get("category")
+            if not category or category == "None" or category == "":
+                if it.get("image") is not None:
+                    print(f"🔍 DEBUG: Auto-detecting category for item {i} using CLIP...")
+                    category = self._detect_category_with_clip(it["image"])
+                else:
+                    category = "other"
+                    print(f"🔍 DEBUG: No image available for item {i}, using 'other' category")
+
             emb = it.get("embedding")
             if emb is None and it.get("image") is not None:
                 # Compute on-the-fly if image provided
@@ -249,9 +334,9 @@ class InferenceService:
             proc_items.append({
                 "id": it.get("id"),
                 "embedding": emb_np,
-                "category": it.get("category")
+                "category": category
             })
-            print(f"🔍 DEBUG: Added item {i} to proc_items, total: {len(proc_items)}")
+            print(f"🔍 DEBUG: Added item {i} to proc_items with category '{category}', total: {len(proc_items)}")
 
         print(f"🔍 DEBUG: Final proc_items count: {len(proc_items)}")
         if len(proc_items) < 2:
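
For quick verification outside the service, here is a minimal, standalone sketch of the zero-shot classification pattern the new _detect_category_with_clip helper relies on. It assumes the OpenAI clip package plus torch and Pillow are installed; "item.jpg" is a placeholder path, and the prompt/label lists are shortened to three of the categories used in the diff. Unlike the helper above, the sketch L2-normalizes the image and text features before the dot product, which is the usual CLIP recipe and makes the softmax scores behave like cosine-similarity-based confidences.

import clip
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Shortened prompt/label lists mirroring the categories in the diff
prompts = [
    "a shirt, t-shirt, blouse, or top",
    "pants, jeans, trousers, or bottoms",
    "shoes, sneakers, boots, or footwear",
]
labels = ["shirt", "pants", "shoes"]

image = preprocess(Image.open("item.jpg")).unsqueeze(0).to(device)  # placeholder image path
text = clip.tokenize(prompts).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    # Normalize so the dot product is a cosine similarity
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

conf, idx = probs[0].topk(1)
print(f"predicted: {labels[idx.item()]} (confidence: {conf.item():.3f})")

If this prints a sensible label, the in-service path should behave the same way; the helper in the diff additionally applies its category_map step to collapse the eleven prompts into the outfit categories used downstream.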