# test2text / app/backend/chromadb_utils.py
# Davide Panza — "Update app/backend/chromadb_utils.py" (commit cf868d7, verified)
import chromadb
from chromadb.utils import embedding_functions
from .text_processing import text_chunking
def initialize_chromadb(EMBEDDING_MODEL, local_model_path=None):
    """
    Build a ChromaDB client plus a SentenceTransformer embedding function.

    Args:
        EMBEDDING_MODEL: Model name passed to
            ``SentenceTransformerEmbeddingFunction`` when ``local_model_path``
            is not supplied (or is falsy).
        local_model_path: Optional local model path. When truthy it is used
            as ``model_name`` instead of ``EMBEDDING_MODEL``.

    Returns:
        tuple: ``(client, embedding_func)``, where ``client`` comes from
        ``chromadb.Client()``.
    """
    client = chromadb.Client()
    # A truthy local path takes precedence; otherwise use the named model.
    model_source = local_model_path if local_model_path else EMBEDDING_MODEL
    embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=model_source
    )
    return client, embedding_func
def initialize_collection(client, embedding_func, collection_name):
    """
    Return the named ChromaDB collection, creating it if it does not exist.

    The collection is requested with ``embedding_func`` as its embedding
    function and with ``{"hnsw:space": "cosine"}`` as its metadata, which
    selects cosine distance for the HNSW index.

    Args:
        client: ChromaDB client exposing ``get_or_create_collection``.
        embedding_func: Embedding function to attach to the collection.
        collection_name (str): Name of the collection.

    Returns:
        Whatever ``client.get_or_create_collection`` returns.
    """
    hnsw_settings = {"hnsw:space": "cosine"}
    return client.get_or_create_collection(
        name=collection_name,
        embedding_function=embedding_func,
        metadata=hnsw_settings,
    )
def update_collection(
    collection,
    text,
    max_words=200,
    min_words=100,
    overlap_sentences=3,
):
    """
    Chunk ``text`` and add the resulting chunks to a ChromaDB collection.

    Each chunk is stored with ID ``chunk_{index:04d}`` and metadata
    ``{"chunk_index": index}``, where ``index`` is the chunk's position in
    the list returned by ``text_chunking``. If chunking produces no chunks,
    the collection is left untouched.

    NOTE(review): IDs restart at ``chunk_0000`` on every call, so calling this
    more than once on the same collection reuses IDs. How Chroma handles an
    ``add`` with already-existing IDs depends on the installed version —
    confirm this is only used on fresh collections.

    Args:
        collection: ChromaDB collection object (must provide ``add``).
        text (str): The text to be chunked and added.
        max_words (int): Maximum number of words per chunk (forwarded to
            ``text_chunking``).
        min_words (int): Minimum number of words per chunk (forwarded to
            ``text_chunking``).
        overlap_sentences (int): Number of sentences to overlap between
            chunks (forwarded to ``text_chunking``).

    Returns:
        None
    """
    chunks = text_chunking(
        text,
        max_words=max_words,
        min_words=min_words,
        overlap_sentences=overlap_sentences,
    )
    # Chroma's add() validation rejects an empty ID list, so input that
    # chunks to nothing (e.g. empty text) becomes a no-op instead of an error.
    if not chunks:
        return

    indices = range(len(chunks))
    collection.add(
        documents=chunks,
        ids=[f"chunk_{i:04d}" for i in indices],
        metadatas=[{"chunk_index": i} for i in indices],
    )