File size: 6,951 Bytes
dd67299
 
7c5c440
7116b90
 
 
 
c559bc7
7fae8fb
c559bc7
7116b90
c559bc7
03852d5
7c5c440
22ef1d5
dd67299
8f370f4
7116b90
 
7712c9d
c559bc7
 
7712c9d
 
 
 
 
 
 
 
 
dd67299
 
c559bc7
7712c9d
dd67299
7c5c440
c559bc7
7116b90
c559bc7
22ef1d5
7712c9d
7116b90
c65ef6e
7116b90
c559bc7
7116b90
c559bc7
7116b90
c559bc7
 
7712c9d
c559bc7
9e00920
22ef1d5
03852d5
c559bc7
 
7116b90
03852d5
 
 
c65ef6e
c559bc7
 
 
 
7712c9d
7116b90
 
 
 
 
7712c9d
 
7116b90
 
 
0c4adc5
22ef1d5
7fae8fb
7116b90
 
7fae8fb
9e00920
22ef1d5
 
746bf5b
c559bc7
7712c9d
c559bc7
7c5c440
0d5f8a4
7116b90
 
 
 
 
 
 
 
 
 
 
7712c9d
 
 
 
7116b90
 
 
 
 
7712c9d
7116b90
 
22ef1d5
7116b90
 
 
 
7fae8fb
7c5c440
7116b90
c559bc7
 
7116b90
7712c9d
7116b90
 
 
7712c9d
7116b90
 
 
 
 
 
c559bc7
7116b90
 
 
 
 
 
 
a1501eb
22ef1d5
a06f639
c559bc7
7116b90
c559bc7
 
7116b90
 
 
 
 
7712c9d
7116b90
 
 
 
7712c9d
7116b90
 
 
 
dd67299
c559bc7
03852d5
c559bc7
22ef1d5
7712c9d
7116b90
 
 
 
7712c9d
 
7116b90
 
 
 
 
 
 
 
 
22ef1d5
c559bc7
 
7116b90
7c5c440
dd67299
7116b90
dd67299
7116b90
22ef1d5
7116b90
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import os
import uuid
import gradio as gr
import numpy as np
from PIL import Image
from qdrant_client import QdrantClient
from qdrant_client.http.models import VectorParams, Distance, PointStruct
from sentence_transformers import SentenceTransformer

# ===============================
# Config / Setup
# ===============================
# Directory where uploaded item photos are persisted on disk.
UPLOAD_DIR = "uploaded_images"
os.makedirs(UPLOAD_DIR, exist_ok=True)

# Name of the single Qdrant collection backing the app.
COLLECTION = "lost_and_found"

# In-memory Qdrant instance: data lives only for the process lifetime.
qclient = QdrantClient(":memory:")

# Shared CLIP model used for both text and image embeddings.
encoder = SentenceTransformer("clip-ViT-B-32")

# Determine the embedding dimension. Prefer the model's reported size;
# fall back to encoding a probe string when it is unavailable or None.
try:
    _reported_dim = encoder.get_sentence_embedding_dimension()
except Exception:
    _reported_dim = None
VECTOR_SIZE = (
    _reported_dim
    if _reported_dim is not None
    else len(encoder.encode(["test"])[0])
)

# Create the collection on first run (cosine distance matches CLIP usage).
if not qclient.collection_exists(COLLECTION):
    qclient.create_collection(
        collection_name=COLLECTION,
        vectors_config=VectorParams(size=int(VECTOR_SIZE), distance=Distance.COSINE),
    )

# ===============================
# Encoding function
# ===============================
def encode_data(text=None, image=None):
    """Embed a text string or an image with the shared CLIP encoder.

    An image argument (PIL image or filesystem path) takes precedence
    over text. Returns the embedding vector, or None when neither a
    usable image nor non-empty text was supplied.
    """
    if isinstance(image, Image.Image):
        rgb = image.convert("RGB")
        return encoder.encode(rgb)
    if isinstance(image, str):
        loaded = Image.open(image).convert("RGB")
        return encoder.encode(loaded)
    if text:
        # encode() takes a list for text; unwrap the single result.
        return encoder.encode([text])[0]
    return None

# ===============================
# Add Item
# ===============================
def add_item(text, image, uploader_name, uploader_phone):
    """Store a found item (image and/or description) in the vector DB.

    The image, when present, is saved to UPLOAD_DIR and used for the
    embedding; otherwise the text is embedded. Returns a status string
    for the Gradio UI (errors are reported in the string, not raised).
    """
    try:
        saved_path = None
        embedding = None

        if isinstance(image, Image.Image):
            # Persist the photo under a fresh UUID so names never clash.
            saved_path = os.path.join(UPLOAD_DIR, f"{uuid.uuid4()}.png")
            image.save(saved_path)
            embedding = encode_data(image=image)
        elif text:
            embedding = encode_data(text=text)

        if embedding is None:
            return "❌ Please provide at least an image or text."

        point = PointStruct(
            id=str(uuid.uuid4()),
            vector=np.asarray(embedding, dtype=float).tolist(),
            payload={
                "text": text or "",
                "uploader_name": uploader_name or "N/A",
                "uploader_phone": uploader_phone or "N/A",
                "image_path": saved_path,
                "has_image": bool(saved_path),
            },
        )
        qclient.upsert(collection_name=COLLECTION, points=[point], wait=True)
        return "βœ… Item added to database!"
    except Exception as e:
        return f"❌ Error: {e}"

# ===============================
# Search Items
# ===============================
def search_items(text, image, max_results, min_score):
    """Search stored items by text, image, or both combined.

    When both modalities are given, each embedding is L2-normalized and
    the two are averaged into a single query vector. Returns a pair of
    (result text, list of image paths for the gallery); errors are
    reported in the text, not raised.
    """
    try:
        # Collect whichever query embeddings the user supplied
        # (image first, matching the combination order below).
        query_vectors = []
        if isinstance(image, Image.Image):
            query_vectors.append(np.asarray(encode_data(image=image), dtype=float))
        if text and len(text.strip()) > 0:
            query_vectors.append(np.asarray(encode_data(text=text), dtype=float))

        if not query_vectors:
            return "❌ Provide text or image to search.", []

        if len(query_vectors) == 2:
            # Normalize each modality before averaging so neither
            # dominates by magnitude; epsilon avoids division by zero.
            unit = [v / (np.linalg.norm(v) + 1e-12) for v in query_vectors]
            qvec = (unit[0] + unit[1]) / 2
        else:
            qvec = query_vectors[0]

        hits = qclient.search(
            collection_name=COLLECTION,
            query_vector=qvec.tolist(),
            limit=int(max_results),
            score_threshold=float(min_score),
            with_payload=True,
        )
        if not hits:
            return "No matches found.", []

        lines = []
        gallery = []
        for hit in hits:
            payload = hit.payload or {}
            finder = payload.get("uploader_name", "N/A") or "N/A"
            phone = payload.get("uploader_phone", "N/A") or "N/A"
            score_str = f"{getattr(hit, 'score', 0):.3f}"
            lines.append(
                f"id:{hit.id} | score:{score_str} | text:{payload.get('text','')} "
                f"| finder:{finder} ({phone})"
            )
            # Only show images whose files still exist on disk.
            path = payload.get("image_path")
            if path and os.path.exists(path):
                gallery.append(path)

        return "\n".join(lines), gallery
    except Exception as e:
        return f"❌ Error: {e}", []

# ===============================
# Clear DB
# ===============================
def clear_database():
    """Drop and recreate the Qdrant collection and delete saved images.

    Returns a human-readable status string for the Gradio UI; errors
    are reported in the string rather than raised.
    """
    try:
        # Recreating the collection is the cheapest way to wipe all points.
        if qclient.collection_exists(COLLECTION):
            qclient.delete_collection(COLLECTION)
        qclient.create_collection(
            collection_name=COLLECTION,
            vectors_config=VectorParams(size=int(VECTOR_SIZE), distance=Distance.COSINE),
        )
        # Best-effort cleanup of uploaded images. Catch only OSError so a
        # locked/missing file is skipped without swallowing unrelated
        # exceptions (the original bare `except:` also ate KeyboardInterrupt).
        for fname in os.listdir(UPLOAD_DIR):
            try:
                os.remove(os.path.join(UPLOAD_DIR, fname))
            except OSError:
                pass
        return "πŸ—‘οΈ Database cleared!"
    except Exception as e:
        return f"❌ Error clearing DB: {e}"

# ===============================
# Gradio UI
# ===============================
# Three-tab UI: add found items, search for lost items, admin reset.
with gr.Blocks() as demo:
    gr.Markdown("## πŸ”Ž Lost & Found")

    # --- Tab 1: register a found item -------------------------------
    with gr.Tab("βž• Add Found Item"):
        add_text = gr.Textbox(label="Description (optional)")
        add_image = gr.Image(type="pil", label="Upload Image (optional)")
        finder_name = gr.Textbox(label="Finder's name")
        finder_phone = gr.Textbox(label="Finder's phone")
        add_button = gr.Button("Add to database")
        add_result = gr.Textbox(label="Status")
        add_button.click(
            add_item,
            inputs=[add_text, add_image, finder_name, finder_phone],
            outputs=[add_result],
        )

    # --- Tab 2: look up a lost item ---------------------------------
    with gr.Tab("πŸ” Search Lost Item"):
        query_text = gr.Textbox(label="Search by text (optional)")
        query_image = gr.Image(type="pil", label="Search by image (optional)")
        limit_slider = gr.Slider(1, 20, value=5, step=1, label="Max results")
        score_slider = gr.Slider(0.0, 1.0, value=0.75, step=0.01, label="Min similarity score")
        search_button = gr.Button("Search")
        results_box = gr.Textbox(label="Search results (text)")
        results_gallery = gr.Gallery(label="Search Results", columns=2, height="auto")
        search_button.click(
            search_items,
            inputs=[query_text, query_image, limit_slider, score_slider],
            outputs=[results_box, results_gallery],
        )

    # --- Tab 3: admin-only reset ------------------------------------
    with gr.Tab("πŸ—‘οΈ Admin"):
        wipe_button = gr.Button("Clear database")
        wipe_status = gr.Textbox(label="Status")
        wipe_button.click(clear_database, outputs=[wipe_status])

if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable inside a container.
    demo.launch(server_name="0.0.0.0", server_port=7860)