hackerloi45 commited on
Commit
08d340c
·
1 Parent(s): 0d5f8a4

Fix CLIrrr2 model issue in appetete.py

Browse files
Files changed (1) hide show
  1. app.py +36 -13
app.py CHANGED
@@ -2,13 +2,17 @@
2
  import os
3
  import uuid
4
  import io
 
5
  from PIL import Image
6
  import gradio as gr
7
  from sentence_transformers import SentenceTransformer
8
- from google import genai
9
  from qdrant_client import QdrantClient
10
  from qdrant_client.http.models import VectorParams, Distance, PointStruct
11
 
 
 
 
12
  GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
13
  QDRANT_URL = os.environ.get("QDRANT_URL")
14
  QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
@@ -17,7 +21,8 @@ print("Loading CLIP model...")
17
  MODEL_ID = "sentence-transformers/clip-ViT-B-32-multilingual-v1"
18
  clip_model = SentenceTransformer(MODEL_ID)
19
 
20
- genai_client = genai.Client(api_key=GEMINI_API_KEY) if GEMINI_API_KEY else None
 
21
 
22
  if not QDRANT_URL:
23
  raise RuntimeError("Set QDRANT_URL env var")
@@ -25,7 +30,8 @@ qclient = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
25
 
26
  COLLECTION = "lost_found_items"
27
  VECTOR_SIZE = 512
28
- if not qclient.collection_exists(COLLECTION):
 
29
  qclient.create_collection(
30
  collection_name=COLLECTION,
31
  vectors_config=VectorParams(size=VECTOR_SIZE, distance=Distance.COSINE),
@@ -37,17 +43,30 @@ def embed_text(text: str):
37
  def embed_image_pil(pil_img: Image.Image):
38
  return clip_model.encode(pil_img, convert_to_numpy=True)
39
 
40
- def gen_tags_from_image_file(img_bytes: io.BytesIO) -> str:
41
- if not genai_client:
 
 
 
42
  return ""
43
  try:
44
- file_obj = genai_client.files.upload(file=img_bytes)
 
 
 
 
 
 
45
  prompt = ("Give 4 short tags (comma-separated) describing this item in the image. "
46
  "Respond only with tags.")
47
- resp = genai_client.models.generate_content(model="gemini-2.5-flash",
48
- contents=[prompt, file_obj])
 
 
 
49
  return resp.text.strip()
50
- except Exception:
 
51
  return ""
52
 
53
  def add_item(mode: str, uploaded_image, text_description: str):
@@ -55,12 +74,16 @@ def add_item(mode: str, uploaded_image, text_description: str):
55
  payload = {"mode": mode, "text": text_description}
56
 
57
  if uploaded_image:
58
- img_bytes = io.BytesIO()
59
- uploaded_image.save(img_bytes, format="PNG")
60
- img_bytes.seek(0)
61
  vec = embed_image_pil(uploaded_image).tolist()
62
  payload["has_image"] = True
63
- payload["tags"] = gen_tags_from_image_file(img_bytes)
 
 
 
 
 
 
64
  img_bytes.seek(0)
65
  payload["image_b64"] = base64.b64encode(img_bytes.read()).decode("utf-8")
66
  else:
 
2
  import os
3
  import uuid
4
  import io
5
+ import base64 # <-- FIX: This was missing
6
  from PIL import Image
7
  import gradio as gr
8
  from sentence_transformers import SentenceTransformer
9
+ import google.generativeai as genai # <-- FIX: Correct import for the genai library
10
  from qdrant_client import QdrantClient
11
  from qdrant_client.http.models import VectorParams, Distance, PointStruct
12
 
13
+ # Note: The QDRANT_URL, QDRANT_API_KEY, and GEMINI_API_KEY environment variables
14
+ # must be set for this application to work correctly.
15
+
16
  GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
17
  QDRANT_URL = os.environ.get("QDRANT_URL")
18
  QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
 
21
  MODEL_ID = "sentence-transformers/clip-ViT-B-32-multilingual-v1"
22
  clip_model = SentenceTransformer(MODEL_ID)
23
 
24
+ # Initialize the GenAI client with the correct API key
25
+ genai.configure(api_key=GEMINI_API_KEY)
26
 
27
  if not QDRANT_URL:
28
  raise RuntimeError("Set QDRANT_URL env var")
 
30
 
31
  COLLECTION = "lost_found_items"
32
  VECTOR_SIZE = 512
33
+ # Only create the collection if it doesn't already exist
34
+ if not qclient.get_collections().collections:
35
  qclient.create_collection(
36
  collection_name=COLLECTION,
37
  vectors_config=VectorParams(size=VECTOR_SIZE, distance=Distance.COSINE),
 
43
  def embed_image_pil(pil_img: Image.Image):
44
  return clip_model.encode(pil_img, convert_to_numpy=True)
45
 
46
+ # FIX: This function is updated to take a PIL Image object directly and
47
+ # uses an inlineData object for the Gemini API call, as file upload is
48
+ # not supported for gemini-2.5-flash in this manner.
49
+ def gen_tags_from_image(pil_img: Image.Image) -> str:
50
+ if not GEMINI_API_KEY:
51
  return ""
52
  try:
53
+ # Convert PIL Image to a byte array
54
+ img_bytes = io.BytesIO()
55
+ pil_img.save(img_bytes, format="PNG")
56
+ img_bytes.seek(0)
57
+
58
+ # Use inlineData to pass the image to the model
59
+ model = genai.GenerativeModel("gemini-2.5-flash")
60
  prompt = ("Give 4 short tags (comma-separated) describing this item in the image. "
61
  "Respond only with tags.")
62
+ image_part = {
63
+ "mime_type": "image/png",
64
+ "data": img_bytes.getvalue()
65
+ }
66
+ resp = model.generate_content([prompt, image_part])
67
  return resp.text.strip()
68
+ except Exception as e:
69
+ print(f"Error generating tags: {e}")
70
  return ""
71
 
72
  def add_item(mode: str, uploaded_image, text_description: str):
 
74
  payload = {"mode": mode, "text": text_description}
75
 
76
  if uploaded_image:
77
+ # Use the PIL image directly for embedding
 
 
78
  vec = embed_image_pil(uploaded_image).tolist()
79
  payload["has_image"] = True
80
+
81
+ # FIX: Pass the PIL image object to the tag generation function
82
+ payload["tags"] = gen_tags_from_image(uploaded_image)
83
+
84
+ # Convert the PIL image to base64 string for storage in payload
85
+ img_bytes = io.BytesIO()
86
+ uploaded_image.save(img_bytes, format="PNG")
87
  img_bytes.seek(0)
88
  payload["image_b64"] = base64.b64encode(img_bytes.read()).decode("utf-8")
89
  else: