| from __future__ import annotations | |
| import time | |
| from app.agent.kb_embedding import KBEmbeddingService | |
| from app.db.chroma_client import get_collection | |
def _query_collection(
    query: str,
    disease_id: str | None,
    top_k: int,
    include: list[str],
    *,
    max_attempts: int = 5,
    backoff_seconds: float = 0.4,
):
    """Embed *query* and run a nearest-neighbour query, retrying on failure.

    Args:
        query: Free text to embed with ``KBEmbeddingService().embed``.
        disease_id: When truthy and not ``"general"``, the query is restricted
            with ``where={"disease_id": disease_id}``; otherwise it is unfiltered.
        top_k: Passed through as ``n_results``.
        include: Result fields to request (e.g. ``["documents"]``).
        max_attempts: Total number of query attempts (default 5).
        backoff_seconds: Linear backoff step; attempt ``i`` (0-based) that fails
            waits ``backoff_seconds * (i + 1)`` before the next attempt.

    Returns:
        Whatever ``collection.query`` returns. Callers read it with ``.get``,
        so presumably a dict-like Chroma query result.

    Raises:
        ValueError: If ``max_attempts`` is less than 1.
        RuntimeError: If every attempt fails; the last failure is chained as
            ``__cause__``. Errors raised while embedding the query are not
            retried and propagate unchanged.
    """
    if max_attempts < 1:
        raise ValueError(f"max_attempts must be >= 1, got {max_attempts}")

    query_embedding = KBEmbeddingService().embed(query)
    query_kwargs = {
        "query_embeddings": [query_embedding],
        "n_results": top_k,
        "include": include,
    }
    # "general" (or no id at all) means a collection-wide, unfiltered search.
    if disease_id and disease_id != "general":
        query_kwargs["where"] = {"disease_id": disease_id}

    last_exc: Exception | None = None
    for attempt in range(max_attempts):
        try:
            # After a failure, request a refreshed collection handle in case
            # the cached one went stale.
            collection = get_collection(force_refresh=attempt > 0)
            return collection.query(**query_kwargs)
        except Exception as exc:  # broad on purpose: any DB/client failure is retried
            last_exc = exc
            # Only back off if another attempt follows; sleeping after the
            # final failure just delays the error.
            if attempt < max_attempts - 1:
                time.sleep(backoff_seconds * (attempt + 1))
    raise RuntimeError(f"Vector DB query failed after retries: {last_exc}") from last_exc
def retrieve(query: str, disease_id: str | None = None, top_k: int = 5) -> list[str]:
    """Return up to *top_k* document texts nearest to *query*.

    If *disease_id* is truthy and not ``"general"``, the search is first
    filtered to that disease. If the filtered search finds nothing, it is
    repeated without a filter. Returns an empty list when no documents are
    found.
    """
    is_scoped = bool(disease_id) and disease_id != "general"
    # Scopes to try in order; ``None`` means an unfiltered search.
    scopes = [disease_id, None] if is_scoped else [disease_id]
    for scope in scopes:
        response = _query_collection(query, scope, top_k, include=["documents"])
        rows = response.get("documents", [[]])
        # One query embedding was sent, so row 0 holds its matches.
        if rows and rows[0]:
            return list(rows[0])
    return []
def retrieve_with_scores(
    query: str,
    disease_id: str | None = None,
    top_k: int = 5,
) -> list[dict]:
    """Return up to *top_k* hits for *query*, each with a score and metadata.

    Each hit is a dict with these keys:

    - ``"content"``: the document text.
    - ``"score"``: ``1 - distance``.
    - ``"metadata"``: the stored metadata.

    A disease-scoped search (*disease_id* truthy and not ``"general"``) that
    finds nothing is retried without the disease filter. Returns an empty
    list when no documents are found.
    """

    def fetch(scope: str | None):
        # Fresh include list per request so no two calls share a mutable list.
        return _query_collection(
            query,
            scope,
            top_k,
            include=["documents", "distances", "metadatas"],
        )

    response = fetch(disease_id)
    texts = response.get("documents", [[]])
    found = bool(texts and texts[0])
    if not found and disease_id and disease_id != "general":
        response = fetch(None)
        texts = response.get("documents", [[]])

    if not texts or not texts[0]:
        return []

    distance_rows = response.get("distances", [[]])
    metadata_rows = response.get("metadatas", [[]])
    # NOTE(review): ``1 - distance`` reads as a similarity only if the collection
    # uses a cosine-style distance; with L2 it can go negative. Confirm the
    # collection's distance setting.
    return [
        {"content": text, "score": 1 - float(dist), "metadata": meta}
        for text, dist, meta in zip(texts[0], distance_rows[0], metadata_rows[0])
    ]
def validate_similarity(text1: str, text2: str) -> float:
    """Return the dot product of the embeddings of *text1* and *text2*.

    Both texts are embedded together in a single ``embed_batch`` call.

    NOTE(review): the dot product equals cosine similarity only if
    ``KBEmbeddingService`` returns unit-length vectors. Confirm before reading
    the result as a cosine score.
    """
    vectors = KBEmbeddingService().embed_batch([text1, text2])
    left, right = vectors
    dot = sum(x * y for x, y in zip(left, right))
    return float(dot)
Xet Storage Details
- Size: 3.03 kB
- Xet hash: 07412d4eb655c311f0207b5cdd5c91bbc8698f8ccca08455445f9ace1a02020c

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.