# Uploaded by ABAO77 ("Upload 12 files", commit 0ec5620, verified).
import os
import base64
import json
from dotenv import load_dotenv
load_dotenv(override=True)
def load_encoded_env(encoded):
    """Export the entries of a base64-encoded JSON object into ``os.environ``.

    Args:
        encoded: Base64 text whose decoded payload is a JSON object mapping
            environment variable names to values.

    Raises:
        binascii.Error, UnicodeDecodeError, json.JSONDecodeError: If the
            payload is not valid base64 / UTF-8 / JSON.
    """
    env_data = json.loads(base64.b64decode(encoded).decode())
    for key, value in env_data.items():
        # os.environ only accepts strings; JSON numbers or booleans would
        # otherwise raise TypeError and abort startup.
        os.environ[key] = value if isinstance(value, str) else str(value)


# Optional bootstrap: ENCODED_ENV_IMAGE carries a whole environment as
# base64-encoded JSON. It is applied here, before the remaining imports run.
encoded_env = os.getenv("ENCODED_ENV_IMAGE")
if encoded_env:
    load_encoded_env(encoded_env)
import torch
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from dotenv import load_dotenv
import faulthandler
from PIL import Image
from src.utils.image_utils import base64_to_image, image_to_base64, is_image_file
from src.utils.zip_utils import extract_zip_file
from src.utils.model_utils import init_models, search_similar_images
from src.firebase.firebase_provider import process_images
# Enable the fault handler to help debug segmentation faults.
faulthandler.enable()
# Load environment variables from the .env file.
# NOTE(review): this is the second load_dotenv(override=True) call in the file
# and it runs after the ENCODED_ENV_IMAGE values were exported, so matching keys
# in .env would presumably replace them — confirm this precedence is intended.
load_dotenv(override=True)
# Force CPU mode to avoid segmentation faults with ONNX/PyTorch: hide every CUDA
# device and keep torch on a single thread. This assignment comes after
# load_dotenv, so it wins over any CUDA_VISIBLE_DEVICES value loaded from .env.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
torch.set_num_threads(1)
# Create the FastAPI application; docs_url="/" serves the interactive docs at the root path.
app = FastAPI(docs_url="/")

# CORS policy: every origin, method and header is accepted.
# NOTE(review): FastAPI's CORS documentation advises against combining wildcard
# origins/methods/headers with allow_credentials=True — confirm whether
# credentialed cross-origin requests are really needed here.
origins = ["*"]
cors_settings = {
    "allow_origins": origins,
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **cors_settings)
# Initialize paths and models. Both paths are relative, so they resolve against
# the process working directory. init_models runs once, at import time, and its
# two results are shared by every request handled by search_image.
index_path = "./model/db_vit_b_16.index"
onnx_path = "./model/vit_b_16_feature_extractor.onnx"
index, feature_extractor = init_models(index_path, onnx_path)
# Extract images if needed (also at import time). search_image later lists the
# files under "<extract_path>/images".
# NOTE(review): presumably the archive unpacks into an "images" sub-directory and
# the index was built in sorted-file-name order — verify against the helpers.
zip_file = "./images_2.zip"
extract_path = "./data"
extract_zip_file(zip_file, extract_path)
class ImageSearchBody(BaseModel):
    # Request payload for POST /search-image/.
    # (Documented with comments rather than a docstring so the generated
    # schema stays exactly as it was.)

    # Required: the query image encoded as a single base64 string.
    base64_image: str = Field(default=..., title="Base64 Image String")
@app.post("/search-image/")
def search_image(body: ImageSearchBody):
    """Return the stored image most similar to the submitted base64 image.

    The query is decoded, passed through the feature extractor and looked up in
    the similarity index. The first hit is mapped onto the lexicographically
    sorted image files found under ``<extract_path>/images``.

    Args:
        body: Request payload carrying the query image as a base64 string.

    Returns:
        JSONResponse: 200 with ``image_base64``, ``image_name`` and
        ``similarity_score`` (the first value of ``D``); 404 when the hit index
        does not refer to a stored image; 500 with an ``error`` message when
        any step raises.
    """
    try:
        # Convert base64 to image and extract its features
        query_image = base64_to_image(body.base64_image)
        features = feature_extractor.extract_features(query_image)

        # Search for similar images
        D, I = search_similar_images(index, features)

        images_dir = os.path.join(extract_path, "images")
        image_list = sorted(f for f in os.listdir(images_dir) if is_image_file(f))

        # A negative index would silently wrap around to the wrong file and an
        # oversized one would surface as a 500, so reject both explicitly.
        # NOTE(review): some index backends report "no result" as -1 — confirm
        # what search_similar_images returns in that case.
        best_idx = int(I[0][0])
        if not 0 <= best_idx < len(image_list):
            return JSONResponse(
                content={"error": "No matching image found"}, status_code=404
            )

        image_name = image_list[best_idx]
        # Context manager releases the underlying file handle once encoded.
        with Image.open(os.path.join(images_dir, image_name)) as matched_image:
            matched_image_base64 = image_to_base64(matched_image)

        return JSONResponse(
            content={
                "image_base64": matched_image_base64,
                "image_name": _display_name(image_name),
                "similarity_score": float(D[0][0]),
            },
            status_code=200,
        )
    except Exception as e:
        print(f"Error in search_image: {str(e)}")
        return JSONResponse(
            content={"error": f"Error processing image: {str(e)}"}, status_code=500
        )


def _display_name(file_name):
    """Turn a file name into a label: underscores to spaces, digits dropped, extension removed."""
    spaced = file_name.replace("_", " ")
    without_digits = "".join(ch for ch in spaced if not ch.isdigit())
    # Strip the extension last, after digits are gone (e.g. ".jp2" -> ".jp").
    return without_digits.rsplit(".", 1)[0]
class Body(BaseModel):
    # Request payload for POST /upload_image.
    # (Documented with comments rather than a docstring so the generated
    # schema stays exactly as it was.)

    # Required: one base64 string per image to upload.
    base64_image: list[str] = Field(default=..., title="Base64 Image String")

    # Example payload attached to the generated JSON schema.
    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "base64_image": [
                        "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAYAAACNiR0NAAABdUlEQVR42mNk",
                    ],
                },
            ],
        },
    }
@app.post("/upload_image")
async def upload_image(body: Body):
    """Pass the submitted base64 images to ``process_images`` and return its result.

    Args:
        body: Request payload carrying a list of base64-encoded images.

    Returns:
        JSONResponse: 200 with ``public_url`` set to whatever ``process_images``
        returned; 500 with an ``error`` message when anything raises.
    """
    try:
        public_url = await process_images(body.base64_image)
        return JSONResponse(content={"public_url": public_url}, status_code=200)
    except Exception as e:
        # Log server-side as search_image does; otherwise the failure is only
        # visible to the client.
        print(f"Error in upload_image: {str(e)}")
        return JSONResponse(content={"error": str(e)}, status_code=500)
if __name__ == "__main__":
    # Imported inside the guard, so uvicorn is only needed when this file is
    # executed directly as a script.
    from uvicorn import run as serve

    # Listen on all interfaces, port 8000.
    serve(app, host="0.0.0.0", port=8000)