"""FastAPI service for reverse image search and image upload.

At import time this module:
  * loads ``.env`` and, if set, the base64-encoded JSON in ``ENCODED_ENV``
    into ``os.environ``;
  * loads the FAISS index and the ONNX-backed ``vit_b_16`` feature extractor;
  * extracts ``./images.zip`` into ``./data``.

Endpoints:
    POST /search-image/  -- return the indexed image closest to a base64 query.
    POST /upload_image   -- forward base64 images to ``process_images``.
"""

import asyncio
import base64
import json
import os
import shutil
import zipfile
from io import BytesIO

# torch is kept ahead of faiss, matching the original import order.
# NOTE(review): load order can matter when both libraries bundle OpenMP.
import torch
import faiss
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel, Field

from src.modules import FeatureExtractor

load_dotenv(override=True)

# ENCODED_ENV, when present, is a base64-encoded JSON object whose entries are
# exported as environment variables.
encoded_env = os.getenv("ENCODED_ENV")
if encoded_env:
    decoded_env = base64.b64decode(encoded_env).decode()
    env_data = json.loads(decoded_env)
    for key, value in env_data.items():
        # os.environ only accepts str values. Serialize JSON numbers, booleans,
        # null and nested objects back to JSON text instead of raising TypeError.
        os.environ[key] = value if isinstance(value, str) else json.dumps(value)

app = FastAPI(docs_url="/")

origins = ["*"]
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; restrict `origins` for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize paths
index_path = "./model/db_vit_b_16.index"
onnx_path = "./model/vit_b_16_feature_extractor.onnx"

# Fail fast with a clear message instead of an opaque FAISS read error.
if not os.path.exists(index_path):
    raise FileNotFoundError(f"Index file not found: {index_path}")

try:
    index = faiss.read_index(index_path)
    print(f"Successfully loaded FAISS index from {index_path}")

    feature_extractor = FeatureExtractor(base_model="vit_b_16", onnx_path=onnx_path)
    print("Successfully initialized feature extractor with ONNX support")
except Exception as e:
    # Chain the cause so the original traceback is preserved.
    raise RuntimeError(f"Error initializing models: {str(e)}") from e


def base64_to_image(base64_str: str) -> Image.Image:
    """Decode a base64 string into an RGB PIL image.

    A ``data:<mime>;base64,`` URL prefix is accepted and stripped.

    Raises:
        HTTPException: 400 if the payload cannot be decoded as an image.
    """
    # Without this, the prefix characters would be mixed into the base64 payload.
    if base64_str.startswith("data:") and "," in base64_str:
        base64_str = base64_str.split(",", 1)[1]
    try:
        image_data = base64.b64decode(base64_str)
        return Image.open(BytesIO(image_data)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail="Invalid Base64 image") from e


def image_to_base64(image: Image.Image) -> str:
    """Encode ``image`` as JPEG and return the bytes as a base64 str."""
    # JPEG cannot store alpha or palette modes (RGBA, LA, P, ...); convert them
    # first so that PNG/GIF sources do not make the save fail.
    if image.mode not in ("1", "L", "RGB", "CMYK"):
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def _decode_member_name(member: zipfile.ZipInfo) -> str:
    """Return the member's file name, repairing UTF-8 names stored without the UTF-8 flag."""
    name = member.filename
    # zipfile already decoded the name as UTF-8 when flag bit 0x800 is set.
    if member.flag_bits & 0x800:
        return name
    # Otherwise it was decoded as cp437. Some archivers write UTF-8 without
    # setting the flag, so try to reinterpret the bytes, keeping the cp437
    # reading when they are not valid UTF-8.
    try:
        return name.encode("cp437").decode("utf-8")
    except (UnicodeEncodeError, UnicodeDecodeError):
        return name


def unzip_folder(zip_file_path, extract_to_path):
    """Extract every member of ``zip_file_path`` into ``extract_to_path``.

    Raises:
        FileNotFoundError: if the archive does not exist.
        ValueError: if a member would be written outside ``extract_to_path``
            (absolute path or ``..`` traversal).
    """
    if not os.path.exists(zip_file_path):
        raise FileNotFoundError(f"Zip file not found: {zip_file_path}")

    root = os.path.realpath(extract_to_path)
    with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
        for member in zip_ref.infolist():
            filename = _decode_member_name(member)
            extracted_path = os.path.join(extract_to_path, filename)

            # Guard against "zip slip": never write outside the target directory.
            resolved = os.path.realpath(extracted_path)
            if os.path.commonpath([root, resolved]) != root:
                raise ValueError(f"Unsafe path in zip archive: {filename}")

            # Directory entries end in "/" and must not be opened as files.
            if member.is_dir():
                os.makedirs(extracted_path, exist_ok=True)
                continue

            os.makedirs(os.path.dirname(extracted_path), exist_ok=True)
            with zip_ref.open(member) as source, open(extracted_path, "wb") as target:
                # Stream the member instead of loading it fully into memory.
                shutil.copyfileobj(source, target)
    print(f"Extracted all files to: {extract_to_path}")


def is_image_file(filename):
    """Return True if ``filename`` has a supported image extension (case-insensitive)."""
    valid_extensions = (".png", ".jpg", ".jpeg", ".bmp", ".gif", ".tiff", ".webp")
    return filename.lower().endswith(valid_extensions)


zip_file = "./images.zip"
extract_path = "./data"
unzip_folder(zip_file, extract_path)

# The search endpoint maps FAISS id i to the i-th entry of this sorted list, so
# the ordering must match how the index was built. The directory only changes
# at startup, so list it once here rather than on every request.
image_list = sorted(f for f in os.listdir(extract_path) if is_image_file(f))
if len(image_list) != index.ntotal:
    print(
        f"Warning: {len(image_list)} images in {extract_path} but "
        f"{index.ntotal} vectors in the FAISS index"
    )


class ImageSearchBody(BaseModel):
    base64_image: str = Field(..., title="Base64 Image String")


def _search_sync(base64_image: str) -> dict:
    """Run the blocking search pipeline and return the JSON response payload.

    Raises:
        HTTPException: 400 for an undecodable image, 404 if FAISS returns no match.
        RuntimeError: if FAISS returns an id with no corresponding image file.
    """
    image = base64_to_image(base64_image)

    output = feature_extractor.extract_features(image)
    # Flatten to (batch, features) and L2-normalize each row. reshape, unlike
    # view, also accepts non-contiguous tensors, and normalize clamps the norm
    # so that a zero vector does not produce NaNs.
    # NOTE(review): assumes the index was built from vectors normalized the same way.
    output = output.reshape(output.size(0), -1)
    output = torch.nn.functional.normalize(output, p=2, dim=1)

    distances, indices = index.search(output.detach().cpu().float().numpy(), 1)

    match = int(indices[0][0])
    if match < 0:
        # FAISS reports -1 when it has no neighbour to return (e.g. an empty index).
        raise HTTPException(status_code=404, detail="No matching image found")
    if match >= len(image_list):
        raise RuntimeError(
            f"Index returned id {match} but only {len(image_list)} images are available"
        )

    image_name = image_list[match]
    # Close the file handle once the image has been re-encoded.
    with Image.open(os.path.join(extract_path, image_name)) as matched_image:
        matched_image_base64 = image_to_base64(matched_image)

    return {
        "image_base64": matched_image_base64,
        "image_name": image_name,
        # NOTE(review): this is FAISS's raw score (a distance for L2 indexes, a
        # similarity for inner-product ones); confirm which this index uses.
        "similarity_score": float(distances[0][0]),
    }


@app.post("/search-image/")
async def search_image(body: ImageSearchBody):
    """Return the indexed image most similar to the base64 query image.

    Client errors (bad image, no match) propagate as HTTPException with their
    own status code. Any other failure is reported as a 500 JSON error.
    """
    try:
        # Feature extraction and FAISS search are CPU-bound; run them in a worker
        # thread so they do not block the event loop.
        result = await asyncio.to_thread(_search_sync, body.base64_image)
        return JSONResponse(content=result, status_code=200)
    except HTTPException:
        raise
    except Exception as e:
        print(f"Error in search_image: {str(e)}")
        return JSONResponse(
            content={"error": f"Error processing image: {str(e)}"}, status_code=500
        )


# Kept below the ENCODED_ENV handling so any environment it reads at import time
# is already populated. NOTE(review): confirm whether this ordering is required.
from src.firebase.firebase_provider import process_images  # noqa: E402


class Body(BaseModel):
    base64_image: list[str] = Field(..., title="Base64 Image String")

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "base64_image": [
                        "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAYAAACNiR0NAAABdUlEQVR42mNk",
                    ]
                }
            ]
        }
    }


@app.post("/upload_image")
async def upload_image(body: Body):
    """Pass the base64 images to ``process_images`` and return its result as "public_url"."""
    try:
        public_url = await process_images(body.base64_image)
        return JSONResponse(content={"public_url": public_url}, status_code=200)
    except Exception as e:
        print(f"Error in upload_image: {str(e)}")
        return JSONResponse(content={"error": str(e)}, status_code=500)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)