File size: 3,659 Bytes
b36cb8b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
"""
Routes for the Image Similarity Search API
Contains all endpoints of the application, ported from the original route implementation.
"""
import base64
import io
import uuid
from typing import List, Optional

from fastapi import APIRouter, FastAPI, File, Form, HTTPException, Path, Query, UploadFile  # type: ignore
from PIL import Image
from pydantic import BaseModel

from services.embedding_service import ImageEmbeddingModel
from services.vector_db_service import VectorDatabaseClient
class Base64ImageRequest(BaseModel):
    """JSON request body carrying a single image encoded as base64 text.

    ``image_data`` may be a bare base64 payload or a data URL of the form
    ``data:<mime>;base64,<payload>``; the scan-search route strips everything
    up to the first comma before decoding.
    """

    # Required field: the encoded image (no default value is declared).
    image_data: str
# Upper bound for the optional ``limit`` query parameter on the search
# endpoints, so a single request cannot ask the vector DB for an unbounded
# number of hits.
_MAX_SEARCH_LIMIT = 100


def register_routes(
    app: FastAPI,
    embedding_model: ImageEmbeddingModel,
    vector_db: VectorDatabaseClient,
):
    """Register all routes with the FastAPI app.

    Every handler is a closure over ``embedding_model`` and ``vector_db``,
    so all requests served by ``app`` share these two instances.

    Args:
        app: Application the endpoints are attached to.
        embedding_model: Produces embeddings from uploaded files, PIL images
            and server-side folders.
        vector_db: Stores image embeddings with a metadata payload and
            answers nearest-neighbour searches.

    NOTE(review): the ``async def`` handlers call ``vector_db`` methods and
    ``embedding_model.get_embedding_from_pil`` without awaiting them. If those
    calls block, they stall the event loop while they run. Consider
    ``run_in_threadpool`` once the model's thread-safety is confirmed.
    """

    @app.api_route("/", methods=["GET", "HEAD"])
    async def read_root():
        """Report that the API is running."""
        return {"status": "API running"}

    @app.post("/add-image/")
    async def add_image(
        file: UploadFile = File(...),
        item_name: str = Form(...),
        design_name: str = Form(...),
        item_price: float = Form(...),
    ):
        """Embed an uploaded image and store it with its product details.

        Returns:
            A confirmation message and the freshly generated UUID4 string
            under which the embedding was stored.
        """
        embedding = await embedding_model.get_embedding_from_upload(file)
        image_id = str(uuid.uuid4())
        # Metadata stored alongside the vector in the database payload.
        payload = {
            "filename": file.filename,
            "item_name": item_name,
            "design_name": design_name,
            "item_price": item_price,
        }
        vector_db.add_image(image_id, embedding, payload)
        return {"message": "Image added successfully", "id": image_id}

    @app.post("/add-images-from-folder/")
    async def add_images_from_folder(folder_path: str):
        """Compute embeddings for the images in a server-side folder.

        Returns:
            ``{"embeddings": ...}`` with whatever
            ``get_embeddings_from_folder`` produced.

        NOTE(review): despite the route name, nothing is written to
        ``vector_db`` here; the embeddings are only returned to the caller.
        Confirm whether ``get_embeddings_from_folder`` persists them itself.
        NOTE(review): ``folder_path`` comes straight from the client and names
        a path on the *server*. Restrict it to an allowed base directory
        before exposing this endpoint publicly.
        """
        embeddings = embedding_model.get_embeddings_from_folder(folder_path)
        return {"embeddings": embeddings}

    @app.post("/search-by-image/")
    async def search_by_image(
        file: UploadFile = File(...),
        limit: int = Query(1, ge=1, le=_MAX_SEARCH_LIMIT),
    ):
        """Find the stored images most similar to an uploaded file.

        Args:
            file: The query image.
            limit: Maximum number of hits to return (default 1).

        Returns:
            The result of ``vector_db.search_by_vector`` unchanged.
        """
        embedding = await embedding_model.get_embedding_from_upload(file)
        return vector_db.search_by_vector(embedding, limit=limit)

    @app.post("/search-by-image-scan/")
    async def search_by_image_scan(
        request: Base64ImageRequest,
        limit: int = Query(1, ge=1, le=_MAX_SEARCH_LIMIT),
    ):
        """Find the stored images most similar to a base64-encoded image.

        Args:
            request: Body whose ``image_data`` is a bare base64 payload or a
                data URL (``data:<mime>;base64,<payload>``).
            limit: Maximum number of hits to return (default 1).

        Returns:
            The result of ``vector_db.search_by_vector`` unchanged.

        Raises:
            HTTPException: 400 if ``image_data`` is not valid base64 or does
                not decode to an image PIL can read.
        """
        image_data = request.image_data
        # A data URL carries its payload after the first comma; otherwise the
        # whole string is the payload.
        encoded = image_data.split(",")[1] if "," in image_data else image_data
        try:
            image_bytes = base64.b64decode(encoded)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except (ValueError, OSError) as exc:
            # binascii.Error (bad base64) subclasses ValueError, and
            # PIL.UnidentifiedImageError / truncated-file errors subclass
            # OSError. Both are client errors, so answer 400 rather than
            # letting them surface as an unhandled 500.
            raise HTTPException(
                status_code=400,
                detail="image_data is not a valid base64-encoded image",
            ) from exc
        embedding = embedding_model.get_embedding_from_pil(image)
        return vector_db.search_by_vector(embedding, limit=limit)

    @app.get("/collections")
    def list_collections():
        """List all available collections in the vector database."""
        return vector_db.list_collections()