File size: 3,659 Bytes
b36cb8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""
Routes for the Image Similarity Search API
Contains all endpoints for the application using your original route implementation
"""

import base64
import io
import uuid
from typing import List, Optional

from fastapi import APIRouter, FastAPI, File, Form, HTTPException, Path, Query, UploadFile  # type: ignore
from PIL import Image
from pydantic import BaseModel

from services.embedding_service import ImageEmbeddingModel
from services.vector_db_service import VectorDatabaseClient


class Base64ImageRequest(BaseModel):
    """JSON request body carrying one base64-encoded image.

    The value may be bare base64 text, or base64 preceded by a header that
    ends in a comma (for example a data-URL prefix such as
    ``data:image/png;base64,``); the routes in this module accept both forms.
    """

    # Required: the encoded image, optionally with a comma-terminated header.
    image_data: str


def register_routes(
    app: FastAPI,
    embedding_model: ImageEmbeddingModel,
    vector_db: VectorDatabaseClient,
):
    """Register all API routes on ``app``.

    The handlers are defined as closures so they share the ``embedding_model``
    and ``vector_db`` instances passed in here.

    Args:
        app: FastAPI application the routes are attached to.
        embedding_model: Service that turns uploaded files / PIL images into
            embedding vectors.
        vector_db: Client used to store embeddings and run similarity searches.
    """

    @app.api_route("/", methods=["GET", "HEAD"])
    async def read_root():
        """Liveness check: report that the API is running."""
        return {"status": "API running"}

    @app.post("/add-image/")
    async def add_image(
        file: UploadFile = File(...),
        item_name: str = Form(...),
        design_name: str = Form(...),
        item_price: float = Form(...),
    ):
        """Embed an uploaded image and store it with its product details.

        Returns:
            A confirmation message and the newly generated UUID under which the
            embedding and its metadata were stored.
        """
        embedding = await embedding_model.get_embedding_from_upload(file)

        # Random UUID, so the ID is unique without asking the database.
        image_id = str(uuid.uuid4())

        # Product metadata stored alongside the vector and returned by searches.
        payload = {
            "filename": file.filename,
            "item_name": item_name,
            "design_name": design_name,
            "item_price": item_price,
        }
        vector_db.add_image(image_id, embedding, payload)

        return {"message": "Image added successfully", "id": image_id}

    @app.post("/add-images-from-folder/")
    async def add_images_from_folder(folder_path: str):
        """Compute embeddings for the images in a server-side folder.

        Returns:
            ``{"embeddings": ...}`` with whatever
            ``embedding_model.get_embeddings_from_folder`` produced.

        NOTE(review): this handler does not write anything to ``vector_db``;
        confirm whether ``get_embeddings_from_folder`` persists the results
        itself or whether a storage step is missing here.
        NOTE(review): ``folder_path`` comes straight from the client and is
        passed to the embedding service unchanged, so a caller can presumably
        point it at any directory the server process can read. Restrict it to
        an allow-listed base directory before exposing this route publicly.
        """
        embeddings = embedding_model.get_embeddings_from_folder(folder_path)
        return {"embeddings": embeddings}

    @app.post("/search-by-image/")
    async def search_by_image(file: UploadFile = File(...)):
        """Return the closest stored match for an uploaded image file."""
        embedding = await embedding_model.get_embedding_from_upload(file)
        # limit=1: only the single best match is returned.
        return vector_db.search_by_vector(embedding, limit=1)

    @app.post("/search-by-image-scan/")
    async def search_by_image_scan(request: Base64ImageRequest):
        """Return the closest stored match for a base64-encoded image.

        ``request.image_data`` may be bare base64 or a data URL
        (``data:image/png;base64,...``); everything up to and including the
        first comma is treated as a header and discarded.

        Raises:
            HTTPException: 400 if the payload is not valid base64 or does not
                decode to a readable image.
        """
        # Strip an optional data-URL header; with no comma this keeps the
        # whole string, so both input forms share one code path.
        encoded = request.image_data.split(",", 1)[-1]
        try:
            image_bytes = base64.b64decode(encoded)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except (ValueError, OSError, Image.DecompressionBombError) as exc:
            # binascii.Error (bad base64/padding) subclasses ValueError;
            # PIL's UnidentifiedImageError and truncated-image errors are
            # OSErrors; DecompressionBombError derives from plain Exception.
            # All of these are client mistakes, not server faults.
            raise HTTPException(
                status_code=400, detail="Invalid base64 image data"
            ) from exc

        embedding = embedding_model.get_embedding_from_pil(image)
        # limit=1: only the single best match is returned.
        return vector_db.search_by_vector(embedding, limit=1)

    @app.get("/collections")
    def list_collections():
        """List the collections available in the vector database."""
        return vector_db.list_collections()