|
|
""" |
|
|
Routes for the Image Similarity Search API.

Contains all endpoints for the application.
|
|
""" |
|
|
|
|
|
import uuid |
|
|
import base64 |
|
|
import io |
|
|
from typing import List, Optional |
|
|
from fastapi import APIRouter, FastAPI, File, UploadFile, Form, Query, Path |
|
|
from pydantic import BaseModel |
|
|
from PIL import Image |
|
|
|
|
|
from services.embedding_service import ImageEmbeddingModel |
|
|
from services.vector_db_service import VectorDatabaseClient |
|
|
|
|
|
|
|
|
class Base64ImageRequest(BaseModel):
    """Request model for base64 encoded images"""

    # The docstring above is left untouched on purpose: pydantic/FastAPI
    # surface a model's docstring as its schema description.
    #
    # Base64 text of the image. The consumer in this module accepts either
    # the bare base64 string or a data URL, in which case it decodes the
    # part that follows the comma.
    image_data: str
|
|
|
|
|
|
|
|
def register_routes(
    app: FastAPI,
    embedding_model: ImageEmbeddingModel,
    vector_db: VectorDatabaseClient,
):
    """Register all routes with the FastAPI app.

    Every endpoint is defined as a closure inside this function so that it
    can use ``embedding_model`` and ``vector_db`` directly.

    Args:
        app: Application the endpoints are attached to.
        embedding_model: Produces embeddings from an uploaded file, a PIL
            image, or a folder path.
        vector_db: Stores embeddings with their payload, runs vector
            searches and lists collections.
    """
    # Imported here (runs once, at registration time) so the 400 responses
    # below do not depend on a change to the module-level import list.
    from fastapi import HTTPException

    # FastAPI publishes each handler's docstring as the endpoint description
    # in the generated OpenAPI docs, so the handler docstrings below are kept
    # as they were and implementation notes are written as comments.
    #
    # NOTE(review): the embedding/vector-db calls that are not awaited run
    # synchronously inside ``async def`` handlers; if they do blocking I/O or
    # heavy compute they will hold up the event loop -- confirm.

    @app.api_route("/", methods=["GET", "HEAD"])
    async def read_root():
        # Status probe; answers HEAD as well as GET.
        return {"status": "API running"}

    @app.post("/add-image/")
    async def add_image(
        file: UploadFile = File(...),
        item_name: str = Form(...),
        design_name: str = Form(...),
        item_price: float = Form(...),
    ):
        """Upload an image with product details and store its embedding"""
        embedding = await embedding_model.get_embedding_from_upload(file)

        # A fresh random UUID identifies the stored entry and is returned to
        # the caller.
        image_id = str(uuid.uuid4())

        # Product details stored next to the embedding.
        payload = {
            "filename": file.filename,
            "item_name": item_name,
            "design_name": design_name,
            "item_price": item_price,
        }
        vector_db.add_image(image_id, embedding, payload)

        return {"message": "Image added successfully", "id": image_id}

    @app.post("/add-images-from-folder/")
    async def add_images_from_folder(folder_path: str):
        """Process and add all images from a specified folder"""
        # NOTE(review): the embeddings are only returned here; nothing in
        # this handler writes them to ``vector_db`` -- confirm whether
        # ``get_embeddings_from_folder`` stores them itself.
        # NOTE(review): ``folder_path`` is a client-supplied filesystem path
        # passed through unchecked; consider restricting it to an allowed
        # base directory.
        embeddings = embedding_model.get_embeddings_from_folder(folder_path)
        return {"embeddings": embeddings}

    @app.post("/search-by-image/")
    async def search_by_image(file: UploadFile = File(...)):
        """Search for similar images by uploading a file"""
        embedding = await embedding_model.get_embedding_from_upload(file)
        # Only the single best match is requested.
        return vector_db.search_by_vector(embedding, limit=1)

    @app.post("/search-by-image-scan/")
    async def search_by_image_scan(request: Base64ImageRequest):
        """Search for similar images using a base64 encoded image"""
        image_data = request.image_data

        # Accept a bare base64 string or a data URL such as
        # "data:image/png;base64,PAYLOAD"; for the latter keep the part
        # after the comma.
        encoded = image_data.split(",")[1] if "," in image_data else image_data

        # The payload comes straight from the client, so a malformed value
        # is a client error (400) rather than a server error (500).
        try:
            image_bytes = base64.b64decode(encoded)
        except ValueError as exc:  # binascii.Error subclasses ValueError
            raise HTTPException(
                status_code=400,
                detail="image_data is not valid base64",
            ) from exc

        try:
            # The context manager closes the source image; convert()
            # returns an independent RGB copy that outlives it.
            with Image.open(io.BytesIO(image_bytes)) as decoded:
                image = decoded.convert("RGB")
        except (OSError, ValueError, Image.DecompressionBombError) as exc:
            # OSError covers PIL.UnidentifiedImageError and truncated files.
            raise HTTPException(
                status_code=400,
                detail="image_data could not be decoded as an image",
            ) from exc

        embedding = embedding_model.get_embedding_from_pil(image)
        # Only the single best match is requested.
        return vector_db.search_by_vector(embedding, limit=1)

    @app.get("/collections")
    def list_collections():
        """List all available collections in the vector database"""
        return vector_db.list_collections()