# image-retrieval / app.py
# Uploaded by ABAO77 ("Upload 3 files"), commit 4dc9354 (verified).
import base64
import json
import os
import shutil
import zipfile
from io import BytesIO

# Keep the original torch-before-faiss order: load order between these two
# native libraries can matter on some platforms.
import torch
import faiss
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel, Field

from src.modules import FeatureExtractor
# Load variables from a local .env file, if one exists. override=True lets
# values from .env replace variables already present in the process environment.
load_dotenv(override=True)
def decode_encoded_env(encoded: str) -> dict[str, str]:
    """Decode an ``ENCODED_ENV`` payload into environment-variable strings.

    The payload is base64-encoded UTF-8 JSON whose top level must be an
    object. ``os.environ`` only accepts ``str`` values, so non-string values
    (numbers, booleans, nested objects, ...) are stored as their JSON text,
    e.g. ``8000`` -> ``"8000"``, ``true`` -> ``"true"``, ``{"k": 1}`` ->
    ``'{"k": 1}'``.

    Args:
        encoded: The base64 text read from the ``ENCODED_ENV`` variable.

    Returns:
        A mapping of variable name to string value.

    Raises:
        ValueError: If the payload is not valid base64, is not UTF-8 JSON, or
            its top-level JSON value is not an object. (``binascii.Error``,
            ``UnicodeDecodeError`` and ``json.JSONDecodeError`` are all
            ``ValueError`` subclasses.)
    """
    decoded = base64.b64decode(encoded).decode("utf-8")
    data = json.loads(decoded)
    if not isinstance(data, dict):
        raise ValueError("ENCODED_ENV must decode to a JSON object")
    return {
        key: value if isinstance(value, str) else json.dumps(value)
        for key, value in data.items()
    }


# Optional bundle of extra settings, shipped as one base64-encoded JSON object.
encoded_env = os.getenv("ENCODED_ENV")
if encoded_env:
    os.environ.update(decode_encoded_env(encoded_env))
# FastAPI application; the interactive API docs are served at the root path.
app = FastAPI(docs_url="/")

# Cross-origin policy: every origin, method and header is accepted.
# NOTE(review): a wildcard origin combined with allow_credentials=True
# effectively lets any site make credentialed cross-origin requests;
# confirm this is intended for this service.
origins = ["*"]
_cors_options = {
    "allow_origins": origins,
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Locations of the prebuilt FAISS index and the exported ONNX feature extractor.
index_path = "./model/db_vit_b_16.index"
onnx_path = "./model/vit_b_16_feature_extractor.onnx"

# Fail fast with a precise message when the index file is missing.
if not os.path.exists(index_path):
    raise FileNotFoundError(f"Index file not found: {index_path}")

try:
    # Load the FAISS index used by /search-image/.
    index = faiss.read_index(index_path)
    print(f"Successfully loaded FAISS index from {index_path}")
    # Build the ViT-B/16 feature extractor backed by the ONNX export.
    feature_extractor = FeatureExtractor(base_model="vit_b_16", onnx_path=onnx_path)
    print("Successfully initialized feature extractor with ONNX support")
except Exception as e:
    # Chain explicitly so the traceback shows the original failure as the
    # cause rather than as an error raised "during handling".
    raise RuntimeError(f"Error initializing models: {e}") from e
def base64_to_image(base64_str: str) -> Image.Image:
    """Decode base64 image text into an RGB PIL image.

    Accepts either bare base64 or a data URL such as
    ``data:image/png;base64,<payload>``. For a data URL, everything up to and
    including the first comma is stripped before decoding.

    Args:
        base64_str: Base64 text of an encoded image file.

    Returns:
        The decoded image, converted to RGB mode.

    Raises:
        HTTPException: 400 if the text cannot be decoded or the bytes are not
            an image PIL can open.
    """
    # Browser FileReader.readAsDataURL output carries a "data:<mime>;base64,"
    # prefix. ':' and ',' are not base64 characters, so bare payloads are unaffected.
    if isinstance(base64_str, str) and base64_str.startswith("data:") and "," in base64_str:
        base64_str = base64_str.split(",", 1)[1]
    try:
        image_data = base64.b64decode(base64_str)
        # convert() forces PIL to actually decode the pixel data here, so
        # corrupt payloads fail inside this try block.
        return Image.open(BytesIO(image_data)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail="Invalid Base64 image") from e
def image_to_base64(image: Image.Image) -> str:
    """Encode a PIL image as base64 JPEG text.

    JPEG cannot store alpha channels or palettes. Images in any other mode
    (e.g. RGBA PNGs, palette GIFs) are therefore converted to RGB first, and
    their alpha channel is dropped. Previously such images made ``save`` raise.

    Args:
        image: The image to encode. It must still be readable (its file open).

    Returns:
        The JPEG bytes, base64-encoded as ASCII text.
    """
    # Modes Pillow's JPEG encoder can write directly.
    if image.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
def _member_name(member):
    """Return a zip member's name, repairing UTF-8 names mis-decoded as CP437."""
    name = member.filename
    # Bit 0x800 means the archive declared the name as UTF-8; zipfile has
    # already decoded it correctly.
    if member.flag_bits & 0x800:
        return name
    # Otherwise zipfile decoded the raw bytes as CP437. Many archivers write
    # UTF-8 without setting the flag, so try to reinterpret the bytes, and keep
    # the CP437 reading if they are not valid UTF-8.
    try:
        return name.encode("cp437").decode("utf-8")
    except UnicodeError:
        return name


def unzip_folder(zip_file_path, extract_to_path):
    """Extract every member of a zip archive into ``extract_to_path``.

    Directory entries are created as directories. File contents are streamed
    to disk rather than read fully into memory.

    Args:
        zip_file_path: Path to the ``.zip`` archive.
        extract_to_path: Destination directory, created as needed.

    Raises:
        FileNotFoundError: If ``zip_file_path`` does not exist.
        ValueError: If a member would land outside ``extract_to_path``
            (e.g. a name containing ``..`` or an absolute path).
    """
    if not os.path.exists(zip_file_path):
        raise FileNotFoundError(f"Zip file not found: {zip_file_path}")
    root = os.path.realpath(extract_to_path)
    with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
        for member in zip_ref.infolist():
            filename = _member_name(member)
            extracted_path = os.path.join(extract_to_path, filename)
            # Guard against "zip slip": every member must stay under root.
            resolved = os.path.realpath(extracted_path)
            if os.path.commonpath([root, resolved]) != root:
                raise ValueError(
                    f"Refusing to extract {member.filename!r} outside {extract_to_path}"
                )
            if member.is_dir():
                os.makedirs(extracted_path, exist_ok=True)
                continue
            parent = os.path.dirname(extracted_path)
            if parent:
                os.makedirs(parent, exist_ok=True)
            with zip_ref.open(member) as source, open(extracted_path, "wb") as target:
                shutil.copyfileobj(source, target)
    print(f"Extracted all files to: {extract_to_path}")
# Unpack the bundled image archive at import time. search_image later resolves
# FAISS hits against the image files listed directly under extract_path.
zip_file, extract_path = "./images.zip", "./data"
unzip_folder(zip_file, extract_path)
def is_image_file(filename):
    """Return True when ``filename`` ends with a known image extension.

    The comparison is case-insensitive; only the name's suffix is inspected.
    """
    lowered = filename.lower()
    suffixes = (".png", ".jpg", ".jpeg", ".bmp", ".gif", ".tiff", ".webp")
    return any(lowered.endswith(suffix) for suffix in suffixes)
# Request body for POST /search-image/: a single query image as base64 text.
class ImageSearchBody(BaseModel):
    # Base64-encoded image; decoded by base64_to_image in search_image.
    base64_image: str = Field(..., title="Base64 Image String")
@app.post("/search-image/")
async def search_image(body: ImageSearchBody):
    """Find the indexed image closest to the uploaded one.

    Decodes ``body.base64_image`` and embeds it with the feature extractor.
    The embedding is L2-normalised, and the FAISS index is queried for its
    single nearest neighbour. The hit is resolved to a file under
    ``extract_path``. The response contains that file as base64 JPEG, its
    name, and the raw FAISS score (``similarity_score``); whether a larger or
    smaller score means a closer match depends on the index metric.

    Invalid image input yields a 400. Any other failure yields a 500 with an
    ``{"error": ...}`` body.
    """
    # NOTE(review): the model inference and FAISS search below are synchronous
    # calls made inside an async def, so they block the event loop while running.
    try:
        # Raises HTTPException(400) for undecodable input.
        image = base64_to_image(body.base64_image)
        # extract_features is expected to return a torch tensor (.view/.norm/.cpu used below).
        output = feature_extractor.extract_features(image)
        # Flatten to (batch, features) and L2-normalise each row.
        output = output.view(output.size(0), -1)
        output = output / output.norm(p=2, dim=1, keepdim=True)
        # Query the single nearest neighbour.
        distances, ids = index.search(output.cpu().numpy(), 1)
        # Relies on the index rows matching the image files in sorted-name
        # order — TODO confirm against the script that built the index.
        image_list = sorted(f for f in os.listdir(extract_path) if is_image_file(f))
        match_id = int(ids[0][0])
        # FAISS reports -1 when no neighbour exists. A negative id must not
        # wrap around to the last file in the list.
        if not 0 <= match_id < len(image_list):
            raise LookupError(
                f"Search returned id {match_id}, but {len(image_list)} images are available"
            )
        image_name = image_list[match_id]
        matched_image_path = os.path.join(extract_path, image_name)
        # Encode while the file is open; the context manager then closes it.
        with Image.open(matched_image_path) as matched_image:
            matched_image_base64 = image_to_base64(matched_image)
        return JSONResponse(
            content={
                "image_base64": matched_image_base64,
                "image_name": image_name,
                "similarity_score": float(distances[0][0]),
            },
            status_code=200,
        )
    except HTTPException:
        # Let client errors (e.g. invalid base64 -> 400) reach FastAPI intact
        # instead of being rewritten as a 500 below.
        raise
    except Exception as e:
        print(f"Error in search_image: {str(e)}")
        return JSONResponse(
            content={"error": f"Error processing image: {str(e)}"}, status_code=500
        )
from src.firebase.firebase_provider import process_images
# Request body for POST /upload_image: one or more base64-encoded images.
class Body(BaseModel):
    # List of base64 image strings, passed unchanged to process_images.
    base64_image: list[str] = Field(..., title="Base64 Image String")
    # Pydantic v2-style model configuration; json_schema_extra adds the example
    # below to the generated JSON/OpenAPI schema.
    # NOTE(review): the sample string appears to be only a truncated PNG prefix,
    # so it likely does not decode to a complete image.
    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "base64_image": [
                        "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAYAAACNiR0NAAABdUlEQVR42mNk",
                    ]
                }
            ]
        }
    }
@app.post("/upload_image")
async def upload_image(body: Body):
    # Forward the base64 payloads to the Firebase helper and report whatever it
    # returns under "public_url". Any exception, including one raised while
    # building the success response, becomes a 500 carrying the error text.
    try:
        url = await process_images(body.base64_image)
        response = JSONResponse(status_code=200, content={"public_url": url})
    except Exception as exc:
        response = JSONResponse(status_code=500, content={"error": str(exc)})
    return response
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces, port 8000.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8000)