# Source: llm-ready-data/app/services/vector_store_service.py -- uploaded by light-infer-chat, commit 08240ea ("ok"), 18.8 kB.
from __future__ import annotations
import asyncio
import concurrent.futures
import json
import os
import shutil
import time
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple
import zvec
from sqlalchemy import select
from app.config import get_settings
from app.core.logger import get_logger
from app.core.vector_store.deps import AsyncSessionLocal
from app.core.vector_store.models import VectorStoreIndex
from app.services.chunking_service import chunk_text_async
from app.services.embeddings_service import EmbeddingService
# Module-level logger named after this module.
logger = get_logger(__name__)
# Application settings, fetched once when the module is imported.
settings = get_settings()
# Dimension of the embedding vectors; used for the collection schema and
# passed to the embedding service whenever a model is loaded or queried.
_EMBEDDING_DIM = 384
# Worker cap for the service's thread pool: CPU count (1 when unknown) plus 4,
# never more than 16.
_MAX_WORKERS = min(16, (os.cpu_count() or 1) + 4)
def _run_sync(fn, *args, **kwargs):
return fn(*args, **kwargs)
class VectorStoreRecord:
    """In-memory description of one vector store.

    Attributes:
        store_id: Identifier of the store.
        name: Store name.
        path: Filesystem path recorded for the store.
        description: Free-form description (empty string by default).
        metadata: Arbitrary metadata mapping; an empty dict when none is given.
        created_at: Creation timestamp string; defaults to the current UTC
            time in ISO-8601 form when not supplied.
    """

    def __init__(
        self,
        store_id: str,
        name: str,
        path: str,
        description: str = "",
        metadata: Optional[Dict[str, Any]] = None,
        created_at: Optional[str] = None,
    ):
        # Falsy inputs (None, empty dict, empty string) fall back to defaults.
        if not metadata:
            metadata = {}
        if not created_at:
            created_at = datetime.now(timezone.utc).isoformat()

        self.store_id, self.name, self.path = store_id, name, path
        self.description = description
        self.metadata = metadata
        self.created_at = created_at
class VectorStoreService:
def __init__(self, embedding_service: EmbeddingService):
    """Set up in-memory registries, the worker pool and the on-disk data directory.

    Args:
        embedding_service: Service used to load the embedding model and
            generate embeddings.
    """
    self._embedding_service = embedding_service

    # Registries keyed by store id: store records and opened collections.
    self._stores: Dict[str, VectorStoreRecord] = {}
    self._collections: Dict[str, "zvec.Collection"] = {}

    base_dir = os.path.join(settings.data_dir, "vector_stores")
    self._data_dir = base_dir

    # Dedicated pool used for the blocking collection/embedding calls.
    executor = concurrent.futures.ThreadPoolExecutor(
        max_workers=_MAX_WORKERS,
        thread_name_prefix="zvec",
    )
    self._thread_pool = executor

    os.makedirs(base_dir, exist_ok=True)
def _store_path(self, store_id: str) -> str:
return os.path.join(self._data_dir, f"store_{store_id}")
def _get_collection(self, store_id: str) -> Optional["zvec.Collection"]:
return self._collections.get(store_id)
# --- SQLite persistence ---
async def init_db(self) -> None:
    """Run the vector-store DB initialiser and load persisted stores into memory.

    Every ``VectorStoreIndex`` row becomes a ``VectorStoreRecord`` in
    ``self._stores``. When the row's path contains a ``__zvec_meta`` entry
    the collection is opened and cached too; a failure to open is logged as
    a warning and does not stop the remaining rows from loading.
    """
    from app.core.vector_store.deps import init_vector_store_db

    await init_vector_store_db()

    record_keys = ("store_id", "name", "path", "description", "metadata", "created_at")
    async with AsyncSessionLocal() as session:
        query_result = await session.execute(select(VectorStoreIndex))
        for index_row in query_result.scalars().all():
            data = index_row.to_dict()
            record = VectorStoreRecord(**{key: data[key] for key in record_keys})
            self._stores[record.store_id] = record

            location = data["path"]
            if not os.path.exists(os.path.join(location, "__zvec_meta")):
                # Nothing on disk yet for this store; keep only the record.
                continue
            try:
                collection = zvec.open(location)
                if collection is not None:
                    self._collections[record.store_id] = collection
            except Exception as exc:
                logger.warning("Could not open collection %s: %s", record.store_id, exc)
async def _persist_store(self, record: VectorStoreRecord) -> None:
    """Insert or update the ``VectorStoreIndex`` row for ``record`` and commit.

    An existing row has its name, description and JSON-encoded metadata
    overwritten; otherwise a new row is built from all record fields.
    """
    async with AsyncSessionLocal() as session:
        row = await session.get(VectorStoreIndex, record.store_id)
        if not row:
            payload = {
                "store_id": record.store_id,
                "name": record.name,
                "path": record.path,
                "description": record.description,
                "metadata": record.metadata,
                "created_at": record.created_at,
            }
            session.add(VectorStoreIndex.from_dict(payload))
        else:
            row.name = record.name
            row.description = record.description
            row.metadata_json = json.dumps(record.metadata)
        await session.commit()
async def _remove_persisted_store(self, store_id: str) -> None:
    """Delete the ``VectorStoreIndex`` row for ``store_id`` if one exists."""
    async with AsyncSessionLocal() as session:
        persisted = await session.get(VectorStoreIndex, store_id)
        if not persisted:
            return
        await session.delete(persisted)
        await session.commit()
# --- Synchronous helpers (run in thread pool) ---
def _open_or_create_collection_sync(self, store_id: str, store_path: str) -> "zvec.Collection":
    """Return the collection for ``store_id``, opening or creating it as needed.

    Lookup order: the in-memory cache, then an existing on-disk collection
    (detected via ``__zvec_meta`` under ``store_path``), then a freshly
    created collection. A failed or empty open is logged and falls through
    to creation. The returned collection is cached in ``self._collections``.
    """
    cached = self._collections.get(store_id)
    if cached is not None:
        return cached

    meta_marker = os.path.join(store_path, "__zvec_meta")
    if os.path.exists(meta_marker):
        logger.info("Opening existing collection: %s", store_id)
        try:
            opened = zvec.open(store_path)
            if opened is not None:
                self._collections[store_id] = opened
                return opened
        except Exception as exc:
            logger.warning("Could not open collection %s: %s", store_id, exc)

    # Scalar fields stored next to each vector; doc_id and source are indexed.
    scalar_fields = [
        zvec.FieldSchema(name="text", data_type=zvec.DataType.STRING),
        zvec.FieldSchema(
            name="doc_id",
            data_type=zvec.DataType.STRING,
            index_param=zvec.InvertIndexParam(),
        ),
        zvec.FieldSchema(name="chunk_index", data_type=zvec.DataType.INT32),
        zvec.FieldSchema(
            name="source",
            data_type=zvec.DataType.STRING,
            index_param=zvec.InvertIndexParam(),
        ),
        zvec.FieldSchema(name="created_at", data_type=zvec.DataType.INT64),
    ]
    vector_fields = [
        zvec.VectorSchema(
            name="embedding",
            data_type=zvec.DataType.VECTOR_FP32,
            dimension=_EMBEDDING_DIM,
            index_param=zvec.HnswIndexParam(metric_type=zvec.MetricType.COSINE),
        ),
    ]
    schema = zvec.CollectionSchema(
        name="store_{}".format(store_id),
        fields=scalar_fields,
        vectors=vector_fields,
    )

    logger.info("Creating new collection: %s", store_id)
    created = zvec.create_and_open(path=store_path, schema=schema)
    self._collections[store_id] = created
    return created
def _ingest_document_sync(
self,
store_id: str,
doc_id: str,
text: str,
chunks: List[str],
embeddings: List[List[float]],
source: str,
) -> Tuple[int, float]:
col = self._get_collection(store_id)
if col is None:
record = self._stores.get(store_id)
if record is None:
raise ValueError(f"Vector store {store_id} not found")
col = self._open_or_create_collection_sync(store_id, record.path)
now_ts = int(time.time())
docs = [
zvec.Doc(
id=f"{doc_id}_{i}",
vectors={"embedding": emb},
fields={
"text": chunk_text_str,
"doc_id": doc_id,
"chunk_index": i,
"source": source or "",
"created_at": now_ts,
},
)
for i, (chunk_text_str, emb) in enumerate(zip(chunks, embeddings))
]
for i in range(0, len(docs), 100):
col.insert(docs[i:i + 100])
col.flush()
col.optimize()
return len(chunks), 0.0
def _search_sync(
    self,
    store_id: str,
    query_emb: List[float],
    top_k: int,
    filter_expr: Optional[str],
    min_score: Optional[float] = None,
    include_vectors: bool = False,
    include_metadata: bool = False,
) -> List[Dict[str, Any]]:
    """Run a vector query against a store and shape the hits into plain dicts.

    Args:
        store_id: Store to query.
        query_emb: Query embedding.
        top_k: Maximum number of hits requested from the collection.
        filter_expr: Optional filter expression passed through to the query.
        min_score: When given, hits whose score exceeds ``1.0 - min_score``
            are dropped.
        include_vectors: Attach the hit's ``embedding`` vector when available.
        include_metadata: Fill ``metadata`` with doc_id/chunk_index/source.

    Returns:
        Hit dicts with keys ``rank``, ``doc_id``, ``chunk_index``, ``text``,
        ``score``, ``source``, ``metadata`` and ``vector``. ``rank`` counts
        only the hits that were kept.

    Raises:
        ValueError: If the store is unknown.
    """
    collection = self._get_collection(store_id)
    if collection is None:
        store = self._stores.get(store_id)
        if store is None:
            raise ValueError(f"Vector store {store_id} not found")
        collection = self._open_or_create_collection_sync(store_id, store.path)

    query_args: Dict[str, Any] = dict(
        vectors=zvec.VectorQuery(field_name="embedding", vector=query_emb),
        topk=top_k,
    )
    if filter_expr:
        query_args["filter"] = filter_expr

    hits = []
    for match in collection.query(**query_args):
        raw_score = getattr(match, "score", None)
        score = 0.0 if raw_score is None else float(raw_score)

        # NOTE(review): this comparison treats the score as a distance
        # (smaller is better) -- confirm against the zvec COSINE metric.
        if min_score is not None and score > (1.0 - min_score):
            continue

        if hasattr(match, "field"):
            doc_id = match.field("doc_id")
            chunk_index = match.field("chunk_index")
            text = match.field("text")
            source = match.field("source")
        else:
            # Without field access, derive the document id from the chunk id.
            doc_id = match.id.rsplit("_", 1)[0]
            chunk_index, text, source = 0, "", ""

        entry: Dict[str, Any] = {
            "rank": len(hits) + 1,
            "doc_id": doc_id,
            "chunk_index": chunk_index,
            "text": text,
            "score": score,
            "source": source,
            "metadata": {},
            "vector": None,
        }

        if include_vectors:
            try:
                raw_vec = match.vector("embedding") if hasattr(match, "vector") else None
                entry["vector"] = None if raw_vec is None else list(raw_vec)
            except Exception:
                entry["vector"] = None

        if include_metadata:
            entry["metadata"] = {
                "doc_id": entry["doc_id"],
                "chunk_index": entry["chunk_index"],
                "source": entry["source"],
            }

        hits.append(entry)
    return hits
def _fetch_documents_sync(self, store_id: str, ids: List[str]) -> Dict[str, Any]:
col = self._get_collection(store_id)
if col is None:
record = self._stores.get(store_id)
if record is None:
raise ValueError(f"Vector store {store_id} not found")
col = self._open_or_create_collection_sync(store_id, record.path)
internal_ids = []
for doc_id in ids:
fetched = col.fetch(ids=[doc_id])
if doc_id in fetched:
internal_ids.append(doc_id)
continue
for i in range(0, 1024):
chunk_id = f"{doc_id}_{i}"
fetched = col.fetch(ids=[chunk_id])
if chunk_id in fetched:
internal_ids.append(chunk_id)
else:
break
break
if not internal_ids:
return {}
fetched = col.fetch(ids=internal_ids)
result = {}
for k, v in fetched.items():
result[k] = {
"id": v.id,
"text": v.field("text") if hasattr(v, "field") else "",
"doc_id": v.field("doc_id") if hasattr(v, "field") else "",
"chunk_index": v.field("chunk_index") if hasattr(v, "field") else 0,
"source": v.field("source") if hasattr(v, "field") else "",
}
return result
def _delete_documents_sync(
self,
store_id: str,
ids: Optional[List[str]],
filter_expr: Optional[str],
) -> int:
col = self._get_collection(store_id)
if col is None:
record = self._stores.get(store_id)
if record is None:
raise ValueError(f"Vector store {store_id} not found")
col = self._open_or_create_collection_sync(store_id, record.path)
deleted = 0
if ids:
chunk_ids = []
for doc_id in ids:
chunk_ids.extend(f"{doc_id}_{i}" for i in range(4096))
fetched = col.fetch(ids=chunk_ids[:1000])
actual_ids = [k for k in chunk_ids if k in fetched]
if actual_ids:
col.delete(ids=actual_ids)
deleted = len(actual_ids)
if filter_expr:
col.delete_by_filter(filter=filter_expr)
deleted = max(deleted, 1)
col.flush()
return deleted
def _get_store_stats_sync(self, store_id: str) -> Dict[str, Any]:
col = self._get_collection(store_id)
if col is None:
record = self._stores.get(store_id)
if record is None:
raise ValueError(f"Vector store {store_id} not found")
col = self._open_or_create_collection_sync(store_id, record.path)
try:
stats = col.stats
doc_count = getattr(stats, 'doc_count', 0) if stats else 0
except Exception:
doc_count = 0
record = self._stores[store_id]
return {
"store_id": store_id,
"name": record.name,
"description": record.description,
"app_id": settings.application_id or store_id,
"embedding_dimension": _EMBEDDING_DIM,
"document_count": doc_count,
"created_at": record.created_at,
"metadata": record.metadata,
}
# --- Async public API ---
async def _run_sync_fn(self, fn):
loop = asyncio.get_running_loop()
return await loop.run_in_executor(self._thread_pool, fn)
async def create_store(
    self,
    name: str,
    description: str = "",
    metadata: Optional[Dict[str, Any]] = None,
) -> Tuple[str, str]:
    """Create a new vector store, register it in memory and persist it.

    Args:
        name: Store name.
        description: Optional description.
        metadata: Optional metadata mapping.

    Returns:
        ``(store_id, app_id)`` where ``app_id`` is the configured application
        id, or the store id when none is configured.
    """
    store_id = str(uuid.uuid4())
    store_path = self._store_path(store_id)

    def _open_collection():
        return self._open_or_create_collection_sync(store_id, store_path)

    # The collection is created on the worker pool before the record exists.
    await self._run_sync_fn(_open_collection)

    new_record = VectorStoreRecord(
        store_id=store_id,
        name=name,
        path=store_path,
        description=description,
        metadata=metadata or {},
    )
    self._stores[store_id] = new_record
    await self._persist_store(new_record)

    app_id = settings.application_id or store_id
    logger.info("Created vector store: %s (name=%s, app_id=%s)", store_id, name, app_id)
    return store_id, app_id
def list_stores(self) -> List[VectorStoreRecord]:
    """Return a new list holding every registered store record."""
    return [record for record in self._stores.values()]
def get_store(self, store_id: str) -> Optional[VectorStoreRecord]:
    """Return the record registered for ``store_id``, or ``None`` when unknown."""
    registry = self._stores
    return registry.get(store_id)
async def delete_store(self, store_id: str) -> bool:
    """Remove a store from memory, disk and the persistence table.

    Returns:
        ``False`` when the store id is unknown, ``True`` once it has been
        removed. A failure while destroying the collection is logged as a
        warning and does not abort the removal.
    """
    removed_record = self._stores.pop(store_id, None)
    if removed_record is None:
        return False

    collection = self._collections.pop(store_id, None)
    if collection is not None:
        def _destroy():
            return collection.destroy()

        try:
            await self._run_sync_fn(_destroy)
        except Exception as exc:
            logger.warning("Error destroying collection %s: %s", store_id, exc)

    store_path = removed_record.path
    if os.path.exists(store_path):
        def _remove_tree():
            return shutil.rmtree(store_path, ignore_errors=True)

        await self._run_sync_fn(_remove_tree)

    await self._remove_persisted_store(store_id)
    logger.info("Deleted vector store: %s", store_id)
    return True
async def _ensure_embedding_model(self) -> None:
    """Load the embedding model on the default executor if it is not loaded yet."""
    service = self._embedding_service
    if service.is_loaded(_EMBEDDING_DIM):
        return
    running_loop = asyncio.get_running_loop()
    await running_loop.run_in_executor(None, service.load_model, _EMBEDDING_DIM)
async def ingest_document(
    self,
    store_id: str,
    doc_id: str,
    text: str,
    source: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
    chunk_size: Optional[int] = None,
    chunk_overlap: Optional[int] = None,
) -> Tuple[int, float]:
    """Chunk, embed and store one document.

    Args:
        store_id: Target store.
        doc_id: Identifier for the document.
        text: Document text to chunk.
        source: Source label; falls back to ``metadata["source"]`` and then "".
        metadata: Optional metadata; only its ``source`` key is read here.
        chunk_size: Chunk size; 512 when omitted (or 0).
        chunk_overlap: Overlap between chunks; 64 only when omitted, so an
            explicit 0 is respected.

    Returns:
        ``(chunk_count, elapsed_ms)``; the timing covers embedding and
        insertion, not chunking.

    Raises:
        ValueError: Propagated from the insert step when the store is unknown.
    """
    await self._ensure_embedding_model()

    size = chunk_size or 512
    # ``chunk_overlap or 64`` would turn a deliberate 0 into 64; only
    # substitute the default when the caller did not pass a value.
    overlap = 64 if chunk_overlap is None else chunk_overlap
    chunks = await chunk_text_async(text, chunk_size=size, chunk_overlap=overlap)

    metadata = metadata or {}
    source = source or metadata.get("source", "")

    start = time.perf_counter()
    loop = asyncio.get_running_loop()
    embeddings = await loop.run_in_executor(
        self._thread_pool,
        self._embedding_service.generate_embedding,
        chunks,
        _EMBEDDING_DIM,
    )
    await self._run_sync_fn(
        lambda: self._ingest_document_sync(store_id, doc_id, text, chunks, embeddings, source or "")
    )
    elapsed = (time.perf_counter() - start) * 1000
    logger.info("Ingested doc %s into store %s: %d chunks in %.2f ms", doc_id, store_id, len(chunks), elapsed)
    return len(chunks), elapsed
async def search(
    self,
    store_id: str,
    query_text: str,
    top_k: int = 10,
    filter_expr: Optional[str] = None,
    min_score: Optional[float] = None,
    include_vectors: bool = False,
    include_metadata: bool = False,
) -> Tuple[List[Dict[str, Any]], float]:
    """Embed ``query_text`` and search a store with it.

    Returns:
        ``(hits, elapsed_ms)`` where ``hits`` is the list produced by
        ``_search_sync`` and the timing covers embedding plus the query.
    """
    await self._ensure_embedding_model()
    started = time.perf_counter()

    def _embed_query():
        vectors = self._embedding_service.generate_embedding([query_text], _EMBEDDING_DIM)
        return vectors[0]

    running_loop = asyncio.get_running_loop()
    query_emb = await running_loop.run_in_executor(self._thread_pool, _embed_query)

    def _run_query():
        return self._search_sync(
            store_id, query_emb, top_k, filter_expr, min_score, include_vectors, include_metadata
        )

    hits = await self._run_sync_fn(_run_query)
    elapsed_ms = (time.perf_counter() - started) * 1000
    return hits, elapsed_ms
async def fetch_documents(self, store_id: str, ids: List[str]) -> Dict[str, Any]:
    """Async wrapper: run ``_fetch_documents_sync`` on the worker pool."""
    def _job():
        return self._fetch_documents_sync(store_id, ids)

    fetched = await self._run_sync_fn(_job)
    return fetched
async def delete_documents(
    self,
    store_id: str,
    ids: Optional[List[str]] = None,
    filter_expr: Optional[str] = None,
) -> int:
    """Async wrapper: run ``_delete_documents_sync`` on the worker pool and return its count."""
    def _job():
        return self._delete_documents_sync(store_id, ids, filter_expr)

    removed = await self._run_sync_fn(_job)
    return removed
async def get_store_stats(self, store_id: str) -> Dict[str, Any]:
    """Async wrapper: run ``_get_store_stats_sync`` on the worker pool."""
    def _job():
        return self._get_store_stats_sync(store_id)

    stats = await self._run_sync_fn(_job)
    return stats
async def get_total_document_count(self) -> int:
    """Return the sum of ``document_count`` over all registered stores.

    Best effort: a store whose stats cannot be read is skipped, and the
    failure is logged as a warning.
    """
    total = 0
    # Iterate over a snapshot so stores added or removed meanwhile do not
    # disturb the loop.
    for store_id in list(self._stores):
        try:
            stats = await self.get_store_stats(store_id)
        except Exception as exc:
            logger.warning("Skipping store %s in document count: %s", store_id, exc)
            continue
        total += stats.get("document_count", 0)
    return total
async def close_all(self) -> None:
    """Flush every open collection, clear the registries and stop the pool.

    A flush failure is logged as a warning and does not prevent the other
    collections from being flushed. The thread pool is shut down with
    ``wait=True`` afterwards.
    """
    for store_id, col in list(self._collections.items()):
        try:
            # Pass the bound method rather than a lambda closing over the
            # loop variable.
            await self._run_sync_fn(col.flush)
        except Exception as exc:
            logger.warning("Error flushing collection %s during shutdown: %s", store_id, exc)
    self._collections.clear()
    self._stores.clear()
    self._thread_pool.shutdown(wait=True)