File size: 4,198 Bytes
0ec5620
5c8a6b6
 
0ec5620
 
5c8a6b6
0ec5620
5c8a6b6
 
 
 
 
 
 
73ce1d5
 
0ec5620
 
 
5c8a6b6
 
 
0ec5620
 
 
73ce1d5
 
 
0ec5620
73ce1d5
 
 
 
 
0ec5620
73ce1d5
 
5c8a6b6
73ce1d5
 
 
 
 
 
 
 
 
 
 
0ec5620
73ce1d5
 
0ec5620
 
 
5c8a6b6
 
73ce1d5
 
 
0ec5620
 
 
 
73ce1d5
 
 
 
 
 
 
0ec5620
73ce1d5
 
0ec5620
73ce1d5
 
0ec5620
 
 
73ce1d5
 
 
 
0ec5620
 
 
 
 
 
 
 
 
 
 
 
73ce1d5
 
 
 
 
 
 
 
 
 
 
 
0ec5620
73ce1d5
 
 
0ec5620
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73ce1d5
0ec5620
73ce1d5
 
0ec5620
73ce1d5
0ec5620
73ce1d5
 
 
 
0ec5620
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import os
import base64
import json
from dotenv import load_dotenv

# Load .env first so that ENCODED_ENV_IMAGE (read right below) may itself come
# from the .env file. override=True lets .env values replace variables that are
# already set in the process environment.
load_dotenv(override=True)
def apply_encoded_env(encoded, environ=None):
    """Decode a base64-encoded JSON object and copy its entries into an env mapping.

    Args:
        encoded: Base64 text whose decoded bytes are a UTF-8 JSON object that
            maps variable names to values. ``None`` or ``""`` is a no-op.
        environ: Mapping to write into; defaults to ``os.environ``.

    Raises:
        ValueError: If ``encoded`` is not valid base64 / UTF-8 / JSON
            (``binascii.Error``, ``UnicodeDecodeError`` and
            ``json.JSONDecodeError`` are all ``ValueError`` subclasses), or if
            the decoded JSON is not an object.
    """
    if not encoded:
        return
    if environ is None:
        environ = os.environ
    env_data = json.loads(base64.b64decode(encoded).decode())
    if not isinstance(env_data, dict):
        raise ValueError(
            "ENCODED_ENV_IMAGE must decode to a JSON object, "
            f"got {type(env_data).__name__}"
        )
    for key, value in env_data.items():
        # os.environ only accepts str values; store numbers, booleans, null and
        # nested values in their JSON text form ("3", "true", "null", ...)
        # instead of crashing at import time with a TypeError.
        environ[key] = value if isinstance(value, str) else json.dumps(value)


# Overlay the variables carried in the ENCODED_ENV_IMAGE blob (if any).
encoded_env = os.getenv("ENCODED_ENV_IMAGE")
apply_encoded_env(encoded_env)
import torch
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from dotenv import load_dotenv
import faulthandler
from PIL import Image

from src.utils.image_utils import base64_to_image, image_to_base64, is_image_file
from src.utils.zip_utils import extract_zip_file
from src.utils.model_utils import init_models, search_similar_images
from src.firebase.firebase_provider import process_images

# Enable faulthandler so that a segmentation fault (or another fatal signal)
# dumps the Python traceback to stderr, to help debug native crashes.
faulthandler.enable()
# NOTE(review): this second .env load runs after the ENCODED_ENV_IMAGE values
# were written to os.environ above. With override=True, any key present in
# both .env and the encoded blob ends up with the .env value. Confirm that
# precedence is intended, or drop this call.
load_dotenv(override=True)

# Force CPU mode to avoid segmentation faults with ONNX/PyTorch.
# NOTE(review): CUDA_VISIBLE_DEVICES is set after `import torch`. It only takes
# effect if CUDA has not been initialized yet; consider setting it before the
# torch import to be certain.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# Limit PyTorch to a single intra-op CPU thread.
torch.set_num_threads(1)


# Create the FastAPI app; the interactive Swagger UI is served at the root path.
app = FastAPI(docs_url="/")

# CORS: accept requests from any origin, with any method and any header.
# NOTE(review): per the Starlette CORS docs, credentialed requests need explicit
# origins/methods/headers rather than ["*"]; confirm this combination with
# allow_credentials=True is what is intended.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=origins,
)

# Initialize paths and models.
# The search index and the ONNX feature extractor are loaded once, at import
# time, and shared by every request handled by search_image.
# Paths are relative to the process's working directory.
index_path = "./model/db_vit_b_16.index"
onnx_path = "./model/vit_b_16_feature_extractor.onnx"
index, feature_extractor = init_models(index_path, onnx_path)

# Extract images if needed (presumably extract_zip_file skips work that is
# already done; confirm in src.utils.zip_utils). search_image reads the matched
# files from f"{extract_path}/images", so the archive is expected to contain an
# `images/` directory.
zip_file = "./images_2.zip"
extract_path = "./data"
extract_zip_file(zip_file, extract_path)


class ImageSearchBody(BaseModel):
    # Request body for POST /search-image/: a single query image as a base64
    # string, decoded by base64_to_image in the handler.
    base64_image: str = Field(..., title="Base64 Image String")


def _list_dataset_images(images_dir):
    """Return the names of the image files in *images_dir*, sorted by name."""
    return sorted(f for f in os.listdir(images_dir) if is_image_file(f))


def _display_name(image_name):
    """Turn a dataset file name into a human-readable label.

    Drops the file extension first (so digits inside an extension are never
    touched), replaces underscores with spaces, removes digits, then collapses
    and strips whitespace, e.g. ``"golden_retriever_12.jpg" -> "golden retriever"``.
    """
    stem = os.path.splitext(image_name)[0]
    stem = stem.replace("_", " ")
    stem = "".join(c for c in stem if not c.isdigit())
    return " ".join(stem.split())


@app.post("/search-image/")
def search_image(body: ImageSearchBody):
    """Return the dataset image closest to the uploaded query image.

    The response holds the matched image as base64, a readable label derived
    from its file name, and the first score returned by search_similar_images.
    Whether a higher or a lower score means "more similar" depends on the
    index metric; that is not visible here. Any failure yields a 500 with an
    ``error`` message.
    """
    try:
        image = base64_to_image(body.base64_image)
        features = feature_extractor.extract_features(image)
        distances, indices = search_similar_images(index, features)

        # The search result is a position into the sorted listing of the
        # extracted images (assumed to match the order the index was built in).
        images_dir = os.path.join(extract_path, "images")
        image_list = _list_dataset_images(images_dir)
        match_idx = int(indices[0][0])
        # Guard explicitly: a negative position (e.g. a -1 "no result"
        # sentinel) would otherwise wrap around and return the wrong image.
        if not 0 <= match_idx < len(image_list):
            raise IndexError(
                f"search returned position {match_idx}, but {len(image_list)} "
                f"images are available in {images_dir}"
            )
        image_name = image_list[match_idx]

        # Use a context manager so the file handle is released after encoding.
        with Image.open(os.path.join(images_dir, image_name)) as matched_image:
            matched_image_base64 = image_to_base64(matched_image)

        return JSONResponse(
            content={
                "image_base64": matched_image_base64,
                "image_name": _display_name(image_name),
                "similarity_score": float(distances[0][0]),
            },
            status_code=200,
        )

    except Exception as e:
        print(f"Error in search_image: {str(e)}")
        return JSONResponse(
            content={"error": f"Error processing image: {str(e)}"}, status_code=500
        )


class Body(BaseModel):
    # Request body for POST /upload_image: a list of base64-encoded images.
    base64_image: list[str] = Field(..., title="Base64 Image String")

    # Example payload shown in the generated OpenAPI / Swagger docs.
    model_config = dict(
        json_schema_extra=dict(
            examples=[
                dict(
                    base64_image=[
                        "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAYAAACNiR0NAAABdUlEQVR42mNk",
                    ]
                )
            ]
        )
    )


@app.post("/upload_image")
async def upload_image(body: Body):
    """Pass the base64 images to process_images and return its result.

    Returns 200 with ``{"public_url": <process_images result>}`` on success, or
    500 with ``{"error": <message>}`` if anything raises.
    """
    try:
        public_url = await process_images(body.base64_image)
        return JSONResponse(content={"public_url": public_url}, status_code=200)
    except Exception as e:
        # Log server-side, as search_image does, so upload failures are not
        # visible only to the HTTP client.
        print(f"Error in upload_image: {str(e)}")
        return JSONResponse(content={"error": str(e)}, status_code=500)


if __name__ == "__main__":
    # Local entry point: serve the app with uvicorn on all interfaces, port 8000.
    from uvicorn import run

    run(app, port=8000, host="0.0.0.0")