# FastAPI / services/vector_db_service.py
# Author: ravi19 — commit b36cb8b ("Deploy FastAPI to HF Space"), 2.53 kB
"""
Vector database service for interacting with Qdrant
"""
from typing import List, Dict, Any
from fastapi import HTTPException # type: ignore
from qdrant_client import QdrantClient # type: ignore
from qdrant_client.models import Distance, PointStruct, VectorParams # type: ignore
class VectorDatabaseClient:
    """Thin wrapper around a QdrantClient bound to a single collection.

    All operations target ``collection_name``; the collection is created on
    demand with ``embedding_size``-dimensional vectors and cosine distance.
    """

    def __init__(self, url: str, api_key: str, collection_name: str, embedding_size: int):
        """Store connection settings and build the underlying QdrantClient.

        Args:
            url: Qdrant server URL, passed straight to ``QdrantClient``.
            api_key: API key, passed straight to ``QdrantClient``.
            collection_name: Collection every other method operates on.
            embedding_size: Vector dimension used when the collection is created.
        """
        self.url = url
        self.api_key = api_key
        self.collection_name = collection_name
        self.embedding_size = embedding_size
        self.client = QdrantClient(url=url, api_key=api_key)

    def ensure_collection_exists(self) -> None:
        """Ensure the Qdrant collection exists, creating it if it is missing.

        A new collection is configured with ``embedding_size``-dimensional
        vectors and cosine distance. If ``create_collection`` fails but the
        collection is present afterwards (another process created it between
        the existence check and the create call), the failure is treated as
        "already exists" instead of being raised.

        Raises:
            Exception: whatever ``create_collection`` raised, when the
                collection still does not exist after the failed attempt.
        """
        if self.collection_name in self.list_collections():
            print(f"ℹ️ Collection '{self.collection_name}' already exists.")
            return

        # Built outside the try so only the remote call is guarded below.
        vectors_config = VectorParams(
            size=self.embedding_size,
            distance=Distance.COSINE,
        )
        try:
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=vectors_config,
            )
        except Exception:
            # Check-then-create is racy when several processes start together
            # (e.g. multiple server workers). Swallow the error only if the
            # collection now exists; anything else is a genuine failure.
            if self.collection_name not in self.list_collections():
                raise
            print(f"ℹ️ Collection '{self.collection_name}' already exists.")
            return
        print(f"✅ Collection '{self.collection_name}' created.")

    def add_image(self, image_id: str, embedding: List[float], payload: Dict[str, Any]) -> None:
        """Upsert a single image embedding, with its payload, into the collection.

        Args:
            image_id: Point id for the image. NOTE(review): Qdrant point ids are
                expected to be unsigned integers or UUIDs — confirm callers pass
                a UUID string rather than an arbitrary name.
            embedding: Vector to store; presumably of length ``embedding_size``
                (not validated here).
            payload: Metadata stored alongside the vector.
        """
        self.client.upsert(
            collection_name=self.collection_name,
            points=[PointStruct(id=image_id, vector=embedding, payload=payload)],
        )

    def search_by_vector(self, embedding: List[float], limit: int = 1) -> List[Dict[str, Any]]:
        """Search the collection for points similar to ``embedding``.

        Args:
            embedding: Query vector.
            limit: Maximum number of hits to request (default 1).

        Returns:
            One dict per hit, in the order the client returned them, with the
            keys ``"id"``, ``"score"`` and ``"payload"``.
        """
        results = self.client.search(
            collection_name=self.collection_name,
            query_vector=embedding,
            limit=limit,
        )
        return [{"id": r.id, "score": r.score, "payload": r.payload} for r in results]

    def list_collections(self) -> List[str]:
        """Return the names of all collections on the server."""
        return [c.name for c in self.client.get_collections().collections]