"""FastAPI service exposing image similarity search and image upload.

Startup sequence (order matters):

1. Load ``.env`` and, if present, the base64-encoded JSON environment held
   in ``ENCODED_ENV_IMAGE``; this happens before the heavy imports below.
2. Force CPU-only, single-threaded execution.
3. Load the search index and the ONNX feature extractor.
4. Extract the image archive that the index entries refer to.
"""

import base64
import json
import os

from dotenv import load_dotenv

load_dotenv(override=True)

# Environment bootstrap. This runs before the remaining imports so that any
# module reading os.environ at import time sees the decoded values.
encoded_env = os.getenv("ENCODED_ENV_IMAGE")
if encoded_env:
    decoded_env = base64.b64decode(encoded_env).decode()
    env_data = json.loads(decoded_env)
    for key, value in env_data.items():
        # os.environ only accepts strings; JSON numbers, booleans and nested
        # objects used to raise TypeError here. Re-serialise them as JSON text.
        os.environ[key] = value if isinstance(value, str) else json.dumps(value)

import faulthandler  # noqa: E402

import torch  # noqa: E402
from fastapi import FastAPI  # noqa: E402
from fastapi.middleware.cors import CORSMiddleware  # noqa: E402
from fastapi.responses import JSONResponse  # noqa: E402
from PIL import Image  # noqa: E402
from pydantic import BaseModel, Field  # noqa: E402

from src.firebase.firebase_provider import process_images  # noqa: E402
from src.utils.image_utils import (  # noqa: E402
    base64_to_image,
    image_to_base64,
    is_image_file,
)
from src.utils.model_utils import init_models, search_similar_images  # noqa: E402
from src.utils.zip_utils import extract_zip_file  # noqa: E402

# Enable fault handler to debug segmentation faults
faulthandler.enable()

# NOTE(review): this second load re-applies .env on top of the values decoded
# from ENCODED_ENV_IMAGE for any overlapping keys. Kept to preserve existing
# behaviour -- confirm whether that precedence is intended.
load_dotenv(override=True)

# Force CPU mode to avoid segmentation faults with ONNX/PyTorch
os.environ["CUDA_VISIBLE_DEVICES"] = ""
torch.set_num_threads(1)

# Initialize FastAPI app (interactive docs are served at the root URL)
app = FastAPI(docs_url="/")

# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; consider restricting origins if credentials are actually used.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize paths and models
index_path = "./model/db_vit_b_16.index"
onnx_path = "./model/vit_b_16_feature_extractor.onnx"
index, feature_extractor = init_models(index_path, onnx_path)

# Extract images if needed
zip_file = "./images_2.zip"
extract_path = "./data"
extract_zip_file(zip_file, extract_path)


def _display_name(image_name: str) -> str:
    """Turn an image file name into a human-readable label.

    Underscores become spaces, digits are dropped, the file extension is
    removed, and any resulting runs of whitespace (including leading and
    trailing) are collapsed, e.g. ``"golden_retriever_01.jpg"`` becomes
    ``"golden retriever"``.
    """
    label = image_name.replace("_", " ")
    label = "".join(c for c in label if not c.isdigit())
    label = label.rsplit(".", 1)[0]
    # Removing digits leaves stray spaces behind ("cat 01" -> "cat ").
    return " ".join(label.split())


class ImageSearchBody(BaseModel):
    """Request body for ``/search-image/``: one base64-encoded image."""

    base64_image: str = Field(..., title="Base64 Image String")


@app.post("/search-image/")
def search_image(body: ImageSearchBody):
    """Return the stored image most similar to the submitted one.

    The request image is decoded, run through the feature extractor and
    looked up in the search index. The first returned index is mapped to a
    file by position in the sorted listing of ``<extract_path>/images``.

    Responses:
        200: ``image_base64``, ``image_name`` (cleaned label) and
            ``similarity_score`` of the best match.
        404: the search produced no valid (non-negative) match index.
        500: any other failure; the error text is included in the body.
    """
    try:
        # Convert base64 to image
        image = base64_to_image(body.base64_image)

        # Extract features using ONNX model
        features = feature_extractor.extract_features(image)

        # Search for similar images
        D, I = search_similar_images(index, features)

        best_index = int(I[0][0])
        if best_index < 0:
            # A negative index would silently wrap around via Python's
            # negative indexing and return an unrelated image.
            return JSONResponse(
                content={"error": "No similar image found"}, status_code=404
            )

        # Get the matched image; sorting keeps the position -> file mapping
        # deterministic regardless of directory listing order.
        images_dir = os.path.join(extract_path, "images")
        image_list = sorted(f for f in os.listdir(images_dir) if is_image_file(f))
        image_name = image_list[best_index]

        matched_image_path = f"{extract_path}/images/{image_name}"
        # Context manager ensures the underlying file handle is released.
        with Image.open(matched_image_path) as matched_image:
            matched_image_base64 = image_to_base64(matched_image)

        return JSONResponse(
            content={
                "image_base64": matched_image_base64,
                "image_name": _display_name(image_name),
                "similarity_score": float(D[0][0]),
            },
            status_code=200,
        )

    except Exception as e:
        # Top-level boundary of the endpoint: report instead of crashing.
        print(f"Error in search_image: {str(e)}")
        return JSONResponse(
            content={"error": f"Error processing image: {str(e)}"}, status_code=500
        )


class Body(BaseModel):
    """Request body for ``/upload_image``: a list of base64-encoded images."""

    base64_image: list[str] = Field(..., title="Base64 Image String")

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "base64_image": [
                        "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAYAAACNiR0NAAABdUlEQVR42mNk",
                    ]
                }
            ]
        }
    }


@app.post("/upload_image")
async def upload_image(body: Body):
    """Pass the submitted images to ``process_images`` and return its result.

    Responses:
        200: ``public_url`` holding whatever ``process_images`` returned.
        500: any failure; the error text is included in the body.
    """
    try:
        public_url = await process_images(body.base64_image)
        return JSONResponse(content={"public_url": public_url}, status_code=200)
    except Exception as e:
        # Log for parity with search_image so failures are visible server-side.
        print(f"Error in upload_image: {str(e)}")
        return JSONResponse(content={"error": str(e)}, status_code=500)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)