# document-based-assistant/retriever.py
import os
from typing import List, Dict, Any
from sentence_transformers import SentenceTransformer
import chromadb
# Hard-coded docs folder; adjust to your environment
FOLDER_PATH = "/home/nishtha/document-based-assistant/docs"

# Embedding model and in-memory (ephemeral) Chroma client shared by the functions below
model = SentenceTransformer("all-MiniLM-L6-v2")
chroma_client = chromadb.Client()
def load_document(docs_path: str = FOLDER_PATH) -> List[Dict[str, Any]]:
    """
    Load every text document in docs_path as a {filename, content} dict.
    """
    documents = []
    for doc in os.listdir(docs_path):
        filepath = os.path.join(docs_path, doc)
        # Skip subdirectories and other non-file entries
        if not os.path.isfile(filepath):
            continue
        with open(filepath, 'r', encoding='utf-8') as file:
            content = file.read()
        documents.append({
            "filename": doc,
            "content": content
        })
    return documents
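
# Illustrative usage (hypothetical file name; each item pairs a filename
# with the file's full text):
#   docs = load_document()
#   docs[0] -> {"filename": "notes.txt", "content": "..."}
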
def chunk_text(content: str, max_tokens: int = 250, overlap_tokens: int = 50) -> List[str]:
    """
    Clean, preprocess and split a long text into chunks of up to max_tokens
    whitespace-delimited words, with an overlap of overlap_tokens words between
    consecutive chunks to preserve context.
    """
    if overlap_tokens >= max_tokens:
        raise ValueError("overlap_tokens must be smaller than max_tokens")
    # Normalise whitespace and lowercase before splitting into words
    content = " ".join(content.strip().split()).lower()
    words = content.split()
    chunks = []
    start = 0
    length = len(words)
    while start < length:
        end = min(start + max_tokens, length)
        chunks.append(" ".join(words[start:end]))
        # Stop once the final words are covered, so the tail is not re-emitted
        if end == length:
            break
        start += (max_tokens - overlap_tokens)
    return chunks
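
# A quick sketch of the sliding window above (hypothetical input):
#   chunk_text("a b c d e f", max_tokens=4, overlap_tokens=2)
#   -> ["a b c d", "c d e f"]
# Each window advances by max_tokens - overlap_tokens words, so consecutive
# chunks share overlap_tokens words of context.
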
def embed_text(chunks: List[str], batch_size: int = 32) -> List[List[float]]:
    """
    Embed text chunks using the pre-loaded SentenceTransformer model.
    Returns one embedding (a list of floats) per chunk.
    """
    embeddings = model.encode(chunks, batch_size=batch_size, convert_to_tensor=False, show_progress_bar=True)
    # encode() returns a numpy array; convert so the return type matches the annotation
    return embeddings.tolist()
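
# Illustrative shape check: all-MiniLM-L6-v2 produces 384-dimensional vectors.
#   vecs = embed_text(["hello world", "goodbye"])
#   len(vecs) == 2 and len(vecs[0]) == 384
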
def vector_store(collection, ids: list, documents: list, metadatas: list, embeddings: list = None):
    """
    Add chunks (with metadata and optional embeddings) to a ChromaDB collection.
    Falls back to the default collection when none is passed in.
    """
    if collection is None:
        collection = chroma_client.get_or_create_collection(name="document_assistant_collection")
    collection.add(
        ids=ids,
        documents=documents,
        metadatas=metadatas,
        embeddings=embeddings,
    )
    return collection
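
# Illustrative call (hypothetical values), reusing a collection created elsewhere:
#   coll = chroma_client.get_or_create_collection(name="document_assistant_collection")
#   vector_store(coll, ids=["doc1.txt_chunk0"], documents=["some chunk text"],
#                metadatas=[{"source": "doc1.txt", "chunk_index": 0}],
#                embeddings=embed_text(["some chunk text"]))
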
def retrieve_query(query: str, collection, top_k: int = 5) -> List[Dict[str, Any]]:
    """
    Retrieve the top_k most relevant document chunks for the given query
    from the collection.
    """
    # Apply the same normalisation used when the chunks were indexed
    clean_q = query.strip().lower()
    q_emb = model.encode([clean_q], convert_to_tensor=False)[0]
    results = collection.query(
        query_embeddings=[q_emb.tolist()],
        n_results=top_k,
        include=["documents", "metadatas", "distances"]
    )
    retrieved = []
    docs_list = results["documents"][0]
    metas_list = results["metadatas"][0]
    dists_list = results["distances"][0]
    for text, meta, dist in zip(docs_list, metas_list, dists_list):
        retrieved.append({
            "source": meta.get("source"),
            "text": text,
            "score": dist  # a distance, so lower means more similar
        })
    return retrieved
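
# Illustrative result (values are hypothetical; "score" is a distance,
# so lower means more similar):
#   retrieve_query("what is the refund policy?", coll, top_k=2)
#   -> [{"source": "policy.txt", "text": "...", "score": 0.31}, ...]
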
if __name__ == "__main__":
# Get or create the collection
collection = chroma_client.get_or_create_collection(name="document_assistant_collection")
# Load and prepare documents
docs = load_document()
ids, chunks, metas = [], [], []
for d in docs:
for idx, c in enumerate(chunk_text(d["content"], max_tokens=250, overlap_tokens=50)):
ids.append(f"{d['filename']}_chunk{idx}")
chunks.append(c)
metas.append({"source": d["filename"], "chunk_index": idx})
# Embed chunks and add to vector store
embeddings = embed_text(chunks, batch_size=32)
coll = vector_store(collection, ids=ids, documents=chunks, metadatas=metas, embeddings=embeddings)
# Interactive question loop
    while True:
        q = input("Your question (or 'exit'): ")
        if q.lower() in ("exit", "quit"):
            break
        results = retrieve_query(q, coll, top_k=5)
        for r in results:
            print(f"Source: {r['source']} | Score: {r['score']}\nText: {r['text']}\n---")