hackerloi45 committed
Commit 7fae8fb · 1 Parent(s): 571e22c

Fix CLIP model issue in app.py

Files changed (1): app.py (+109, -80)
app.py CHANGED
@@ -1,140 +1,169 @@
 import os
 import uuid
 import io
 import base64
 from PIL import Image
 import gradio as gr
 from sentence_transformers import SentenceTransformer
-import google.generativeai as genai
 from qdrant_client import QdrantClient
 from qdrant_client.http.models import VectorParams, Distance, PointStruct
 
-# --- Configuration ---
-GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
-QDRANT_URL = os.environ.get("QDRANT_URL")
-QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
-
-# --- Model Loading and Client Initialization ---
-print("Loading CLIP model...")
 MODEL_ID = "sentence-transformers/clip-ViT-B-32-multilingual-v1"
 clip_model = SentenceTransformer(MODEL_ID)
 
-# Configure the Gemini client
-if GEMINI_API_KEY:
-    genai.configure(api_key=GEMINI_API_KEY)
 
 if not QDRANT_URL:
-    raise RuntimeError("Set QDRANT_URL env var")
-qclient = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
 
-# --- Qdrant Collection Setup ---
 COLLECTION = "lost_found_items"
 VECTOR_SIZE = 512
-if not qclient.collection_exists(COLLECTION):
-    print(f"Creating collection: {COLLECTION}")
-    qclient.create_collection(
-        collection_name=COLLECTION,
-        vectors_config=VectorParams(size=VECTOR_SIZE, distance=Distance.COSINE),
-    )
-
-# --- Core Functions ---
 def embed_text(text: str):
-    """Generates an embedding for the given text."""
     return clip_model.encode(text, convert_to_numpy=True)
 
 def embed_image_pil(pil_img: Image.Image):
-    """Generates an embedding for the given PIL image."""
     return clip_model.encode(pil_img, convert_to_numpy=True)
 
-def gen_tags_from_image_file(img_bytes: io.BytesIO) -> str:
-    """Generates descriptive tags for an image using the Gemini API."""
-    if not GEMINI_API_KEY:
-        print("Warning: GEMINI_API_KEY not set. Skipping tag generation.")
         return ""
     try:
-        img = Image.open(img_bytes)
-        model = genai.GenerativeModel('gemini-pro-vision')
-        prompt = ("Give 4 short tags (comma-separated) describing this item in the image. "
-                  "Respond only with tags.")
-        resp = model.generate_content([prompt, img])
-        return resp.text.strip()
-    except Exception as e:
-        print(f"Error calling Gemini API: {e}")
         return ""
 
-def add_item(mode: str, uploaded_image: Image.Image, text_description: str):
-    """Adds a new lost or found item to the database."""
-    if not uploaded_image and not text_description:
-        return "Error: Please provide either an image or a description."
-
     item_id = str(uuid.uuid4())
     payload = {"mode": mode, "text": text_description}
 
-    if uploaded_image:
         img_bytes = io.BytesIO()
-        uploaded_image.save(img_bytes, format="PNG")
         img_bytes.seek(0)
 
         vec = embed_image_pil(uploaded_image).tolist()
         payload["has_image"] = True
         payload["tags"] = gen_tags_from_image_file(img_bytes)
-
-        img_bytes.seek(0)
-        payload["image_b64"] = base64.b64encode(img_bytes.read()).decode("utf-8")
     else:
         vec = embed_text(text_description).tolist()
         payload["has_image"] = False
-        payload["tags"] = ""
-
-    point = PointStruct(id=item_id, vector=vec, payload=payload)
-    qclient.upsert(collection_name=COLLECTION, points=[point], wait=True)
 
     return f"Saved item id: {item_id}\nTags: {payload.get('tags','')}"
 
-def search_items(query_image: Image.Image, query_text: str, limit: int = 5):
-    """Searches for similar items in the database."""
-    if query_image:
         qvec = embed_image_pil(query_image).tolist()
-    elif query_text:
         qvec = embed_text(query_text).tolist()
     else:
-        return "Provide a query image or text to search."
 
-    hits = qclient.search(collection_name=COLLECTION, query_vector=qvec, limit=limit)
     if not hits:
-        return "No results found."
 
     results = []
     for h in hits:
         payload = h.payload or {}
-        score = getattr(h, "score", 0)
         results.append(
-            f"ID: {h.id}\nScore: {float(score):.4f}\nMode: {payload.get('mode','')}\n"
-            f"Tags: {payload.get('tags','')}\nText: {payload.get('text','')}\n"
         )
-
     return "\n\n".join(results)
 
-# --- Gradio UI ---
-with gr.Blocks(theme=gr.themes.Soft()) as demo:
-    gr.Markdown("## Lost & Found Helper")
-    gr.Markdown("Add items that were lost or found, and search for them using a photo or description.")
-
     with gr.Row():
-        with gr.Column(scale=1):
-            gr.Markdown("### Add an Item")
-            mode = gr.Radio(["lost", "found"], value="lost", label="I have...")
-            upload_img = gr.Image(type="pil", label="Item Photo (optional)")
-            text_desc = gr.Textbox(lines=2, placeholder="e.g., 'red backpack with a keychain'", label="Description")
-            add_btn = gr.Button("Add Item", variant="primary")
-            add_out = gr.Textbox(interactive=False, label="Result", lines=3)
-
-        with gr.Column(scale=2):
-            gr.Markdown("### Search for an Item")
-            query_img = gr.Image(type="pil", label="Search by Image (optional)")
-            query_text = gr.Textbox(lines=2, placeholder="e.g., 'blue water bottle'", label="Search by Text (optional)")
-            search_btn = gr.Button("Search", variant="primary")
-            search_out = gr.Textbox(interactive=False, label="Search Results", lines=10)
 
     add_btn.click(add_item, inputs=[mode, upload_img, text_desc], outputs=[add_out])
     search_btn.click(search_items, inputs=[query_img, query_text], outputs=[search_out])
+# app.py
 import os
 import uuid
 import io
 import base64
 from PIL import Image
 import gradio as gr
+import numpy as np
+
+# CLIP via Sentence-Transformers
 from sentence_transformers import SentenceTransformer
+
+# Gemini (Google) client
+from google import genai
+
+# Qdrant client & helpers
 from qdrant_client import QdrantClient
 from qdrant_client.http.models import VectorParams, Distance, PointStruct
 
+# -------------------------
+# CONFIG (reads env vars)
+# -------------------------
+GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "").strip()
+QDRANT_URL = os.environ.get("QDRANT_URL", "").strip()
+QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", "").strip()
+
+# -------------------------
+# Initialize clients/models
+# -------------------------
+print("Loading CLIP model (this may take 20-60s the first time)...")
 MODEL_ID = "sentence-transformers/clip-ViT-B-32-multilingual-v1"
 clip_model = SentenceTransformer(MODEL_ID)
 
+genai_client = genai.Client(api_key=GEMINI_API_KEY) if GEMINI_API_KEY else None
 
 if not QDRANT_URL:
+    raise RuntimeError("Please set QDRANT_URL environment variable")
 
+qclient = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
 COLLECTION = "lost_found_items"
 VECTOR_SIZE = 512
+
+# Create collection if missing
+try:
+    if not qclient.collection_exists(COLLECTION):
+        qclient.create_collection(
+            collection_name=COLLECTION,
+            vectors_config=VectorParams(size=VECTOR_SIZE, distance=Distance.COSINE),
+        )
+except Exception as e:
+    print("Error initializing Qdrant collection:", e)
+
+# -------------------------
+# Helpers
+# -------------------------
 def embed_text(text: str):
     return clip_model.encode(text, convert_to_numpy=True)
 
 def embed_image_pil(pil_img: Image.Image):
+    pil_img = pil_img.convert("RGB")
     return clip_model.encode(pil_img, convert_to_numpy=True)
 
+def gen_tags_from_image_file(image_bytes: io.BytesIO) -> str:
+    if genai_client is None:
         return ""
     try:
+        file_obj = genai_client.files.upload(file=image_bytes)
+        prompt_text = (
+            "Give 4 short tags (comma-separated) describing this item in the image. "
+            "Tags should be short single words or two-word phrases (e.g. 'black backpack', 'water bottle'). "
+            "Respond only with tags, no extra explanation."
+        )
+        response = genai_client.models.generate_content(
+            model="gemini-2.5-flash",
+            contents=[prompt_text, file_obj],
+        )
+        return response.text.strip()
+    except Exception:
         return ""
 
+# -------------------------
+# App logic: add item
+# -------------------------
+def add_item(mode: str, uploaded_image, text_description: str):
     item_id = str(uuid.uuid4())
     payload = {"mode": mode, "text": text_description}
 
+    if uploaded_image is not None:
         img_bytes = io.BytesIO()
+        uploaded_image.convert("RGB").save(img_bytes, format="PNG")
         img_bytes.seek(0)
 
         vec = embed_image_pil(uploaded_image).tolist()
         payload["has_image"] = True
+
         payload["tags"] = gen_tags_from_image_file(img_bytes)
+        payload["image_b64"] = base64.b64encode(img_bytes.getvalue()).decode("utf-8")
     else:
         vec = embed_text(text_description).tolist()
         payload["has_image"] = False
+        if genai_client:
+            try:
+                resp = genai_client.models.generate_content(
+                    model="gemini-2.5-flash",
+                    contents=f"Give 4 short, comma-separated tags for this item described as: {text_description}. Reply only with tags."
+                )
+                payload["tags"] = resp.text.strip()
+            except Exception:
+                payload["tags"] = ""
+        else:
+            payload["tags"] = ""
+
+    try:
+        point = PointStruct(id=item_id, vector=vec, payload=payload)
+        qclient.upsert(collection_name=COLLECTION, points=[point], wait=True)
+    except Exception as e:
+        return f"Error saving to Qdrant: {e}"
 
     return f"Saved item id: {item_id}\nTags: {payload.get('tags','')}"
 
+# -------------------------
+# App logic: search
+# -------------------------
+def search_items(query_image, query_text, limit: int = 5):
+    if query_image is not None:
         qvec = embed_image_pil(query_image).tolist()
+    elif query_text and len(query_text.strip()) > 0:
         qvec = embed_text(query_text).tolist()
     else:
+        return "Please provide a query image or some query text."
+
+    try:
+        hits = qclient.search(collection_name=COLLECTION, query_vector=qvec, limit=limit)
+    except Exception as e:
+        return f"Error querying Qdrant: {e}"
 
     if not hits:
+        return "No results."
 
     results = []
     for h in hits:
         payload = h.payload or {}
+        score = getattr(h, "score", None)
         results.append(
+            f"id:{h.id} score:{float(score) if score else None} mode:{payload.get('mode','')} tags:{payload.get('tags','')} text:{payload.get('text','')}"
         )
     return "\n\n".join(results)
 
+# -------------------------
+# Gradio UI
+# -------------------------
+with gr.Blocks(title="Lost & Found Simple Helper") as demo:
+    gr.Markdown("## Lost & Found Helper (image/text search) — upload items, then search by image or text.")
     with gr.Row():
+        with gr.Column():
+            mode = gr.Radio(choices=["lost", "found"], value="lost", label="Add as")
+            upload_img = gr.Image(type="pil", label="Item photo (optional)")
+            text_desc = gr.Textbox(lines=2, placeholder="Short description (e.g. 'black backpack with blue zipper')", label="Description (optional)")
+            add_btn = gr.Button("Add item")
+            add_out = gr.Textbox(label="Add result", interactive=False)
+        with gr.Column():
+            gr.Markdown("### Search")
+            query_img = gr.Image(type="pil", label="Search by image (optional)")
+            query_text = gr.Textbox(lines=2, label="Search by text (optional)")
+            search_btn = gr.Button("Search")
+            search_out = gr.Textbox(label="Search results", interactive=False)
 
     add_btn.click(add_item, inputs=[mode, upload_img, text_desc], outputs=[add_out])
     search_btn.click(search_items, inputs=[query_img, query_text], outputs=[search_out])
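
Reviewer note: the functional change in this commit is the move from the deprecated google-generativeai SDK (genai.configure plus the retired gemini-pro-vision model) to the google-genai SDK (genai.Client plus gemini-2.5-flash). Below is a minimal standalone sketch of the new call path for trying it outside the Space. It is not part of the commit; GEMINI_API_KEY being set and "item.png" existing locally are assumptions for the sketch, while the client calls and model name mirror the diff.

# tag_check.py - minimal sketch of the google-genai usage in this commit.
# Assumptions: GEMINI_API_KEY is set in the environment, and "item.png"
# is any local test image (both illustrative, not part of the commit).
import os
from google import genai

client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])
# Upload the image; from a file path, the SDK infers the MIME type itself.
file_obj = client.files.upload(file="item.png")
resp = client.models.generate_content(
    model="gemini-2.5-flash",  # same model the diff uses
    contents=["Give 4 short tags (comma-separated) describing this item.", file_obj],
)
print(resp.text.strip())

Note that the new import resolves only if the Space's requirements.txt (not touched by this commit) lists google-genai in place of google-generativeai.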