# Math-Mentor / rag/knowledge_base.py
# krushnakant27's picture
# Upload 29 files
# 3c02b94 verified
import os
from typing import List, Dict
import openai
import chromadb
from chromadb.config import Settings
from chromadb import EmbeddingFunction
class OpenAIEmbedder(EmbeddingFunction):
    """Chroma embedding function that embeds text with an OpenAI model.

    Supports both generations of the ``openai`` SDK: the 1.x module-level
    client (``openai.embeddings.create``) and the legacy pre-1.0 call
    (``openai.Embedding.create``), which was removed in 1.0.
    """

    # Embedding model used for every request.
    _MODEL = "text-embedding-ada-002"

    def __init__(self):
        """Copy ``OPENAI_API_KEY`` from the environment onto the SDK module."""
        # os.getenv returns None when the variable is unset; nothing is
        # validated here.
        openai.api_key = os.getenv('OPENAI_API_KEY')

    def __call__(self, input: List[str]) -> List[List[float]]:
        """Embed a batch of texts.

        Args:
            input: Texts to embed. The parameter name shadows the builtin,
                but it is kept because callers may pass it by keyword.

        Returns:
            One embedding vector per input text, in the order the API
            returned them. An empty input yields an empty list without
            issuing a request.
        """
        if not input:
            return []

        if hasattr(openai, "OpenAI"):
            # openai>=1.0: module-level client, attribute-style response.
            response = openai.embeddings.create(input=input, model=self._MODEL)
            return [item.embedding for item in response.data]

        # openai<1.0: legacy resource class, dict-style response.
        response = openai.Embedding.create(input=input, model=self._MODEL)
        return [item['embedding'] for item in response['data']]
class KnowledgeBase:
    """Vector store over the ``.txt`` files of a knowledge directory.

    On construction the ``math_knowledge`` Chroma collection is dropped (if
    present), recreated with an ``OpenAIEmbedder`` and filled with chunks of
    every ``.txt`` file found in ``knowledge_dir``.
    """

    # Number of chunks sent per ``collection.add`` call. Keeping batches
    # small avoids one unbounded request when the corpus grows.
    # NOTE(review): 100 is a conservative choice, not a documented limit.
    _ADD_BATCH_SIZE = 100

    def __init__(self, knowledge_dir='knowledge/docs'):
        """Create the client, reset the collection and index the documents.

        Args:
            knowledge_dir: Directory containing the ``.txt`` source files.

        Raises:
            OSError: If ``knowledge_dir`` cannot be listed or a file cannot
                be read.
        """
        self.knowledge_dir = knowledge_dir
        self.embedder = OpenAIEmbedder()
        self.client = chromadb.Client(Settings(
            anonymized_telemetry=False,
            allow_reset=True
        ))
        # Always reset collection to ensure clean state with correct embeddings
        try:
            self.client.delete_collection('math_knowledge')
        except Exception:
            # Best effort: the collection usually just does not exist yet.
            # Exception (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate.
            pass
        self.collection = self.client.create_collection(
            'math_knowledge', embedding_function=self.embedder
        )
        self._load_documents()

    def _load_documents(self):
        """Chunk every ``.txt`` file in ``knowledge_dir`` and add it to the collection.

        Each chunk is stored with ``topic`` (file name without the trailing
        ``.txt``) and ``source`` (file name) metadata and the id
        ``"<topic>_<chunk index>"``.
        """
        suffix = '.txt'
        docs = []
        metadatas = []
        ids = []
        # Sorted so the indexing order does not depend on the filesystem.
        for filename in sorted(os.listdir(self.knowledge_dir)):
            if not filename.endswith(suffix):
                continue
            filepath = os.path.join(self.knowledge_dir, filename)
            if not os.path.isfile(filepath):
                # e.g. a directory whose name happens to end in ".txt"
                continue
            # Strip only the trailing suffix; str.replace would remove every
            # occurrence and could map two different files to the same topic
            # (and therefore to colliding ids).
            topic = filename[:-len(suffix)]
            with open(filepath, 'r', encoding='utf-8') as f:
                content = f.read()
            for i, chunk in enumerate(self._chunk_document(content)):
                docs.append(chunk)
                metadatas.append({
                    'topic': topic,
                    'source': filename
                })
                ids.append(f"{topic}_{i}")

        for start in range(0, len(docs), self._ADD_BATCH_SIZE):
            end = start + self._ADD_BATCH_SIZE
            self.collection.add(
                documents=docs[start:end],
                metadatas=metadatas[start:end],
                ids=ids[start:end]
            )

    def _chunk_document(self, content: str, chunk_size: int = 500) -> List[str]:
        """Split text into line-aligned chunks of at most ``chunk_size`` characters.

        Whole lines are packed greedily; the newline joiners count towards
        the limit. A single line longer than ``chunk_size`` is cut into
        fixed-width pieces. Chunks that contain only whitespace are dropped,
        so empty content yields an empty list.

        Args:
            content: Full document text.
            chunk_size: Maximum length of each returned chunk, in characters.

        Returns:
            The chunks in document order.

        Raises:
            ValueError: If ``chunk_size`` is smaller than 1.
        """
        if chunk_size < 1:
            raise ValueError("chunk_size must be a positive integer")

        chunks: List[str] = []
        pending: List[str] = []
        pending_len = 0  # always equals len('\n'.join(pending))

        def flush():
            nonlocal pending_len
            text = '\n'.join(pending)
            if text.strip():
                chunks.append(text)
            pending.clear()
            pending_len = 0

        for line in content.split('\n'):
            if len(line) > chunk_size:
                # The line can never fit in one chunk: emit what we have,
                # then hard-split the line itself.
                flush()
                for start in range(0, len(line), chunk_size):
                    piece = line[start:start + chunk_size]
                    if piece.strip():
                        chunks.append(piece)
                continue

            # +1 accounts for the '\n' that will join this line to the chunk.
            added = len(line) + (1 if pending else 0)
            if pending and pending_len + added > chunk_size:
                flush()
                added = len(line)
            pending.append(line)
            pending_len += added

        flush()
        return chunks

    def search(self, query: str, topic: str = None, k: int = 3) -> List[Dict]:
        """Return the chunks most similar to ``query``.

        Args:
            query: Free-text query.
            topic: When given (and non-empty), restrict results to chunks
                whose ``topic`` metadata equals this value.
            k: Maximum number of results; values below 1 return ``[]``.

        Returns:
            A list of dicts with ``content``, ``metadata`` and ``distance``
            keys. ``distance`` falls back to 0 and ``metadata`` to ``{}``
            when the query result does not carry them.
        """
        if k < 1:
            return []

        where = {"topic": topic} if topic else None
        results = self.collection.query(
            query_texts=[query],
            n_results=k,
            where=where
        )

        # Each field is a list with one entry per query text; any of them
        # may be missing, None or empty.
        documents = (results.get('documents') or [[]])[0] or []
        metadatas = (results.get('metadatas') or [[]])[0] or []
        distances = (results.get('distances') or [[]])[0] or []

        retrieved = []
        for i, doc in enumerate(documents):
            retrieved.append({
                'content': doc,
                'metadata': metadatas[i] if i < len(metadatas) else {},
                'distance': distances[i] if i < len(distances) else 0
            })
        return retrieved