meet4150/alive_pine / app /db /chroma_client.py
download
raw
8.55 kB
from __future__ import annotations
import os
import time
from pathlib import Path
from threading import Lock
from typing import Any
import chromadb
from dotenv import load_dotenv
# Optional dependency: the Pinecone SDK may be absent in some environments;
# availability is re-checked in _create_collection before use.
try:
    from pinecone import Pinecone
    _PINECONE_AVAILABLE = True
except Exception:
    Pinecone = None
    _PINECONE_AVAILABLE = False

# Load ALIVEAI_* settings from a .env file, if one is present.
load_dotenv()

# Expected embedding size; the Pinecone index dimension is validated against it.
EMBEDDING_DIMENSION = 768

# Local Chroma storage directory (two levels above this file); created eagerly.
CHROMA_PATH = Path(__file__).resolve().parents[2] / "chroma_db"
CHROMA_PATH.mkdir(parents=True, exist_ok=True)

# Pinecone connection settings; an empty string means "not configured".
PINECONE_API_KEY = os.getenv("ALIVEAI_PINECONE_API_KEY", "").strip()
PINECONE_INDEX_NAME = os.getenv("ALIVEAI_PINECONE_INDEX_NAME", "").strip()
PINECONE_INDEX_HOST = os.getenv("ALIVEAI_PINECONE_INDEX_HOST", "").strip()
PINECONE_NAMESPACE = os.getenv("ALIVEAI_PINECONE_NAMESPACE", "").strip()

# This module runs Pinecone-only (_create_collection builds no Chroma backend).
VECTOR_BACKEND = "pinecone"

# Lazily-initialized process-wide client/collection handles, guarded by _lock.
_client: Any = None
_collection: Any = None
_lock = Lock()
class PineconeCollectionAdapter:
def __init__(self, index) -> None:
self._index = index
@staticmethod
def _namespace() -> str | None:
return PINECONE_NAMESPACE or None
@staticmethod
def _to_dict(value: Any) -> dict[str, Any]:
if isinstance(value, dict):
return value
to_dict = getattr(value, "to_dict", None)
if callable(to_dict):
return to_dict()
return {}
@staticmethod
def _match_get(match: Any, key: str, default: Any = None) -> Any:
if isinstance(match, dict):
return match.get(key, default)
return getattr(match, key, default)
@staticmethod
def _to_pinecone_filter(where: dict | None) -> dict | None:
if not where:
return None
converted: dict[str, Any] = {}
for key, value in where.items():
converted[key] = {"$eq": value}
return converted
def get(self, ids: list[str]) -> dict:
if not ids:
return {"ids": []}
response = self._index.fetch(ids=ids, namespace=self._namespace())
response_dict = self._to_dict(response)
vectors = response_dict.get("vectors", {})
if not vectors and not response_dict:
vectors = getattr(response, "vectors", {}) or {}
existing_ids = list(vectors.keys())
return {"ids": existing_ids}
def add(
self,
ids: list[str],
embeddings: list[list[float]],
documents: list[str],
metadatas: list[dict],
) -> None:
vectors = []
for record_id, embedding, document, metadata in zip(ids, embeddings, documents, metadatas):
merged_metadata = dict(metadata or {})
merged_metadata["_document"] = document
vectors.append(
{
"id": record_id,
"values": embedding,
"metadata": merged_metadata,
}
)
# Upsert is idempotent for Pinecone.
self._index.upsert(vectors=vectors, namespace=self._namespace())
def query(
self,
query_embeddings: list[list[float]],
n_results: int,
include: list[str],
where: dict | None = None,
) -> dict:
vector = query_embeddings[0]
pinecone_filter = self._to_pinecone_filter(where)
response = self._index.query(
vector=vector,
top_k=n_results,
filter=pinecone_filter,
namespace=self._namespace(),
include_metadata=True,
include_values=False,
)
response_dict = self._to_dict(response)
matches = response_dict.get("matches", [])
if not matches and not response_dict:
matches = getattr(response, "matches", []) or []
documents: list[str] = []
distances: list[float] = []
metadatas: list[dict] = []
for match in matches:
metadata = dict(self._match_get(match, "metadata", {}) or {})
documents.append(str(metadata.pop("_document", "")))
score = float(self._match_get(match, "score", 0.0))
distances.append(1.0 - score)
metadatas.append(metadata)
payload: dict[str, list[list[Any]]] = {}
if "documents" in include:
payload["documents"] = [documents]
if "distances" in include:
payload["distances"] = [distances]
if "metadatas" in include:
payload["metadatas"] = [metadatas]
return payload
def count(self) -> int:
stats = self._index.describe_index_stats()
stats_dict = self._to_dict(stats)
if stats_dict:
if PINECONE_NAMESPACE:
namespace_stats = (stats_dict.get("namespaces") or {}).get(PINECONE_NAMESPACE, {})
return int(namespace_stats.get("vector_count", 0))
return int(stats_dict.get("total_vector_count", 0))
if PINECONE_NAMESPACE:
namespaces = getattr(stats, "namespaces", {}) or {}
namespace_stats = namespaces.get(PINECONE_NAMESPACE, {})
return int(
namespace_stats.get("vector_count", 0)
if isinstance(namespace_stats, dict)
else getattr(namespace_stats, "vector_count", 0)
)
return int(getattr(stats, "total_vector_count", 0))
def get_vector_backend() -> str:
    """Return the name of the active vector backend (always "pinecone" here)."""
    return VECTOR_BACKEND
def _to_dict(value: Any) -> dict[str, Any]:
if isinstance(value, dict):
return value
to_dict = getattr(value, "to_dict", None)
if callable(to_dict):
return to_dict()
return {}
def _validate_pinecone_index(client: Any) -> None:
    """Fail fast unless the configured Pinecone index exists with the
    expected dimension and a cosine metric.

    Raises:
        RuntimeError: when the index is missing, or its dimension/metric
            does not match this module's configuration.
    """
    if not client.has_index(PINECONE_INDEX_NAME):
        raise RuntimeError(
            "Pinecone index not found. "
            f"Create index '{PINECONE_INDEX_NAME}' with dimension={EMBEDDING_DIMENSION} and metric='cosine'."
        )
    description = client.describe_index(PINECONE_INDEX_NAME)
    description_dict = _to_dict(description)

    def read_field(name: str) -> Any:
        # Prefer the dict form; fall back to attribute access on SDK objects.
        field = description_dict.get(name)
        return getattr(description, name, None) if field is None else field

    dimension = read_field("dimension")
    if dimension is not None and int(dimension) != EMBEDDING_DIMENSION:
        raise RuntimeError(
            "Pinecone index dimension mismatch. "
            f"Expected {EMBEDDING_DIMENSION}, found {dimension} on '{PINECONE_INDEX_NAME}'."
        )
    metric = read_field("metric")
    if metric and str(metric).lower() != "cosine":
        raise RuntimeError(
            "Pinecone index metric mismatch. "
            f"Expected 'cosine', found '{metric}' on '{PINECONE_INDEX_NAME}'."
        )
def _create_chroma_client_and_collection() -> None:
    """Initialize a local persistent Chroma client and the `medical_kb` collection.

    NOTE(review): not reachable through `_create_collection`, which builds only
    the Pinecone backend — presumably kept as a local/offline fallback; confirm.
    """
    global _client, _collection
    _client = chromadb.PersistentClient(path=str(CHROMA_PATH))
    # Cosine space to match the metric the Pinecone path validates.
    _collection = _client.get_or_create_collection(
        name="medical_kb",
        metadata={"hnsw:space": "cosine"},
    )
def _create_pinecone_client_and_collection() -> None:
    """Connect to Pinecone, validate the configured index, and install the
    adapter as the module-wide collection handle."""
    global _client, _collection
    _client = Pinecone(api_key=PINECONE_API_KEY)
    _validate_pinecone_index(_client)
    # A direct host connection skips the control-plane name lookup when set.
    index = (
        _client.Index(host=PINECONE_INDEX_HOST)
        if PINECONE_INDEX_HOST
        else _client.Index(PINECONE_INDEX_NAME)
    )
    _collection = PineconeCollectionAdapter(index=index)
def _create_collection() -> None:
    """Build the Pinecone-backed collection, failing fast on missing setup.

    Raises:
        RuntimeError: when the SDK is not installed or a required
            environment variable is unset.
    """
    if not _PINECONE_AVAILABLE:
        raise RuntimeError(
            "Pinecone SDK is not installed. Run: python3.12 -m pip install pinecone"
        )
    required = (
        (PINECONE_API_KEY, "Missing ALIVEAI_PINECONE_API_KEY for Pinecone-only mode."),
        (PINECONE_INDEX_NAME, "Missing ALIVEAI_PINECONE_INDEX_NAME for Pinecone-only mode."),
    )
    for value, message in required:
        if not value:
            raise RuntimeError(message)
    _create_pinecone_client_and_collection()
def get_collection(force_refresh: bool = False):
    """Return the process-wide collection handle, creating it on first use.

    Args:
        force_refresh: rebuild the client/collection even if one is cached.
    """
    global _collection
    with _lock:  # serialize lazy initialization across threads
        if force_refresh or _collection is None:
            _create_collection()
        return _collection
def safe_count(retries: int = 3, delay_seconds: float = 0.5) -> int:
    """Read the collection's vector count, retrying transient failures.

    After the first failure the collection is rebuilt (force_refresh) for
    the next attempt, and the sleep grows linearly with the attempt number.

    Raises:
        RuntimeError: when every attempt fails; the message embeds the
            last underlying exception.
    """
    last_exc = None
    for attempt in range(retries):
        try:
            collection = get_collection(force_refresh=attempt > 0)
            return int(collection.count())
        except Exception as exc:
            last_exc = exc
            # No sleep after the final attempt — fall through to the raise.
            if attempt < retries - 1:
                time.sleep(delay_seconds * (attempt + 1))
    raise RuntimeError(
        f"Failed to read vector DB count after {retries} attempts "
        f"(backend={VECTOR_BACKEND}): {last_exc}"
    )

Xet Storage Details

Size:
8.55 kB
·
Xet hash:
b69d60857af8d4982d9a8f058bcc5750fcc0dfe2f341f2322e51b3b6ca1e52d1

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.