File size: 5,277 Bytes
982b011
 
 
55cd92f
982b011
351bcee
55cd92f
 
982b011
55cd92f
8b4bdcd
351bcee
36457bd
 
04ab814
36457bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
982b011
55cd92f
 
 
 
 
 
 
 
 
982b011
4dc9354
b4384d8
4dc9354
b4384d8
4dc9354
b4384d8
 
 
 
4dc9354
55ecbbd
4dc9354
982b011
4dc9354
 
 
 
 
607aaf8
982b011
55cd92f
 
 
 
 
 
 
982b011
dfde318
55cd92f
 
 
 
712da4d
 
8b4bdcd
 
 
712da4d
351bcee
 
 
 
 
 
8b4bdcd
 
712da4d
351bcee
 
8b4bdcd
 
 
351bcee
 
 
712da4d
 
55cd92f
 
dfde318
712da4d
982b011
55cd92f
55ecbbd
4dc9354
55ecbbd
4dc9354
 
 
 
 
 
 
 
 
 
 
 
55ecbbd
 
4dc9354
55ecbbd
 
4dc9354
55ecbbd
 
 
 
 
 
 
 
 
 
4dc9354
 
 
 
982b011
dfde318
b4384d8
04ab814
4dc9354
519e314
b4384d8
04ab814
 
 
 
b4384d8
 
 
04ab814
 
 
 
519e314
 
 
b4384d8
519e314
b4384d8
519e314
 
 
 
 
982b011
 
dfde318
982b011
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import os
import torch
import faiss
import base64
from PIL import Image
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from io import BytesIO
from src.modules import FeatureExtractor
from fastapi.middleware.cors import CORSMiddleware
import zipfile
from pydantic import BaseModel, Field
import json
from dotenv import load_dotenv

load_dotenv(override=True)

# ENCODED_ENV may carry extra settings as a base64-encoded JSON object; every
# key/value pair in it is exported into os.environ.
encoded_env = os.getenv("ENCODED_ENV")
if encoded_env:
    # Decode the base64 payload into JSON text.
    decoded_env = base64.b64decode(encoded_env).decode()

    # Parse it; only a JSON object (name -> value mapping) can be exported.
    env_data = json.loads(decoded_env)
    if not isinstance(env_data, dict):
        raise ValueError("ENCODED_ENV must decode to a JSON object")

    # os.environ accepts only strings, so non-string JSON values (numbers,
    # booleans, null, nested objects) are stored as their JSON text.
    for key, value in env_data.items():
        os.environ[key] = value if isinstance(value, str) else json.dumps(value)


# FastAPI application; the interactive API docs are served at the root path.
app = FastAPI(docs_url="/")

# Origins accepted by the CORS middleware ("*" = any origin).
origins = ["*"]

# NOTE(review): wildcard origins together with allow_credentials=True lets any
# site issue credentialed requests (Starlette reflects the request Origin in
# this case — confirm for the installed version). Narrow `origins` if the API
# ever relies on cookies or other credentials.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=origins,
    allow_headers=["*"],
    allow_methods=["*"],
)

# Locations of the prebuilt FAISS index and the ONNX feature-extractor model,
# relative to the current working directory.
index_path = "./model/db_vit_b_16.index"
onnx_path = "./model/vit_b_16_feature_extractor.onnx"

# Fail fast with a clear message if the index file is missing.
if not os.path.exists(index_path):
    raise FileNotFoundError(f"Index file not found: {index_path}")

try:
    # Load FAISS index
    index = faiss.read_index(index_path)
    print(f"Successfully loaded FAISS index from {index_path}")

    # Initialize feature extractor with ONNX support
    feature_extractor = FeatureExtractor(base_model="vit_b_16", onnx_path=onnx_path)
    print("Successfully initialized feature extractor with ONNX support")
except Exception as e:
    # Chain the original exception so its traceback survives in the logs.
    raise RuntimeError(f"Error initializing models: {e}") from e


def base64_to_image(base64_str: str) -> Image.Image:
    """Decode a base64 string into an RGB PIL image.

    Accepts either a bare base64 payload or a data URL such as
    ``data:image/png;base64,iVBOR...``; the ``data:...,`` prefix is stripped.

    Args:
        base64_str: Base64-encoded image bytes, optionally as a data URL.

    Returns:
        The decoded image, converted to RGB mode.

    Raises:
        HTTPException: status 400 if the payload is not valid base64 or is not
            an image Pillow can decode.
    """
    # Base64 text never contains a comma, so splitting once on the first comma
    # cleanly removes a "data:<mime>;base64," prefix.
    if base64_str.startswith("data:") and "," in base64_str:
        base64_str = base64_str.split(",", 1)[1]
    try:
        image_data = base64.b64decode(base64_str)
        # convert() forces the pixel data to load, so corrupt input fails here
        # (and becomes a 400) rather than later in feature extraction.
        return Image.open(BytesIO(image_data)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail="Invalid Base64 image") from e


def image_to_base64(image: Image.Image) -> str:
    """Encode a PIL image as a base64 JPEG string.

    JPEG cannot store alpha channels or palettes, so images whose mode is not
    RGB or grayscale ("L") are converted to RGB first. Saving, e.g., an RGBA
    PNG directly would raise ``OSError``. The caller's image is not modified.

    Args:
        image: Image to encode.

    Returns:
        The JPEG bytes, base64-encoded, as a UTF-8 string.
    """
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def unzip_folder(zip_file_path, extract_to_path):
    """Extract every member of a zip archive into a directory.

    Member names flagged as UTF-8 in the archive are used as-is. Unflagged
    names were decoded as CP437 by ``zipfile``. These are re-decoded as UTF-8
    when possible, since many tools store UTF-8 names without setting the flag.
    Directory entries are created as directories, and file data is streamed
    to disk rather than read into memory.

    Args:
        zip_file_path: Path to the ``.zip`` archive.
        extract_to_path: Destination directory (created as needed).

    Raises:
        FileNotFoundError: If ``zip_file_path`` does not exist.
        ValueError: If a member's path would land outside ``extract_to_path``.
    """
    import shutil

    if not os.path.exists(zip_file_path):
        raise FileNotFoundError(f"Zip file not found: {zip_file_path}")

    root = os.path.realpath(extract_to_path)
    with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
        for member in zip_ref.infolist():
            filename = member.filename
            # Bit 11 (0x800) means zipfile already decoded the name as UTF-8;
            # re-encoding such a name as CP437 would raise UnicodeEncodeError.
            if not member.flag_bits & 0x800:
                try:
                    filename = filename.encode("cp437").decode("utf-8")
                except UnicodeDecodeError:
                    pass  # genuinely a CP437 name; keep zipfile's decoding

            target = os.path.realpath(os.path.join(extract_to_path, filename))
            # Reject "../" or absolute member names ("zip slip").
            if os.path.commonpath([root, target]) != root:
                raise ValueError(f"Unsafe path in zip archive: {member.filename}")

            # Directory entries carry no data; opening them for writing would
            # raise IsADirectoryError.
            if member.is_dir():
                os.makedirs(target, exist_ok=True)
                continue

            os.makedirs(os.path.dirname(target), exist_ok=True)
            with zip_ref.open(member) as source, open(target, "wb") as dest:
                shutil.copyfileobj(source, dest)
        print(f"Extracted all files to: {extract_to_path}")


# Unpack the bundled gallery images at import time. search_image later maps
# FAISS hits to files by position in the sorted listing of this directory.
zip_file = "./images.zip"
extract_path = "./data"
unzip_folder(zip_file_path=zip_file, extract_to_path=extract_path)


_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".gif", ".tiff", ".webp")


def is_image_file(filename):
    """Return True when *filename* ends with a recognised image extension.

    The comparison is case-insensitive and purely name-based; the file is
    never opened.
    """
    lowered = filename.lower()
    return any(lowered.endswith(ext) for ext in _IMAGE_EXTENSIONS)


class ImageSearchBody(BaseModel):
    # Request body for POST /search-image/: a single image as a base64 string
    # (required, via `...`), decoded by base64_to_image inside search_image.
    base64_image: str = Field(..., title="Base64 Image String")


@app.post("/search-image/")
async def search_image(body: ImageSearchBody):
    """Return the gallery image most similar to the uploaded image.

    The query image is embedded with the feature extractor and L2-normalised
    row-wise, then searched in the FAISS index for its single nearest
    neighbour. The returned index row is mapped to a file by position in the
    sorted list of image files under ``extract_path``, so the index must have
    been built from that same ordering.

    Returns:
        200 JSON with ``image_base64`` (JPEG), ``image_name`` and
        ``similarity_score`` (the raw FAISS score of the top hit).

    Raises:
        HTTPException: 400 for an undecodable image (from base64_to_image);
            404 when the index yields no usable match.

    Any other failure is reported as a 500 JSON response with an ``error`` key.
    """
    try:
        # Convert base64 to image
        image = base64_to_image(body.base64_image)

        # Extract features using ONNX model
        output = feature_extractor.extract_features(image)

        # Flatten to (batch, -1) and L2-normalise each row for FAISS search
        output = output.view(output.size(0), -1)
        output = output / output.norm(p=2, dim=1, keepdim=True)

        # Search for the single most similar image
        D, I = index.search(output.cpu().numpy(), 1)

        # Map the FAISS row id back to a file name
        image_list = sorted(f for f in os.listdir(extract_path) if is_image_file(f))
        match_idx = int(I[0][0])
        # FAISS returns -1 when it has no result; a negative id would otherwise
        # silently pick the last file via Python's negative indexing.
        if not 0 <= match_idx < len(image_list):
            raise HTTPException(status_code=404, detail="No matching image found")
        image_name = image_list[match_idx]
        matched_image_path = os.path.join(extract_path, image_name)
        # Close the underlying file once the image has been re-encoded.
        with Image.open(matched_image_path) as matched_image:
            matched_image_base64 = image_to_base64(matched_image)

        return JSONResponse(
            content={
                "image_base64": matched_image_base64,
                "image_name": image_name,
                "similarity_score": float(D[0][0]),
            },
            status_code=200,
        )

    except HTTPException:
        # Keep deliberate client-facing errors (400 bad input, 404 no match)
        # instead of masking them as a generic 500.
        raise
    except Exception as e:
        print(f"Error in search_image: {str(e)}")
        return JSONResponse(
            content={"error": f"Error processing image: {str(e)}"}, status_code=500
        )


from src.firebase.firebase_provider import process_images


class Body(BaseModel):
    # Request body for POST /upload_image: a required list of base64-encoded
    # image strings, passed unchanged to process_images.
    # NOTE(review): the title reads singular although the field is a list.
    base64_image: list[str] = Field(..., title="Base64 Image String")
    # Example payload shown in the generated OpenAPI schema / Swagger UI.
    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "base64_image": [
                        "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAYAAACNiR0NAAABdUlEQVR42mNk",
                    ]
                }
            ]
        }
    }


@app.post("/upload_image")
async def upload_image(body: Body):
    """Upload the given base64 images via ``process_images``.

    Returns:
        200 JSON with ``public_url`` set to whatever ``process_images``
        returns, or 500 JSON with an ``error`` message if anything raises.
    """
    try:
        public_url = await process_images(body.base64_image)
        return JSONResponse(content={"public_url": public_url}, status_code=200)
    except Exception as e:
        # Log before converting to a 500 so the failure is visible
        # server-side, in the same way search_image reports its errors.
        print(f"Error in upload_image: {str(e)}")
        return JSONResponse(content={"error": str(e)}, status_code=500)


if __name__ == "__main__":
    # Development entry point: serve the app on every interface at port 8000.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8000)