Sankie005's picture
Upload 434 files
c446951
import base64
import os
from io import BytesIO
import faiss
import numpy as np
import requests
from PIL import Image
import argparse
import json
from flask import Flask, request, render_template, send_from_directory
# Command-line configuration.
#
# Options that carry a default are NOT marked required=True: a required option
# must always be supplied on the command line, so its default could never take
# effect. Callers that pass every flag explicitly are unaffected.
parser = argparse.ArgumentParser(description="Build a search engine with CLIP")
parser.add_argument(
    "--dataset_path", type=str, help="Path to dataset", default="images"
)
parser.add_argument(
    "--inference_endpoint",
    type=str,
    help="Roboflow Inference endpoint URL",
    default="http://localhost:9001",
)
# The API key may come from the environment; the flag is only mandatory when
# the environment does not provide a non-empty key.
_env_api_key = os.environ.get("ROBOFLOW_API_KEY")
parser.add_argument(
    "--api_key",
    type=str,
    required=not _env_api_key,
    help="Roboflow API key",
    default=_env_api_key,
)
args = parser.parse_args()

# Flask application object; the route handlers in this file register on it.
app = Flask(__name__)
def get_image_embedding(image: "Image.Image") -> list:
    """Request an embedding for *image* from the configured inference server.

    The image is converted to RGB, JPEG-encoded in memory, base64-encoded and
    POSTed as JSON to ``<inference_endpoint>/clip/embed_image``.

    Args:
        image: An opened PIL image (must support ``convert`` and ``save``).

    Returns:
        The value stored under ``"embeddings"`` in the JSON response.
        NOTE(review): callers wrap this in ``np.array`` and hand it to FAISS,
        so it is presumably a list of float vectors -- confirm against the
        server's response schema.

    Raises:
        requests.HTTPError: If the server replies with an error status.
        requests.RequestException: On connection failure or timeout.
        KeyError: If the response JSON has no ``"embeddings"`` field.
    """
    # Encode entirely in memory; nothing is written to disk.
    buffer = BytesIO()
    image.convert("RGB").save(buffer, format="JPEG")
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")

    payload = {
        "image": {"type": "base64", "value": encoded},
    }

    response = requests.post(
        args.inference_endpoint + "/clip/embed_image",
        # Let requests build the query string so the key is URL-encoded.
        params={"api_key": args.api_key},
        json=payload,
        # Without a timeout an unresponsive server would block forever.
        timeout=60,
    )
    # Surface HTTP failures directly instead of as an opaque KeyError/JSON
    # decoding error further down.
    response.raise_for_status()

    return response.json()["embeddings"]
# Directory whose files are indexed. It is assigned before the branch below so
# that it exists on BOTH paths: the search view joins result file names onto it
# even when the index was loaded from disk rather than built.
TRAIN_IMAGES = os.path.join(args.dataset_path, "train")

# The FAISS index and the file-name list are only meaningful together (row i of
# the index corresponds to file_names[i]), so reuse the cache only when both
# files are present; otherwise rebuild both.
if os.path.exists("index.bin") and os.path.exists("database.json"):
    index = faiss.read_index("index.bin")
    with open("database.json", "r") as f:
        file_names = json.load(f)
else:
    # NOTE(review): 512 is the vector dimensionality given to the index; it
    # must match the embedding size returned by the server -- confirm.
    index = faiss.IndexFlatL2(512)
    file_names = []

    for frame_name in os.listdir(TRAIN_IMAGES):
        # Keep the try narrow: only failures to open the file are skipped.
        try:
            frame = Image.open(os.path.join(TRAIN_IMAGES, frame_name))
        except IOError:
            print("error computing embedding for", frame_name)
            continue

        # Close the underlying file handle once the embedding is computed so
        # large datasets do not exhaust file descriptors.
        with frame:
            embedding = get_image_embedding(frame)

        index.add(np.array(embedding).astype(np.float32))
        file_names.append(frame_name)

    # Persist both artifacts so later runs can skip the embedding pass.
    faiss.write_index(index, "index.bin")
    with open("database.json", "w") as f:
        json.dump(file_names, f)
@app.route("/", methods=["GET", "POST"])
def search():
    """Render the search page and, on POST, the nearest indexed images.

    GET renders ``index.html`` with no results. POST reads the uploaded image
    from the ``file`` form field, embeds it, queries the index for up to three
    neighbours and renders ``index.html`` with their paths as ``images``.
    """
    if request.method != "POST":
        return render_template("index.html")

    file = request.files["file"]
    # Close the decoded upload as soon as its embedding has been computed.
    with Image.open(file) as uploaded:
        query = get_image_embedding(uploaded)

    _, neighbor_ids = index.search(np.array(query).astype(np.float32), 3)

    # FAISS pads the result with -1 when the index holds fewer vectors than
    # requested; indexing file_names with -1 would wrongly return the last
    # file. Skip any id that does not map to a stored file name.
    images = [
        os.path.join(TRAIN_IMAGES, file_names[i])
        for i in neighbor_ids[0]
        if 0 <= i < len(file_names)
    ]
    return render_template("index.html", images=images)
@app.route("/images/<path:path>")
def send_image(path):
    """Serve the file at *path*, resolved relative to the dataset directory."""
    dataset_root = args.dataset_path
    return send_from_directory(dataset_root, path)
if __name__ == "__main__":
    # Start Flask's built-in server when the file is executed as a script.
    # NOTE(review): debug=True enables Flask's debug mode; confirm this is
    # never used for a publicly reachable deployment.
    run_options = {"debug": True}
    app.run(**run_options)