Anusha806 committed
Commit c3e083b · 1 Parent(s): c967063

gradionotworking

Files changed (2)
  1. app.py +548 -307
  2. requirements.txt +10 -6
app.py CHANGED
@@ -1,76 +1,483 @@
 
- import os
- from pinecone import Pinecone, ServerlessSpec
- from PIL import Image, ImageOps
- import numpy as np
- from datasets import load_dataset
- from pinecone_text.sparse import BM25Encoder
- from sentence_transformers import SentenceTransformer
- import torch
- from tqdm.auto import tqdm
- import gradio as gr
 
- # ------------------- Pinecone Setup -------------------
- os.environ["PINECONE_API_KEY"] = "pcsk_TMCYK_LrbmZMTDhkxTjUXcr8iTcQ8LxurwKBFDvv4ahFis8SVob7QexVPPEt6g2zW6d3g"
- api_key = os.environ.get('PINECONE_API_KEY')
- pc = Pinecone(api_key=api_key)
 
- cloud = os.environ.get('PINECONE_CLOUD') or 'aws'
- region = os.environ.get('PINECONE_REGION') or 'us-east-1'
 
- spec = ServerlessSpec(cloud=cloud, region=region)
 
- index_name = "hybrid-image-search"
- spec = ServerlessSpec(cloud="aws", region="us-east-1")
- # choose a name for your index
- index_name = "hybrid-image-search"
  import time
 
- # check if index already exists (it shouldn't if this is first time)
  if index_name not in pc.list_indexes().names():
-     # if does not exist, create index
-     pc.create_index(
-         index_name,
-         dimension=512,
-         metric='dotproduct',
-         spec=spec
-     )
-     # wait for index to be initialized
      while not pc.describe_index(index_name).status['ready']:
          time.sleep(1)
-
- # connect to index
  index = pc.Index(index_name)
- # view index stats
- index.describe_index_stats()
 
- # ------------------- Dataset Loading -------------------
  fashion = load_dataset("ashraq/fashion-product-images-small", split="train")
  images = fashion["image"]
  metadata = fashion.remove_columns("image").to_pandas()
-
- # ------------------- Encoders -------------------
  bm25 = BM25Encoder()
  bm25.fit(metadata["productDisplayName"])
- model = SentenceTransformer('sentence-transformers/clip-ViT-B-32', device='cuda' if torch.cuda.is_available() else 'cpu')
- from sentence_transformers import SentenceTransformer
- import torch
-
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
- # load a CLIP model from huggingface
- model = SentenceTransformer(
-     'sentence-transformers/clip-ViT-B-32',
-     device=device
- )
- model
- # ------------------- Hybrid Scaling -------------------
  def hybrid_scale(dense, sparse, alpha: float):
-
      if alpha < 0 or alpha > 1:
          raise ValueError("Alpha must be between 0 and 1")
-     # scale sparse and dense vectors to create hybrid search vecs
      hsparse = {
          'indices': sparse['indices'],
          'values': [v * (1 - alpha) for v in sparse['values']]
@@ -78,176 +485,77 @@ def hybrid_scale(dense, sparse, alpha: float):
      hdense = [v * alpha for v in dense]
      return hdense, hsparse
80
 
81
- # ------------------- Metadata Filter Extraction -------------------
82
- from PIL import Image, ImageOps
83
- import numpy as np
84
- from PIL import Image, ImageOps
85
- import numpy as np
86
- from PIL import Image, ImageOps
87
- import numpy as np
88
-
89
- from transformers import CLIPProcessor, CLIPModel
90
-
91
- clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
92
- clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
93
-
94
- def extract_metadata_filters(query: str):
95
- query_lower = query.lower()
96
- gender = None
97
- category = None
98
- subcategory = None
99
- color = None
100
-
101
- # --- Gender Mapping ---
102
- gender_map = {
103
- "men": "Men", "man": "Men", "mens": "Men", "mans": "Men", "male": "Men",
104
- "women": "Women", "woman": "Women", "womens": "Women", "female": "Women",
105
- "boys": "Boys", "boy": "Boys",
106
- "girls": "Girls", "girl": "Girls",
107
- "kids": "Kids","kid": "Kids",
108
- "unisex": "Unisex"
109
- }
110
- for term, mapped_value in gender_map.items():
111
- if term in query_lower:
112
- gender = mapped_value
113
- break
114
-
115
- # --- Category Mapping ---
116
- category_map = {
117
- "shirt": "Shirts",
118
- "tshirt": "Tshirts", "t-shirt": "Tshirts",
119
- "jeans": "Jeans",
120
- "watch": "Watches",
121
- "kurta": "Kurtas",
122
- "dress": "Dresses", "dresses": "Dresses",
123
- "trousers": "Trousers", "pants": "Trousers",
124
- "shorts": "Shorts",
125
- "footwear": "Footwear",
126
- "shoes": "Shoes", # note kept as Shoes
127
- "fashion": "Apparel"
128
- }
129
- for term, mapped_value in category_map.items():
130
- if term in query_lower:
131
- category = mapped_value
132
- break
133
-
134
- # --- SubCategory Mapping ---
135
- subCategory_list = [
136
- "Accessories", "Apparel Set", "Bags", "Bath and Body", "Beauty Accessories",
137
- "Belts", "Bottomwear", "Cufflinks", "Dress", "Eyes", "Eyewear", "Flip Flops",
138
- "Fragrance", "Free Gifts", "Gloves", "Hair", "Headwear", "Home Furnishing",
139
- "Innerwear", "Jewellery", "Lips", "Loungewear and Nightwear", "Makeup",
140
- "Mufflers", "Nails", "Perfumes", "Sandal", "Saree", "Scarves", "Shoe Accessories",
141
- "Shoes", "Skin", "Skin Care", "Socks", "Sports Accessories", "Sports Equipment",
142
- "Stoles", "Ties", "Topwear", "Umbrellas", "Vouchers", "Wallets", "Watches",
143
- "Water Bottle", "Wristbands"
144
- ]
145
- if "topwear" in query_lower or "top" in query_lower:
146
- subcategory = "Topwear"
147
- else:
148
- for subcat in subCategory_list:
149
- if subcat.lower() in query_lower:
150
- subcategory = subcat
151
- break
152
-
153
- # --- Color Extraction ---
154
- colors = [
155
- "red","blue","green","yellow","black","white",
156
- "orange","pink","purple","brown","grey","beige"
157
- ]
158
- for c in colors:
159
- if c in query_lower:
160
- color = c.capitalize()
161
- break
162
-
163
- # --- Invalid pairs ---
164
- invalid_pairs = {
165
- ("Men", "Dresses"), ("Men", "Sarees"), ("Men", "Skirts"),
166
- ("Boys", "Dresses"), ("Boys", "Sarees"),
167
- ("Girls", "Boxers"), ("Men", "Heels")
168
- }
169
- if (gender, category) in invalid_pairs:
170
- print(f"⚠️ Invalid pair: {gender} + {category}, dropping gender")
171
- gender = None
172
-
173
- # fallback
174
- if gender and not category:
175
- category = "Apparel"
176
-
177
- return gender, category, subcategory, color
178
-
179
-
180
  def search_fashion(query: str, alpha: float, start: int = 0, end: int = 12, gender_override: str = None):
181
- gender, category, subcategory, color = extract_metadata_filters(query)
182
-
183
- # override from dropdown
 
 
184
  if gender_override:
185
  gender = gender_override
186
 
187
- # --- Pinecone Filter ---
188
  filter = {}
189
-
190
  if gender:
191
  filter["gender"] = gender
192
-
193
  if category:
194
  if category in ["Footwear", "Shoes"]:
195
- shoe_article_types = [
196
- "Casual Shoes", "Sports Shoes", "Formal Shoes", "Training Shoes",
197
- "Sneakers", "Sandals", "Slippers", "Boots", "Flip Flops"
198
- ]
199
- filter["articleType"] = {"$in": shoe_article_types}
200
  else:
201
  filter["articleType"] = category
202
-
203
  if subcategory:
204
  filter["subCategory"] = subcategory
205
-
206
  if color:
207
  filter["baseColour"] = color
208
 
209
- print(f"🔍 Using filter: {filter} (showing {start} to {end})")
210
-
211
  sparse = bm25.encode_queries(query)
212
  dense = model.encode(query).tolist()
213
  hdense, hsparse = hybrid_scale(dense, sparse, alpha=alpha)
214
 
215
  result = index.query(
216
- top_k=end,
217
  vector=hdense,
218
  sparse_vector=hsparse,
219
  include_metadata=True,
220
  filter=filter if filter else None
221
  )
222
 
223
- # fallback if no results
224
  if len(result["matches"]) == 0:
225
  print("⚠️ No results, retrying with alpha=0 sparse only")
226
  hdense, hsparse = hybrid_scale(dense, sparse, alpha=0)
227
- result = index.query(
228
- top_k=end,
229
- vector=hdense,
230
- sparse_vector=hsparse,
231
- include_metadata=True,
232
- filter=filter if filter else None
233
- )
234
-
235
- # fallback if no results with gender
236
- if gender and len(result["matches"]) == 0:
237
- print(f"⚠️ No results for gender {gender}, relaxing gender filter")
238
- filter.pop("gender", None)
239
- result = index.query(
240
- top_k=end,
241
- vector=hdense,
242
- sparse_vector=hsparse,
243
- include_metadata=True,
244
- filter=filter if filter else None
245
- )
246
-
247
- matches = result["matches"][start:end]
248
 
249
  imgs_with_captions = []
250
- for r in matches:
 
251
  idx = int(r["id"])
252
  img = images[idx]
253
  meta = r.get("metadata", {})
@@ -255,183 +563,116 @@ def search_fashion(query: str, alpha: float, start: int = 0, end: int = 12, gend
255
  img = Image.fromarray(np.array(img))
256
  padded = ImageOps.pad(img, (256, 256), color="white")
257
  caption = str(meta.get("productDisplayName", "Unknown Product"))
258
- imgs_with_captions.append((padded, caption))
 
 
 
259
 
260
  return imgs_with_captions
261
 
262
-
263
-
264
- # this is working code block
265
-
266
- from PIL import Image, ImageOps
267
- import numpy as np
268
-
269
  def search_by_image(uploaded_image, alpha=0.5, start=0, end=12):
270
- """
271
- Search visually similar products with support for pagination.
272
- """
273
- # Preprocess image for CLIP
274
  processed = clip_processor(images=uploaded_image, return_tensors="pt").to(device)
275
-
276
  with torch.no_grad():
277
  image_vec = clip_model.get_image_features(**processed)
278
  image_vec = image_vec.cpu().numpy().flatten().tolist()
279
 
280
- # Query a larger top_k so you have enough to paginate
281
- result = index.query(
282
- top_k=end,
283
- vector=image_vec,
284
- include_metadata=True
285
- )
286
-
287
- matches = result["matches"][start:end] # slice for pagination
288
-
289
  imgs_with_captions = []
290
- for r in matches:
 
 
291
  idx = int(r["id"])
292
  img = images[idx]
293
  meta = r.get("metadata", {})
 
294
  if not isinstance(img, Image.Image):
295
  img = Image.fromarray(np.array(img))
296
  padded = ImageOps.pad(img, (256, 256), color="white")
297
- caption = str(meta.get("productDisplayName", "Unknown Product"))
298
- imgs_with_captions.append((padded, caption))
 
 
299
 
300
  return imgs_with_captions
301
 
302
- # with gr.Blocks(css=custom_css) as demo:
303
- # gr.Markdown("# 🛍️ Fashion Product Hybrid Search")
304
-
305
- # with gr.Row(equal_height=True):
306
- # with gr.Column(scale=5, elem_classes="query-slider"):
307
- # query = gr.Textbox(
308
- # label="Enter your fashion search query",
309
- # placeholder="Type something or leave blank to only use the image"
310
- # )
311
- # alpha = gr.Slider(
312
- # 0, 1, value=0.5,
313
- # label="Hybrid Weight (alpha: 0=sparse, 1=dense)"
314
- # )
315
- # with gr.Column(scale=1):
316
- # image_input = gr.Image(
317
- # type="pil",
318
- # label="Upload an image (optional)",
319
- # height=256,
320
- # width=356,
321
- # show_label=True
322
- # )
323
-
324
- # search_btn = gr.Button("Search", elem_classes="search-btn")
325
-
326
- # gallery = gr.Gallery(
327
- # label="Search Results",
328
- # columns=6,
329
- # height="40vh"
330
- # )
331
- import gradio as gr
332
- import gradio as gr
333
  custom_css = """
334
- .search-btn {
335
- width: 100%;
336
- }
337
- .gr-row {
338
- gap: 8px !important;
339
- }
340
- .query-slider > div {
341
- margin-bottom: 4px !important;
342
- }
343
- .upload-box .icon-container {
344
- display: none !important;
345
- }
346
  """
347
 
348
  with gr.Blocks(css=custom_css) as demo:
349
- gr.Markdown("# 🛍️ Fashion Product Hybrid Search")
350
 
351
  with gr.Row(equal_height=True):
352
  with gr.Column(scale=5, elem_classes="query-slider"):
353
- query = gr.Textbox(
354
- label="Enter your fashion search query",
355
- placeholder="Type something or leave blank to only use the image"
356
- )
357
  alpha = gr.Slider(0, 1, value=0.5, label="Hybrid Weight (alpha: 0=sparse, 1=dense)")
358
-
359
- gender_dropdown = gr.Dropdown(
360
- ["", "Men", "Women", "Boys", "Girls", "Kids", "Unisex"],
361
- label="Gender Filter (optional)"
362
- )
363
- # with gr.Column(scale=1):
364
- # image_input = gr.Image(
365
- # type="pil",
366
- # label="Upload an image (optional)",
367
- # height=256,
368
- # width=356
369
- # )
370
  with gr.Column(scale=1):
371
- image_input = gr.Image(
372
- type="pil",
373
- label="Upload an image (optional)",
374
- height=256,
375
- width=356,
376
- sources=["upload", "clipboard"] # only upload and paste allowed
377
- )
378
-
379
 
380
  search_btn = gr.Button("Search", elem_classes="search-btn")
381
- gallery = gr.Gallery(label="Search Results", columns=6, height="50vh")
382
  load_more_btn = gr.Button("Load More")
383
 
384
- # States to track
385
  search_offset = gr.State(0)
386
  current_query = gr.State("")
387
  current_image = gr.State(None)
388
  current_gender = gr.State("")
389
- shown_results = gr.State([]) # new: store the list of shown images
 
390
 
391
  def unified_search(q, uploaded_image, a, offset, gender_ui):
392
  start = 0
393
  end = 12
394
-
395
- gender_override = gender_ui if gender_ui else None
396
 
397
  if uploaded_image is not None:
398
  results = search_by_image(uploaded_image, a, start, end)
399
- elif q.strip() != "":
400
  results = search_fashion(q, a, start, end, gender_override)
401
  else:
402
  results = []
403
 
404
- # reset shown_results to just these first 12
405
- return results, end, q, uploaded_image, gender_ui, results
406
 
407
- search_btn.click(
408
- unified_search,
409
- inputs=[query, image_input, alpha, search_offset, gender_dropdown],
410
- outputs=[gallery, search_offset, current_query, current_image, current_gender, shown_results]
411
- )
412
 
413
- def load_more_fn(a, offset, q, img, gender_ui, prev_results):
414
  start = offset
415
  end = offset + 12
416
-
417
- gender_override = gender_ui if gender_ui else None
418
 
419
  if img is not None:
420
  new_results = search_by_image(img, a, start, end)
421
- elif q.strip() != "":
422
  new_results = search_fashion(q, a, start, end, gender_override)
423
  else:
424
  new_results = []
425
 
426
- combined_results = prev_results + new_results
427
- return combined_results, end, combined_results
 
 
 
 
 
428
 
429
- load_more_btn.click(
430
- load_more_fn,
431
- inputs=[alpha, search_offset, current_query, current_image, current_gender, shown_results],
432
- outputs=[gallery, search_offset, shown_results]
433
- )
 
 
434
 
435
- gr.Markdown("Powered by your hybrid AI search model 🚀")
436
 
437
- demo.launch()
 
1
 
2
+ # import os
3
+ # from pinecone import Pinecone, ServerlessSpec
4
+ # from PIL import Image, ImageOps
5
+ # import numpy as np
6
+ # from datasets import load_dataset
7
+ # from pinecone_text.sparse import BM25Encoder
8
+ # from sentence_transformers import SentenceTransformer
9
+ # import torch
10
+ # from tqdm.auto import tqdm
11
+ # import gradio as gr
12
+
13
+ # # ------------------- Pinecone Setup -------------------
14
+ # os.environ["PINECONE_API_KEY"] = "pcsk_TMCYK_LrbmZMTDhkxTjUXcr8iTcQ8LxurwKBFDvv4ahFis8SVob7QexVPPEt6g2zW6d3g"
15
+ # api_key = os.environ.get('PINECONE_API_KEY')
16
+ # pc = Pinecone(api_key=api_key)
17
+
18
+
19
+ # cloud = os.environ.get('PINECONE_CLOUD') or 'aws'
20
+ # region = os.environ.get('PINECONE_REGION') or 'us-east-1'
21
+
22
+ # spec = ServerlessSpec(cloud=cloud, region=region)
23
+
24
+ # index_name = "hybrid-image-search"
25
+ # spec = ServerlessSpec(cloud="aws", region="us-east-1")
26
+ # # choose a name for your index
27
+ # index_name = "hybrid-image-search"
28
+ # import time
29
+
30
+ # # check if index already exists (it shouldn't if this is first time)
31
+ # if index_name not in pc.list_indexes().names():
32
+ # # if does not exist, create index
33
+ # pc.create_index(
34
+ # index_name,
35
+ # dimension=512,
36
+ # metric='dotproduct',
37
+ # spec=spec
38
+ # )
39
+ # # wait for index to be initialized
40
+ # while not pc.describe_index(index_name).status['ready']:
41
+ # time.sleep(1)
42
+
43
+ # # connect to index
44
+ # index = pc.Index(index_name)
45
+ # # view index stats
46
+ # index.describe_index_stats()
47
+
48
+ # # ------------------- Dataset Loading -------------------
49
+ # fashion = load_dataset("ashraq/fashion-product-images-small", split="train")
50
+ # images = fashion["image"]
51
+ # metadata = fashion.remove_columns("image").to_pandas()
52
+
53
+ # # ------------------- Encoders -------------------
54
+ # bm25 = BM25Encoder()
55
+ # bm25.fit(metadata["productDisplayName"])
56
+ # model = SentenceTransformer('sentence-transformers/clip-ViT-B-32', device='cuda' if torch.cuda.is_available() else 'cpu')
57
+ # from sentence_transformers import SentenceTransformer
58
+ # import torch
59
+
60
+ # device = 'cuda' if torch.cuda.is_available() else 'cpu'
61
+
62
+ # # load a CLIP model from huggingface
63
+ # model = SentenceTransformer(
64
+ # 'sentence-transformers/clip-ViT-B-32',
65
+ # device=device
66
+ # )
67
+ # model
68
+ # # ------------------- Hybrid Scaling -------------------
69
+ # def hybrid_scale(dense, sparse, alpha: float):
70
+
71
+ # if alpha < 0 or alpha > 1:
72
+ # raise ValueError("Alpha must be between 0 and 1")
73
+ # # scale sparse and dense vectors to create hybrid search vecs
74
+ # hsparse = {
75
+ # 'indices': sparse['indices'],
76
+ # 'values': [v * (1 - alpha) for v in sparse['values']]
77
+ # }
78
+ # hdense = [v * alpha for v in dense]
79
+ # return hdense, hsparse
80
+
81
+ # # ------------------- Metadata Filter Extraction -------------------
82
+ # from PIL import Image, ImageOps
83
+ # import numpy as np
84
+ # from PIL import Image, ImageOps
85
+ # import numpy as np
86
+ # from PIL import Image, ImageOps
87
+ # import numpy as np
88
+
89
+ # from transformers import CLIPProcessor, CLIPModel
90
+
91
+ # clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
92
+ # clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
93
+
94
+ # def extract_metadata_filters(query: str):
95
+ # query_lower = query.lower()
96
+ # gender = None
97
+ # category = None
98
+ # subcategory = None
99
+ # color = None
100
+
101
+ # # --- Gender Mapping ---
102
+ # gender_map = {
103
+ # "men": "Men", "man": "Men", "mens": "Men", "mans": "Men", "male": "Men",
104
+ # "women": "Women", "woman": "Women", "womens": "Women", "female": "Women",
105
+ # "boys": "Boys", "boy": "Boys",
106
+ # "girls": "Girls", "girl": "Girls",
107
+ # "kids": "Kids","kid": "Kids",
108
+ # "unisex": "Unisex"
109
+ # }
110
+ # for term, mapped_value in gender_map.items():
111
+ # if term in query_lower:
112
+ # gender = mapped_value
113
+ # break
114
+
115
+ # # --- Category Mapping ---
116
+ # category_map = {
117
+ # "shirt": "Shirts",
118
+ # "tshirt": "Tshirts", "t-shirt": "Tshirts",
119
+ # "jeans": "Jeans",
120
+ # "watch": "Watches",
121
+ # "kurta": "Kurtas",
122
+ # "dress": "Dresses", "dresses": "Dresses",
123
+ # "trousers": "Trousers", "pants": "Trousers",
124
+ # "shorts": "Shorts",
125
+ # "footwear": "Footwear",
126
+ # "shoes": "Shoes", # note kept as Shoes
127
+ # "fashion": "Apparel"
128
+ # }
129
+ # for term, mapped_value in category_map.items():
130
+ # if term in query_lower:
131
+ # category = mapped_value
132
+ # break
133
+
134
+ # # --- SubCategory Mapping ---
135
+ # subCategory_list = [
136
+ # "Accessories", "Apparel Set", "Bags", "Bath and Body", "Beauty Accessories",
137
+ # "Belts", "Bottomwear", "Cufflinks", "Dress", "Eyes", "Eyewear", "Flip Flops",
138
+ # "Fragrance", "Free Gifts", "Gloves", "Hair", "Headwear", "Home Furnishing",
139
+ # "Innerwear", "Jewellery", "Lips", "Loungewear and Nightwear", "Makeup",
140
+ # "Mufflers", "Nails", "Perfumes", "Sandal", "Saree", "Scarves", "Shoe Accessories",
141
+ # "Shoes", "Skin", "Skin Care", "Socks", "Sports Accessories", "Sports Equipment",
142
+ # "Stoles", "Ties", "Topwear", "Umbrellas", "Vouchers", "Wallets", "Watches",
143
+ # "Water Bottle", "Wristbands"
144
+ # ]
145
+ # if "topwear" in query_lower or "top" in query_lower:
146
+ # subcategory = "Topwear"
147
+ # else:
148
+ # for subcat in subCategory_list:
149
+ # if subcat.lower() in query_lower:
150
+ # subcategory = subcat
151
+ # break
152
+
153
+ # # --- Color Extraction ---
154
+ # colors = [
155
+ # "red","blue","green","yellow","black","white",
156
+ # "orange","pink","purple","brown","grey","beige"
157
+ # ]
158
+ # for c in colors:
159
+ # if c in query_lower:
160
+ # color = c.capitalize()
161
+ # break
162
+
163
+ # # --- Invalid pairs ---
164
+ # invalid_pairs = {
165
+ # ("Men", "Dresses"), ("Men", "Sarees"), ("Men", "Skirts"),
166
+ # ("Boys", "Dresses"), ("Boys", "Sarees"),
167
+ # ("Girls", "Boxers"), ("Men", "Heels")
168
+ # }
169
+ # if (gender, category) in invalid_pairs:
170
+ # print(f"⚠️ Invalid pair: {gender} + {category}, dropping gender")
171
+ # gender = None
172
+
173
+ # # fallback
174
+ # if gender and not category:
175
+ # category = "Apparel"
176
+
177
+ # return gender, category, subcategory, color
178
+
179
+
180
+ # def search_fashion(query: str, alpha: float, start: int = 0, end: int = 12, gender_override: str = None):
181
+ # gender, category, subcategory, color = extract_metadata_filters(query)
182
+
183
+ # # override from dropdown
184
+ # if gender_override:
185
+ # gender = gender_override
186
+
187
+ # # --- Pinecone Filter ---
188
+ # filter = {}
189
+
190
+ # if gender:
191
+ # filter["gender"] = gender
192
+
193
+ # if category:
194
+ # if category in ["Footwear", "Shoes"]:
195
+ # shoe_article_types = [
196
+ # "Casual Shoes", "Sports Shoes", "Formal Shoes", "Training Shoes",
197
+ # "Sneakers", "Sandals", "Slippers", "Boots", "Flip Flops"
198
+ # ]
199
+ # filter["articleType"] = {"$in": shoe_article_types}
200
+ # else:
201
+ # filter["articleType"] = category
202
+
203
+ # if subcategory:
204
+ # filter["subCategory"] = subcategory
205
+
206
+ # if color:
207
+ # filter["baseColour"] = color
208
+
209
+ # print(f"🔍 Using filter: {filter} (showing {start} to {end})")
210
+
211
+ # sparse = bm25.encode_queries(query)
212
+ # dense = model.encode(query).tolist()
213
+ # hdense, hsparse = hybrid_scale(dense, sparse, alpha=alpha)
214
+
215
+ # result = index.query(
216
+ # top_k=end,
217
+ # vector=hdense,
218
+ # sparse_vector=hsparse,
219
+ # include_metadata=True,
220
+ # filter=filter if filter else None
221
+ # )
222
+
223
+ # # fallback if no results
224
+ # if len(result["matches"]) == 0:
225
+ # print("⚠️ No results, retrying with alpha=0 sparse only")
226
+ # hdense, hsparse = hybrid_scale(dense, sparse, alpha=0)
227
+ # result = index.query(
228
+ # top_k=end,
229
+ # vector=hdense,
230
+ # sparse_vector=hsparse,
231
+ # include_metadata=True,
232
+ # filter=filter if filter else None
233
+ # )
234
+
235
+ # # fallback if no results with gender
236
+ # if gender and len(result["matches"]) == 0:
237
+ # print(f"⚠️ No results for gender {gender}, relaxing gender filter")
238
+ # filter.pop("gender", None)
239
+ # result = index.query(
240
+ # top_k=end,
241
+ # vector=hdense,
242
+ # sparse_vector=hsparse,
243
+ # include_metadata=True,
244
+ # filter=filter if filter else None
245
+ # )
246
+
247
+ # matches = result["matches"][start:end]
248
+
249
+ # imgs_with_captions = []
250
+ # for r in matches:
251
+ # idx = int(r["id"])
252
+ # img = images[idx]
253
+ # meta = r.get("metadata", {})
254
+ # if not isinstance(img, Image.Image):
255
+ # img = Image.fromarray(np.array(img))
256
+ # padded = ImageOps.pad(img, (256, 256), color="white")
257
+ # caption = str(meta.get("productDisplayName", "Unknown Product"))
258
+ # imgs_with_captions.append((padded, caption))
259
+
260
+ # return imgs_with_captions
261
+
262
+
263
+
264
+ # # this is working code block
265
+
266
+ # from PIL import Image, ImageOps
267
+ # import numpy as np
268
+
269
+ # def search_by_image(uploaded_image, alpha=0.5, start=0, end=12):
270
+ # """
271
+ # Search visually similar products with support for pagination.
272
+ # """
273
+ # # Preprocess image for CLIP
274
+ # processed = clip_processor(images=uploaded_image, return_tensors="pt").to(device)
275
+
276
+ # with torch.no_grad():
277
+ # image_vec = clip_model.get_image_features(**processed)
278
+ # image_vec = image_vec.cpu().numpy().flatten().tolist()
279
+
280
+ # # Query a larger top_k so you have enough to paginate
281
+ # result = index.query(
282
+ # top_k=end,
283
+ # vector=image_vec,
284
+ # include_metadata=True
285
+ # )
286
+
287
+ # matches = result["matches"][start:end] # slice for pagination
288
+
289
+ # imgs_with_captions = []
290
+ # for r in matches:
291
+ # idx = int(r["id"])
292
+ # img = images[idx]
293
+ # meta = r.get("metadata", {})
294
+ # if not isinstance(img, Image.Image):
295
+ # img = Image.fromarray(np.array(img))
296
+ # padded = ImageOps.pad(img, (256, 256), color="white")
297
+ # caption = str(meta.get("productDisplayName", "Unknown Product"))
298
+ # imgs_with_captions.append((padded, caption))
299
+
300
+ # return imgs_with_captions
301
+
302
+ # # with gr.Blocks(css=custom_css) as demo:
303
+ # # gr.Markdown("# 🛍️ Fashion Product Hybrid Search")
304
+
305
+ # # with gr.Row(equal_height=True):
306
+ # # with gr.Column(scale=5, elem_classes="query-slider"):
307
+ # # query = gr.Textbox(
308
+ # # label="Enter your fashion search query",
309
+ # # placeholder="Type something or leave blank to only use the image"
310
+ # # )
311
+ # # alpha = gr.Slider(
312
+ # # 0, 1, value=0.5,
313
+ # # label="Hybrid Weight (alpha: 0=sparse, 1=dense)"
314
+ # # )
315
+ # # with gr.Column(scale=1):
316
+ # # image_input = gr.Image(
317
+ # # type="pil",
318
+ # # label="Upload an image (optional)",
319
+ # # height=256,
320
+ # # width=356,
321
+ # # show_label=True
322
+ # # )
323
+
324
+ # # search_btn = gr.Button("Search", elem_classes="search-btn")
325
+
326
+ # # gallery = gr.Gallery(
327
+ # # label="Search Results",
328
+ # # columns=6,
329
+ # # height="40vh"
330
+ # # )
331
+ # import gradio as gr
332
+ # custom_css = """
333
+ # .search-btn {
334
+ # width: 100%;
335
+ # }
336
+ # .gr-row {
337
+ # gap: 8px !important;
338
+ # }
339
+ # .query-slider > div {
340
+ # margin-bottom: 4px !important;
341
+ # }
342
+ # .upload-box .icon-container {
343
+ # display: none !important;
344
+ # }
345
+ # """
346
+
347
+ # with gr.Blocks(css=custom_css) as demo:
348
+ # gr.Markdown("# 🛍️ Fashion Product Hybrid Search")
349
 
350
+ # with gr.Row(equal_height=True):
351
+ # with gr.Column(scale=5, elem_classes="query-slider"):
352
+ # query = gr.Textbox(
353
+ # label="Enter your fashion search query",
354
+ # placeholder="Type something or leave blank to only use the image"
355
+ # )
356
+ # alpha = gr.Slider(0, 1, value=0.5, label="Hybrid Weight (alpha: 0=sparse, 1=dense)")
357
+
358
+ # gender_dropdown = gr.Dropdown(
359
+ # ["", "Men", "Women", "Boys", "Girls", "Kids", "Unisex"],
360
+ # label="Gender Filter (optional)"
361
+ # )
362
+ # # with gr.Column(scale=1):
363
+ # # image_input = gr.Image(
364
+ # # type="pil",
365
+ # # label="Upload an image (optional)",
366
+ # # height=256,
367
+ # # width=356
368
+ # # )
369
+ # with gr.Column(scale=1):
370
+ # image_input = gr.Image(
371
+ # type="pil",
372
+ # label="Upload an image (optional)",
373
+ # height=256,
374
+ # width=356,
375
+ # sources=["upload", "clipboard"] # only upload and paste allowed
376
+ # )
377
 
378
 
379
+ # search_btn = gr.Button("Search", elem_classes="search-btn")
380
+ # gallery = gr.Gallery(label="Search Results", columns=6, height="50vh")
381
+ # load_more_btn = gr.Button("Load More")
382
+
383
+ # # States to track
384
+ # search_offset = gr.State(0)
385
+ # current_query = gr.State("")
386
+ # current_image = gr.State(None)
387
+ # current_gender = gr.State("")
388
+ # shown_results = gr.State([]) # new: store the list of shown images
389
+
390
+ # def unified_search(q, uploaded_image, a, offset, gender_ui):
391
+ # start = 0
392
+ # end = 12
393
+
394
+ # gender_override = gender_ui if gender_ui else None
395
+
396
+ # if uploaded_image is not None:
397
+ # results = search_by_image(uploaded_image, a, start, end)
398
+ # elif q.strip() != "":
399
+ # results = search_fashion(q, a, start, end, gender_override)
400
+ # else:
401
+ # results = []
402
+
403
+ # # reset shown_results to just these first 12
404
+ # return results, end, q, uploaded_image, gender_ui, results
405
+
406
+ # search_btn.click(
407
+ # unified_search,
408
+ # inputs=[query, image_input, alpha, search_offset, gender_dropdown],
409
+ # outputs=[gallery, search_offset, current_query, current_image, current_gender, shown_results]
410
+ # )
411
 
412
+ # def load_more_fn(a, offset, q, img, gender_ui, prev_results):
413
+ # start = offset
414
+ # end = offset + 12
415
 
416
+ # gender_override = gender_ui if gender_ui else None
417
+
418
+ # if img is not None:
419
+ # new_results = search_by_image(img, a, start, end)
420
+ # elif q.strip() != "":
421
+ # new_results = search_fashion(q, a, start, end, gender_override)
422
+ # else:
423
+ # new_results = []
424
+
425
+ # combined_results = prev_results + new_results
426
+ # return combined_results, end, combined_results
427
+
428
+ # load_more_btn.click(
429
+ # load_more_fn,
430
+ # inputs=[alpha, search_offset, current_query, current_image, current_gender, shown_results],
431
+ # outputs=[gallery, search_offset, shown_results]
432
+ # )
433
+
434
+ # gr.Markdown("Powered by your hybrid AI search model 🚀")
435
+
436
+ # demo.launch()
437
+
438
+
439
+ # app.py
+ import os
  import time
+ import torch
+ import numpy as np
+ import gradio as gr
+ from PIL import Image, ImageOps
+ from tqdm.auto import tqdm
+ from datasets import load_dataset
+ from sentence_transformers import SentenceTransformer
+ from pinecone import Pinecone, ServerlessSpec
+ from pinecone_text.sparse import BM25Encoder
+ from transformers import CLIPProcessor, CLIPModel
+ import openai
+
+ # ------------------- Keys & Setup -------------------
+ openai.api_key = os.getenv("OPENAI_API_KEY")
+ pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
+ spec = ServerlessSpec(cloud=os.getenv("PINECONE_CLOUD") or "aws", region=os.getenv("PINECONE_REGION") or "us-east-1")
+ index_name = "hybrid-image-search"
+
  if index_name not in pc.list_indexes().names():
+     pc.create_index(index_name, dimension=512, metric='dotproduct', spec=spec)
      while not pc.describe_index(index_name).status['ready']:
          time.sleep(1)
  index = pc.Index(index_name)
 
+ # ------------------- Models & Dataset -------------------
  fashion = load_dataset("ashraq/fashion-product-images-small", split="train")
  images = fashion["image"]
  metadata = fashion.remove_columns("image").to_pandas()
  bm25 = BM25Encoder()
  bm25.fit(metadata["productDisplayName"])
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ model = SentenceTransformer('sentence-transformers/clip-ViT-B-32', device=device)
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
+ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
 
+ # ------------------- Helper Functions -------------------
  def hybrid_scale(dense, sparse, alpha: float):
      if alpha < 0 or alpha > 1:
          raise ValueError("Alpha must be between 0 and 1")
      hsparse = {
          'indices': sparse['indices'],
          'values': [v * (1 - alpha) for v in sparse['values']]
      }
      hdense = [v * alpha for v in dense]
      return hdense, hsparse
 
+ def extract_intent_from_openai(query: str):
+     prompt = f'''
+ You are an assistant for a fashion search engine. Extract the user's intent from the following query.
+ Return a Python dictionary with keys: category, gender, subcategory, color.
+ If something is missing, use null.
+ Query: "{query}"
+ Only return the dictionary.
+ '''
+     try:
+         response = openai.ChatCompletion.create(
+             model="gpt-4",
+             messages=[{"role": "user", "content": prompt}],
+             temperature=0
+         )
+         raw = response.choices[0].message['content']
+         structured = eval(raw)
+         return structured
+     except Exception as e:
+         print(f"⚠️ OpenAI intent extraction failed: {e}")
+         return {}
+
+ def is_duplicate(img, seen_hashes):
+     h = hash(img.tobytes())
+     if h in seen_hashes:
+         return True
+     seen_hashes.add(h)
+     return False
+
+ # ------------------- Search Functions -------------------
  def search_fashion(query: str, alpha: float, start: int = 0, end: int = 12, gender_override: str = None):
+     intent = extract_intent_from_openai(query)
+     gender = intent.get("gender")
+     category = intent.get("category")
+     subcategory = intent.get("subcategory")
+     color = intent.get("color")
      if gender_override:
          gender = gender_override
 
      filter = {}
      if gender:
          filter["gender"] = gender
      if category:
          if category in ["Footwear", "Shoes"]:
+             filter["articleType"] = {"$regex": ".*(Shoe|Footwear).*"}
          else:
              filter["articleType"] = category
      if subcategory:
          filter["subCategory"] = subcategory
      if color:
          filter["baseColour"] = color
 
      sparse = bm25.encode_queries(query)
      dense = model.encode(query).tolist()
      hdense, hsparse = hybrid_scale(dense, sparse, alpha=alpha)
 
      result = index.query(
+         top_k=100,
          vector=hdense,
          sparse_vector=hsparse,
          include_metadata=True,
          filter=filter if filter else None
      )
 
      if len(result["matches"]) == 0:
          print("⚠️ No results, retrying with alpha=0 sparse only")
          hdense, hsparse = hybrid_scale(dense, sparse, alpha=0)
+         result = index.query(top_k=100, vector=hdense, sparse_vector=hsparse, include_metadata=True, filter=filter)
 
      imgs_with_captions = []
+     seen_hashes = set()
+     for r in result["matches"]:
          idx = int(r["id"])
          img = images[idx]
          meta = r.get("metadata", {})
          if not isinstance(img, Image.Image):
              img = Image.fromarray(np.array(img))
          padded = ImageOps.pad(img, (256, 256), color="white")
          caption = str(meta.get("productDisplayName", "Unknown Product"))
+         if not is_duplicate(padded, seen_hashes):
+             imgs_with_captions.append((padded, caption))
+         if len(imgs_with_captions) >= end:
+             break
 
      return imgs_with_captions
 
  def search_by_image(uploaded_image, alpha=0.5, start=0, end=12):
      processed = clip_processor(images=uploaded_image, return_tensors="pt").to(device)
      with torch.no_grad():
          image_vec = clip_model.get_image_features(**processed)
      image_vec = image_vec.cpu().numpy().flatten().tolist()
 
+     result = index.query(top_k=100, vector=image_vec, include_metadata=True)
      imgs_with_captions = []
+     seen_hashes = set()
+
+     for r in result["matches"]:
          idx = int(r["id"])
          img = images[idx]
          meta = r.get("metadata", {})
+         caption = str(meta.get("productDisplayName", "Unknown Product"))
          if not isinstance(img, Image.Image):
              img = Image.fromarray(np.array(img))
          padded = ImageOps.pad(img, (256, 256), color="white")
+         if not is_duplicate(padded, seen_hashes):
+             imgs_with_captions.append((padded, caption))
+         if len(imgs_with_captions) >= end:
+             break
 
      return imgs_with_captions
 
+ # ------------------- UI -------------------
  custom_css = """
+ .search-btn { width: 100%; }
+ .gr-row { gap: 8px !important; }
+ .query-slider > div { margin-bottom: 4px !important; }
+ .gr-gallery-item { width: 256px !important; height: 256px !important; }
+ .gr-gallery-item img { width: 100% !important; height: 100% !important; object-fit: cover !important; }
  """
 
  with gr.Blocks(css=custom_css) as demo:
+     gr.Markdown("# 🛍️ Fashion Product Hybrid Search (with GPT-4 powered query parsing)")
 
      with gr.Row(equal_height=True):
          with gr.Column(scale=5, elem_classes="query-slider"):
+             query = gr.Textbox(label="Enter your fashion search query", placeholder="e.g., black sneakers for women")
              alpha = gr.Slider(0, 1, value=0.5, label="Hybrid Weight (alpha: 0=sparse, 1=dense)")
+             gender_dropdown = gr.Dropdown(["", "Men", "Women", "Boys", "Girls", "Kids", "Unisex"], label="Gender Filter (optional)")
          with gr.Column(scale=1):
+             image_input = gr.Image(type="pil", label="Upload an image (optional)", sources=["upload", "clipboard"], height=256, width=356)
 
      search_btn = gr.Button("Search", elem_classes="search-btn")
+     gallery = gr.Gallery(label="Search Results", columns=6, height=None)
      load_more_btn = gr.Button("Load More")
 
      search_offset = gr.State(0)
      current_query = gr.State("")
      current_image = gr.State(None)
      current_gender = gr.State("")
+     shown_results = gr.State([])
+     shown_ids = gr.State(set())
 
      def unified_search(q, uploaded_image, a, offset, gender_ui):
          start = 0
          end = 12
+         filters = extract_intent_from_openai(q) if q.strip() else {}
+         gender_override = gender_ui if gender_ui else filters.get("gender")
 
          if uploaded_image is not None:
              results = search_by_image(uploaded_image, a, start, end)
+         elif q.strip():
              results = search_fashion(q, a, start, end, gender_override)
          else:
              results = []
 
+         seen_ids = {r[1] for r in results}
+         return results, end, q, uploaded_image, gender_override, results, seen_ids
 
+     search_btn.click(unified_search, inputs=[query, image_input, alpha, search_offset, gender_dropdown],
+                      outputs=[gallery, search_offset, current_query, current_image, current_gender, shown_results, shown_ids])
 
+     def load_more_fn(a, offset, q, img, gender_ui, prev_results, prev_ids):
          start = offset
          end = offset + 12
+         gender_override = gender_ui
 
          if img is not None:
              new_results = search_by_image(img, a, start, end)
+         elif q.strip():
              new_results = search_fashion(q, a, start, end, gender_override)
          else:
              new_results = []
 
+         filtered_new = []
+         new_ids = set()
+         for item in new_results:
+             img_obj, caption = item
+             if caption not in prev_ids:
+                 filtered_new.append(item)
+                 new_ids.add(caption)
 
+         combined = prev_results + filtered_new
+         updated_ids = prev_ids.union(new_ids)
+
+         return combined, end, combined, updated_ids
+
+     load_more_btn.click(load_more_fn, inputs=[alpha, search_offset, current_query, current_image, current_gender, shown_results, shown_ids],
+                         outputs=[gallery, search_offset, shown_results, shown_ids])
 
+     gr.Markdown("🧠 Powered by OpenAI + Hybrid AI Fashion Search")
 
+ demo.launch()
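
A note on the new extract_intent_from_openai helper added above: requirements.txt below pins openai==1.30.1, but the 1.x SDK removed the legacy openai.ChatCompletion.create interface that app.py calls, and parsing the reply with eval() executes whatever text the model returns. Below is a minimal sketch of the same intent-extraction step written against the 1.x client, using ast.literal_eval instead of eval; the prompt wording and dictionary keys come from the diff, while the helper name and everything else is an assumption for illustration, not part of this commit. (requirements.txt also no longer lists pinecone-text, which app.py still imports for BM25Encoder.)

    # Sketch only (assumes OPENAI_API_KEY is set in the environment).
    import ast
    from openai import OpenAI

    client = OpenAI()  # openai>=1.0 client; openai.ChatCompletion.create no longer exists there

    def extract_intent_sketch(query: str) -> dict:
        # "use None" (not "null") keeps the reply parseable as a Python literal
        prompt = (
            "You are an assistant for a fashion search engine. Extract the user's intent "
            "from the following query. Return a Python dictionary with keys: category, "
            "gender, subcategory, color. If something is missing, use None. "
            f'Query: "{query}" Only return the dictionary.'
        )
        try:
            response = client.chat.completions.create(
                model="gpt-4",
                messages=[{"role": "user", "content": prompt}],
                temperature=0,
            )
            raw = response.choices[0].message.content
            # literal_eval accepts only Python literals, unlike eval(), which runs arbitrary code
            return ast.literal_eval(raw)
        except Exception as e:
            print(f"Intent extraction failed: {e}")
            return {}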
requirements.txt CHANGED
@@ -1,7 +1,11 @@
- gradio==4.14.0
  datasets
- transformers
- sentence-transformers
- pinecone-client==3.1.0
- pinecone-text
- pillow
+ gradio==4.34.1
+ openai==1.30.1
+ sentence-transformers==2.6.1
+ torch>=2.0.0
+ transformers==4.41.1
  datasets
+ Pillow
+ pinecone-client==3.2.2
+ scikit-learn
+ tqdm
+ numpy
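
For reference, the alpha slider in the Gradio UI feeds straight into the hybrid_scale helper shown in the app.py diff: the sparse BM25 values are scaled by 1 - alpha and the dense CLIP vector by alpha, so 0 gives keyword-only ranking and 1 gives semantic-only ranking. A self-contained toy example of that weighting (the vectors below are made-up stand-ins, not real encoder output):

    # hybrid_scale as it appears in app.py, exercised with toy inputs
    def hybrid_scale(dense, sparse, alpha: float):
        if alpha < 0 or alpha > 1:
            raise ValueError("Alpha must be between 0 and 1")
        hsparse = {
            'indices': sparse['indices'],
            'values': [v * (1 - alpha) for v in sparse['values']],
        }
        hdense = [v * alpha for v in dense]
        return hdense, hsparse

    dense = [0.2, 0.4, 0.4]                              # stand-in for a CLIP embedding
    sparse = {'indices': [3, 17], 'values': [1.5, 0.7]}  # stand-in for BM25 term weights

    print(hybrid_scale(dense, sparse, alpha=0.0))  # dense zeroed out -> keyword-only
    print(hybrid_scale(dense, sparse, alpha=1.0))  # sparse zeroed out -> semantic-only
    print(hybrid_scale(dense, sparse, alpha=0.5))  # equal blend of both signals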