hackerloi45 commited on
Commit
e8736ae
·
1 Parent(s): 746bf5b

Fix CLIP model issue in app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -54
app.py CHANGED
@@ -19,29 +19,22 @@ from qdrant_client.http.models import VectorParams, Distance, PointStruct
19
  # -------------------------
20
  # CONFIG (reads env vars)
21
  # -------------------------
22
- GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY") # set in Hugging Face Space secrets
23
- QDRANT_URL = os.environ.get("QDRANT_URL") # set in Hugging Face Space secrets
24
- QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY") # set in Hugging Face Space secrets
25
-
26
- # Local fallbacks (for local testing) - set them before running locally if needed:
27
- # os.environ["GEMINI_API_KEY"]="..." ; os.environ["QDRANT_URL"]="..." ; os.environ["QDRANT_API_KEY"]="..."
28
 
29
  # -------------------------
30
  # Initialize clients/models
31
  # -------------------------
32
  print("Loading CLIP model (this may take 20-60s the first time)...")
33
  MODEL_ID = "sentence-transformers/clip-ViT-B-32-multilingual-v1"
34
- clip_model = SentenceTransformer(MODEL_ID) # model maps text & images to same vector space
35
 
36
- # Gemini client (for tags/captions)
37
- if GEMINI_API_KEY:
38
- genai_client = genai.Client(api_key=GEMINI_API_KEY)
39
- else:
40
- genai_client = None
41
 
42
  # Qdrant client
43
  if not QDRANT_URL:
44
- # If you prefer local Qdrant for dev: client = QdrantClient(":memory:") or local url
45
  raise RuntimeError("Please set QDRANT_URL environment variable")
46
  qclient = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
47
 
@@ -59,26 +52,20 @@ if not qclient.collection_exists(COLLECTION):
59
  # Helpers
60
  # -------------------------
61
  def embed_text(text: str):
62
- vec = clip_model.encode(text, convert_to_numpy=True)
63
  return vec
64
 
65
  def embed_image_pil(pil_img: Image.Image):
66
- # sentence-transformers supports directly encoding a PIL image for CLIP models
67
- vec = clip_model.encode(pil_img, convert_to_numpy=True)
68
  return vec
69
 
70
  def gen_tags_from_image_file(local_path: str) -> str:
71
- """Upload image file to Gemini and ask for 4 short tags.
72
- Returns the raw text response (expected comma-separated tags)."""
73
- if genai_client is None:
74
  return ""
75
- # Upload file (Gemini Developer API supports client.files.upload)
76
  file_obj = genai_client.files.upload(file=local_path)
77
- # Ask Gemini: produce short tags only
78
  prompt_text = (
79
  "Give 4 short tags (comma-separated) describing this item in the image. "
80
- "Tags should be short single words or two-word phrases (e.g. 'black backpack', 'water bottle'). "
81
- "Respond only with tags, no extra explanation."
82
  )
83
  response = genai_client.models.generate_content(
84
  model="gemini-2.5-flash",
@@ -90,36 +77,25 @@ def gen_tags_from_image_file(local_path: str) -> str:
90
  # App logic: add item
91
  # -------------------------
92
  def add_item(mode: str, uploaded_image, text_description: str):
93
- """
94
- mode: 'lost' or 'found'
95
- uploaded_image: PIL image or None
96
- text_description: str
97
- """
98
  item_id = str(uuid.uuid4())
99
  payload = {"mode": mode, "text": text_description}
100
 
101
  if uploaded_image is not None:
102
- # Save image to temp file (so we can upload to Gemini)
103
  tmp_path = f"/tmp/{item_id}.png"
104
  uploaded_image.save(tmp_path)
105
- # embed image
106
  vec = embed_image_pil(uploaded_image).tolist()
107
  payload["has_image"] = True
108
- # optional: get tags from Gemini (if available)
109
  try:
110
  tags = gen_tags_from_image_file(tmp_path)
111
- except Exception as e:
112
  tags = ""
113
  payload["tags"] = tags
114
- # store image bytes (tiny) so we can show result in the UI (base64)
115
  with open(tmp_path, "rb") as f:
116
  b64 = f.read()
117
- payload["image_b64"] = True # flag (we will return/show image via Gradio from file bytes)
118
  else:
119
- # only text provided
120
  vec = embed_text(text_description).tolist()
121
  payload["has_image"] = False
122
- # ask Gemini to suggest tags from text
123
  if genai_client:
124
  try:
125
  resp = genai_client.models.generate_content(
@@ -132,30 +108,24 @@ def add_item(mode: str, uploaded_image, text_description: str):
132
  else:
133
  payload["tags"] = ""
134
 
135
- # Upsert into Qdrant
136
  point = PointStruct(id=item_id, vector=vec, payload=payload)
137
  qclient.upsert(collection_name=COLLECTION, points=[point], wait=True)
138
 
139
  return f"Saved item id: {item_id}\nTags: {payload.get('tags','')}"
140
 
141
-
142
  # -------------------------
143
  # App logic: search
144
  # -------------------------
145
  def search_items(query_image, query_text, limit: int = 5):
146
- # produce query embedding
147
  if query_image is not None:
148
  qvec = embed_image_pil(query_image).tolist()
149
- q_type = "image"
150
- else:
151
- if (not query_text) or (len(query_text.strip()) == 0):
152
- return "Please provide a query image or some query text."
153
  qvec = embed_text(query_text).tolist()
154
- q_type = "text"
 
155
 
156
  hits = qclient.search(collection_name=COLLECTION, query_vector=qvec, limit=limit)
157
 
158
- # Format output (list)
159
  results = []
160
  for h in hits:
161
  payload = h.payload or {}
@@ -163,36 +133,35 @@ def search_items(query_image, query_text, limit: int = 5):
163
  results.append(
164
  {
165
  "id": h.id,
166
- "score": float(score) if score is not None else None,
167
  "mode": payload.get("mode", ""),
168
  "text": payload.get("text", ""),
169
  "tags": payload.get("tags", ""),
170
  "has_image": payload.get("has_image", False),
171
  }
172
  )
173
- # Return a simple list for Gradio to show
174
  if not results:
175
  return "No results."
176
- # Convert to text for display
177
- out_lines = []
178
- for r in results:
179
- out_lines.append(f"id:{r['id']} score:{r['score']:.4f} mode:{r['mode']} tags:{r['tags']} text:{r['text']}")
180
  return "\n\n".join(out_lines)
181
 
182
  # -------------------------
183
  # Gradio UI
184
  # -------------------------
185
  with gr.Blocks(title="Lost & Found — Simple Helper") as demo:
186
- gr.Markdown("## Lost & Found Helper (image/text search) — upload items, then search by image or text.")
187
  with gr.Row():
188
  with gr.Column():
189
  mode = gr.Radio(choices=["lost", "found"], value="lost", label="Add as")
190
  upload_img = gr.Image(type="pil", label="Item photo (optional)")
191
- text_desc = gr.Textbox(lines=2, placeholder="Short description (e.g. 'black backpack with blue zipper')", label="Description (optional)")
192
  add_btn = gr.Button("Add item")
193
  add_out = gr.Textbox(label="Add result", interactive=False)
194
  with gr.Column():
195
- gr.Markdown("### Search")
196
  query_img = gr.Image(type="pil", label="Search by image (optional)")
197
  query_text = gr.Textbox(lines=2, label="Search by text (optional)")
198
  search_btn = gr.Button("Search")
 
19
  # -------------------------
20
  # CONFIG (reads env vars)
21
  # -------------------------
22
+ GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
23
+ QDRANT_URL = os.environ.get("QDRANT_URL")
24
+ QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
 
 
 
25
 
26
  # -------------------------
27
  # Initialize clients/models
28
  # -------------------------
29
  print("Loading CLIP model (this may take 20-60s the first time)...")
30
  MODEL_ID = "sentence-transformers/clip-ViT-B-32-multilingual-v1"
31
+ clip_model = SentenceTransformer(MODEL_ID)
32
 
33
+ # Gemini client
34
+ genai_client = genai.Client(api_key=GEMINI_API_KEY) if GEMINI_API_KEY else None
 
 
 
35
 
36
  # Qdrant client
37
  if not QDRANT_URL:
 
38
  raise RuntimeError("Please set QDRANT_URL environment variable")
39
  qclient = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
40
 
 
52
  # Helpers
53
  # -------------------------
54
  def embed_text(text: str):
55
+ vec = clip_model.encode([text], convert_to_numpy=True)[0] # wrap in list
56
  return vec
57
 
58
  def embed_image_pil(pil_img: Image.Image):
59
+ vec = clip_model.encode([pil_img], convert_to_numpy=True)[0] # wrap in list
 
60
  return vec
61
 
62
  def gen_tags_from_image_file(local_path: str) -> str:
63
+ if not genai_client:
 
 
64
  return ""
 
65
  file_obj = genai_client.files.upload(file=local_path)
 
66
  prompt_text = (
67
  "Give 4 short tags (comma-separated) describing this item in the image. "
68
+ "Tags should be short single words or two-word phrases. Respond only with tags."
 
69
  )
70
  response = genai_client.models.generate_content(
71
  model="gemini-2.5-flash",
 
77
  # App logic: add item
78
  # -------------------------
79
  def add_item(mode: str, uploaded_image, text_description: str):
 
 
 
 
 
80
  item_id = str(uuid.uuid4())
81
  payload = {"mode": mode, "text": text_description}
82
 
83
  if uploaded_image is not None:
 
84
  tmp_path = f"/tmp/{item_id}.png"
85
  uploaded_image.save(tmp_path)
 
86
  vec = embed_image_pil(uploaded_image).tolist()
87
  payload["has_image"] = True
 
88
  try:
89
  tags = gen_tags_from_image_file(tmp_path)
90
+ except Exception:
91
  tags = ""
92
  payload["tags"] = tags
 
93
  with open(tmp_path, "rb") as f:
94
  b64 = f.read()
95
+ payload["image_b64"] = True
96
  else:
 
97
  vec = embed_text(text_description).tolist()
98
  payload["has_image"] = False
 
99
  if genai_client:
100
  try:
101
  resp = genai_client.models.generate_content(
 
108
  else:
109
  payload["tags"] = ""
110
 
 
111
  point = PointStruct(id=item_id, vector=vec, payload=payload)
112
  qclient.upsert(collection_name=COLLECTION, points=[point], wait=True)
113
 
114
  return f"Saved item id: {item_id}\nTags: {payload.get('tags','')}"
115
 
 
116
  # -------------------------
117
  # App logic: search
118
  # -------------------------
119
  def search_items(query_image, query_text, limit: int = 5):
 
120
  if query_image is not None:
121
  qvec = embed_image_pil(query_image).tolist()
122
+ elif query_text and query_text.strip():
 
 
 
123
  qvec = embed_text(query_text).tolist()
124
+ else:
125
+ return "Please provide a query image or some query text."
126
 
127
  hits = qclient.search(collection_name=COLLECTION, query_vector=qvec, limit=limit)
128
 
 
129
  results = []
130
  for h in hits:
131
  payload = h.payload or {}
 
133
  results.append(
134
  {
135
  "id": h.id,
136
+ "score": float(score) if score else None,
137
  "mode": payload.get("mode", ""),
138
  "text": payload.get("text", ""),
139
  "tags": payload.get("tags", ""),
140
  "has_image": payload.get("has_image", False),
141
  }
142
  )
143
+
144
  if not results:
145
  return "No results."
146
+ out_lines = [
147
+ f"id:{r['id']} score:{r['score']:.4f} mode:{r['mode']} tags:{r['tags']} text:{r['text']}"
148
+ for r in results
149
+ ]
150
  return "\n\n".join(out_lines)
151
 
152
  # -------------------------
153
  # Gradio UI
154
  # -------------------------
155
  with gr.Blocks(title="Lost & Found — Simple Helper") as demo:
156
+ gr.Markdown("## Lost & Found Helper (image/text search)")
157
  with gr.Row():
158
  with gr.Column():
159
  mode = gr.Radio(choices=["lost", "found"], value="lost", label="Add as")
160
  upload_img = gr.Image(type="pil", label="Item photo (optional)")
161
+ text_desc = gr.Textbox(lines=2, placeholder="Short description", label="Description (optional)")
162
  add_btn = gr.Button("Add item")
163
  add_out = gr.Textbox(label="Add result", interactive=False)
164
  with gr.Column():
 
165
  query_img = gr.Image(type="pil", label="Search by image (optional)")
166
  query_text = gr.Textbox(lines=2, label="Search by text (optional)")
167
  search_btn = gr.Button("Search")