ChristopherMarais committed on
Commit bb00855 · verified · 1 Parent(s): c633c3f

Create app.py

Files changed (1)
app.py +93 -0
app.py ADDED
@@ -0,0 +1,93 @@
+ # app.py
+ import json
+ import numpy as np
+ import torch
+ from flask import Flask, request, render_template_string
+ from transformers import CLIPProcessor, CLIPModel
+ import faiss
+
+ app = Flask(__name__)
+
+ # Global variables for the model, processor, FAISS index, and image metadata.
+ model = None
+ processor = None
+ index = None
+ image_embeddings = None
+ image_metadata = None
+
+ def load_model():
+     global model, processor
+     print("Loading CLIP model and processor...")
+     model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+     processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ def load_data():
+     global image_embeddings, image_metadata, index
+     print("Loading image embeddings and metadata...")
+     # Load precomputed embeddings and metadata from a JSON file.
+     with open("data/embeddings.json", "r") as f:
+         data = json.load(f)
+
+     # Each item in data should have an "embedding" key and a "url" (and optionally an "id").
+     image_embeddings = np.array([d["embedding"] for d in data]).astype("float32")
+     image_metadata = data
+
+     # Build a FAISS index using L2 distance. The dimension 'd' must match the embedding size.
+     d = image_embeddings.shape[1]
+     index = faiss.IndexFlatL2(d)
+     index.add(image_embeddings)
+     print(f"FAISS index built with {index.ntotal} embeddings.")
+
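+ # NOTE: illustrative sketch of the data/embeddings.json shape load_data()
+ # expects; the IDs and URLs here are hypothetical, and each "embedding" must
+ # match the CLIP checkpoint's output size (512 floats for ViT-B/32):
+ # [
+ #   {"id": "0001", "url": "https://example.com/img1.jpg", "embedding": [0.12, -0.03, ...]},
+ #   {"id": "0002", "url": "https://example.com/img2.jpg", "embedding": [0.07, 0.21, ...]}
+ # ]
+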
+ @app.route("/", methods=["GET", "POST"])
+ def search():
+     results_html = ""
+     query = ""
+     if request.method == "POST":
+         query = request.form.get("query", "")
+         if query:
+             # Encode the text query using CLIP's text encoder.
+             inputs = processor(text=[query], return_tensors="pt", padding=True)
+             with torch.no_grad():
+                 text_features = model.get_text_features(**inputs)
+             text_features = text_features.cpu().numpy().astype("float32")
+
+             # Query the FAISS index for the top k similar images.
+             k = 10  # number of results to return
+             distances, indices = index.search(text_features, k)
+
+             # Build HTML image elements for each result. FAISS pads the result
+             # with -1 when fewer than k vectors are indexed, so skip those slots.
+             results = []
+             for idx in indices[0]:
+                 if idx == -1:
+                     continue
+                 meta = image_metadata[idx]
+                 results.append(
+                     f'<div style="margin:10px;"><img src="{meta["url"]}" alt="Image {meta.get("id", "")}" style="max-width:200px;"><br>ID: {meta.get("id", "N/A")}</div>'
+                 )
+             results_html = "".join(results)
+
+     # Simple HTML form with results displayed below. The user-supplied query
+     # is passed to Jinja as a template variable so it gets HTML-escaped,
+     # rather than interpolated directly into the markup.
+     html = """
+     <!DOCTYPE html>
+     <html>
+     <head>
+         <meta charset="UTF-8">
+         <title>Image Search with CLIP &amp; FAISS</title>
+     </head>
+     <body>
+         <h1>Image Search</h1>
+         <form method="post">
+             <input type="text" name="query" placeholder="Enter search text" value="{{ query }}" required>
+             <input type="submit" value="Search">
+         </form>
+         <div style="display:flex; flex-wrap: wrap; margin-top:20px;">
+             {{ results_html|safe }}
+         </div>
+     </body>
+     </html>
+     """
+     return render_template_string(html, query=query, results_html=results_html)
+
+ if __name__ == "__main__":
+     load_model()
+     load_data()
+     # Run the Flask app on the port expected by Hugging Face Spaces.
+     app.run(host="0.0.0.0", port=8080)
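
The app assumes data/embeddings.json already exists before startup. Below is a minimal sketch of a companion script that could produce it with the same CLIP checkpoint; the image list, file paths, IDs, and URLs are hypothetical placeholders, not part of this commit.

# build_embeddings.py -- hypothetical companion script (not in this commit)
import json

import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Hypothetical (id, local path, public URL) triples for the images to index.
images = [
    ("0001", "images/cat.jpg", "https://example.com/images/cat.jpg"),
    ("0002", "images/dog.jpg", "https://example.com/images/dog.jpg"),
]

records = []
for image_id, path, url in images:
    image = Image.open(path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        # 512-dimensional image features for the ViT-B/32 checkpoint.
        features = model.get_image_features(**inputs)
    records.append({"id": image_id, "url": url, "embedding": features[0].tolist()})

with open("data/embeddings.json", "w") as f:
    json.dump(records, f)

One design note: IndexFlatL2 ranks by Euclidean distance on raw CLIP features. A common alternative for CLIP retrieval is to L2-normalize both the image and text features and use faiss.IndexFlatIP instead, which makes the inner product equivalent to cosine similarity.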