Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import faiss | |
| import base64 | |
| from PIL import Image | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from io import BytesIO | |
| from src.modules import FeatureExtractor | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import zipfile | |
| from pydantic import BaseModel, Field | |
| import json | |
| from dotenv import load_dotenv | |
load_dotenv(override=True)

# Optional ENCODED_ENV: a base64-encoded JSON object whose key/value pairs are
# exported into os.environ (after .env has been loaded above).
encoded_env = os.getenv("ENCODED_ENV")
if encoded_env:
    decoded_env = base64.b64decode(encoded_env).decode()
    env_data = json.loads(decoded_env)
    if not isinstance(env_data, dict):
        raise ValueError("ENCODED_ENV must decode to a JSON object")
    for key, value in env_data.items():
        # os.environ only accepts str values: keep strings verbatim and
        # serialize any other JSON value (number, bool, null, nested object)
        # back to its JSON text instead of crashing with TypeError.
        os.environ[key] = value if isinstance(value, str) else json.dumps(value)
# The interactive API docs page is mounted at the root path.
app = FastAPI(docs_url="/")

# CORS: accept cross-origin requests from any origin, method and header.
# NOTE(review): the Starlette CORS docs state that wildcard origins/methods/
# headers are not meant to be combined with allow_credentials=True; confirm
# whether credentials are actually required by any client.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=origins,
    allow_headers=["*"],
    allow_methods=["*"],
)
# Model artifacts: the FAISS index file and the ONNX export used by the
# vit_b_16 feature extractor.
index_path = "./model/db_vit_b_16.index"
onnx_path = "./model/vit_b_16_feature_extractor.onnx"

# Fail fast with an explicit message when the index file is missing.
if not os.path.exists(index_path):
    raise FileNotFoundError(f"Index file not found: {index_path}")

try:
    index = faiss.read_index(index_path)
    print(f"Successfully loaded FAISS index from {index_path}")

    feature_extractor = FeatureExtractor(base_model="vit_b_16", onnx_path=onnx_path)
    print("Successfully initialized feature extractor with ONNX support")
except Exception as e:
    # Chain the underlying error so its original traceback is kept in the logs.
    raise RuntimeError(f"Error initializing models: {str(e)}") from e
def base64_to_image(base64_str: str) -> Image.Image:
    """Decode a base64 string into an RGB PIL image.

    A leading ``data:<mime>;base64,`` data-URL prefix is stripped before
    decoding, so both raw base64 and data URLs are accepted.

    Args:
        base64_str: Base64-encoded image bytes, optionally as a data URL.

    Returns:
        The decoded image, converted to RGB mode.

    Raises:
        HTTPException: status 400 if the payload cannot be decoded into an image.
    """
    payload = base64_str.strip()
    if payload.startswith("data:") and "," in payload:
        # Keep only the base64 body after the data-URL header.
        payload = payload.split(",", 1)[1]
    try:
        image_data = base64.b64decode(payload)
        return Image.open(BytesIO(image_data)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail="Invalid Base64 image") from e
def image_to_base64(image: Image.Image) -> str:
    """Encode a PIL image as a base64 JPEG string.

    JPEG cannot store alpha channels or palettes, so images in other modes
    (e.g. RGBA or P PNGs, LA, I) are converted to RGB first. Otherwise
    ``Image.save(format="JPEG")`` raises ``OSError``. The caller's image is
    not modified.

    Args:
        image: Image to encode.

    Returns:
        The JPEG bytes, base64-encoded as an ASCII string.
    """
    if image.mode not in ("1", "L", "RGB", "CMYK"):
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
def _decode_zip_member_name(member):
    """Return *member*'s filename, repairing UTF-8 names mis-decoded as cp437."""
    name = member.filename
    # Bit 11 (0x800) of the general-purpose flags means zipfile already decoded
    # the name as UTF-8; re-encoding it as cp437 would fail or garble it.
    if member.flag_bits & 0x800:
        return name
    try:
        # Unflagged names are decoded as cp437 by zipfile; many tools still
        # write UTF-8 bytes without the flag, so try to recover those.
        return name.encode("cp437").decode("utf-8")
    except UnicodeError:
        # Genuinely cp437 (or otherwise non-UTF-8) name: keep zipfile's decoding.
        return name


def unzip_folder(zip_file_path, extract_to_path):
    """Extract every entry of a zip archive into *extract_to_path*.

    Directory entries are created as directories, and parent directories
    are created as needed. Entry names are repaired when UTF-8 bytes were
    stored without the UTF-8 flag.

    Args:
        zip_file_path: Path to the ``.zip`` archive.
        extract_to_path: Destination directory (created if missing).

    Raises:
        FileNotFoundError: if *zip_file_path* does not exist.
        ValueError: if an entry would be written outside *extract_to_path*
            (``..`` components or absolute paths).
    """
    if not os.path.exists(zip_file_path):
        raise FileNotFoundError(f"Zip file not found: {zip_file_path}")

    root = os.path.realpath(extract_to_path)
    with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
        for member in zip_ref.infolist():
            filename = _decode_zip_member_name(member)
            extracted_path = os.path.join(extract_to_path, filename)

            # Refuse entries that resolve outside the destination ("zip slip").
            resolved = os.path.realpath(extracted_path)
            if os.path.commonpath([root, resolved]) != root:
                raise ValueError(f"Unsafe path in zip archive: {member.filename}")

            if member.is_dir():
                # "name/" entries are directories; opening them for writing fails.
                os.makedirs(extracted_path, exist_ok=True)
                continue

            parent = os.path.dirname(extracted_path)
            if parent:
                os.makedirs(parent, exist_ok=True)
            with zip_ref.open(member) as source, open(extracted_path, "wb") as target:
                target.write(source.read())
    print(f"Extracted all files to: {extract_to_path}")
# Unpack the bundled image archive at import time. search_image later lists
# image files from extract_path, so both paths must stay in sync.
zip_file, extract_path = "./images.zip", "./data"
unzip_folder(zip_file, extract_path)
def is_image_file(filename):
    """Return True if *filename* ends with a known image extension (case-insensitive)."""
    lowered = filename.lower()
    return any(
        lowered.endswith(extension)
        for extension in (".png", ".jpg", ".jpeg", ".bmp", ".gif", ".tiff", ".webp")
    )
class ImageSearchBody(BaseModel):
    # Request payload for search_image: the query image as a single base64 string.
    base64_image: str = Field(default=..., title="Base64 Image String")
async def search_image(body: ImageSearchBody):
    """Return the indexed image closest to the uploaded query image.

    NOTE(review): no ``@app.post(...)`` decorator is visible for this handler,
    so as shown it is never registered as a route. Confirm against the
    deployed source; the decorator may have been lost when the file was copied.

    Args:
        body: Request payload carrying the query image as base64.

    Returns:
        A JSONResponse with one of these shapes:

        * 200 with ``image_base64`` (JPEG), ``image_name`` and
          ``similarity_score``. The score is the raw FAISS distance/score
          ``D[0][0]``; whether higher or lower is better depends on the index
          metric, so confirm that.
        * 400 (or another client status) with ``{"error": ...}`` when the
          payload is rejected, e.g. undecodable base64.
        * 500 with ``{"error": ...}`` for any other failure.
    """
    try:
        image = base64_to_image(body.base64_image)

        output = feature_extractor.extract_features(image)
        # Flatten to (batch, features) and L2-normalise each row. reshape()
        # also works for non-contiguous tensors, and normalize() clamps the norm
        # so an all-zero embedding does not produce NaNs.
        output = output.reshape(output.size(0), -1)
        output = torch.nn.functional.normalize(output, p=2, dim=1)
        # FAISS expects a float32 NumPy array.
        query = output.detach().cpu().numpy().astype("float32", copy=False)
        D, I = index.search(query, 1)

        # Assumes FAISS ids follow the sorted order of the image files in
        # extract_path. TODO: confirm against how the index was built.
        image_list = sorted(f for f in os.listdir(extract_path) if is_image_file(f))
        match_id = int(I[0][0])
        # FAISS pads missing results with -1; a negative or stale id must not
        # be used as a list index (-1 would silently pick the last file).
        if not 0 <= match_id < len(image_list):
            raise LookupError(
                f"No matching image for index id {match_id} "
                f"({len(image_list)} images available)"
            )
        image_name = image_list[match_id]
        matched_image_path = os.path.join(extract_path, image_name)
        # Close the file handle once the image has been encoded.
        with Image.open(matched_image_path) as matched_image:
            matched_image_base64 = image_to_base64(matched_image)

        return JSONResponse(
            content={
                "image_base64": matched_image_base64,
                "image_name": image_name,
                "similarity_score": float(D[0][0]),
            },
            status_code=200,
        )
    except HTTPException as e:
        # Keep client errors (e.g. 400 for a bad base64 payload) as client
        # errors instead of reporting them as 500s.
        return JSONResponse(content={"error": e.detail}, status_code=e.status_code)
    except Exception as e:
        print(f"Error in search_image: {str(e)}")
        return JSONResponse(
            content={"error": f"Error processing image: {str(e)}"}, status_code=500
        )
| from src.firebase.firebase_provider import process_images | |
class Body(BaseModel):
    # Request payload for upload_image: a list of base64-encoded images.
    base64_image: list[str] = Field(default=..., title="Base64 Image String")

    # Sample request included in the generated JSON schema / API docs page.
    model_config = {
        "json_schema_extra": {
            "examples": [
                {"base64_image": ["iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAYAAACNiR0NAAABdUlEQVR42mNk"]}
            ]
        }
    }
async def upload_image(body: Body):
    """Upload the given base64 images via ``process_images`` and return its result.

    NOTE(review): no route decorator is visible for this handler; confirm it is
    registered on ``app`` in the deployed source.

    Returns a 200 JSONResponse ``{"public_url": ...}`` holding whatever
    ``process_images`` returned. On any exception (including failure to build
    the response), returns a 500 ``{"error": <message>}``.
    """
    try:
        uploaded = await process_images(body.base64_image)
        response = JSONResponse(content={"public_url": uploaded}, status_code=200)
    except Exception as exc:
        response = JSONResponse(content={"error": str(exc)}, status_code=500)
    return response
if __name__ == "__main__":
    # Direct script launch: serve the app with uvicorn on all interfaces, port 8000.
    from uvicorn import run as run_server

    run_server(app, host="0.0.0.0", port=8000)