Spaces:
Sleeping
Sleeping
File size: 4,198 Bytes
0ec5620 5c8a6b6 0ec5620 5c8a6b6 0ec5620 5c8a6b6 73ce1d5 0ec5620 5c8a6b6 0ec5620 73ce1d5 0ec5620 73ce1d5 0ec5620 73ce1d5 5c8a6b6 73ce1d5 0ec5620 73ce1d5 0ec5620 5c8a6b6 73ce1d5 0ec5620 73ce1d5 0ec5620 73ce1d5 0ec5620 73ce1d5 0ec5620 73ce1d5 0ec5620 73ce1d5 0ec5620 73ce1d5 0ec5620 73ce1d5 0ec5620 73ce1d5 0ec5620 73ce1d5 0ec5620 73ce1d5 0ec5620 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 | import os
import base64
import json
from dotenv import load_dotenv
load_dotenv(override=True)
def apply_encoded_env(encoded: "str | None") -> None:
    """Decode a base64-encoded JSON object and export its entries to ``os.environ``.

    Args:
        encoded: Base64 text whose decoded payload is a JSON object mapping
            variable names to values. ``None`` or ``""`` is a no-op.

    Non-string JSON values (numbers, booleans, null, nested objects) are
    serialized with ``json.dumps`` before export, because ``os.environ``
    only accepts ``str`` values and would otherwise raise ``TypeError``.
    """
    if not encoded:
        return
    env_data = json.loads(base64.b64decode(encoded).decode())
    for key, value in env_data.items():
        os.environ[key] = value if isinstance(value, str) else json.dumps(value)


# Unpack the optional ENCODED_ENV_IMAGE blob right after .env is loaded and
# before the project imports below — presumably so those modules see the
# decoded values (TODO confirm which settings they read at import time).
encoded_env = os.getenv("ENCODED_ENV_IMAGE")
apply_encoded_env(encoded_env)
import torch
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from dotenv import load_dotenv
import faulthandler
from PIL import Image
from src.utils.image_utils import base64_to_image, image_to_base64, is_image_file
from src.utils.zip_utils import extract_zip_file
from src.utils.model_utils import init_models, search_similar_images
from src.firebase.firebase_provider import process_images
# Enable fault handler to debug segmentation faults
faulthandler.enable()

# NOTE: .env is loaded (override=True) near the top of the file, before
# ENCODED_ENV_IMAGE is applied. Calling load_dotenv(override=True) again here
# would overwrite any decoded value whose key also appears in .env, so the
# reload is intentionally not repeated.

# Force CPU mode to avoid segmentation faults with ONNX/PyTorch.
# torch is already imported at this point; hiding GPUs here still affects
# anything that initializes CUDA later — presumably including the ONNX
# session built by init_models (TODO confirm it honors CUDA_VISIBLE_DEVICES).
os.environ["CUDA_VISIBLE_DEVICES"] = ""
torch.set_num_threads(1)

# Initialize FastAPI app; the interactive API docs are served at "/".
app = FastAPI(docs_url="/")

# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site issue credentialed cross-origin requests; narrow `origins` if the API
# ever relies on cookies or auth headers.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize paths and models
index_path = "./model/db_vit_b_16.index"
onnx_path = "./model/vit_b_16_feature_extractor.onnx"
index, feature_extractor = init_models(index_path, onnx_path)

# Extract images if needed; search_image reads files from f"{extract_path}/images".
zip_file = "./images_2.zip"
extract_path = "./data"
extract_zip_file(zip_file, extract_path)
# Request body for POST /search-image/: one base64-encoded query image.
# Documented with comments rather than a docstring, because pydantic copies a
# model docstring into the generated JSON schema description.
class ImageSearchBody(BaseModel):
    # Required field (default=... marks it as mandatory).
    base64_image: str = Field(default=..., title="Base64 Image String")
@app.post("/search-image/")
def search_image(body: ImageSearchBody):
    """Find the indexed image most similar to the uploaded one.

    Returns ``image_base64`` (the matched image), ``image_name`` (a cleaned,
    human-readable file name) and ``similarity_score``; errors yield a 500.
    """
    try:
        # Convert base64 to image and extract features using the ONNX model.
        image = base64_to_image(body.base64_image)
        features = feature_extractor.extract_features(image)

        # D / I presumably follow the FAISS (distances, indices) convention;
        # only the top hit ([0][0]) is used.
        D, I = search_similar_images(index, features)

        # Index rows map to files by position in the sorted directory listing —
        # assumes the index was built from the same sorted list (TODO confirm).
        image_dir = f"{extract_path}/images"
        image_list = sorted(f for f in os.listdir(image_dir) if is_image_file(f))
        match_idx = int(I[0][0])
        # A negative id (FAISS uses -1 for "no result") would otherwise wrap to
        # the last file through Python's negative indexing and return a wrong match.
        if not 0 <= match_idx < len(image_list):
            raise IndexError(
                f"search returned index {match_idx}, "
                f"but {len(image_list)} images are available"
            )
        image_name = image_list[match_idx]

        # The context manager closes the file handle that Image.open keeps open.
        with Image.open(f"{image_dir}/{image_name}") as matched_image:
            matched_image_base64 = image_to_base64(matched_image)

        # Post-process image name: underscores -> spaces, drop digits, drop the
        # file extension, then trim whitespace left behind (e.g. "cat_01.jpg" -> "cat").
        display_name = "".join(
            c for c in image_name.replace("_", " ") if not c.isdigit()
        )
        display_name = display_name.rsplit(".", 1)[0].strip()

        return JSONResponse(
            content={
                "image_base64": matched_image_base64,
                "image_name": display_name,
                "similarity_score": float(D[0][0]),
            },
            status_code=200,
        )
    except Exception as e:
        print(f"Error in search_image: {str(e)}")
        return JSONResponse(
            content={"error": f"Error processing image: {str(e)}"}, status_code=500
        )
# Request body for POST /upload_image: a list of base64-encoded images.
# Comments instead of a docstring, since pydantic would copy a docstring into
# the JSON schema; the example below surfaces in the generated OpenAPI schema.
class Body(BaseModel):
    base64_image: list[str] = Field(default=..., title="Base64 Image String")

    model_config = dict(
        json_schema_extra=dict(
            examples=[
                dict(
                    base64_image=[
                        "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAYAAACNiR0NAAABdUlEQVR42mNk",
                    ]
                )
            ]
        )
    )
@app.post("/upload_image")
async def upload_image(body: Body):
    """Pass the uploaded base64 images to the Firebase provider's process_images.

    Returns ``public_url`` (whatever ``process_images`` yields) with status 200,
    or an ``error`` message with status 500 if anything raises.
    """
    try:
        public_url = await process_images(body.base64_image)
        return JSONResponse(content={"public_url": public_url}, status_code=200)
    except Exception as e:
        # Log server-side (matching search_image) instead of only echoing the
        # error back to the client.
        print(f"Error in upload_image: {str(e)}")
        return JSONResponse(content={"error": str(e)}, status_code=500)
if __name__ == "__main__":
    # Direct-run entry point: serve the app on all interfaces, port 8000.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8000)
|