Spaces:
Runtime error
Runtime error
File size: 5,277 Bytes
982b011 55cd92f 982b011 351bcee 55cd92f 982b011 55cd92f 8b4bdcd 351bcee 36457bd 04ab814 36457bd 982b011 55cd92f 982b011 4dc9354 b4384d8 4dc9354 b4384d8 4dc9354 b4384d8 4dc9354 55ecbbd 4dc9354 982b011 4dc9354 607aaf8 982b011 55cd92f 982b011 dfde318 55cd92f 712da4d 8b4bdcd 712da4d 351bcee 8b4bdcd 712da4d 351bcee 8b4bdcd 351bcee 712da4d 55cd92f dfde318 712da4d 982b011 55cd92f 55ecbbd 4dc9354 55ecbbd 4dc9354 55ecbbd 4dc9354 55ecbbd 4dc9354 55ecbbd 4dc9354 982b011 dfde318 b4384d8 04ab814 4dc9354 519e314 b4384d8 04ab814 b4384d8 04ab814 519e314 b4384d8 519e314 b4384d8 519e314 982b011 dfde318 982b011 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 | import os
import torch
import faiss
import base64
from PIL import Image
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from io import BytesIO
from src.modules import FeatureExtractor
from fastapi.middleware.cors import CORSMiddleware
import zipfile
from pydantic import BaseModel, Field
import json
from dotenv import load_dotenv
# Load variables from a local .env file; override=True lets .env values
# replace variables already present in the process environment.
load_dotenv(override=True)

# Optional extra configuration: ENCODED_ENV holds a base64-encoded JSON object
# whose key/value pairs are exported into os.environ.
encoded_env = os.getenv("ENCODED_ENV")
if encoded_env:
    # Decode the base64 string into JSON text.
    decoded_env = base64.b64decode(encoded_env).decode()
    env_data = json.loads(decoded_env)
    if not isinstance(env_data, dict):
        raise ValueError("ENCODED_ENV must decode to a JSON object")
    for key, value in env_data.items():
        # os.environ only accepts str values; non-string JSON values (numbers,
        # booleans, null, nested objects) are stored in their JSON text form
        # instead of raising TypeError.
        os.environ[key] = value if isinstance(value, str) else json.dumps(value)
# The API application; Swagger UI is mounted at the root path instead of /docs.
app = FastAPI(docs_url="/")

# Origins accepted by the CORS middleware ("*" matches any origin).
origins = ["*"]

# NOTE(review): a wildcard origin together with allow_credentials=True lets any
# site make credentialed cross-origin calls; confirm credentials are needed.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=origins,
)
# Locations of the prebuilt FAISS index and the ONNX feature-extractor model.
index_path = "./model/db_vit_b_16.index"
onnx_path = "./model/vit_b_16_feature_extractor.onnx"

# Fail fast with a descriptive error when the index file is absent.
if not os.path.exists(index_path):
    raise FileNotFoundError(f"Index file not found: {index_path}")

try:
    # Load the FAISS index that search_image queries.
    index = faiss.read_index(index_path)
    print(f"Successfully loaded FAISS index from {index_path}")
    # Build the ViT-B/16 feature extractor backed by the ONNX model file.
    feature_extractor = FeatureExtractor(base_model="vit_b_16", onnx_path=onnx_path)
    print("Successfully initialized feature extractor with ONNX support")
except Exception as e:
    # Chain explicitly so the original error is reported as the direct cause
    # rather than as an error raised while handling another exception.
    raise RuntimeError(f"Error initializing models: {e}") from e
def base64_to_image(base64_str: str) -> Image.Image:
    """Decode base64-encoded image bytes into an RGB PIL image.

    Accepts either a bare base64 payload or a data URL such as
    ``data:image/png;base64,<payload>``; the ``data:...,`` prefix is stripped
    before decoding.

    Args:
        base64_str: Base64 text of an image file in any format PIL can open.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        HTTPException: status 400 if the payload cannot be decoded or opened
            as an image.
    """
    payload = base64_str
    # Non-validating b64decode discards ':', ';' and ',' but keeps the letters
    # of a data-URL prefix ("dataimagepngbase64"), which corrupts the bytes.
    if isinstance(payload, str):
        payload = payload.strip()
        if payload.startswith("data:") and "," in payload:
            payload = payload.split(",", 1)[1]
    try:
        image_data = base64.b64decode(payload)
        return Image.open(BytesIO(image_data)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail="Invalid Base64 image") from e
def image_to_base64(image: Image.Image) -> str:
    """Encode a PIL image as base64 JPEG text.

    JPEG supports neither alpha nor palettes, so images in other modes (for
    example RGBA PNGs or palette GIFs) are converted to RGB first; saving them
    directly makes Pillow raise ``OSError: cannot write mode ... as JPEG``.
    Any transparency is discarded by that conversion.

    Args:
        image: Image to encode; it is not modified in place.

    Returns:
        The JPEG bytes, base64-encoded, as a UTF-8 string.
    """
    # Modes Pillow's JPEG encoder can write as-is; everything else becomes RGB.
    if image.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
def _decode_member_name(member):
    """Return a zip member's name, repairing UTF-8 names stored as CP437."""
    # Bit 0x800 is the zip "language encoding" flag: zipfile has already
    # decoded such names as UTF-8, so re-encoding them would corrupt them.
    if member.flag_bits & 0x800:
        return member.filename
    try:
        # Unflagged names were decoded as CP437; many tools nevertheless store
        # UTF-8 bytes there, so undo that decoding and retry as UTF-8.
        return member.filename.encode("cp437").decode("utf-8")
    except UnicodeError:
        # Genuinely CP437 (not valid UTF-8): keep zipfile's decoding.
        return member.filename


def unzip_folder(zip_file_path, extract_to_path):
    """Extract every member of a zip archive into ``extract_to_path``.

    Directory entries are created as directories, and file contents are
    streamed to disk rather than read fully into memory.

    Args:
        zip_file_path: Path to the ``.zip`` archive.
        extract_to_path: Destination directory; created as needed.

    Raises:
        FileNotFoundError: If ``zip_file_path`` does not exist.
        ValueError: If a member would be written outside ``extract_to_path``
            (absolute name or ``..`` traversal, a.k.a. "zip slip").
    """
    import shutil

    if not os.path.exists(zip_file_path):
        raise FileNotFoundError(f"Zip file not found: {zip_file_path}")

    root = os.path.realpath(extract_to_path)
    with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
        for member in zip_ref.infolist():
            filename = _decode_member_name(member)
            extracted_path = os.path.realpath(os.path.join(root, filename))
            # Refuse members that resolve outside the destination directory.
            if os.path.commonpath([root, extracted_path]) != root:
                raise ValueError(f"Unsafe path in zip archive: {member.filename}")
            if member.is_dir():
                # Opening a directory path for writing would fail; just create it.
                os.makedirs(extracted_path, exist_ok=True)
                continue
            os.makedirs(os.path.dirname(extracted_path), exist_ok=True)
            with zip_ref.open(member) as source, open(extracted_path, "wb") as target:
                shutil.copyfileobj(source, target)
    print(f"Extracted all files to: {extract_to_path}")
# Unpack the bundled image archive into ./data at import time; search_image
# later reads its result images from this directory.
zip_file = "./images.zip"
extract_path = "./data"
unzip_folder(zip_file_path=zip_file, extract_to_path=extract_path)
def is_image_file(filename):
    """Tell whether *filename* ends with a known image extension, ignoring case."""
    lowered = filename.lower()
    return any(
        lowered.endswith(suffix)
        for suffix in (".png", ".jpg", ".jpeg", ".bmp", ".gif", ".tiff", ".webp")
    )
class ImageSearchBody(BaseModel):
    # Request payload for /search-image/: one base64-encoded query image.
    base64_image: str = Field(default=..., title="Base64 Image String")
@app.post("/search-image/")
async def search_image(body: ImageSearchBody):
    """Return the indexed image closest to the uploaded query image.

    Responds with ``image_base64`` (JPEG), ``image_name`` and
    ``similarity_score`` (the raw FAISS score of the best hit). Invalid base64
    yields 400, an empty search result yields 404, and any other failure
    yields 500 with an ``error`` message.
    """
    try:
        image = base64_to_image(body.base64_image)

        # Extract features with the ONNX-backed extractor, flatten each sample
        # to one row and L2-normalise the rows before searching.
        output = feature_extractor.extract_features(image)
        output = output.view(output.size(0), -1)
        output = output / output.norm(p=2, dim=1, keepdim=True)

        # Only the single nearest neighbour is requested.
        distances, ids = index.search(output.cpu().numpy(), 1)
        match_idx = int(ids[0][0])
        if match_idx < 0:
            # FAISS pads missing results with -1; negative indexing would
            # otherwise silently return the last image in the list.
            raise HTTPException(status_code=404, detail="No matching image found")

        # FAISS ids are used as positions in the sorted list of image files;
        # presumably the index was built in this same order — verify if the
        # data set changes.
        image_list = sorted(f for f in os.listdir(extract_path) if is_image_file(f))
        if match_idx >= len(image_list):
            raise RuntimeError(
                f"FAISS returned id {match_idx} but only {len(image_list)} "
                f"images exist in {extract_path}"
            )
        image_name = image_list[match_idx]
        matched_image_path = os.path.join(extract_path, image_name)
        # Close the file handle once the image has been encoded.
        with Image.open(matched_image_path) as matched_image:
            matched_image_base64 = image_to_base64(matched_image)

        return JSONResponse(
            content={
                "image_base64": matched_image_base64,
                "image_name": image_name,
                "similarity_score": float(distances[0][0]),
            },
            status_code=200,
        )
    except HTTPException:
        # Keep deliberate client-error statuses (400 bad input, 404 no match)
        # instead of collapsing them into a generic 500.
        raise
    except Exception as e:
        print(f"Error in search_image: {str(e)}")
        return JSONResponse(
            content={"error": f"Error processing image: {str(e)}"}, status_code=500
        )
from src.firebase.firebase_provider import process_images
class Body(BaseModel):
    # Request payload for /upload_image: a list of base64-encoded images.
    base64_image: list[str] = Field(default=..., title="Base64 Image String")

    # Sample request shown in the generated OpenAPI documentation.
    model_config = dict(
        json_schema_extra=dict(
            examples=[
                dict(
                    base64_image=[
                        "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAYAAACNiR0NAAABdUlEQVR42mNk",
                    ],
                ),
            ],
        ),
    )
@app.post("/upload_image")
async def upload_image(body: Body):
    """Pass the base64 images to ``process_images`` and return its result.

    On success responds 200 with ``{"public_url": <process_images result>}``;
    otherwise responds 500 with ``{"error": <message>}``. An HTTPException
    raised downstream (if any) keeps its own status code.
    """
    try:
        public_url = await process_images(body.base64_image)
        return JSONResponse(content={"public_url": public_url}, status_code=200)
    except HTTPException:
        raise
    except Exception as e:
        # Log server-side as search_image does; previously upload failures
        # were only reported to the client and left no trace in the logs.
        print(f"Error in upload_image: {str(e)}")
        return JSONResponse(content={"error": str(e)}, status_code=500)
if __name__ == "__main__":
    # Run directly (python app.py): serve the API on all interfaces, port 8000.
    import uvicorn

    uvicorn.run(app=app, host="0.0.0.0", port=8000)
|