meet4150/ALIV_AI / app /agent /kb_retrieval.py
download
raw
3.03 kB
from __future__ import annotations
import time
from app.agent.kb_embedding import KBEmbeddingService
from app.db.chroma_client import get_collection
def _query_collection(query: str, disease_id: str | None, top_k: int, include: list[str]):
    """Embed *query* and run a similarity query against the vector collection.

    Args:
        query: Free-text query; embedded once via ``KBEmbeddingService().embed``.
        disease_id: When truthy and not ``"general"``, results are restricted
            with a ``{"disease_id": disease_id}`` ``where`` filter. Otherwise
            no filter is applied.
        top_k: Passed to the collection as ``n_results``.
        include: Result fields to request from the collection.

    Returns:
        Whatever ``collection.query`` returns (callers in this module treat it
        as a mapping and read it with ``.get``).

    Raises:
        RuntimeError: If every attempt fails. The last underlying exception is
            attached as ``__cause__``.
    """
    # The embedding is computed once, outside the retry loop; only the
    # collection lookup and query are retried.
    query_embedding = KBEmbeddingService().embed(query)
    query_kwargs = {
        "query_embeddings": [query_embedding],
        "n_results": top_k,
        "include": include,
    }
    if disease_id and disease_id != "general":
        query_kwargs["where"] = {"disease_id": disease_id}

    max_attempts = 5
    last_exc: Exception | None = None
    for attempt in range(max_attempts):
        try:
            # After a failure, ask for a refreshed collection handle
            # (presumably re-creates a stale client — confirm in chroma_client).
            collection = get_collection(force_refresh=(attempt > 0))
            return collection.query(**query_kwargs)
        except Exception as exc:
            last_exc = exc
            # Linear backoff between attempts; no point sleeping after the
            # final failure since we are about to raise.
            if attempt < max_attempts - 1:
                time.sleep(0.4 * (attempt + 1))
    raise RuntimeError(f"Vector DB query failed after retries: {last_exc}") from last_exc
def retrieve(query: str, disease_id: str | None = None, top_k: int = 5) -> list[str]:
    """Return the documents matching *query*, preferring disease-scoped hits.

    The collection is first queried with ``disease_id``. If that yields no
    documents and ``disease_id`` names a specific disease (truthy and not
    ``"general"``), the query is repeated without any disease filter.

    Returns:
        The first batch of documents as a list, or ``[]`` when nothing matched.
    """

    def _first_batch(scope: str | None) -> list[str] | None:
        # Returns the first non-empty documents batch, or None if there is none.
        response = _query_collection(query, scope, top_k, include=["documents"])
        batches = response.get("documents", [[]])
        if batches and batches[0]:
            return list(batches[0])
        return None

    scoped_hits = _first_batch(disease_id)
    if scoped_hits is not None:
        return scoped_hits

    # Only retry unfiltered when the first query was actually filtered.
    if not disease_id or disease_id == "general":
        return []

    unscoped_hits = _first_batch(None)
    return unscoped_hits if unscoped_hits is not None else []
def retrieve_with_scores(
    query: str,
    disease_id: str | None = None,
    top_k: int = 5,
) -> list[dict]:
    """Return matching documents together with a score and their metadata.

    Queries with ``disease_id`` first; if that returns no documents and
    ``disease_id`` is truthy and not ``"general"``, queries again unfiltered.

    Returns:
        A list of ``{"content", "score", "metadata"}`` dicts, or ``[]`` when
        nothing matched. ``score`` is ``1 - distance``; this reads as a
        similarity only for distance metrics where that holds (e.g. cosine)
        — TODO confirm the collection's metric.
    """
    fields = ("documents", "distances", "metadatas")

    response = _query_collection(query, disease_id, top_k, include=list(fields))
    docs = response.get("documents", [[]])

    nothing_found = not (docs and docs[0])
    if nothing_found and disease_id and disease_id != "general":
        # Scoped query came back empty: fall back to the whole collection.
        response = _query_collection(query, None, top_k, include=list(fields))
        docs = response.get("documents", [[]])

    if not (docs and docs[0]):
        return []

    dists = response.get("distances", [[]])
    metas = response.get("metadatas", [[]])
    return [
        {
            "content": text,
            "score": 1 - float(dist),
            "metadata": meta,
        }
        for text, dist, meta in zip(docs[0], dists[0], metas[0])
    ]
def validate_similarity(text1: str, text2: str) -> float:
    """Return the dot product of the embeddings of *text1* and *text2*.

    Both texts are embedded in one ``embed_batch`` call. The result equals
    cosine similarity only if the embeddings are unit-normalised — presumably
    the case for this service; verify against ``KBEmbeddingService``.
    """
    vectors = KBEmbeddingService().embed_batch([text1, text2])
    first_vector, second_vector = vectors
    pairwise_products = (a * b for a, b in zip(first_vector, second_vector))
    return float(sum(pairwise_products))

Xet Storage Details

Size:
3.03 kB
·
Xet hash:
07412d4eb655c311f0207b5cdd5c91bbc8698f8ccca08455445f9ace1a02020c

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.