|
|
import os
from typing import Any, Dict, List, Optional

import chromadb
from sentence_transformers import SentenceTransformer
|
|
|
|
|
FOLDER_PATH = "/home/nishtha/document-based-assistant/docs" |
|
|
|
|
|
model = SentenceTransformer("all-MiniLM-L6-v2") |
|
|
chroma_client = chromadb.Client() |
|
|
|
|
|
def load_document(docs_path: str = FOLDER_PATH) -> List[Dict[str, Any]]: |
|
|
""" |
|
|
Load text documents from docs_path. |
|
|
""" |
|
|
documents = [] |
|
|
for doc in os.listdir(docs_path): |
|
|
filepath = os.path.join(docs_path, doc) |
|
|
with open(filepath, 'r', encoding='utf-8') as file: |
|
|
content = file.read() |
|
|
documents.append({ |
|
|
"filename": doc, |
|
|
"content": content |
|
|
}) |
|
|
return documents |
|
|
|
|
|
|
|
|
def chunk_text(content: str, max_tokens: int = 250, overlap_tokens: int = 50) -> List[str]: |
|
|
""" |
|
|
Clean, preprocess and split a long text into chunks of up to max_tokens, |
|
|
with overlap of overlap_tokens between consecutive chunks to preserve context. |
|
|
""" |
|
|
content = content.strip() |
|
|
content = " ".join(content.split()) |
|
|
content = content.lower() |
|
|
|
|
|
words = content.split() |
|
|
chunks = [] |
|
|
start = 0 |
|
|
length = len(words) |
|
|
|
|
|
while start < length: |
|
|
end = min(start + max_tokens, length) |
|
|
chunk = " ".join(words[start:end]) |
|
|
chunks.append(chunk) |
|
|
|
|
|
start += (max_tokens - overlap_tokens) |
|
|
|
|
|
return chunks |
|
|
|
|
|
|
|
|
def embed_text(chunks: List[str], batch_size: int = 32) -> List[List[float]]: |
|
|
""" |
|
|
Embed text chunks using the pre-loaded SentenceTransformer model. |
|
|
Returns a list of embeddings. |
|
|
""" |
|
|
embeddings = model.encode(chunks, batch_size=batch_size, convert_to_tensor=False, show_progress_bar=True) |
|
|
return embeddings |
|
|
|
|
|
|
|
|
def vector_store(collection, ids: list, documents: list, metadatas: list, embeddings: list = None): |
|
|
""" |
|
|
Create or get a ChromaDB collection for storing embeddings. |
|
|
""" |
|
|
collection = chroma_client.get_or_create_collection(name="document_assistant_collection") |
|
|
collection.add( |
|
|
ids=ids, |
|
|
documents=documents, |
|
|
metadatas=metadatas, |
|
|
embeddings=embeddings, |
|
|
) |
|
|
return collection |
|
|
|
|
|
|
|
|
def retrieve_query(query: str, collection, top_k: int = 5) -> List[Dict[str, Any]]: |
|
|
""" |
|
|
Retrieve the top_k most relevant document chunks for the given query from the collection.""" |
|
|
clean_q = query.strip().lower() |
|
|
q_emb = model.encode([clean_q], convert_to_tensor=False)[0] |
|
|
results = collection.query( |
|
|
query_embeddings=[q_emb], |
|
|
n_results=top_k, |
|
|
include=["documents", "metadatas", "distances"] |
|
|
) |
|
|
retrieved = [] |
|
|
docs_list = results["documents"][0] |
|
|
metas_list = results["metadatas"][0] |
|
|
dists_list = results["distances"][0] |
|
|
for text, meta, dist in zip(docs_list, metas_list, dists_list): |
|
|
retrieved.append({ |
|
|
"source": meta.get("source"), |
|
|
"text": text, |
|
|
"score": dist |
|
|
}) |
|
|
return retrieved |
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
collection = chroma_client.get_or_create_collection(name="document_assistant_collection") |
|
|
|
|
|
|
|
|
docs = load_document() |
|
|
ids, chunks, metas = [], [], [] |
|
|
for d in docs: |
|
|
for idx, c in enumerate(chunk_text(d["content"], max_tokens=250, overlap_tokens=50)): |
|
|
ids.append(f"{d['filename']}_chunk{idx}") |
|
|
chunks.append(c) |
|
|
metas.append({"source": d["filename"], "chunk_index": idx}) |
|
|
|
|
|
|
|
|
embeddings = embed_text(chunks, batch_size=32) |
|
|
coll = vector_store(collection, ids=ids, documents=chunks, metadatas=metas, embeddings=embeddings) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|