Anusha806 committed on
Commit
ce1c241
·
1 Parent(s): 51c1e26
Files changed (1) hide show
  1. app.py +32 -19
app.py CHANGED
@@ -279,16 +279,35 @@ os.environ["PINECONE_API_KEY"] = "pcsk_TMCYK_LrbmZMTDhkxTjUXcr8iTcQ8LxurwKBFDvv4
279
  api_key = os.environ.get('PINECONE_API_KEY')
280
  pc = Pinecone(api_key=api_key)
281
 
 
 
 
 
 
 
282
  index_name = "hybrid-image-search"
283
  spec = ServerlessSpec(cloud="aws", region="us-east-1")
 
 
 
284
 
 
285
  if index_name not in pc.list_indexes().names():
286
- pc.create_index(index_name, dimension=512, metric="dotproduct", spec=spec)
287
- import time
 
 
 
 
 
 
288
  while not pc.describe_index(index_name).status['ready']:
289
  time.sleep(1)
290
 
 
291
  index = pc.Index(index_name)
 
 
292
 
293
  # ------------------- Dataset Loading -------------------
294
  fashion = load_dataset("ashraq/fashion-product-images-small", split="train")
@@ -313,24 +332,11 @@ def hybrid_scale(dense, sparse, alpha: float):
313
  hdense = [v * alpha for v in dense]
314
  return hdense, hsparse
315
 
316
-
317
- # def search_fashion(query: str, alpha: float):
318
- # sparse = bm25.encode_queries(query)
319
- # dense = model.encode(query).tolist()
320
- # hdense, hsparse = hybrid_scale(dense, sparse, alpha=alpha)
321
- # result = index.query(
322
- # top_k=8,
323
- # vector=hdense,
324
- # sparse_vector=hsparse,
325
- # include_metadata=True
326
- # )
327
- # imgs = [images[int(r["id"])] for r in result["matches"]]
328
- # return imgs
329
-
330
-
331
  # ------------------- Metadata Filter Extraction -------------------
332
  from PIL import Image, ImageOps
333
  import numpy as np
 
 
334
 
335
  def extract_metadata_filters(query: str):
336
  query_lower = query.lower()
@@ -452,7 +458,7 @@ def search_fashion(query: str, alpha: float):
452
  print(f"⚠️ No results with gender {gender}, relaxing gender filter")
453
  filter.pop("gender")
454
  result = index.query(
455
- top_k=12,
456
  vector=hdense,
457
  sparse_vector=hsparse,
458
  include_metadata=True,
@@ -473,6 +479,14 @@ def search_fashion(query: str, alpha: float):
473
 
474
  return imgs_with_captions
475
 
 
 
 
 
 
 
 
 
476
  from PIL import Image, ImageOps
477
  import numpy as np
478
 
@@ -509,7 +523,6 @@ def search_by_image(uploaded_image, alpha=0.5):
509
 
510
 
511
 
512
- # ------------------- Gradio UI -------------------
513
  custom_css = """
514
  .search-btn {
515
  width: 100%;
 
279
  api_key = os.environ.get('PINECONE_API_KEY')
280
  pc = Pinecone(api_key=api_key)
281
 
282
+
283
+ cloud = os.environ.get('PINECONE_CLOUD') or 'aws'
284
+ region = os.environ.get('PINECONE_REGION') or 'us-east-1'
285
+
286
+ spec = ServerlessSpec(cloud=cloud, region=region)
287
+
288
  index_name = "hybrid-image-search"
289
  spec = ServerlessSpec(cloud="aws", region="us-east-1")
290
+ # choose a name for your index
291
+ index_name = "hybrid-image-search"
292
+ import time
293
 
294
+ # check if index already exists (it shouldn't if this is first time)
295
  if index_name not in pc.list_indexes().names():
296
+ # if does not exist, create index
297
+ pc.create_index(
298
+ index_name,
299
+ dimension=512,
300
+ metric='dotproduct',
301
+ spec=spec
302
+ )
303
+ # wait for index to be initialized
304
  while not pc.describe_index(index_name).status['ready']:
305
  time.sleep(1)
306
 
307
+ # connect to index
308
  index = pc.Index(index_name)
309
+ # view index stats
310
+ index.describe_index_stats()
311
 
312
  # ------------------- Dataset Loading -------------------
313
  fashion = load_dataset("ashraq/fashion-product-images-small", split="train")
 
332
  hdense = [v * alpha for v in dense]
333
  return hdense, hsparse
334
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
  # ------------------- Metadata Filter Extraction -------------------
336
  from PIL import Image, ImageOps
337
  import numpy as np
338
+ from PIL import Image, ImageOps
339
+ import numpy as np
340
 
341
  def extract_metadata_filters(query: str):
342
  query_lower = query.lower()
 
458
  print(f"⚠️ No results with gender {gender}, relaxing gender filter")
459
  filter.pop("gender")
460
  result = index.query(
461
+ top_k=12,
462
  vector=hdense,
463
  sparse_vector=hsparse,
464
  include_metadata=True,
 
479
 
480
  return imgs_with_captions
481
 
482
+
483
+ from transformers import CLIPProcessor, CLIPModel
484
+
485
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
486
+ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
487
+
488
+
489
+
490
  from PIL import Image, ImageOps
491
  import numpy as np
492
 
 
523
 
524
 
525
 
 
526
  custom_css = """
527
  .search-btn {
528
  width: 100%;