Spaces:
Running
Running
| import os | |
| from typing import List, Dict, Any, Tuple | |
| import chromadb | |
| from src.config import CHROMA_DIR, COLLECTION_NAME | |
| # ---------------- COLLECTION ---------------- | |
def get_collection():
    """Return the project's Chroma collection, creating it on first use.

    The storage directory ``CHROMA_DIR`` is created if it is missing, a
    persistent client is opened on it, and the collection named
    ``COLLECTION_NAME`` is fetched (or created when absent).
    """
    os.makedirs(CHROMA_DIR, exist_ok=True)
    store = chromadb.PersistentClient(path=CHROMA_DIR)
    collection = store.get_or_create_collection(COLLECTION_NAME)
    return collection
| # ---------------- ADD DOCUMENTS ---------------- | |
def add_documents(
    docs: List[str],
    embeddings: List[List[float]],
    metadatas: List[Dict[str, Any]],
    ids: List[str]
) -> None:
    """Add a batch of documents with precomputed embeddings to the collection.

    The four arguments are parallel sequences: element ``i`` of each one
    describes the same record.

    Args:
        docs: Document texts.
        embeddings: One embedding vector per document.
        metadatas: One metadata dict per document.
        ids: One identifier per document.

    Raises:
        ValueError: If ``docs``, ``embeddings`` or ``metadatas`` is given
            with a length different from ``ids``.
    """
    expected = len(ids)
    for label, batch in (
        ("docs", docs),
        ("embeddings", embeddings),
        ("metadatas", metadatas),
    ):
        # A None argument is passed through untouched, exactly as before;
        # only sequences that are actually supplied are length-checked.
        if batch is not None and len(batch) != expected:
            raise ValueError(
                f"{label} has {len(batch)} items but ids has {expected}; "
                "all inputs must have the same length"
            )

    if expected == 0:
        # Nothing to insert. NOTE(review): Chroma appears to reject empty
        # batches, so skip the store call instead of surfacing an error
        # for a harmless no-op.
        return

    get_collection().add(
        documents=docs,
        embeddings=embeddings,
        metadatas=metadatas,
        ids=ids
    )
| # ---------------- QUERY ---------------- | |
def query_by_embedding(
    q_embedding: List[float],
    top_k: int
) -> Tuple[List[str], List[Dict[str, Any]]]:
    """Return the documents and metadatas of the nearest stored records.

    Args:
        q_embedding: Embedding vector of the query.
        top_k: Maximum number of results to return.

    Returns:
        A ``(documents, metadatas)`` pair of parallel lists for the single
        query. Both lists are empty when ``top_k`` is not positive or the
        collection holds no records.
    """
    if top_k <= 0:
        # A non-positive result count cannot match anything; answer
        # without opening the store.
        return [], []

    col = get_collection()
    available = col.count()
    if available == 0:
        return [], []

    res = col.query(
        query_embeddings=[q_embedding],
        # Never request more neighbours than the collection contains.
        n_results=min(top_k, available),
        include=["documents", "metadatas"]
    )
    # One query embedding was sent, so take the first (only) result row.
    return res["documents"][0], res["metadatas"][0]
| # ---------------- RESET ---------------- | |
def reset_collection() -> None:
    """Drop the collection and recreate it empty.

    The storage directory ``CHROMA_DIR`` is created if missing. The
    collection is first ensured to exist so that the delete step never
    fails merely because there was nothing to delete; any error raised by
    the delete itself is therefore a real failure and is allowed to
    propagate rather than being silently ignored (which would leave the
    old data in place while appearing to succeed).
    """
    os.makedirs(CHROMA_DIR, exist_ok=True)
    client = chromadb.PersistentClient(path=CHROMA_DIR)
    # Guarantee the target exists so delete_collection has something to
    # remove on a fresh store.
    client.get_or_create_collection(COLLECTION_NAME)
    client.delete_collection(COLLECTION_NAME)
    client.get_or_create_collection(COLLECTION_NAME)