Spaces:
Running
Running
| from __future__ import annotations | |
| import asyncio | |
| import concurrent.futures | |
| import json | |
| import os | |
| import shutil | |
| import time | |
| import uuid | |
| from datetime import datetime, timezone | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import zvec | |
| from sqlalchemy import select | |
| from app.config import get_settings | |
| from app.core.logger import get_logger | |
| from app.core.vector_store.deps import AsyncSessionLocal | |
| from app.core.vector_store.models import VectorStoreIndex | |
| from app.services.chunking_service import chunk_text_async | |
| from app.services.embeddings_service import EmbeddingService | |
# Module-wide logger and application settings (resolved once at import time).
logger = get_logger(__name__)
settings = get_settings()

# Vector width requested from the embedding service and declared in every
# collection schema built by this module.
_EMBEDDING_DIM = 384

# Upper bound on threads in the zvec worker pool: cpu_count + 4 (the stdlib
# ThreadPoolExecutor default formula), but never more than 16.
_MAX_WORKERS = min(4 + (os.cpu_count() or 1), 16)
| def _run_sync(fn, *args, **kwargs): | |
| return fn(*args, **kwargs) | |
class VectorStoreRecord:
    """In-memory description of one vector store.

    Attributes:
        store_id: Identifier of the store.
        name: Human-readable store name.
        path: Directory holding the store's collection on disk.
        description: Free-text description (defaults to "").
        metadata: Arbitrary caller metadata; a fresh empty dict when the
            argument is falsy.
        created_at: ISO-8601 timestamp; the current UTC time when the
            argument is falsy.
    """

    def __init__(
        self,
        store_id: str,
        name: str,
        path: str,
        description: str = "",
        metadata: Optional[Dict[str, Any]] = None,
        created_at: Optional[str] = None,
    ):
        self.store_id, self.name, self.path = store_id, name, path
        self.description = description
        # Any falsy value (None, {}) is replaced by a new dict.
        self.metadata = metadata if metadata else {}
        if created_at:
            self.created_at = created_at
        else:
            self.created_at = datetime.now(timezone.utc).isoformat()
class VectorStoreService:
    """Manage zvec-backed vector stores plus their SQL-persisted registry.

    Each store is a zvec collection on disk under ``<data_dir>/vector_stores``
    and a ``VectorStoreIndex`` row describing it. Blocking zvec and embedding
    calls are dispatched to a dedicated thread pool so the ``async`` public API
    does not block the event loop.

    Documents are stored as chunks whose ids are ``f"{doc_id}_{i}"`` for
    consecutive ``i`` starting at 0.
    """

    # Upper bound on how many consecutive chunk ids are probed per document id.
    _MAX_CHUNKS_PER_DOC = 4096
    # Number of candidate chunk ids requested per ``fetch`` call while probing.
    _FETCH_BATCH_SIZE = 128
    # Number of docs passed to a single ``insert`` call during ingestion.
    _INSERT_BATCH_SIZE = 100

    def __init__(self, embedding_service: EmbeddingService):
        self._embedding_service = embedding_service
        self._stores: Dict[str, VectorStoreRecord] = {}
        self._collections: Dict[str, "zvec.Collection"] = {}
        self._data_dir = os.path.join(settings.data_dir, "vector_stores")
        self._thread_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=_MAX_WORKERS, thread_name_prefix="zvec"
        )
        os.makedirs(self._data_dir, exist_ok=True)

    def _store_path(self, store_id: str) -> str:
        """Return the on-disk directory used for the collection of ``store_id``."""
        return os.path.join(self._data_dir, f"store_{store_id}")

    def _get_collection(self, store_id: str) -> Optional["zvec.Collection"]:
        """Return the already-open collection for ``store_id``, or ``None``."""
        return self._collections.get(store_id)

    def _ensure_store_known(self, store_id: str) -> None:
        """Raise ``ValueError`` if ``store_id`` is neither registered nor open."""
        if store_id not in self._stores and store_id not in self._collections:
            raise ValueError(f"Vector store {store_id} not found")

    # --- SQLite persistence ---

    async def init_db(self) -> None:
        """Create the registry tables and reload every persisted store.

        Stores whose directory contains a ``__zvec_meta`` marker have their
        collection reopened (off the event loop). An open failure is logged and
        the store stays registered; the open is retried on first use.
        """
        from app.core.vector_store.deps import init_vector_store_db

        await init_vector_store_db()
        async with AsyncSessionLocal() as session:
            result = await session.execute(select(VectorStoreIndex))
            rows = result.scalars().all()
            for row in rows:
                d = row.to_dict()
                record = VectorStoreRecord(
                    store_id=d["store_id"],
                    name=d["name"],
                    path=d["path"],
                    description=d["description"],
                    metadata=d["metadata"],
                    created_at=d["created_at"],
                )
                self._stores[record.store_id] = record
                store_path = d["path"]
                if not os.path.exists(os.path.join(store_path, "__zvec_meta")):
                    continue
                try:
                    # zvec.open is blocking; keep it off the event loop.
                    col = await self._run_sync_fn(zvec.open, store_path)
                except Exception as exc:
                    logger.warning("Could not open collection %s: %s", record.store_id, exc)
                    continue
                if col is not None:
                    self._collections[record.store_id] = col

    async def _persist_store(self, record: VectorStoreRecord) -> None:
        """Insert ``record`` into the registry, or update name/description/metadata."""
        async with AsyncSessionLocal() as session:
            existing = await session.get(VectorStoreIndex, record.store_id)
            if existing:
                existing.name = record.name
                existing.description = record.description
                existing.metadata_json = json.dumps(record.metadata)
            else:
                session.add(VectorStoreIndex.from_dict({
                    "store_id": record.store_id,
                    "name": record.name,
                    "path": record.path,
                    "description": record.description,
                    "metadata": record.metadata,
                    "created_at": record.created_at,
                }))
            await session.commit()

    async def _remove_persisted_store(self, store_id: str) -> None:
        """Delete the registry row for ``store_id`` if it exists."""
        async with AsyncSessionLocal() as session:
            row = await session.get(VectorStoreIndex, store_id)
            if row:
                await session.delete(row)
                await session.commit()

    # --- Synchronous helpers (run in thread pool) ---

    def _require_collection_sync(self, store_id: str) -> "zvec.Collection":
        """Return the collection for ``store_id``, opening or creating it if needed.

        Raises:
            ValueError: if the store is neither open nor registered.
        """
        col = self._get_collection(store_id)
        if col is not None:
            return col
        record = self._stores.get(store_id)
        if record is None:
            raise ValueError(f"Vector store {store_id} not found")
        return self._open_or_create_collection_sync(store_id, record.path)

    def _open_or_create_collection_sync(self, store_id: str, store_path: str) -> "zvec.Collection":
        """Return a cached, reopened, or freshly created collection at ``store_path``.

        NOTE(review): if opening an existing collection fails, this falls through
        to ``create_and_open`` on the same path — confirm that is intended for
        corrupted stores.
        """
        col = self._collections.get(store_id)
        if col is not None:
            return col
        if os.path.exists(os.path.join(store_path, "__zvec_meta")):
            logger.info("Opening existing collection: %s", store_id)
            try:
                col = zvec.open(store_path)
                if col is not None:
                    self._collections[store_id] = col
                    return col
            except Exception as exc:
                logger.warning("Could not open collection %s: %s", store_id, exc)
        schema = zvec.CollectionSchema(
            name=f"store_{store_id}",
            fields=[
                zvec.FieldSchema(name="text", data_type=zvec.DataType.STRING),
                zvec.FieldSchema(
                    name="doc_id",
                    data_type=zvec.DataType.STRING,
                    index_param=zvec.InvertIndexParam(),
                ),
                zvec.FieldSchema(name="chunk_index", data_type=zvec.DataType.INT32),
                zvec.FieldSchema(
                    name="source",
                    data_type=zvec.DataType.STRING,
                    index_param=zvec.InvertIndexParam(),
                ),
                zvec.FieldSchema(name="created_at", data_type=zvec.DataType.INT64),
            ],
            vectors=[
                zvec.VectorSchema(
                    name="embedding",
                    data_type=zvec.DataType.VECTOR_FP32,
                    dimension=_EMBEDDING_DIM,
                    index_param=zvec.HnswIndexParam(metric_type=zvec.MetricType.COSINE),
                ),
            ],
        )
        logger.info("Creating new collection: %s", store_id)
        col = zvec.create_and_open(path=store_path, schema=schema)
        self._collections[store_id] = col
        return col

    def _resolve_chunk_ids_sync(self, col: "zvec.Collection", ids: List[str]) -> List[str]:
        """Expand caller-facing ids into the stored ids that exist in ``col``.

        An id that is itself stored (e.g. an explicit chunk id) is used as-is.
        Otherwise it is treated as a document id and the chunk ids
        ``f"{doc_id}_0"``, ``f"{doc_id}_1"``, ... are fetched in batches until
        the first missing one (at most ``_MAX_CHUNKS_PER_DOC``). Every id in
        ``ids`` is processed; duplicates are dropped, order is preserved.
        """
        resolved: List[str] = []
        for doc_id in dict.fromkeys(ids):
            if doc_id in col.fetch(ids=[doc_id]):
                resolved.append(doc_id)
                continue
            for start in range(0, self._MAX_CHUNKS_PER_DOC, self._FETCH_BATCH_SIZE):
                stop = min(start + self._FETCH_BATCH_SIZE, self._MAX_CHUNKS_PER_DOC)
                candidates = [f"{doc_id}_{i}" for i in range(start, stop)]
                fetched = col.fetch(ids=candidates)
                # Chunks are numbered contiguously, so stop at the first gap.
                hits = 0
                for cid in candidates:
                    if cid not in fetched:
                        break
                    hits += 1
                resolved.extend(candidates[:hits])
                if hits < len(candidates):
                    break
        return list(dict.fromkeys(resolved))

    def _ingest_document_sync(
        self,
        store_id: str,
        doc_id: str,
        text: str,
        chunks: List[str],
        embeddings: List[List[float]],
        source: str,
    ) -> Tuple[int, float]:
        """Insert one document's chunks and embeddings, then flush and optimize.

        ``text`` is accepted for interface compatibility but not used here.

        Returns:
            ``(number_of_chunks, 0.0)``; the caller measures elapsed time.

        Raises:
            ValueError: if the store is unknown, or if ``chunks`` and
                ``embeddings`` differ in length (previously ``zip`` silently
                dropped the surplus).
        """
        col = self._require_collection_sync(store_id)
        if len(chunks) != len(embeddings):
            raise ValueError(
                f"Got {len(embeddings)} embeddings for {len(chunks)} chunks of doc {doc_id}"
            )
        now_ts = int(time.time())
        docs = [
            zvec.Doc(
                id=f"{doc_id}_{i}",
                vectors={"embedding": emb},
                fields={
                    "text": chunk_text_str,
                    "doc_id": doc_id,
                    "chunk_index": i,
                    "source": source or "",
                    "created_at": now_ts,
                },
            )
            for i, (chunk_text_str, emb) in enumerate(zip(chunks, embeddings))
        ]
        step = self._INSERT_BATCH_SIZE
        for i in range(0, len(docs), step):
            col.insert(docs[i:i + step])
        col.flush()
        col.optimize()
        return len(chunks), 0.0

    def _search_sync(
        self,
        store_id: str,
        query_emb: List[float],
        top_k: int,
        filter_expr: Optional[str],
        min_score: Optional[float] = None,
        include_vectors: bool = False,
        include_metadata: bool = False,
    ) -> List[Dict[str, Any]]:
        """Run a top-k vector query and convert hits into plain dicts.

        ``min_score`` filtering treats ``score`` as a cosine distance (lower
        is closer) and keeps hits with ``1 - score >= min_score``.
        TODO(review): confirm zvec reports distance, not similarity, here.
        Ranks are renumbered after filtering.
        """
        col = self._require_collection_sync(store_id)
        kwargs: Dict[str, Any] = {
            "vectors": zvec.VectorQuery(field_name="embedding", vector=query_emb),
            "topk": top_k,
        }
        if filter_expr:
            kwargs["filter"] = filter_expr
        items: List[Dict[str, Any]] = []
        for r in col.query(**kwargs):
            raw_score = getattr(r, "score", None)
            score = float(raw_score) if raw_score is not None else 0.0
            if min_score is not None and score > (1.0 - min_score):
                continue
            has_field = hasattr(r, "field")
            item: Dict[str, Any] = {
                "rank": len(items) + 1,
                "doc_id": r.field("doc_id") if has_field else r.id.rsplit("_", 1)[0],
                "chunk_index": r.field("chunk_index") if has_field else 0,
                "text": r.field("text") if has_field else "",
                "score": score,
                "source": r.field("source") if has_field else "",
                "metadata": {},
                "vector": None,
            }
            if include_vectors:
                try:
                    vec = r.vector("embedding") if hasattr(r, "vector") else None
                    item["vector"] = list(vec) if vec is not None else None
                except Exception:
                    item["vector"] = None
            if include_metadata:
                item["metadata"] = {
                    "doc_id": item["doc_id"],
                    "chunk_index": item["chunk_index"],
                    "source": item["source"],
                }
            items.append(item)
        return items

    @staticmethod
    def _doc_fields(v: Any) -> Dict[str, Any]:
        """Convert a fetched zvec doc into a plain dict of its stored fields."""
        has_field = hasattr(v, "field")
        return {
            "id": v.id,
            "text": v.field("text") if has_field else "",
            "doc_id": v.field("doc_id") if has_field else "",
            "chunk_index": v.field("chunk_index") if has_field else 0,
            "source": v.field("source") if has_field else "",
        }

    def _fetch_documents_sync(self, store_id: str, ids: List[str]) -> Dict[str, Any]:
        """Return ``{stored_id: fields}`` for every chunk belonging to ``ids``.

        Each entry of ``ids`` may be a stored id or a document id; see
        ``_resolve_chunk_ids_sync``. Returns ``{}`` when nothing matches.
        """
        col = self._require_collection_sync(store_id)
        internal_ids = self._resolve_chunk_ids_sync(col, ids)
        if not internal_ids:
            return {}
        fetched = col.fetch(ids=internal_ids)
        return {k: self._doc_fields(v) for k, v in fetched.items()}

    def _delete_documents_sync(
        self,
        store_id: str,
        ids: Optional[List[str]],
        filter_expr: Optional[str],
    ) -> int:
        """Delete the chunks of ``ids`` and/or everything matching ``filter_expr``.

        Returns:
            The number of chunk ids deleted by id. When a filter delete ran,
            at least 1 (this code does not read back how many rows the filter
            removed).
        """
        col = self._require_collection_sync(store_id)
        deleted = 0
        if ids:
            actual_ids = self._resolve_chunk_ids_sync(col, ids)
            if actual_ids:
                col.delete(ids=actual_ids)
                deleted = len(actual_ids)
        if filter_expr:
            col.delete_by_filter(filter=filter_expr)
            deleted = max(deleted, 1)
        col.flush()
        return deleted

    def _get_store_stats_sync(self, store_id: str) -> Dict[str, Any]:
        """Return registry info plus the collection's document count.

        Raises:
            ValueError: if the store is not registered.
        """
        col = self._require_collection_sync(store_id)
        try:
            stats = col.stats
            doc_count = getattr(stats, "doc_count", 0) if stats else 0
        except Exception as exc:
            logger.warning("Could not read stats for store %s: %s", store_id, exc)
            doc_count = 0
        record = self._stores.get(store_id)
        if record is None:
            raise ValueError(f"Vector store {store_id} not found")
        return {
            "store_id": store_id,
            "name": record.name,
            "description": record.description,
            "app_id": settings.application_id or store_id,
            "embedding_dimension": _EMBEDDING_DIM,
            "document_count": doc_count,
            "created_at": record.created_at,
            "metadata": record.metadata,
        }

    # --- Async public API ---

    async def _run_sync_fn(self, fn, *args):
        """Run blocking ``fn(*args)`` on the service thread pool and await it."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self._thread_pool, fn, *args)

    async def create_store(
        self,
        name: str,
        description: str = "",
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Tuple[str, str]:
        """Create a new store (collection + registry row).

        Returns:
            ``(store_id, app_id)``, where ``app_id`` falls back to ``store_id``
            when ``settings.application_id`` is unset.
        """
        store_id = str(uuid.uuid4())
        store_path = self._store_path(store_id)
        await self._run_sync_fn(self._open_or_create_collection_sync, store_id, store_path)
        record = VectorStoreRecord(
            store_id=store_id,
            name=name,
            path=store_path,
            description=description,
            metadata=metadata or {},
        )
        self._stores[store_id] = record
        await self._persist_store(record)
        app_id = settings.application_id or store_id
        logger.info("Created vector store: %s (name=%s, app_id=%s)", store_id, name, app_id)
        return store_id, app_id

    def list_stores(self) -> List[VectorStoreRecord]:
        """Return all registered stores."""
        return list(self._stores.values())

    def get_store(self, store_id: str) -> Optional[VectorStoreRecord]:
        """Return the record for ``store_id``, or ``None``."""
        return self._stores.get(store_id)

    async def delete_store(self, store_id: str) -> bool:
        """Destroy a store's collection, its directory and its registry row.

        Returns:
            ``False`` if the store was not registered, else ``True``.
        """
        record = self._stores.pop(store_id, None)
        if record is None:
            return False
        col = self._collections.pop(store_id, None)
        if col is not None:
            try:
                await self._run_sync_fn(col.destroy)
            except Exception as exc:
                logger.warning("Error destroying collection %s: %s", store_id, exc)
        store_path = record.path
        if os.path.exists(store_path):
            await self._run_sync_fn(lambda: shutil.rmtree(store_path, ignore_errors=True))
        await self._remove_persisted_store(store_id)
        logger.info("Deleted vector store: %s", store_id)
        return True

    async def _ensure_embedding_model(self) -> None:
        """Load the embedding model for ``_EMBEDDING_DIM`` if not already loaded."""
        if not self._embedding_service.is_loaded(_EMBEDDING_DIM):
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self._embedding_service.load_model, _EMBEDDING_DIM)

    async def ingest_document(
        self,
        store_id: str,
        doc_id: str,
        text: str,
        source: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
        chunk_size: Optional[int] = None,
        chunk_overlap: Optional[int] = None,
    ) -> Tuple[int, float]:
        """Chunk, embed and insert ``text`` as document ``doc_id``.

        ``source`` falls back to ``metadata["source"]`` and then "".
        Chunking defaults to size 512 / overlap 64.

        Returns:
            ``(chunk_count, elapsed_ms)`` covering embedding plus insertion.

        Raises:
            ValueError: if the store is unknown (checked before any embedding work).
        """
        self._ensure_store_known(store_id)
        await self._ensure_embedding_model()
        size = chunk_size or 512
        overlap = chunk_overlap or 64
        chunks = await chunk_text_async(text, chunk_size=size, chunk_overlap=overlap)
        metadata = metadata or {}
        source = source or metadata.get("source") or ""
        start = time.perf_counter()
        if not chunks:
            logger.info("Doc %s produced no chunks; nothing ingested into store %s", doc_id, store_id)
            return 0, (time.perf_counter() - start) * 1000
        embeddings = await self._run_sync_fn(
            self._embedding_service.generate_embedding, chunks, _EMBEDDING_DIM
        )
        await self._run_sync_fn(
            self._ingest_document_sync, store_id, doc_id, text, chunks, embeddings, source
        )
        elapsed = (time.perf_counter() - start) * 1000
        logger.info("Ingested doc %s into store %s: %d chunks in %.2f ms", doc_id, store_id, len(chunks), elapsed)
        return len(chunks), elapsed

    async def search(
        self,
        store_id: str,
        query_text: str,
        top_k: int = 10,
        filter_expr: Optional[str] = None,
        min_score: Optional[float] = None,
        include_vectors: bool = False,
        include_metadata: bool = False,
    ) -> Tuple[List[Dict[str, Any]], float]:
        """Embed ``query_text`` and return ``(hits, elapsed_ms)`` from the store.

        Raises:
            ValueError: if the store is unknown (checked before embedding).
        """
        self._ensure_store_known(store_id)
        await self._ensure_embedding_model()
        start = time.perf_counter()
        query_emb = await self._run_sync_fn(
            lambda: self._embedding_service.generate_embedding([query_text], _EMBEDDING_DIM)[0]
        )
        items = await self._run_sync_fn(
            self._search_sync,
            store_id,
            query_emb,
            top_k,
            filter_expr,
            min_score,
            include_vectors,
            include_metadata,
        )
        elapsed = (time.perf_counter() - start) * 1000
        return items, elapsed

    async def fetch_documents(self, store_id: str, ids: List[str]) -> Dict[str, Any]:
        """Async wrapper around ``_fetch_documents_sync``."""
        return await self._run_sync_fn(self._fetch_documents_sync, store_id, ids)

    async def delete_documents(
        self,
        store_id: str,
        ids: Optional[List[str]] = None,
        filter_expr: Optional[str] = None,
    ) -> int:
        """Async wrapper around ``_delete_documents_sync``."""
        return await self._run_sync_fn(self._delete_documents_sync, store_id, ids, filter_expr)

    async def get_store_stats(self, store_id: str) -> Dict[str, Any]:
        """Async wrapper around ``_get_store_stats_sync``."""
        return await self._run_sync_fn(self._get_store_stats_sync, store_id)

    async def get_total_document_count(self) -> int:
        """Sum ``document_count`` across stores; stores that fail are logged and skipped."""
        total = 0
        for store_id in list(self._stores):
            try:
                stats = await self.get_store_stats(store_id)
            except Exception as exc:
                logger.warning("Skipping store %s in document count: %s", store_id, exc)
                continue
            total += stats.get("document_count", 0)
        return total

    async def close_all(self) -> None:
        """Flush every open collection, clear caches and shut down the pool.

        Flush failures are logged and do not stop the shutdown.
        """
        for store_id, col in list(self._collections.items()):
            try:
                # Pass the bound method so each call flushes its own collection.
                await self._run_sync_fn(col.flush)
            except Exception as exc:
                logger.warning("Error flushing collection %s: %s", store_id, exc)
        self._collections.clear()
        self._stores.clear()
        self._thread_pool.shutdown(wait=True)