Jaywalker061707 commited on
Commit
b8afcdb
·
verified ·
1 Parent(s): c58c07d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -29
app.py CHANGED
@@ -66,33 +66,8 @@ def search(text_query, image_query, k=5):
66
  return [], "Build the index first."
67
 
68
  with torch.no_grad():
69
- if text_query and text_query.strip():
70
- inputs = processor(text=[text_query.strip()], return_tensors="pt")
71
- q = model.get_text_features(**inputs) # [1, 512]
72
- elif image_query is not None:
73
- pil = image_query.convert("RGB")
74
- inputs = processor(images=pil, return_tensors="pt")
75
- q = model.get_image_features(**inputs) # [1, 512]
76
- else:
77
- return [], "Enter text or upload an image."
78
-
79
- q = F.normalize(q, p=2, dim=-1)[0] # [512]
80
- sims = (INDEX["feats"] @ q).cpu() # [N]
81
- topk = torch.topk(sims, k=min(int(k), sims.shape[0]))
82
-
83
- items = []
84
- for idx in topk.indices.tolist():
85
- cap = f"id: {INDEX['ids'][idx]} score: {float(sims[idx]):.3f} band: {INDEX['bands'][idx]}"
86
- items.append((INDEX["thumbs"][idx], cap))
87
- return items, f"Returned {len(items)} results."
88
-
89
- def search(text_query, image_query, k=5):
90
- if INDEX["feats"] is None:
91
- return [], "Build the index first."
92
-
93
- with torch.no_grad():
94
- if text_query and text_query.strip():
95
- inputs = processor(text=[text_query.strip()], return_tensors="pt")
96
  q = model.get_text_features(**inputs) # [1, 512]
97
  elif image_query is not None:
98
  pil = image_query.convert("RGB")
@@ -103,13 +78,16 @@ def search(text_query, image_query, k=5):
103
 
104
  q = F.normalize(q, p=2, dim=-1)[0] # [512]
105
  sims = (INDEX["feats"] @ q).cpu() # [N]
106
- topk = torch.topk(sims, k=min(int(k), sims.shape[0]))
 
107
 
108
  items = []
109
  for idx in topk.indices.tolist():
110
  cap = f"id: {INDEX['ids'][idx]} score: {float(sims[idx]):.3f} band: {INDEX['bands'][idx]}"
111
  items.append((INDEX["thumbs"][idx], cap))
112
- return items, f"Returned {len(items)} results."
 
 
113
 
114
  # ---------- UI ----------
115
  with gr.Blocks() as demo:
 
66
  return [], "Build the index first."
67
 
68
  with torch.no_grad():
69
+ if text_query and str(text_query).strip():
70
+ inputs = processor(text=[str(text_query).strip()], return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  q = model.get_text_features(**inputs) # [1, 512]
72
  elif image_query is not None:
73
  pil = image_query.convert("RGB")
 
78
 
79
  q = F.normalize(q, p=2, dim=-1)[0] # [512]
80
  sims = (INDEX["feats"] @ q).cpu() # [N]
81
+ k = min(int(k), sims.shape[0])
82
+ topk = torch.topk(sims, k=k)
83
 
84
  items = []
85
  for idx in topk.indices.tolist():
86
  cap = f"id: {INDEX['ids'][idx]} score: {float(sims[idx]):.3f} band: {INDEX['bands'][idx]}"
87
  items.append((INDEX["thumbs"][idx], cap))
88
+
89
+ return items, f"Returned {k} results."
90
+
91
 
92
  # ---------- UI ----------
93
  with gr.Blocks() as demo: