File size: 2,499 Bytes
5ce8318
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from fastapi import APIRouter, status, Depends
from fastapi.responses import JSONResponse
import os
import base64
import json
from dotenv import load_dotenv


# Load variables from a local .env file; override=True lets them replace
# values that are already present in the process environment.
load_dotenv(override=True)

# Optional bundle of extra settings: when ENCODED_ENV_IMAGE is set, it is
# decoded as base64 -> JSON object and every entry is exported into
# os.environ.
encoded_env = os.getenv("ENCODED_ENV_IMAGE")
if encoded_env:
    decoded_env = base64.b64decode(encoded_env).decode()
    env_data = json.loads(decoded_env)
    for key, value in env_data.items():
        # os.environ accepts only str values; a JSON number/boolean/null
        # would otherwise abort the module import with a TypeError.
        os.environ[key] = value if isinstance(value, str) else str(value)

import faulthandler
from PIL import Image
from src.utils.image_utils import base64_to_image, image_to_base64, is_image_file
from src.utils.zip_utils import extract_zip_file
from src.utils.model_utils import init_models, search_similar_images

# Enable fault handler to debug segmentation faults
faulthandler.enable()
# NOTE(review): this is the second load_dotenv(override=True) call in this
# module. With override=True it re-applies .env values, so a variable that was
# set earlier from ENCODED_ENV_IMAGE and also appears in .env ends up with the
# .env value — confirm this precedence is intended.
load_dotenv(override=True)

# Force CPU mode to avoid segmentation faults with ONNX/PyTorch
# NOTE(review): this runs after src.utils.model_utils has been imported; if
# that import already initialises CUDA the setting may come too late — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

# Router exposing this module's endpoints under the /image_search prefix.
router = APIRouter(prefix="/image_search", tags=["Image Search"])
# Model artefacts handed to init_models. The paths are relative, so they are
# resolved against the process working directory.
index_path = "./model/db_vit_b_16.index"
onnx_path = "./model/vit_b_16_feature_extractor.onnx"
# Executed at import time; the returned pair is shared by every call to
# process_search_image.
index, feature_extractor = init_models(index_path, onnx_path)

# Extract images if needed
# process_search_image lists the images from "<extract_path>/images".
zip_file = "images/images_2.zip"
extract_path = "./data"
extract_zip_file(zip_file, extract_path)

def process_search_image(base64_image):
    """Find the stored image most similar to a base64-encoded query image.

    The query is decoded with ``base64_to_image``, passed through the
    module-level ``feature_extractor`` and looked up in the module-level
    ``index``. The top hit is mapped to a file by its position in the sorted
    listing of the image files under ``<extract_path>/images``.

    Args:
        base64_image: The query image, in whatever base64 form
            ``base64_to_image`` accepts.

    Returns:
        A 3-tuple ``(matched_image_base64, image_name, similarity_score)``:
        the matched image as encoded by ``image_to_base64``; the matched file
        name with underscores replaced by spaces and with digits and the file
        extension removed; and ``float(D[0][0])`` from the search.
        NOTE(review): whether a larger or smaller score means "more similar"
        depends on the metric of the index — confirm before thresholding.

    Raises:
        IndexError: If the top search hit does not correspond to a file in
            the image directory (negative or past the end of the listing).
    """
    image = base64_to_image(base64_image)

    # Extract features using ONNX model
    features = feature_extractor.extract_features(image)

    # Search for similar images
    D, I = search_similar_images(index, features)
    best_match = int(I[0][0])
    similarity_score = float(D[0][0])

    # Get the matched image. The sort is what ties a search position to a file.
    # NOTE(review): presumably the index was built from this same sorted
    # listing — confirm.
    image_dir = os.path.join(extract_path, "images")
    image_list = sorted(f for f in os.listdir(image_dir) if is_image_file(f))

    # Reject out-of-range hits explicitly. A negative value (FAISS-style
    # searches report -1 when no neighbour is found — presumably the case
    # here, confirm) would otherwise silently select an image counted from
    # the end of the list.
    if not 0 <= best_match < len(image_list):
        raise IndexError(
            f"search returned index {best_match}, but {image_dir} "
            f"contains {len(image_list)} image(s)"
        )

    image_name = image_list[best_match]
    matched_image_path = os.path.join(image_dir, image_name)
    # Context manager closes the underlying file once the image is encoded.
    with Image.open(matched_image_path) as matched_image:
        matched_image_base64 = image_to_base64(matched_image)

    # Post-process image name: remove underscores, numbers, and file extension
    image_name_post_process = image_name.replace("_", " ")
    image_name_post_process = "".join(
        c for c in image_name_post_process if not c.isdigit()
    )
    image_name_post_process = image_name_post_process.rsplit(".", 1)[0]

    # Log only the small fields; the base64 payload is deliberately not
    # printed because it floods stdout on every request.
    print(
        "image_name: ",
        image_name_post_process,
        "similarity_score: ",
        similarity_score,
    )
    return matched_image_base64, image_name_post_process, similarity_score