Spaces:
Running
Running
File size: 1,382 Bytes
c4233b7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import os
from typing import List, Dict, Any, Tuple
import chromadb
from src.config import CHROMA_DIR, COLLECTION_NAME
# ---------------- COLLECTION ----------------
def get_collection():
    """Return the persistent Chroma collection named by ``COLLECTION_NAME``.

    The storage directory ``CHROMA_DIR`` is created first if it is missing,
    then a ``PersistentClient`` rooted there fetches the collection, creating
    it on first use.
    """
    os.makedirs(CHROMA_DIR, exist_ok=True)
    return chromadb.PersistentClient(path=CHROMA_DIR).get_or_create_collection(
        COLLECTION_NAME
    )
# ---------------- ADD DOCUMENTS ----------------
def add_documents(
    docs: List[str],
    embeddings: List[List[float]],
    metadatas: List[Dict[str, Any]],
    ids: List[str]
) -> None:
    """Insert documents with precomputed embeddings into the collection.

    The four lists are parallel: ``docs[i]``, ``embeddings[i]`` and
    ``metadatas[i]`` are all stored under ``ids[i]``.

    Args:
        docs: Raw document texts.
        embeddings: One embedding vector per document.
        metadatas: One metadata dict per document.
        ids: Unique identifier per document.

    Raises:
        ValueError: If the four lists do not all have the same length.
    """
    n_ids = len(ids)
    # Validate before opening the persistent client so a caller bug fails
    # fast with a message naming every length.
    if not (len(docs) == len(embeddings) == len(metadatas) == n_ids):
        raise ValueError(
            "add_documents: length mismatch "
            f"(docs={len(docs)}, embeddings={len(embeddings)}, "
            f"metadatas={len(metadatas)}, ids={n_ids})"
        )
    if n_ids == 0:
        # Nothing to insert. Chroma's add() rejects an empty id list, so
        # treat an empty batch as a no-op instead of an error.
        return
    col = get_collection()
    col.add(
        documents=docs,
        embeddings=embeddings,
        metadatas=metadatas,
        ids=ids,
    )
# ---------------- QUERY ----------------
def query_by_embedding(
    q_embedding: List[float],
    top_k: int
) -> Tuple[List[str], List[Dict[str, Any]]]:
    """Return the ``top_k`` nearest documents to ``q_embedding``.

    Args:
        q_embedding: A single query embedding vector.
        top_k: Maximum number of results to return. Values of 0 or less
            yield no results.

    Returns:
        A ``(documents, metadatas)`` pair of parallel lists, in the order
        returned by Chroma.
    """
    if top_k <= 0:
        # Chroma requires a positive n_results. A non-positive request
        # simply has no results.
        return [], []
    col = get_collection()
    res = col.query(
        query_embeddings=[q_embedding],
        n_results=top_k,
        include=["documents", "metadatas"],
    )
    # Exactly one query embedding was sent, so each field holds one result
    # row at index 0.
    return res["documents"][0], res["metadatas"][0]
# ---------------- RESET ----------------
def reset_collection() -> None:
    """Drop the collection and recreate it empty.

    Deleting is best-effort because the collection may not exist yet (for
    example on a fresh store). Chroma reports that case with an exception
    whose exact type may differ between releases, so any deletion error is
    tolerated at first. If deletion failed AND the recreated collection still
    holds records, the reset did not actually happen, and an error is raised
    instead of silently keeping stale data.

    Raises:
        RuntimeError: If the existing collection could not be deleted and is
            still non-empty afterwards. The deletion error is chained as the
            cause.
    """
    os.makedirs(CHROMA_DIR, exist_ok=True)
    client = chromadb.PersistentClient(path=CHROMA_DIR)
    delete_error = None
    try:
        client.delete_collection(COLLECTION_NAME)
    except Exception as err:  # "does not exist" is the expected case
        delete_error = err
    col = client.get_or_create_collection(COLLECTION_NAME)
    if delete_error is not None and col.count() > 0:
        raise RuntimeError(
            f"reset_collection: failed to delete {COLLECTION_NAME!r}; "
            "existing records were not cleared"
        ) from delete_error
|