"""Image-search API router.

Loads environment configuration, initialises the feature extractor and the
similarity index once at import time, extracts the image archive, and exposes
``process_search_image`` to find the closest stored image for a base64-encoded
query image.
"""

import base64
import faulthandler
import json
import logging
import os

# Enable the fault handler before any native-extension imports so that a
# segmentation fault raised while importing or initialising ONNX/PyTorch is
# still reported with a Python traceback.
faulthandler.enable()

from dotenv import load_dotenv
from fastapi import APIRouter, Depends, status
from fastapi.responses import JSONResponse

# Load .env exactly once. A second load_dotenv(override=True) after the block
# below would overwrite values supplied via ENCODED_ENV_IMAGE.
load_dotenv(override=True)

# Optional base64-encoded JSON object of extra environment variables. It is
# applied after .env so its values take precedence over the file's.
encoded_env = os.getenv("ENCODED_ENV_IMAGE")
if encoded_env:
    decoded_env = base64.b64decode(encoded_env).decode()
    env_data = json.loads(decoded_env)
    for key, value in env_data.items():
        # os.environ only accepts strings; JSON numbers/booleans would
        # otherwise raise TypeError at import time.
        os.environ[key] = str(value)

# Force CPU mode to avoid segmentation faults with ONNX/PyTorch. Set before the
# model utilities are imported so it is already in effect if a backend reads
# it at import time.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

from PIL import Image

from src.utils.image_utils import base64_to_image, image_to_base64, is_image_file
from src.utils.model_utils import init_models, search_similar_images
from src.utils.zip_utils import extract_zip_file

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/image_search", tags=["Image Search"])

index_path = "./model/db_vit_b_16.index"
onnx_path = "./model/vit_b_16_feature_extractor.onnx"
index, feature_extractor = init_models(index_path, onnx_path)

# Extract images if needed
zip_file = "images/images_2.zip"
extract_path = "./data"
extract_zip_file(zip_file, extract_path)


def _list_indexed_images(images_dir):
    """Return the image file names in *images_dir*, sorted by name.

    The position of a name in this list is treated as the id returned by the
    similarity search. The ordering must therefore match the one used when the
    index was built (presumably plain sorted order; confirm against the
    index-building code).
    """
    return sorted(name for name in os.listdir(images_dir) if is_image_file(name))


def _clean_image_name(filename: str) -> str:
    """Turn a stored image file name into a human-readable label.

    Drops the file extension, turns underscores into spaces, removes digits
    and collapses leftover runs of whitespace. For example,
    ``"golden_retriever_01.jpg"`` becomes ``"golden retriever"``.
    """
    stem = filename.rsplit(".", 1)[0]
    label = "".join(c for c in stem.replace("_", " ") if not c.isdigit())
    return " ".join(label.split())


def process_search_image(base64_image):
    """Find the stored image most similar to a query image.

    Args:
        base64_image: The query image as accepted by ``base64_to_image``.

    Returns:
        A ``(matched_image_base64, image_name, similarity_score)`` tuple:
        - the best match re-encoded with ``image_to_base64``;
        - its cleaned display name;
        - the top value of the first array returned by
          ``search_similar_images``, as a float.

    Raises:
        IndexError: If the search's top id is negative or does not correspond
            to an image in ``<extract_path>/images``.
    """
    image = base64_to_image(base64_image)

    # Extract features using ONNX model
    features = feature_extractor.extract_features(image)

    # Search for similar images; take the top hit for the single query.
    scores, ids = search_similar_images(index, features)
    best_id = int(ids[0][0])
    similarity_score = float(scores[0][0])

    images_dir = os.path.join(extract_path, "images")
    image_list = _list_indexed_images(images_dir)
    # Without this check a negative id (FAISS-style indexes use -1 for "no
    # result") would silently wrap around to the last image in the list.
    if not 0 <= best_id < len(image_list):
        raise IndexError(
            f"Search returned id {best_id}, but {images_dir} holds "
            f"{len(image_list)} images"
        )
    image_name = image_list[best_id]

    # Close the underlying file once the image has been re-encoded.
    with Image.open(os.path.join(images_dir, image_name)) as matched_image:
        matched_image_base64 = image_to_base64(matched_image)

    image_name_post_process = _clean_image_name(image_name)

    # Log only the small fields; the base64 payload can be very large.
    logger.info(
        "Image search match: image_name=%s similarity_score=%s",
        image_name_post_process,
        similarity_score,
    )
    return matched_image_base64, image_name_post_process, similarity_score