hackerloi45 commited on
Commit
571e22c
·
1 Parent(s): 08d340c

Fix CLIrrr2 model issue in appetete333.py

Browse files
Files changed (1) hide show
  1. app.py +54 -52
app.py CHANGED
@@ -1,89 +1,84 @@
1
- # app.py
2
  import os
3
  import uuid
4
  import io
5
- import base64 # <-- FIX: This was missing
6
  from PIL import Image
7
  import gradio as gr
8
  from sentence_transformers import SentenceTransformer
9
- import google.generativeai as genai # <-- FIX: Correct import for the genai library
10
  from qdrant_client import QdrantClient
11
  from qdrant_client.http.models import VectorParams, Distance, PointStruct
12
 
13
- # Note: The QDRANT_URL, QDRANT_API_KEY, and GEMINI_API_KEY environment variables
14
- # must be set for this application to work correctly.
15
-
16
  GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
17
  QDRANT_URL = os.environ.get("QDRANT_URL")
18
  QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
19
 
 
20
  print("Loading CLIP model...")
21
  MODEL_ID = "sentence-transformers/clip-ViT-B-32-multilingual-v1"
22
  clip_model = SentenceTransformer(MODEL_ID)
23
 
24
- # Initialize the GenAI client with the correct API key
25
- genai.configure(api_key=GEMINI_API_KEY)
 
26
 
27
  if not QDRANT_URL:
28
  raise RuntimeError("Set QDRANT_URL env var")
29
  qclient = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
30
 
 
31
  COLLECTION = "lost_found_items"
32
  VECTOR_SIZE = 512
33
- # Only create the collection if it doesn't already exist
34
- if not qclient.get_collections().collections:
35
  qclient.create_collection(
36
  collection_name=COLLECTION,
37
  vectors_config=VectorParams(size=VECTOR_SIZE, distance=Distance.COSINE),
38
  )
39
 
 
40
  def embed_text(text: str):
 
41
  return clip_model.encode(text, convert_to_numpy=True)
42
 
43
  def embed_image_pil(pil_img: Image.Image):
 
44
  return clip_model.encode(pil_img, convert_to_numpy=True)
45
 
46
- # FIX: This function is updated to take a PIL Image object directly and
47
- # uses an inlineData object for the Gemini API call, as file upload is
48
- # not supported for gemini-2.5-flash in this manner.
49
- def gen_tags_from_image(pil_img: Image.Image) -> str:
50
  if not GEMINI_API_KEY:
 
51
  return ""
52
  try:
53
- # Convert PIL Image to a byte array
54
- img_bytes = io.BytesIO()
55
- pil_img.save(img_bytes, format="PNG")
56
- img_bytes.seek(0)
57
-
58
- # Use inlineData to pass the image to the model
59
- model = genai.GenerativeModel("gemini-2.5-flash")
60
  prompt = ("Give 4 short tags (comma-separated) describing this item in the image. "
61
  "Respond only with tags.")
62
- image_part = {
63
- "mime_type": "image/png",
64
- "data": img_bytes.getvalue()
65
- }
66
- resp = model.generate_content([prompt, image_part])
67
  return resp.text.strip()
68
  except Exception as e:
69
- print(f"Error generating tags: {e}")
70
  return ""
71
 
72
- def add_item(mode: str, uploaded_image, text_description: str):
 
 
 
 
73
  item_id = str(uuid.uuid4())
74
  payload = {"mode": mode, "text": text_description}
75
 
76
  if uploaded_image:
77
- # Use the PIL image directly for embedding
 
 
 
78
  vec = embed_image_pil(uploaded_image).tolist()
79
  payload["has_image"] = True
 
80
 
81
- # FIX: Pass the PIL image object to the tag generation function
82
- payload["tags"] = gen_tags_from_image(uploaded_image)
83
-
84
- # Convert the PIL image to base64 string for storage in payload
85
- img_bytes = io.BytesIO()
86
- uploaded_image.save(img_bytes, format="PNG")
87
  img_bytes.seek(0)
88
  payload["image_b64"] = base64.b64encode(img_bytes.read()).decode("utf-8")
89
  else:
@@ -96,43 +91,50 @@ def add_item(mode: str, uploaded_image, text_description: str):
96
 
97
  return f"Saved item id: {item_id}\nTags: {payload.get('tags','')}"
98
 
99
- def search_items(query_image, query_text, limit: int = 5):
 
100
  if query_image:
101
  qvec = embed_image_pil(query_image).tolist()
102
  elif query_text:
103
  qvec = embed_text(query_text).tolist()
104
  else:
105
- return "Provide query image or text."
106
 
107
  hits = qclient.search(collection_name=COLLECTION, query_vector=qvec, limit=limit)
108
  if not hits:
109
- return "No results."
110
 
111
  results = []
112
  for h in hits:
113
  payload = h.payload or {}
114
  score = getattr(h, "score", 0)
115
  results.append(
116
- f"ID:{h.id}\nScore:{float(score):.4f}\nMode:{payload.get('mode','')}\n"
117
- f"Tags:{payload.get('tags','')}\nText:{payload.get('text','')}\n"
118
  )
119
 
120
  return "\n\n".join(results)
121
 
122
- with gr.Blocks() as demo:
 
123
  gr.Markdown("## Lost & Found Helper")
 
 
124
  with gr.Row():
125
- with gr.Column():
126
- mode = gr.Radio(["lost", "found"], value="lost", label="Add as")
127
- upload_img = gr.Image(type="pil", label="Item photo (optional)")
128
- text_desc = gr.Textbox(lines=2, placeholder="Short description", label="Description")
129
- add_btn = gr.Button("Add item")
130
- add_out = gr.Textbox(interactive=False, label="Result")
131
- with gr.Column():
132
- query_img = gr.Image(type="pil", label="Search by image (optional)")
133
- query_text = gr.Textbox(lines=2, label="Search by text (optional)")
134
- search_btn = gr.Button("Search")
135
- search_out = gr.Textbox(interactive=False, label="Search results")
 
 
 
136
 
137
  add_btn.click(add_item, inputs=[mode, upload_img, text_desc], outputs=[add_out])
138
  search_btn.click(search_items, inputs=[query_img, query_text], outputs=[search_out])
 
 
1
  import os
2
  import uuid
3
  import io
4
+ import base64
5
  from PIL import Image
6
  import gradio as gr
7
  from sentence_transformers import SentenceTransformer
8
+ import google.generativeai as genai
9
  from qdrant_client import QdrantClient
10
  from qdrant_client.http.models import VectorParams, Distance, PointStruct
11
 
12
+ # --- Configuration ---
 
 
13
  GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
14
  QDRANT_URL = os.environ.get("QDRANT_URL")
15
  QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
16
 
17
+ # --- Model Loading and Client Initialization ---
18
  print("Loading CLIP model...")
19
  MODEL_ID = "sentence-transformers/clip-ViT-B-32-multilingual-v1"
20
  clip_model = SentenceTransformer(MODEL_ID)
21
 
22
+ # Configure the Gemini client
23
+ if GEMINI_API_KEY:
24
+ genai.configure(api_key=GEMINI_API_KEY)
25
 
26
  if not QDRANT_URL:
27
  raise RuntimeError("Set QDRANT_URL env var")
28
  qclient = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
29
 
30
+ # --- Qdrant Collection Setup ---
31
  COLLECTION = "lost_found_items"
32
  VECTOR_SIZE = 512
33
+ if not qclient.collection_exists(COLLECTION):
34
+ print(f"Creating collection: {COLLECTION}")
35
  qclient.create_collection(
36
  collection_name=COLLECTION,
37
  vectors_config=VectorParams(size=VECTOR_SIZE, distance=Distance.COSINE),
38
  )
39
 
40
+ # --- Core Functions ---
41
  def embed_text(text: str):
42
+ """Generates an embedding for the given text."""
43
  return clip_model.encode(text, convert_to_numpy=True)
44
 
45
  def embed_image_pil(pil_img: Image.Image):
46
+ """Generates an embedding for the given PIL image."""
47
  return clip_model.encode(pil_img, convert_to_numpy=True)
48
 
49
+ def gen_tags_from_image_file(img_bytes: io.BytesIO) -> str:
50
+ """Generates descriptive tags for an image using the Gemini API."""
 
 
51
  if not GEMINI_API_KEY:
52
+ print("Warning: GEMINI_API_KEY not set. Skipping tag generation.")
53
  return ""
54
  try:
55
+ img = Image.open(img_bytes)
56
+ model = genai.GenerativeModel('gemini-pro-vision')
 
 
 
 
 
57
  prompt = ("Give 4 short tags (comma-separated) describing this item in the image. "
58
  "Respond only with tags.")
59
+ resp = model.generate_content([prompt, img])
 
 
 
 
60
  return resp.text.strip()
61
  except Exception as e:
62
+ print(f"Error calling Gemini API: {e}")
63
  return ""
64
 
65
+ def add_item(mode: str, uploaded_image: Image.Image, text_description: str):
66
+ """Adds a new lost or found item to the database."""
67
+ if not uploaded_image and not text_description:
68
+ return "Error: Please provide either an image or a description."
69
+
70
  item_id = str(uuid.uuid4())
71
  payload = {"mode": mode, "text": text_description}
72
 
73
  if uploaded_image:
74
+ img_bytes = io.BytesIO()
75
+ uploaded_image.save(img_bytes, format="PNG")
76
+ img_bytes.seek(0)
77
+
78
  vec = embed_image_pil(uploaded_image).tolist()
79
  payload["has_image"] = True
80
+ payload["tags"] = gen_tags_from_image_file(img_bytes)
81
 
 
 
 
 
 
 
82
  img_bytes.seek(0)
83
  payload["image_b64"] = base64.b64encode(img_bytes.read()).decode("utf-8")
84
  else:
 
91
 
92
  return f"Saved item id: {item_id}\nTags: {payload.get('tags','')}"
93
 
94
+ def search_items(query_image: Image.Image, query_text: str, limit: int = 5):
95
+ """Searches for similar items in the database."""
96
  if query_image:
97
  qvec = embed_image_pil(query_image).tolist()
98
  elif query_text:
99
  qvec = embed_text(query_text).tolist()
100
  else:
101
+ return "Provide a query image or text to search."
102
 
103
  hits = qclient.search(collection_name=COLLECTION, query_vector=qvec, limit=limit)
104
  if not hits:
105
+ return "No results found."
106
 
107
  results = []
108
  for h in hits:
109
  payload = h.payload or {}
110
  score = getattr(h, "score", 0)
111
  results.append(
112
+ f"ID: {h.id}\nScore: {float(score):.4f}\nMode: {payload.get('mode','')}\n"
113
+ f"Tags: {payload.get('tags','')}\nText: {payload.get('text','')}\n"
114
  )
115
 
116
  return "\n\n".join(results)
117
 
118
+ # --- Gradio UI ---
119
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
120
  gr.Markdown("## Lost & Found Helper")
121
+ gr.Markdown("Add items that were lost or found, and search for them using a photo or description.")
122
+
123
  with gr.Row():
124
+ with gr.Column(scale=1):
125
+ gr.Markdown("### Add an Item")
126
+ mode = gr.Radio(["lost", "found"], value="lost", label="I have...")
127
+ upload_img = gr.Image(type="pil", label="Item Photo (optional)")
128
+ text_desc = gr.Textbox(lines=2, placeholder="e.g., 'red backpack with a keychain'", label="Description")
129
+ add_btn = gr.Button("Add Item", variant="primary")
130
+ add_out = gr.Textbox(interactive=False, label="Result", lines=3)
131
+
132
+ with gr.Column(scale=2):
133
+ gr.Markdown("### Search for an Item")
134
+ query_img = gr.Image(type="pil", label="Search by Image (optional)")
135
+ query_text = gr.Textbox(lines=2, placeholder="e.g., 'blue water bottle'", label="Search by Text (optional)")
136
+ search_btn = gr.Button("Search", variant="primary")
137
+ search_out = gr.Textbox(interactive=False, label="Search Results", lines=10)
138
 
139
  add_btn.click(add_item, inputs=[mode, upload_img, text_desc], outputs=[add_out])
140
  search_btn.click(search_items, inputs=[query_img, query_text], outputs=[search_out])