meet4150/ALIV_AI / app /db /chroma_client.py
download
raw
10.7 kB
from __future__ import annotations
import os
import time
from pathlib import Path
from threading import Lock
from typing import Any
import chromadb
from dotenv import load_dotenv
try:
from pinecone import Pinecone
_PINECONE_AVAILABLE = True
except Exception:
Pinecone = None
_PINECONE_AVAILABLE = False
# Load variables from the project's .env file into the process environment
# so the ALIVEAI_* lookups below can see them.
load_dotenv()
# Local Chroma store lives at <repo root>/chroma_db; created eagerly at import time.
CHROMA_PATH = Path(__file__).resolve().parents[2] / "chroma_db"
CHROMA_PATH.mkdir(parents=True, exist_ok=True)
# Pinecone connection settings — only required when the pinecone backend is selected.
PINECONE_API_KEY = os.getenv("ALIVEAI_PINECONE_API_KEY", "").strip()
PINECONE_INDEX_NAME = os.getenv("ALIVEAI_PINECONE_INDEX_NAME", "").strip()
PINECONE_INDEX_HOST = os.getenv("ALIVEAI_PINECONE_INDEX_HOST", "").strip()
PINECONE_NAMESPACE = os.getenv("ALIVEAI_PINECONE_NAMESPACE", "").strip()
# valid values: auto | pinecone | chroma
# default is chroma so Pinecone is opt-in only.
REQUESTED_VECTOR_BACKEND = os.getenv("ALIVEAI_VECTOR_BACKEND", "chroma").strip().lower() or "chroma"
# Lazily-initialised module-level state, guarded by _lock.
_client: Any = None
_collection: Any = None
_lock = Lock()
# One of "chroma" / "pinecone" once initialised; "unknown" before first use.
_active_backend = "unknown"
# Memoised result of _resolve_embedding_dimension().
_embedding_dimension_cache: int | None = None
class PineconeCollectionAdapter:
def __init__(self, index) -> None:
self._index = index
@staticmethod
def _namespace() -> str | None:
return PINECONE_NAMESPACE or None
@staticmethod
def _to_dict(value: Any) -> dict[str, Any]:
if isinstance(value, dict):
return value
to_dict = getattr(value, "to_dict", None)
if callable(to_dict):
return to_dict()
return {}
@staticmethod
def _match_get(match: Any, key: str, default: Any = None) -> Any:
if isinstance(match, dict):
return match.get(key, default)
return getattr(match, key, default)
@staticmethod
def _to_pinecone_filter(where: dict | None) -> dict | None:
if not where:
return None
converted: dict[str, Any] = {}
for key, value in where.items():
converted[key] = {"$eq": value}
return converted
def get(self, ids: list[str]) -> dict:
if not ids:
return {"ids": []}
response = self._index.fetch(ids=ids, namespace=self._namespace())
response_dict = self._to_dict(response)
vectors = response_dict.get("vectors", {})
if not vectors and not response_dict:
vectors = getattr(response, "vectors", {}) or {}
return {"ids": list(vectors.keys())}
def add(
self,
ids: list[str],
embeddings: list[list[float]],
documents: list[str],
metadatas: list[dict],
) -> None:
vectors = []
for record_id, embedding, document, metadata in zip(ids, embeddings, documents, metadatas):
merged_metadata = dict(metadata or {})
merged_metadata["_document"] = document
vectors.append(
{
"id": record_id,
"values": embedding,
"metadata": merged_metadata,
}
)
self._index.upsert(vectors=vectors, namespace=self._namespace())
def query(
self,
query_embeddings: list[list[float]],
n_results: int,
include: list[str],
where: dict | None = None,
) -> dict:
vector = query_embeddings[0]
pinecone_filter = self._to_pinecone_filter(where)
response = self._index.query(
vector=vector,
top_k=n_results,
filter=pinecone_filter,
namespace=self._namespace(),
include_metadata=True,
include_values=False,
)
response_dict = self._to_dict(response)
matches = response_dict.get("matches", [])
if not matches and not response_dict:
matches = getattr(response, "matches", []) or []
documents: list[str] = []
distances: list[float] = []
metadatas: list[dict] = []
for match in matches:
metadata = dict(self._match_get(match, "metadata", {}) or {})
documents.append(str(metadata.pop("_document", "")))
score = float(self._match_get(match, "score", 0.0))
distances.append(1.0 - score)
metadatas.append(metadata)
payload: dict[str, list[list[Any]]] = {}
if "documents" in include:
payload["documents"] = [documents]
if "distances" in include:
payload["distances"] = [distances]
if "metadatas" in include:
payload["metadatas"] = [metadatas]
return payload
def count(self) -> int:
stats = self._index.describe_index_stats()
stats_dict = self._to_dict(stats)
if stats_dict:
if PINECONE_NAMESPACE:
namespace_stats = (stats_dict.get("namespaces") or {}).get(PINECONE_NAMESPACE, {})
return int(namespace_stats.get("vector_count", 0))
return int(stats_dict.get("total_vector_count", 0))
if PINECONE_NAMESPACE:
namespaces = getattr(stats, "namespaces", {}) or {}
namespace_stats = namespaces.get(PINECONE_NAMESPACE, {})
return int(
namespace_stats.get("vector_count", 0)
if isinstance(namespace_stats, dict)
else getattr(namespace_stats, "vector_count", 0)
)
return int(getattr(stats, "total_vector_count", 0))
def _to_dict(value: Any) -> dict[str, Any]:
if isinstance(value, dict):
return value
to_dict = getattr(value, "to_dict", None)
if callable(to_dict):
return to_dict()
return {}
def _resolve_embedding_dimension() -> int | None:
    """Return the expected embedding dimension, or ``None`` when unknown.

    Resolution order (result memoised in ``_embedding_dimension_cache``):
      1. The ``ALIVEAI_EMBEDDING_DIMENSION`` environment variable, if set.
      2. The project's ``KBEmbeddingService`` (imported lazily so this module
         stays importable without the agent package).

    Raises:
        RuntimeError: if ``ALIVEAI_EMBEDDING_DIMENSION`` is set but is not a
            valid integer (chained from the underlying ``ValueError``).
    """
    global _embedding_dimension_cache
    if _embedding_dimension_cache is not None:
        return _embedding_dimension_cache
    configured = os.getenv("ALIVEAI_EMBEDDING_DIMENSION", "").strip()
    if configured:
        try:
            _embedding_dimension_cache = int(configured)
        except ValueError as exc:
            # Chain the original error so the bad value shows up in tracebacks.
            raise RuntimeError("ALIVEAI_EMBEDDING_DIMENSION must be an integer when provided.") from exc
        return _embedding_dimension_cache
    try:
        from app.agent.kb_embedding import KBEmbeddingService
        _embedding_dimension_cache = int(KBEmbeddingService().embedding_dimension())
        return _embedding_dimension_cache
    except Exception:
        # Best effort only — callers treat None as "dimension unknown".
        return None
def _validate_pinecone_index(client: Any) -> None:
    """Raise RuntimeError unless the configured Pinecone index exists with
    the expected dimension and cosine metric."""
    if not client.has_index(PINECONE_INDEX_NAME):
        expected_dimension = _resolve_embedding_dimension()
        dimension_hint = (
            "<your embedding dimension>" if expected_dimension is None else str(expected_dimension)
        )
        raise RuntimeError(
            "Pinecone index not found. "
            f"Create index '{PINECONE_INDEX_NAME}' with dimension={dimension_hint} and metric='cosine'."
        )
    description = client.describe_index(PINECONE_INDEX_NAME)
    description_dict = _to_dict(description)
    expected_dimension = _resolve_embedding_dimension()
    # Prefer the dict form; fall back to attribute access for SDK objects.
    actual_dimension = description_dict.get("dimension")
    if actual_dimension is None:
        actual_dimension = getattr(description, "dimension", None)
    both_known = expected_dimension is not None and actual_dimension is not None
    if both_known and int(actual_dimension) != expected_dimension:
        raise RuntimeError(
            "Pinecone index dimension mismatch. "
            f"Expected {expected_dimension}, found {actual_dimension} on '{PINECONE_INDEX_NAME}'."
        )
    metric = description_dict.get("metric")
    if metric is None:
        metric = getattr(description, "metric", None)
    # Distances are computed as 1 - score elsewhere, which assumes cosine.
    if metric and str(metric).lower() != "cosine":
        raise RuntimeError(
            "Pinecone index metric mismatch. "
            f"Expected 'cosine', found '{metric}' on '{PINECONE_INDEX_NAME}'."
        )
def _create_chroma_client_and_collection() -> None:
    """Point the module-level client/collection at the local Chroma store."""
    global _client, _collection, _active_backend
    chroma_client = chromadb.PersistentClient(path=str(CHROMA_PATH))
    kb_collection = chroma_client.get_or_create_collection(
        name="medical_kb",
        metadata={"hnsw:space": "cosine"},
    )
    _client = chroma_client
    _collection = kb_collection
    _active_backend = "chroma"
def _create_pinecone_client_and_collection() -> None:
    """Connect to Pinecone, validate the index, and install the adapter as
    the module-level collection."""
    global _client, _collection, _active_backend
    _client = Pinecone(api_key=PINECONE_API_KEY)
    _validate_pinecone_index(_client)
    # Connecting by host skips a describe-index round trip when configured.
    index = (
        _client.Index(host=PINECONE_INDEX_HOST)
        if PINECONE_INDEX_HOST
        else _client.Index(PINECONE_INDEX_NAME)
    )
    _collection = PineconeCollectionAdapter(index=index)
    _active_backend = "pinecone"
def _resolve_backend_choice(backend_override: str | None = None) -> str:
override = (backend_override or "").strip().lower()
if override in {"pinecone", "chroma"}:
return override
if REQUESTED_VECTOR_BACKEND == "pinecone":
return "pinecone"
if REQUESTED_VECTOR_BACKEND in {"chroma", "auto"}:
return "chroma"
return "chroma"
def _create_collection(backend_override: str | None = None) -> None:
    """Build the module-level collection for the resolved backend.

    Raises:
        RuntimeError: when pinecone is requested but its SDK or the required
            environment variables are missing.
    """
    backend_choice = _resolve_backend_choice(backend_override)
    if backend_choice != "pinecone":
        _create_chroma_client_and_collection()
        return
    # Fail fast with actionable messages before touching the network.
    if not _PINECONE_AVAILABLE:
        raise RuntimeError("Pinecone SDK is not installed. Run: python3.12 -m pip install pinecone")
    if not PINECONE_API_KEY:
        raise RuntimeError("Missing ALIVEAI_PINECONE_API_KEY while backend is set to pinecone.")
    if not PINECONE_INDEX_NAME:
        raise RuntimeError("Missing ALIVEAI_PINECONE_INDEX_NAME while backend is set to pinecone.")
    _create_pinecone_client_and_collection()
def get_collection(force_refresh: bool = False, backend_override: str | None = None):
    """Return the shared collection, (re)creating it when forced, not yet
    initialised, or bound to a different backend than requested."""
    global _collection
    with _lock:
        desired_backend = _resolve_backend_choice(backend_override)
        stale = force_refresh or _collection is None or _active_backend != desired_backend
        if stale:
            _create_collection(backend_override=desired_backend)
        return _collection
def get_vector_backend() -> str:
    """Report the active backend, or the one that would be chosen if the
    collection has not been initialised yet."""
    return _active_backend if _active_backend != "unknown" else _resolve_backend_choice()
def safe_count(retries: int = 3, delay_seconds: float = 0.5) -> int:
    """Read the vector DB count with linear-backoff retries.

    Args:
        retries: Maximum number of attempts (should be >= 1).
        delay_seconds: Base delay; attempt ``k`` sleeps ``k * delay_seconds``.

    Returns:
        The collection's vector count.

    Raises:
        RuntimeError: when every attempt fails; chained from the last
            underlying error so the original traceback is preserved.
    """
    last_exc: Exception | None = None
    for attempt in range(retries):
        try:
            # After a failure, force a refresh in case the cached
            # client/collection went stale.
            return int(get_collection(force_refresh=(attempt > 0)).count())
        except Exception as exc:
            last_exc = exc
            if attempt == retries - 1:
                break
            time.sleep(delay_seconds * (attempt + 1))
    raise RuntimeError(
        f"Failed to read vector DB count after {retries} attempts "
        f"(backend={get_vector_backend()}): {last_exc}"
    ) from last_exc

Xet Storage Details

Size:
10.7 kB
·
Xet hash:
430dae0b4f7c978704aebe285b43df1d6065e485d24826858e403e1e239d8260

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.