Andy-6 committed on
Commit
35b0ab1
·
1 Parent(s): 1694430

initial commit

Browse files
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ data/images/
README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ # Text-to-Image-Retrieval App
app.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ app.py
3
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
4
+ Gradio web app for text-to-image retrieval.
5
+
6
+ How it works:
7
+ 1. At startup: load CLIP (text encoder) + ChromaDB collection (pre-built)
8
+ 2. On query : encode the user's text prompt → cosine search → top-K images
9
+
10
+ Image source (automatic fallback):
11
+ - Local : if data/images/ exists and contains files → serve from disk
12
+ - Remote : otherwise → load images from HuggingFace Flickr8k dataset
13
+
14
+ Run locally:
15
+ python app.py
16
+
17
+ Deploy to HuggingFace Spaces:
18
+ Push this file + requirements.txt + chroma_db/ to your Space.
19
+ (data/images/ is optional — if absent, images are loaded from HuggingFace)
20
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
21
+ """
22
+
23
+ from pathlib import Path
24
+
25
+ import chromadb
26
+ import gradio as gr
27
+ import torch
28
+ from PIL import Image
29
+ from transformers import CLIPModel, CLIPProcessor
30
+
31
+ # ── Config ────────────────────────────────────────────────────────────────────
32
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
33
+ MODEL_NAME = "openai/clip-vit-base-patch16"
34
+ IMAGES_DIR = Path("data/images")
35
+ CHROMA_DIR = Path("chroma_db")
36
+ COLLECTION = "flickr8k"
37
+ DEFAULT_TOPK = 10
38
+ MAX_TOPK = 20
39
+ # ──────────────────────────────────────────────────────────────────────────────
40
+
41
+
42
+ # ── Load CLIP ─────────────────────────────────────────────────────────────────
43
+ print(f"\nStarting up on device: {DEVICE}")
44
+
45
+ print("Loading CLIP model …")
46
+ model = CLIPModel.from_pretrained(MODEL_NAME).to(DEVICE)
47
+ processor = CLIPProcessor.from_pretrained(MODEL_NAME)
48
+ model.eval()
49
+ print(" CLIP ready.\n")
50
+
51
+
52
+ # ── Connect to ChromaDB ───────────────────────────────────────────────────────
53
+ print("Connecting to ChromaDB …")
54
+ if not (CHROMA_DIR / "chroma.sqlite3").exists():
55
+ raise FileNotFoundError(
56
+ f"ChromaDB not found at '{CHROMA_DIR}'. "
57
+ "Run build_index.py first, then re-launch."
58
+ )
59
+ chroma_client = chromadb.PersistentClient(path=str(CHROMA_DIR))
60
+ collection = chroma_client.get_collection(COLLECTION)
61
+ print(f" Collection ready: {collection.count()} images indexed.\n")
62
+
63
+
64
+ # ── Image source: local disk or HuggingFace dataset ──────────────────────────
65
+ USE_LOCAL_IMAGES = IMAGES_DIR.exists() and any(IMAGES_DIR.iterdir())
66
+
67
+ if USE_LOCAL_IMAGES:
68
+ print(f"Image source: local disk ({IMAGES_DIR})\n")
69
+ dataset = None
70
+ else:
71
+ print("Image source: HuggingFace dataset (data/images/ not found locally)")
72
+ print("Loading Flickr8k …")
73
+ from datasets import load_dataset
74
+ dataset = load_dataset("jxie/flickr8k", split="train+validation+test")
75
+ print(f" Dataset ready: {len(dataset)} images.\n")
76
+
77
+
78
+ # ── Helper: load a single image ───────────────────────────────────────────────
79
+ def load_image(meta: dict) -> Image.Image:
80
+ """
81
+ Load an image from local disk or HuggingFace dataset depending on
82
+ what is available at runtime.
83
+ """
84
+ if USE_LOCAL_IMAGES:
85
+ return Image.open(IMAGES_DIR / meta["filename"]).convert("RGB")
86
+ else:
87
+ return dataset[meta["dataset_index"]]["image"].convert("RGB")
88
+
89
+
90
+ # ── Core retrieval function ───────────────────────────────────────────────────
91
+ def retrieve(query: str, top_k: int = DEFAULT_TOPK) -> list[tuple[Image.Image, str]]:
92
+ """
93
+ Encode `query` with CLIP and return the top-k matching (image, score) pairs.
94
+ Returns an empty list when the query is blank.
95
+ """
96
+ query = query.strip()
97
+ if not query:
98
+ return []
99
+
100
+ # Encode text with CLIP
101
+ inputs = processor(text=[query], return_tensors="pt", padding=True).to(DEVICE)
102
+ with torch.no_grad():
103
+ output = model.get_text_features(**inputs)
104
+ # handle both tensor and object outputs across transformers versions
105
+ text_features = output.pooler_output if hasattr(output, "pooler_output") else output
106
+
107
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
108
+ query_vec = text_features.cpu().numpy().tolist()[0]
109
+
110
+ # Vector search in ChromaDB
111
+ results = collection.query(
112
+ query_embeddings=[query_vec],
113
+ n_results=int(top_k),
114
+ include=["metadatas", "distances"],
115
+ )
116
+
117
+ # Build output: (PIL image, score label)
118
+ output = []
119
+ for meta, dist in zip(results["metadatas"][0], results["distances"][0]):
120
+ img = load_image(meta)
121
+ # ChromaDB cosine distance: 0 = identical, 2 = opposite
122
+ # Convert to a 0-100 similarity percentage for display
123
+ similarity = round((1 - dist / 2) * 100, 1)
124
+ output.append((img, f"Score: {similarity}%"))
125
+
126
+ return output
127
+
128
+
129
+ # ── Gradio UI ─────────────────────────────────────────────────────────────────
130
+ _EXAMPLES = [
131
+ ["a dog playing in the snow"],
132
+ ["children playing at a park"],
133
+ ["a man surfing ocean waves"],
134
+ ["a woman reading a book"],
135
+ ["a group of people watching a performance"],
136
+ ["a cat sitting on a windowsill"],
137
+ ["a bike race on a mountain trail"],
138
+ ["fireworks over a city at night"],
139
+ ]
140
+
141
+ with gr.Blocks(
142
+ title="CLIP Text-to-Image Retrieval",
143
+ theme=gr.themes.Soft(),
144
+ ) as demo:
145
+
146
+ gr.Markdown(
147
+ """
148
+ # 🔍 Text-to-Image Retrieval
149
+ Enter a natural language description and find matching images from the **Flickr8k** dataset.
150
+ Built with [CLIP](https://openai.com/research/clip) (ViT-B/16) + [ChromaDB](https://www.trychroma.com/).
151
+ """
152
+ )
153
+
154
+ with gr.Row():
155
+ query_box = gr.Textbox(
156
+ placeholder="e.g. a dog playing in the snow",
157
+ label="Search prompt",
158
+ scale=5,
159
+ )
160
+ topk_slider = gr.Slider(
161
+ minimum=1, maximum=MAX_TOPK, value=DEFAULT_TOPK, step=1,
162
+ label="Results",
163
+ scale=1,
164
+ )
165
+ search_btn = gr.Button("Search 🔎", variant="primary", scale=1)
166
+
167
+ gallery = gr.Gallery(
168
+ label="Top results",
169
+ columns=5,
170
+ rows=2,
171
+ height="auto",
172
+ object_fit="cover",
173
+ show_label=True,
174
+ )
175
+
176
+ gr.Examples(
177
+ examples=_EXAMPLES,
178
+ inputs=query_box,
179
+ label="Try one of these …",
180
+ )
181
+
182
+ # Wire up interactions — both button click and Enter key trigger retrieve()
183
+ search_btn.click(fn=retrieve, inputs=[query_box, topk_slider], outputs=gallery)
184
+ query_box.submit(fn=retrieve, inputs=[query_box, topk_slider], outputs=gallery)
185
+
186
+
187
+ # ── Entry point ───────────────────────────────────────────────────────────────
188
+ if __name__ == "__main__":
189
+ demo.launch(
190
+ server_name="0.0.0.0", # listen on all interfaces (needed for LAN access)
191
+ share=False, # set True for a temporary public gradio.live URL
192
+ )
chroma_db/57e6ab60-34f7-4656-8506-9bb8673dc71a/data_level0.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:615f5509afa96d52f58663bdd7c0f09db6a17b10172b4949d8e740403637f8d1
3
+ size 15683584
chroma_db/57e6ab60-34f7-4656-8506-9bb8673dc71a/header.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:247d206a62e4985a7fe1ffd10f57fcb1c4fd569a80f4d39aa2fc20804739750f
3
+ size 100
chroma_db/57e6ab60-34f7-4656-8506-9bb8673dc71a/index_metadata.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6967640dd90753ba50c1b3822cf0d5eaf73e8a98c1b97b6d80de51cd9b849992
3
+ size 198640
chroma_db/57e6ab60-34f7-4656-8506-9bb8673dc71a/length.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:af779dc40371f141b64ab92aefa7b1b377c564f9c6b0334606ec1aa3abd9d216
3
+ size 28672
chroma_db/57e6ab60-34f7-4656-8506-9bb8673dc71a/link_lists.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:14619dec3385eb7542ec27e892a07d4eb9dee71df09b51c0dd1e43474fc3fc33
3
+ size 62740
chroma_db/chroma.sqlite3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:257d0073b2e1a2bc06370c4b667c97072468d2c661b7e0b9572a49b10dbc674c
3
+ size 6504448
requirements.txt ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==24.1.0
2
+ aiohappyeyeballs==2.6.1
3
+ aiohttp==3.13.3
4
+ aiosignal==1.4.0
5
+ annotated-doc==0.0.4
6
+ annotated-types==0.7.0
7
+ anyio==4.12.1
8
+ attrs==26.1.0
9
+ bcrypt==5.0.0
10
+ brotli==1.2.0
11
+ build==1.4.0
12
+ certifi==2026.2.25
13
+ charset-normalizer==3.4.6
14
+ chromadb==1.5.5
15
+ click==8.3.1
16
+ cuda-bindings==12.9.4
17
+ cuda-pathfinder==1.2.2
18
+ datasets==4.8.3
19
+ dill==0.4.1
20
+ durationpy==0.10
21
+ fastapi==0.135.1
22
+ ffmpy==1.0.0
23
+ filelock==3.20.0
24
+ flatbuffers==25.12.19
25
+ frozenlist==1.8.0
26
+ fsspec==2025.12.0
27
+ googleapis-common-protos==1.73.0
28
+ gradio==6.9.0
29
+ gradio_client==2.3.0
30
+ groovy==0.1.2
31
+ grpcio==1.78.0
32
+ h11==0.16.0
33
+ hf-xet==1.4.2
34
+ httpcore==1.0.9
35
+ httptools==0.7.1
36
+ httpx==0.28.1
37
+ huggingface_hub==1.7.2
38
+ idna==3.11
39
+ importlib_metadata==8.7.1
40
+ importlib_resources==6.5.2
41
+ Jinja2==3.1.6
42
+ jsonschema==4.26.0
43
+ jsonschema-specifications==2025.9.1
44
+ kubernetes==35.0.0
45
+ markdown-it-py==4.0.0
46
+ MarkupSafe==3.0.2
47
+ mdurl==0.1.2
48
+ mmh3==5.2.1
49
+ mpmath==1.3.0
50
+ multidict==6.7.1
51
+ multiprocess==0.70.19
52
+ networkx==3.6.1
53
+ numpy==2.3.5
54
+ nvidia-cublas-cu12==12.6.4.1
55
+ nvidia-cuda-cupti-cu12==12.6.80
56
+ nvidia-cuda-nvrtc-cu12==12.6.77
57
+ nvidia-cuda-runtime-cu12==12.6.77
58
+ nvidia-cudnn-cu12==9.10.2.21
59
+ nvidia-cufft-cu12==11.3.0.4
60
+ nvidia-cufile-cu12==1.11.1.6
61
+ nvidia-curand-cu12==10.3.7.77
62
+ nvidia-cusolver-cu12==11.7.1.2
63
+ nvidia-cusparse-cu12==12.5.4.2
64
+ nvidia-cusparselt-cu12==0.7.1
65
+ nvidia-nccl-cu12==2.27.5
66
+ nvidia-nvjitlink-cu12==12.6.85
67
+ nvidia-nvshmem-cu12==3.4.5
68
+ nvidia-nvtx-cu12==12.6.77
69
+ oauthlib==3.3.1
70
+ onnxruntime==1.24.4
71
+ opentelemetry-api==1.40.0
72
+ opentelemetry-exporter-otlp-proto-common==1.40.0
73
+ opentelemetry-exporter-otlp-proto-grpc==1.40.0
74
+ opentelemetry-proto==1.40.0
75
+ opentelemetry-sdk==1.40.0
76
+ opentelemetry-semantic-conventions==0.61b0
77
+ orjson==3.11.7
78
+ overrides==7.7.0
79
+ packaging
80
+ pandas==3.0.1
81
+ pillow==12.0.0
82
+ propcache==0.4.1
83
+ protobuf==6.33.6
84
+ pyarrow==23.0.1
85
+ pybase64==1.4.3
86
+ pydantic==2.12.5
87
+ pydantic-settings==2.13.1
88
+ pydantic_core==2.41.5
89
+ pydub==0.25.1
90
+ Pygments==2.19.2
91
+ PyPika==0.51.1
92
+ pyproject_hooks==1.2.0
93
+ python-dateutil==2.9.0.post0
94
+ python-dotenv==1.2.2
95
+ python-multipart==0.0.22
96
+ pytz==2026.1.post1
97
+ PyYAML==6.0.3
98
+ referencing==0.37.0
99
+ regex==2026.2.28
100
+ requests==2.32.5
101
+ requests-oauthlib==2.0.0
102
+ rich==14.3.3
103
+ rpds-py==0.30.0
104
+ safehttpx==0.1.7
105
+ safetensors==0.7.0
106
+ semantic-version==2.10.0
107
+ setuptools==80.10.2
108
+ shellingham==1.5.4
109
+ six==1.17.0
110
+ starlette==0.52.1
111
+ sympy==1.14.0
112
+ tenacity==9.1.4
113
+ tokenizers==0.22.2
114
+ tomlkit==0.13.3
115
+ torch==2.10.0+cu126
116
+ torchvision==0.25.0+cu126
117
+ tqdm==4.67.3
118
+ transformers==5.3.0
119
+ triton==3.6.0
120
+ typer==0.24.1
121
+ typing-inspection==0.4.2
122
+ typing_extensions==4.15.0
123
+ urllib3==2.6.3
124
+ uvicorn==0.42.0
125
+ uvloop==0.22.1
126
+ watchfiles==1.1.1
127
+ websocket-client==1.9.0
128
+ websockets==16.0
129
+ wheel==0.46.3
130
+ xxhash==3.6.0
131
+ yarl==1.23.0
132
+ zipp==3.23.0