from __future__ import annotations import asyncio import concurrent.futures import json import os import shutil import time import uuid from datetime import datetime, timezone from typing import Any, Dict, List, Optional, Tuple import zvec from sqlalchemy import select from app.config import get_settings from app.core.logger import get_logger from app.core.vector_store.deps import AsyncSessionLocal from app.core.vector_store.models import VectorStoreIndex from app.services.chunking_service import chunk_text_async from app.services.embeddings_service import EmbeddingService logger = get_logger(__name__) settings = get_settings() _EMBEDDING_DIM = 384 _MAX_WORKERS = min(16, (os.cpu_count() or 1) + 4) def _run_sync(fn, *args, **kwargs): return fn(*args, **kwargs) class VectorStoreRecord: def __init__( self, store_id: str, name: str, path: str, description: str = "", metadata: Optional[Dict[str, Any]] = None, created_at: Optional[str] = None, ): self.store_id = store_id self.name = name self.path = path self.description = description self.metadata = metadata or {} self.created_at = created_at or datetime.now(timezone.utc).isoformat() class VectorStoreService: def __init__(self, embedding_service: EmbeddingService): self._embedding_service = embedding_service self._stores: Dict[str, VectorStoreRecord] = {} self._collections: Dict[str, "zvec.Collection"] = {} self._data_dir = os.path.join(settings.data_dir, "vector_stores") self._thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=_MAX_WORKERS, thread_name_prefix="zvec") os.makedirs(self._data_dir, exist_ok=True) def _store_path(self, store_id: str) -> str: return os.path.join(self._data_dir, f"store_{store_id}") def _get_collection(self, store_id: str) -> Optional["zvec.Collection"]: return self._collections.get(store_id) # --- SQLite persistence --- async def init_db(self) -> None: from app.core.vector_store.deps import init_vector_store_db await init_vector_store_db() async with AsyncSessionLocal() as session: result = await session.execute(select(VectorStoreIndex)) rows = result.scalars().all() for row in rows: d = row.to_dict() record = VectorStoreRecord( store_id=d["store_id"], name=d["name"], path=d["path"], description=d["description"], metadata=d["metadata"], created_at=d["created_at"], ) self._stores[record.store_id] = record store_path = d["path"] if os.path.exists(os.path.join(store_path, "__zvec_meta")): try: col = zvec.open(store_path) if col is not None: self._collections[record.store_id] = col except Exception as exc: logger.warning("Could not open collection %s: %s", record.store_id, exc) async def _persist_store(self, record: VectorStoreRecord) -> None: async with AsyncSessionLocal() as session: existing = await session.get(VectorStoreIndex, record.store_id) if existing: existing.name = record.name existing.description = record.description existing.metadata_json = json.dumps(record.metadata) else: session.add(VectorStoreIndex.from_dict({ "store_id": record.store_id, "name": record.name, "path": record.path, "description": record.description, "metadata": record.metadata, "created_at": record.created_at, })) await session.commit() async def _remove_persisted_store(self, store_id: str) -> None: async with AsyncSessionLocal() as session: row = await session.get(VectorStoreIndex, store_id) if row: await session.delete(row) await session.commit() # --- Synchronous helpers (run in thread pool) --- def _open_or_create_collection_sync(self, store_id: str, store_path: str) -> "zvec.Collection": col = self._collections.get(store_id) if col is not None: return col if os.path.exists(os.path.join(store_path, "__zvec_meta")): logger.info("Opening existing collection: %s", store_id) try: col = zvec.open(store_path) if col is not None: self._collections[store_id] = col return col except Exception as exc: logger.warning("Could not open collection %s: %s", store_id, exc) schema = zvec.CollectionSchema( name=f"store_{store_id}", fields=[ zvec.FieldSchema(name="text", data_type=zvec.DataType.STRING), zvec.FieldSchema( name="doc_id", data_type=zvec.DataType.STRING, index_param=zvec.InvertIndexParam(), ), zvec.FieldSchema(name="chunk_index", data_type=zvec.DataType.INT32), zvec.FieldSchema( name="source", data_type=zvec.DataType.STRING, index_param=zvec.InvertIndexParam(), ), zvec.FieldSchema(name="created_at", data_type=zvec.DataType.INT64), ], vectors=[ zvec.VectorSchema( name="embedding", data_type=zvec.DataType.VECTOR_FP32, dimension=_EMBEDDING_DIM, index_param=zvec.HnswIndexParam(metric_type=zvec.MetricType.COSINE), ), ], ) logger.info("Creating new collection: %s", store_id) col = zvec.create_and_open(path=store_path, schema=schema) self._collections[store_id] = col return col def _ingest_document_sync( self, store_id: str, doc_id: str, text: str, chunks: List[str], embeddings: List[List[float]], source: str, ) -> Tuple[int, float]: col = self._get_collection(store_id) if col is None: record = self._stores.get(store_id) if record is None: raise ValueError(f"Vector store {store_id} not found") col = self._open_or_create_collection_sync(store_id, record.path) now_ts = int(time.time()) docs = [ zvec.Doc( id=f"{doc_id}_{i}", vectors={"embedding": emb}, fields={ "text": chunk_text_str, "doc_id": doc_id, "chunk_index": i, "source": source or "", "created_at": now_ts, }, ) for i, (chunk_text_str, emb) in enumerate(zip(chunks, embeddings)) ] for i in range(0, len(docs), 100): col.insert(docs[i:i + 100]) col.flush() col.optimize() return len(chunks), 0.0 def _search_sync( self, store_id: str, query_emb: List[float], top_k: int, filter_expr: Optional[str], min_score: Optional[float] = None, include_vectors: bool = False, include_metadata: bool = False, ) -> List[Dict[str, Any]]: col = self._get_collection(store_id) if col is None: record = self._stores.get(store_id) if record is None: raise ValueError(f"Vector store {store_id} not found") col = self._open_or_create_collection_sync(store_id, record.path) kwargs: Dict[str, Any] = { "vectors": zvec.VectorQuery(field_name="embedding", vector=query_emb), "topk": top_k, } if filter_expr: kwargs["filter"] = filter_expr results = col.query(**kwargs) items = [] for i, r in enumerate(results): score = float(r.score) if hasattr(r, "score") and r.score is not None else 0.0 if min_score is not None and score > (1.0 - min_score): continue item: Dict[str, Any] = { "rank": len(items) + 1, "doc_id": r.field("doc_id") if hasattr(r, "field") else r.id.rsplit("_", 1)[0], "chunk_index": r.field("chunk_index") if hasattr(r, "field") else 0, "text": r.field("text") if hasattr(r, "field") else "", "score": score, "source": r.field("source") if hasattr(r, "field") else "", "metadata": {}, "vector": None, } if include_vectors: try: vec = r.vector("embedding") if hasattr(r, "vector") else None item["vector"] = list(vec) if vec is not None else None except Exception: item["vector"] = None if include_metadata: item["metadata"] = { "doc_id": item["doc_id"], "chunk_index": item["chunk_index"], "source": item["source"], } items.append(item) return items def _fetch_documents_sync(self, store_id: str, ids: List[str]) -> Dict[str, Any]: col = self._get_collection(store_id) if col is None: record = self._stores.get(store_id) if record is None: raise ValueError(f"Vector store {store_id} not found") col = self._open_or_create_collection_sync(store_id, record.path) internal_ids = [] for doc_id in ids: fetched = col.fetch(ids=[doc_id]) if doc_id in fetched: internal_ids.append(doc_id) continue for i in range(0, 1024): chunk_id = f"{doc_id}_{i}" fetched = col.fetch(ids=[chunk_id]) if chunk_id in fetched: internal_ids.append(chunk_id) else: break break if not internal_ids: return {} fetched = col.fetch(ids=internal_ids) result = {} for k, v in fetched.items(): result[k] = { "id": v.id, "text": v.field("text") if hasattr(v, "field") else "", "doc_id": v.field("doc_id") if hasattr(v, "field") else "", "chunk_index": v.field("chunk_index") if hasattr(v, "field") else 0, "source": v.field("source") if hasattr(v, "field") else "", } return result def _delete_documents_sync( self, store_id: str, ids: Optional[List[str]], filter_expr: Optional[str], ) -> int: col = self._get_collection(store_id) if col is None: record = self._stores.get(store_id) if record is None: raise ValueError(f"Vector store {store_id} not found") col = self._open_or_create_collection_sync(store_id, record.path) deleted = 0 if ids: chunk_ids = [] for doc_id in ids: chunk_ids.extend(f"{doc_id}_{i}" for i in range(4096)) fetched = col.fetch(ids=chunk_ids[:1000]) actual_ids = [k for k in chunk_ids if k in fetched] if actual_ids: col.delete(ids=actual_ids) deleted = len(actual_ids) if filter_expr: col.delete_by_filter(filter=filter_expr) deleted = max(deleted, 1) col.flush() return deleted def _get_store_stats_sync(self, store_id: str) -> Dict[str, Any]: col = self._get_collection(store_id) if col is None: record = self._stores.get(store_id) if record is None: raise ValueError(f"Vector store {store_id} not found") col = self._open_or_create_collection_sync(store_id, record.path) try: stats = col.stats doc_count = getattr(stats, 'doc_count', 0) if stats else 0 except Exception: doc_count = 0 record = self._stores[store_id] return { "store_id": store_id, "name": record.name, "description": record.description, "app_id": settings.application_id or store_id, "embedding_dimension": _EMBEDDING_DIM, "document_count": doc_count, "created_at": record.created_at, "metadata": record.metadata, } # --- Async public API --- async def _run_sync_fn(self, fn): loop = asyncio.get_running_loop() return await loop.run_in_executor(self._thread_pool, fn) async def create_store( self, name: str, description: str = "", metadata: Optional[Dict[str, Any]] = None, ) -> Tuple[str, str]: store_id = str(uuid.uuid4()) store_path = self._store_path(store_id) await self._run_sync_fn( lambda: self._open_or_create_collection_sync(store_id, store_path) ) record = VectorStoreRecord( store_id=store_id, name=name, path=store_path, description=description, metadata=metadata or {}, ) self._stores[store_id] = record await self._persist_store(record) app_id = settings.application_id or store_id logger.info("Created vector store: %s (name=%s, app_id=%s)", store_id, name, app_id) return store_id, app_id def list_stores(self) -> List[VectorStoreRecord]: return list(self._stores.values()) def get_store(self, store_id: str) -> Optional[VectorStoreRecord]: return self._stores.get(store_id) async def delete_store(self, store_id: str) -> bool: record = self._stores.pop(store_id, None) if record is None: return False col = self._collections.pop(store_id, None) if col is not None: try: await self._run_sync_fn(lambda: col.destroy()) except Exception as exc: logger.warning("Error destroying collection %s: %s", store_id, exc) store_path = record.path if os.path.exists(store_path): await self._run_sync_fn(lambda: shutil.rmtree(store_path, ignore_errors=True)) await self._remove_persisted_store(store_id) logger.info("Deleted vector store: %s", store_id) return True async def _ensure_embedding_model(self) -> None: if not self._embedding_service.is_loaded(_EMBEDDING_DIM): loop = asyncio.get_running_loop() await loop.run_in_executor(None, self._embedding_service.load_model, _EMBEDDING_DIM) async def ingest_document( self, store_id: str, doc_id: str, text: str, source: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, chunk_size: Optional[int] = None, chunk_overlap: Optional[int] = None, ) -> Tuple[int, float]: await self._ensure_embedding_model() size = chunk_size or 512 overlap = chunk_overlap or 64 chunks = await chunk_text_async(text, chunk_size=size, chunk_overlap=overlap) metadata = metadata or {} source = source or metadata.get("source", "") start = time.perf_counter() loop = asyncio.get_running_loop() embeddings = await loop.run_in_executor( self._thread_pool, self._embedding_service.generate_embedding, chunks, _EMBEDDING_DIM, ) elapsed_sync = await self._run_sync_fn( lambda: self._ingest_document_sync(store_id, doc_id, text, chunks, embeddings, source or "") ) elapsed = (time.perf_counter() - start) * 1000 logger.info("Ingested doc %s into store %s: %d chunks in %.2f ms", doc_id, store_id, len(chunks), elapsed) return len(chunks), elapsed async def search( self, store_id: str, query_text: str, top_k: int = 10, filter_expr: Optional[str] = None, min_score: Optional[float] = None, include_vectors: bool = False, include_metadata: bool = False, ) -> Tuple[List[Dict[str, Any]], float]: await self._ensure_embedding_model() start = time.perf_counter() loop = asyncio.get_running_loop() query_emb = await loop.run_in_executor( self._thread_pool, lambda: self._embedding_service.generate_embedding([query_text], _EMBEDDING_DIM)[0], ) items = await self._run_sync_fn( lambda: self._search_sync(store_id, query_emb, top_k, filter_expr, min_score, include_vectors, include_metadata) ) elapsed = (time.perf_counter() - start) * 1000 return items, elapsed async def fetch_documents(self, store_id: str, ids: List[str]) -> Dict[str, Any]: return await self._run_sync_fn( lambda: self._fetch_documents_sync(store_id, ids) ) async def delete_documents( self, store_id: str, ids: Optional[List[str]] = None, filter_expr: Optional[str] = None, ) -> int: return await self._run_sync_fn( lambda: self._delete_documents_sync(store_id, ids, filter_expr) ) async def get_store_stats(self, store_id: str) -> Dict[str, Any]: return await self._run_sync_fn( lambda: self._get_store_stats_sync(store_id) ) async def get_total_document_count(self) -> int: total = 0 for store_id in list(self._stores.keys()): try: stats = await self.get_store_stats(store_id) total += stats.get("document_count", 0) except Exception: pass return total async def close_all(self) -> None: for store_id, col in list(self._collections.items()): try: await self._run_sync_fn(lambda: col.flush()) except Exception: pass self._collections.clear() self._stores.clear() self._thread_pool.shutdown(wait=True)