Spaces:
Runtime error
Runtime error
| import os | |
| import uuid | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| from qdrant_client import QdrantClient | |
| from qdrant_client.http.models import VectorParams, Distance, PointStruct | |
| from sentence_transformers import SentenceTransformer | |
# ===============================
# Config / Setup
# ===============================
UPLOAD_DIR = "uploaded_images"
os.makedirs(UPLOAD_DIR, exist_ok=True)
COLLECTION = "lost_and_found"

# ":memory:" keeps the collection in-process only, so it starts empty on every
# restart. NOTE(review): image files left in UPLOAD_DIR by a previous run are
# neither re-indexed nor removed here.
qclient = QdrantClient(":memory:")

# Load CLIP model; encode_data uses it for both images and text.
encoder = SentenceTransformer("clip-ViT-B-32")


def _probe_vector_size(model):
    """Return the embedding dimension of *model* as an int.

    Prefer the dimension the model declares. If that call raises or returns
    None, measure the length of a single text embedding instead.
    """
    try:
        dim = model.get_sentence_embedding_dimension()
    except Exception:
        dim = None
    if dim is None:
        dim = len(model.encode(["test"])[0])
    return int(dim)


VECTOR_SIZE = _probe_vector_size(encoder)

# Create collection if not exists (a fresh in-memory client never has it, but
# the guard keeps this safe if a persistent client is swapped in).
if not qclient.collection_exists(COLLECTION):
    qclient.create_collection(
        collection_name=COLLECTION,
        vectors_config=VectorParams(size=int(VECTOR_SIZE), distance=Distance.COSINE),
    )
# ===============================
# Encoding function
# ===============================
def encode_data(text=None, image=None):
    """Encode either an image or a text string into a single CLIP embedding.

    The image takes priority: when both are supplied only the image is encoded.

    Args:
        text: Free-text description. ``None``, ``""`` and whitespace-only
            strings are treated as "no text".
        image: A ``PIL.Image.Image`` or a filesystem path to an image file.

    Returns:
        One embedding vector, or ``None`` when there is nothing to encode.
    """
    if isinstance(image, Image.Image):
        # Encode a one-element batch so image and text take the same code path
        # and return the same shape.
        return encoder.encode([image.convert("RGB")])[0]
    if isinstance(image, str):
        # Use a context manager so the underlying file handle is closed.
        with Image.open(image) as img:
            return encoder.encode([img.convert("RGB")])[0]
    if text and text.strip():
        return encoder.encode([text])[0]
    return None
# ===============================
# Add Item
# ===============================
def add_item(text, image, uploader_name, uploader_phone):
    """Index a found item in the Qdrant collection.

    The item is embedded from its image when one is uploaded, otherwise from
    its description. An uploaded image is saved as a PNG under UPLOAD_DIR so
    that search results can display it.

    Args:
        text: Optional description. Whitespace-only text counts as missing.
        image: Optional ``PIL.Image.Image`` from the Gradio image input.
        uploader_name: Finder's name; stored as "N/A" when empty.
        uploader_phone: Finder's phone; stored as "N/A" when empty.

    Returns:
        A status string for the UI. Errors are reported, not raised.
    """
    img_path = None
    try:
        has_image = isinstance(image, Image.Image)
        if has_image:
            vector = encode_data(image=image)
        elif text and text.strip():
            vector = encode_data(text=text)
        else:
            vector = None
        if vector is None:
            return "β Please provide at least an image or text."
        vec = np.asarray(vector, dtype=float)

        # Write the image only after it has been embedded successfully.
        if has_image:
            img_path = os.path.join(UPLOAD_DIR, f"{uuid.uuid4()}.png")
            image.save(img_path)

        payload = {
            "text": text or "",
            "uploader_name": uploader_name or "N/A",
            "uploader_phone": uploader_phone or "N/A",
            "image_path": img_path,
            "has_image": bool(img_path),
        }
        qclient.upsert(
            collection_name=COLLECTION,
            points=[PointStruct(id=str(uuid.uuid4()), vector=vec.tolist(), payload=payload)],
            wait=True,
        )
        return "β Item added to database!"
    except Exception as e:
        # Don't leave an orphaned image on disk for an item that never got indexed.
        if img_path and os.path.exists(img_path):
            try:
                os.remove(img_path)
            except OSError:
                pass  # best-effort cleanup; the original error is what matters
        return f"β Error: {e}"
# ===============================
# Search Items
# ===============================
def search_items(text, image, max_results, min_score):
    """Search the collection by image, text, or both.

    When both are given, their unit-normalised embeddings are averaged into a
    single query vector.

    Args:
        text: Optional text query. Whitespace-only text counts as missing.
        image: Optional ``PIL.Image.Image`` query.
        max_results: Maximum number of hits (cast to int).
        min_score: Minimum similarity score a hit must reach (cast to float).

    Returns:
        ``(text, gallery)`` where ``text`` has one line per hit and ``gallery``
        is a list of ``(image_path, caption)`` pairs for hits that have a
        stored image. Errors are returned as the text with an empty gallery.
    """
    try:
        img_vec = None
        text_vec = None
        if isinstance(image, Image.Image):
            img_vec = np.asarray(encode_data(image=image), dtype=float)
        if text and text.strip():
            text_vec = np.asarray(encode_data(text=text), dtype=float)

        if img_vec is not None and text_vec is not None:
            # Normalise first so neither modality dominates the combined query.
            v1 = img_vec / (np.linalg.norm(img_vec) + 1e-12)
            v2 = text_vec / (np.linalg.norm(text_vec) + 1e-12)
            qvec = (v1 + v2) / 2
        elif img_vec is not None:
            qvec = img_vec
        elif text_vec is not None:
            qvec = text_vec
        else:
            return "β Provide text or image to search.", []

        query = qvec.tolist()
        limit = int(max_results)
        threshold = float(min_score)
        if hasattr(qclient, "query_points"):
            # `search` is deprecated in qdrant-client in favour of `query_points`.
            hits = qclient.query_points(
                collection_name=COLLECTION,
                query=query,
                limit=limit,
                score_threshold=threshold,
                with_payload=True,
            ).points
        else:
            # Older clients that predate `query_points`.
            hits = qclient.search(
                collection_name=COLLECTION,
                query_vector=query,
                limit=limit,
                score_threshold=threshold,
                with_payload=True,
            )
        if not hits:
            return "No matches found.", []

        result_texts = []
        gallery_items = []
        for h in hits:
            payload = h.payload or {}
            score_str = f"{getattr(h, 'score', 0):.3f}"
            uploader_name = payload.get("uploader_name", "N/A") or "N/A"
            uploader_phone = payload.get("uploader_phone", "N/A") or "N/A"
            result_texts.append(
                f"id:{h.id} | score:{score_str} | text:{payload.get('text','')} "
                f"| finder:{uploader_name} ({uploader_phone})"
            )
            img_path = payload.get("image_path")
            if img_path and os.path.exists(img_path):
                # Caption each thumbnail so it can be matched to its text line.
                gallery_items.append((img_path, f"id:{h.id} | score:{score_str}"))
        return "\n".join(result_texts), gallery_items
    except Exception as e:
        return f"β Error: {e}", []
# ===============================
# Clear DB
# ===============================
def clear_database():
    """Drop and recreate the collection, then delete the uploaded image files.

    File removal is best-effort. Files that cannot be deleted are counted and
    reported in the status message instead of being silently ignored.

    Returns:
        A status string for the UI. Errors are reported, not raised.
    """
    try:
        if qclient.collection_exists(COLLECTION):
            qclient.delete_collection(COLLECTION)
        qclient.create_collection(
            collection_name=COLLECTION,
            vectors_config=VectorParams(size=int(VECTOR_SIZE), distance=Distance.COSINE),
        )

        # Recreate the upload dir if it vanished, so listing (and later saves) work.
        os.makedirs(UPLOAD_DIR, exist_ok=True)
        failed = 0
        for name in os.listdir(UPLOAD_DIR):
            path = os.path.join(UPLOAD_DIR, name)
            # add_item only writes files here; leave real subdirectories alone.
            if os.path.isdir(path) and not os.path.islink(path):
                continue
            try:
                os.remove(path)
            except OSError:
                failed += 1

        message = "ποΈ Database cleared!"
        if failed:
            message += f" ({failed} file(s) could not be removed.)"
        return message
    except Exception as e:
        return f"β Error clearing DB: {e}"
# ===============================
# Gradio UI
# ===============================
with gr.Blocks() as demo:
    gr.Markdown("## π Lost & Found")

    with gr.Tab("β Add Found Item"):
        text_in = gr.Textbox(label="Description (optional)")
        img_in = gr.Image(type="pil", label="Upload Image (optional)")
        uploader_name = gr.Textbox(label="Finder's name")
        uploader_phone = gr.Textbox(label="Finder's phone")
        add_btn = gr.Button("Add to database")
        add_status = gr.Textbox(label="Status")
        add_btn.click(add_item, inputs=[text_in, img_in, uploader_name, uploader_phone], outputs=[add_status])

    with gr.Tab("π Search Lost Item"):
        search_text = gr.Textbox(label="Search by text (optional)")
        search_img = gr.Image(type="pil", label="Search by image (optional)")
        max_results = gr.Slider(1, 20, value=5, step=1, label="Max results")
        # Items are usually indexed by their image. A text query is therefore
        # compared text-to-image, and CLIP cosine scores for that are much
        # lower than image-to-image scores. The old 0.75 default filtered out
        # essentially every text match.
        # NOTE(review): 0.25 is a starting point; tune it against real uploads.
        min_score = gr.Slider(0.0, 1.0, value=0.25, step=0.01, label="Min similarity score")
        search_btn = gr.Button("Search")
        search_text_out = gr.Textbox(label="Search results (text)")
        search_gallery = gr.Gallery(label="Search Results", columns=2, height="auto")
        search_btn.click(search_items, inputs=[search_text, search_img, max_results, min_score], outputs=[search_text_out, search_gallery])

    with gr.Tab("ποΈ Admin"):
        clear_btn = gr.Button("Clear database")
        clear_out = gr.Textbox(label="Status")
        clear_btn.click(clear_database, outputs=[clear_out])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)