Leacb4 committed on
Commit
9be16b3
·
verified ·
1 Parent(s): 61f1c2b

Upload evaluation/fashion_search.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. evaluation/fashion_search.py +365 -0
evaluation/fashion_search.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Fashion search system using multi-modal embeddings.
4
+ This file implements a fashion search engine that allows searching for clothing items
5
+ using text queries. It uses embeddings from the main model to calculate cosine similarities
6
+ and return the most relevant items. The system pre-computes embeddings for all items
7
+ in the dataset for fast search.
8
+ """
9
+
10
+ import torch
11
+ import numpy as np
12
+ import pandas as pd
13
+ from PIL import Image
14
+ import matplotlib.pyplot as plt
15
+ from sklearn.metrics.pairwise import cosine_similarity
16
+ from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
17
+ import warnings
18
+ import os
19
+ from typing import List, Tuple, Union, Optional
20
+ import argparse
21
+
22
+ # Import custom models
23
+ from color_model import CLIPModel as ColorModel
24
+ from hierarchy_model import Model as HierarchyModel, HierarchyExtractor
25
+ from main_model import CustomDataset
26
+ import config
27
+
28
+ warnings.filterwarnings("ignore")
29
+
30
class FashionSearchEngine:
    """
    Text-to-item fashion search over pre-computed CLIP text embeddings.

    On construction the engine loads the color, hierarchy and main CLIP
    models, reads the dataset CSV, and embeds the text description of a
    stratified sample of items. Queries are embedded with the same model and
    ranked against that index by cosine similarity.
    """

    # Upper bound on rows sampled from each (color, category) group.
    SAMPLES_PER_GROUP = 20
    # Number of item texts embedded per forward pass while building the index.
    EMBED_BATCH_SIZE = 32
    # Candidates whose cosine similarity is not above this value are dropped.
    MIN_SIMILARITY = -0.5
    # Hub id used for both the CLIP architecture and its processor.
    CLIP_CHECKPOINT = 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'

    def __init__(self, top_k: int = 10, max_items: int = 10000):
        """
        Initialize the fashion search engine.

        Loads the models, loads the dataset and builds the embedding index,
        so construction performs disk I/O and model inference.

        Args:
            top_k: Number of top results to return per query.
            max_items: Maximum number of items to embed into the search index
                (smaller values give a faster initialization).
        """
        self.device = config.device
        self.top_k = top_k
        self.max_items = max_items
        self.color_dim = config.color_emb_dim
        self.hierarchy_dim = config.hierarchy_emb_dim

        # Load models
        self._load_models()

        # Load dataset
        self._load_dataset()

        # Pre-compute embeddings for all items
        self._precompute_embeddings()

        print("✅ Fashion Search Engine ready!")

    def _load_models(self):
        """Load the color model, the hierarchy model, the main CLIP model and the CLIP processor."""
        print("📦 Loading models...")

        # Load color model
        color_checkpoint = torch.load(config.color_model_path, map_location=self.device, weights_only=True)
        self.color_model = ColorModel(embed_dim=self.color_dim).to(self.device)
        self.color_model.load_state_dict(color_checkpoint)
        self.color_model.eval()

        # Load hierarchy model
        # NOTE(security): unlike the color checkpoint, this load does not pass
        # weights_only=True, so on torch versions where that is not the default
        # it unpickles arbitrary objects. Only load checkpoints from a trusted
        # source.
        hierarchy_checkpoint = torch.load(config.hierarchy_model_path, map_location=self.device)
        self.hierarchy_classes = hierarchy_checkpoint.get('hierarchy_classes', [])
        self.hierarchy_model = HierarchyModel(
            num_hierarchy_classes=len(self.hierarchy_classes),
            embed_dim=self.hierarchy_dim
        ).to(self.device)
        self.hierarchy_model.load_state_dict(hierarchy_checkpoint['model_state'])

        # Set hierarchy extractor
        hierarchy_extractor = HierarchyExtractor(self.hierarchy_classes, verbose=False)
        self.hierarchy_model.set_hierarchy_extractor(hierarchy_extractor)
        self.hierarchy_model.eval()

        # Load main CLIP model: pretrained architecture first, trained weights on top
        self.main_model = CLIPModel_transformers.from_pretrained(self.CLIP_CHECKPOINT)

        # NOTE(security): same caveat as the hierarchy checkpoint above.
        checkpoint = torch.load(config.main_model_path, map_location=self.device)
        if 'model_state_dict' in checkpoint:
            self.main_model.load_state_dict(checkpoint['model_state_dict'])
        else:
            # Fallback: the file is assumed to be a bare state dict
            self.main_model.load_state_dict(checkpoint)
            print("✅ Loaded model weights directly")

        self.main_model.to(self.device)
        self.main_model.eval()

        # Load CLIP processor
        self.clip_processor = CLIPProcessor.from_pretrained(self.CLIP_CHECKPOINT)

        print(f"✅ Models loaded - Colors: {self.color_dim}D, Hierarchy: {self.hierarchy_dim}D")

    def _load_dataset(self):
        """Read the dataset CSV and keep only the rows that have a local image path."""
        print("📊 Loading dataset...")

        self.df = pd.read_csv(config.local_dataset_path)
        # Renumber the surviving rows 0..n-1. The sampler returns index labels
        # which are then used as positions (dataset[idx], df_clean.iloc[idx]);
        # without the reset, labels and positions diverge as soon as dropna
        # removes a row.
        self.df_clean = (
            self.df.dropna(subset=[config.column_local_image_path]).reset_index(drop=True)
        )

        # Create dataset object
        self.dataset = CustomDataset(self.df_clean)
        self.dataset.set_training_mode(False)  # No augmentation for search

        print(f"✅ {len(self.df_clean)} items loaded for search")

    def _sample_row_indices(self) -> List[int]:
        """
        Pick the rows of ``self.df_clean`` that go into the search index.

        Takes up to ``SAMPLES_PER_GROUP`` random rows from every
        (color, category) combination, then caps the total at
        ``self.max_items``.

        Returns:
            Sorted row positions into ``self.df_clean``.
        """
        group_cols = [config.color_column, config.hierarchy_column]
        # Rows without a color or a category belong to no group.
        candidates = self.df_clean.dropna(subset=group_cols)
        # Shuffle, then keep the first rows of each group: a random sample per
        # group that, unlike groupby().sample(n=...), does not raise when a
        # group holds fewer rows than requested.
        shuffled = candidates.sample(frac=1)
        sampled = shuffled.groupby(group_cols, sort=False).head(self.SAMPLES_PER_GROUP)

        if self.max_items is not None and len(sampled) > self.max_items:
            sampled = sampled.sample(n=self.max_items)

        # df_clean carries a fresh 0..n-1 index, so labels equal positions.
        return sorted(sampled.index.tolist())

    def _encode_texts(self, texts: List[str]) -> np.ndarray:
        """
        Embed a list of texts with the main CLIP model.

        Args:
            texts: Texts to embed.

        Returns:
            Array with one row of ``text_embeds`` per input text.
        """
        with torch.no_grad():
            # truncation=True keeps texts longer than the tokenizer's maximum
            # length from overflowing the model's position embeddings.
            text_inputs = self.clip_processor(
                text=texts, padding=True, truncation=True, return_tensors="pt"
            )
            text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}

            # The full forward pass is called with pixel_values as well; blank
            # images are supplied and only the text embeddings are read.
            dummy_images = torch.zeros(len(texts), 3, 224, 224).to(self.device)

            outputs = self.main_model(**text_inputs, pixel_values=dummy_images)
            return outputs.text_embeds.cpu().numpy()

    def _precompute_embeddings(self):
        """
        Build the search index from a stratified sample of the dataset.

        Fills ``all_embeddings``, ``all_texts``, ``all_colors``,
        ``all_hierarchies``, ``all_images`` and ``all_urls``; entry ``i`` of
        every one of them describes the same item. Rows that fail to load are
        reported and skipped.
        """
        print("🔄 Pre-computing embeddings...")

        # OPTIMIZATION: Sample a subset for faster initialization
        print(f"⚠️ Dataset too large ({len(self.dataset)} items). Using stratified sampling of up to {self.SAMPLES_PER_GROUP} items per color-category combination.")
        sampled_indices = self._sample_row_indices()

        all_embeddings = []
        all_texts = []
        all_colors = []
        all_hierarchies = []
        all_images = []
        all_urls = []

        from tqdm import tqdm

        batch_size = self.EMBED_BATCH_SIZE
        for start in tqdm(range(0, len(sampled_indices), batch_size),
                          desc="Computing embeddings"):
            batch_records = []

            for row_idx in sampled_indices[start:start + batch_size]:
                try:
                    _image, text, color, hierarchy = self.dataset[row_idx]
                    row = self.df_clean.iloc[row_idx]
                    record = (
                        text,
                        color,
                        hierarchy,
                        row[config.column_local_image_path],
                        row[config.column_url_image],
                    )
                except Exception as e:
                    # Best effort: one unreadable row must not abort indexing.
                    print(f"⚠️ Skipping item {row_idx}: {e}")
                    continue
                # Recorded only once every field was read, so the metadata
                # lists can never get out of step with each other.
                batch_records.append(record)

            if not batch_records:
                continue

            all_embeddings.extend(self._encode_texts([record[0] for record in batch_records]))
            for text, color, hierarchy, image_path, url in batch_records:
                all_texts.append(text)
                all_colors.append(color)
                all_hierarchies.append(hierarchy)
                all_images.append(image_path)
                all_urls.append(url)

        self.all_embeddings = np.array(all_embeddings)
        self.all_texts = all_texts
        self.all_colors = all_colors
        self.all_hierarchies = all_hierarchies
        self.all_images = all_images
        self.all_urls = all_urls

        print(f"✅ Pre-computed embeddings for {len(self.all_embeddings)} items")

    def search_by_text(self, query_text: str, filter_category: Optional[str] = None) -> List[dict]:
        """
        Search for clothing items using a text query.

        Args:
            query_text: Text description to search for.
            filter_category: If given, keep only items whose hierarchy
                contains this string (case-insensitive).

        Returns:
            Up to ``top_k`` result dictionaries, best match first, with the
            keys ``rank``, ``image_path``, ``text``, ``color``, ``hierarchy``,
            ``similarity``, ``index`` and ``url``. Empty when the index is
            empty or nothing matches.
        """
        print(f"🔍 Searching for: '{query_text}'")

        # Nothing to rank against; cosine_similarity would fail on an empty index.
        if len(self.all_embeddings) == 0:
            print("✅ Found 0 results")
            return []

        # Get query embedding
        query_embedding = self._encode_texts([query_text])

        # Calculate similarities
        similarities = cosine_similarity(query_embedding, self.all_embeddings)[0]

        # Walk every candidate from best to worst. Looking at a fixed-size
        # prefix only would lose category matches that rank further down.
        ranked_indices = np.argsort(similarities)[::-1]
        wanted_category = filter_category.lower() if filter_category else None

        results = []
        for idx in ranked_indices:
            if len(results) >= self.top_k:
                break
            if not similarities[idx] > self.MIN_SIMILARITY:
                continue
            # Filter by category if specified
            if wanted_category and wanted_category not in str(self.all_hierarchies[idx]).lower():
                continue

            results.append({
                'rank': len(results) + 1,
                'image_path': self.all_images[idx],
                'text': self.all_texts[idx],
                'color': self.all_colors[idx],
                'hierarchy': self.all_hierarchies[idx],
                'similarity': float(similarities[idx]),
                'index': int(idx),
                'url': self.all_urls[idx]
            })

        print(f"✅ Found {len(results)} results")
        return results

    def display_results(self, results: List[dict], query_info: str = ""):
        """
        Display search results as an image grid followed by a text listing.

        Args:
            results: List of search result dictionaries as returned by
                ``search_by_text``.
            query_info: Information about the query, shown in the header.
        """
        if not results:
            print("❌ No results found")
            return

        print(f"\n🎯 Search Results for: {query_info}")
        print("=" * 80)

        # Calculate grid layout: at most five images per row
        n_results = len(results)
        cols = min(5, n_results)
        rows = (n_results + cols - 1) // cols

        # squeeze=False always yields a 2-D array of axes, including the 1x1
        # grid, where subplots would otherwise return a single Axes object.
        _fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 4*rows), squeeze=False)

        for i, result in enumerate(results):
            ax = axes[i // cols, i % cols]

            try:
                # Load and display image; the file handle is released right after drawing
                with Image.open(result['image_path']) as image:
                    ax.imshow(image)
                ax.axis('off')

                # Add title with similarity score
                title = f"#{result['rank']} (Similarity: {result['similarity']:.3f})\n{result['color']} {result['hierarchy']}"
                ax.set_title(title, fontsize=10, wrap=True)

            except Exception:
                # Best effort: an unreadable image gets a placeholder cell.
                ax.text(0.5, 0.5, f"Error loading image\n{result['image_path']}",
                        ha='center', va='center', transform=ax.transAxes)
                ax.axis('off')

        # Hide empty subplots
        for i in range(n_results, rows * cols):
            axes[i // cols, i % cols].axis('off')

        plt.tight_layout()
        plt.show()

        # Print detailed results
        print("\n📋 Detailed Results:")
        for result in results:
            # str() keeps the width/slice formatting valid for non-string metadata
            print(f"#{result['rank']:2d} | Similarity: {result['similarity']:.3f} | "
                  f"Color: {str(result['color']):12s} | Category: {str(result['hierarchy']):15s} | "
                  f"Text: {str(result['text'])[:50]}...")
            print(f" 🔗 URL: {result['url']}")
            print()
312
+ print()
313
+
314
+
315
def main():
    """
    Command-line entry point.

    With ``--query`` a single search is run and displayed; the interactive
    prompt follows only when ``--interactive`` is also given. Without
    ``--query`` the interactive prompt is started.
    """
    parser = argparse.ArgumentParser(description="Fashion Search Engine with Category Emphasis")
    parser.add_argument("--query", "-q", type=str, help="Search query")
    parser.add_argument("--top-k", "-k", type=int, default=10, help="Number of results (default: 10)")
    parser.add_argument("--fast", "-f", action="store_true", help="Fast mode (less items)")
    parser.add_argument("--interactive", "-i", action="store_true", help="Interactive mode")

    args = parser.parse_args()

    print("🎯 Fashion Search Engine with Category Emphasis")

    engine_kwargs = {"top_k": args.top_k}
    if args.fast:
        # Fast mode asks the engine for a smaller index through its
        # max_items argument.
        engine_kwargs["max_items"] = 1000
    search_engine = FashionSearchEngine(**engine_kwargs)
    print("✅ Ready!")

    # Single query mode
    if args.query:
        print(f"🔍 Search: '{args.query}'...")
        results = search_engine.search_by_text(args.query)
        search_engine.display_results(results, args.query)
        # A one-shot query ends here unless the prompt was requested too.
        if not args.interactive:
            return

    # Interactive mode
    print("Enter your query (e.g. 'red dress') or 'quit' to exit")

    while True:
        try:
            user_input = input("\n🔍 Query: ").strip()
        except (KeyboardInterrupt, EOFError):
            # EOFError: stdin was closed (e.g. piped input ran out). Retrying
            # would fail the same way forever, so leave the loop.
            print("\n👋 Goodbye!")
            break

        if not user_input or user_input.lower() in ['quit', 'exit', 'q']:
            print("👋 Goodbye!")
            break

        try:
            if user_input.startswith('verify '):
                if 'yellow accessories' in user_input:
                    # Looked up dynamically so an engine without this helper
                    # yields a clear message instead of an AttributeError.
                    show_yellow = getattr(search_engine, 'display_yellow_accessories', None)
                    if callable(show_yellow):
                        show_yellow()
                    else:
                        print("❌ Error: 'verify yellow accessories' is not available")
                continue

            print(f"🔍 Search: '{user_input}'...")
            results = search_engine.search_by_text(user_input)
            search_engine.display_results(results, user_input)

        except KeyboardInterrupt:
            print("\n👋 Goodbye!")
            break
        except Exception as e:
            # Top-level boundary: report the failure and keep the prompt alive.
            print(f"❌ Error: {e}")


if __name__ == "__main__":
    main()