Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -45,7 +45,7 @@ if not os.path.exists(index_path):
|
|
| 45 |
raise FileNotFoundError(f"Index file not found: {index_path}")
|
| 46 |
|
| 47 |
try:
|
| 48 |
-
index =
|
| 49 |
except RuntimeError as e:
|
| 50 |
raise RuntimeError(f"Error reading FAISS index: {e}")
|
| 51 |
feature_extractor = FeatureExtractor(base_model="vit_b_16")
|
|
@@ -104,32 +104,32 @@ class ImageSearchBody(BaseModel):
|
|
| 104 |
|
| 105 |
@app.post("/search-image/")
|
| 106 |
async def search_image(body: ImageSearchBody):
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
|
| 134 |
|
| 135 |
from src.firebase.firebase_provider import process_images
|
|
|
|
| 45 |
raise FileNotFoundError(f"Index file not found: {index_path}")
|
| 46 |
|
| 47 |
try:
|
| 48 |
+
index = faiss.read_index(index_path)
|
| 49 |
except RuntimeError as e:
|
| 50 |
raise RuntimeError(f"Error reading FAISS index: {e}")
|
| 51 |
feature_extractor = FeatureExtractor(base_model="vit_b_16")
|
|
|
|
| 104 |
|
| 105 |
@app.post("/search-image/")
|
| 106 |
async def search_image(body: ImageSearchBody):
|
| 107 |
+
try:
|
| 108 |
+
image = base64_to_image(body.base64_image)
|
| 109 |
+
with torch.no_grad():
|
| 110 |
+
output = feature_extractor.extract_features(image)
|
| 111 |
+
output = output.view(output.size(0), -1)
|
| 112 |
+
output = output / output.norm(p=2, dim=1, keepdim=True)
|
| 113 |
+
D, I = index.search(output.cpu().numpy(), 1)
|
| 114 |
+
print(D, I)
|
| 115 |
+
image_list = sorted([f for f in os.listdir(extract_path) if is_image_file(f)])
|
| 116 |
+
print(image_list)
|
| 117 |
+
image_name = image_list[int(I[0][0])]
|
| 118 |
+
matched_image_path = f"{extract_path}/{image_list[int(I[0][0])]}"
|
| 119 |
+
matched_image = Image.open(matched_image_path)
|
| 120 |
+
matched_image_base64 = image_to_base64(matched_image)
|
| 121 |
+
|
| 122 |
+
return JSONResponse(
|
| 123 |
+
content={
|
| 124 |
+
"image_base64": matched_image_base64,
|
| 125 |
+
"image_name": image_name,
|
| 126 |
+
"similarity_score": float(D[0][0]),
|
| 127 |
+
},
|
| 128 |
+
status_code=200,
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
except Exception as e:
|
| 132 |
+
return JSONResponse(content={"error": str(e)}, status_code=500)
|
| 133 |
|
| 134 |
|
| 135 |
from src.firebase.firebase_provider import process_images
|