Spaces:
Sleeping
Sleeping
File size: 3,585 Bytes
3c02b94 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 | import os
from typing import List, Dict
import openai
import chromadb
from chromadb.config import Settings
from chromadb import EmbeddingFunction
class OpenAIEmbedder(EmbeddingFunction):
    """Chroma embedding function backed by OpenAI ``text-embedding-ada-002``.

    Uses the legacy (pre-1.0) ``openai.Embedding.create`` API and returns one
    embedding vector per input text, in input order.
    """

    # Maximum number of texts sent in a single embeddings request; the OpenAI
    # embeddings endpoint rejects input arrays longer than 2048 items.
    batch_size = 2048

    def __init__(self):
        # Only override the module-level key when the env var is set, so a key
        # configured programmatically elsewhere is not clobbered with None.
        api_key = os.getenv('OPENAI_API_KEY')
        if api_key:
            openai.api_key = api_key

    # NOTE: keep the parameter named ``input`` even though it shadows the
    # builtin -- that is the name in Chroma's EmbeddingFunction interface.
    def __call__(self, input: List[str]) -> List[List[float]]:
        """Embed *input* texts, returning one vector per text in the same order.

        An empty input returns ``[]`` without calling the API. Larger inputs are
        sent in requests of at most ``batch_size`` texts.
        """
        texts = list(input)
        embeddings: List[List[float]] = []
        for start in range(0, len(texts), self.batch_size):
            response = openai.Embedding.create(
                input=texts[start:start + self.batch_size],
                model="text-embedding-ada-002",
            )
            # Each returned item records its position in the request; order by
            # it instead of relying on the order of ``data``.
            data = sorted(response['data'], key=lambda item: item['index'])
            embeddings.extend(item['embedding'] for item in data)
        return embeddings
class KnowledgeBase:
    """Chroma-backed store of chunked ``.txt`` documents, queryable by text.

    Every ``.txt`` file in ``knowledge_dir`` is split into line-aligned chunks
    and stored in the ``math_knowledge`` collection, each chunk tagged with its
    ``topic`` (file name without ``.txt``) and ``source`` (file name).
    """

    # Chunks passed to Chroma per ``collection.add`` call. Chroma embeds each
    # added batch through the embedding function, and the OpenAI embeddings
    # endpoint accepts at most 2048 inputs per request.
    _ADD_BATCH_SIZE = 500

    def __init__(self, knowledge_dir='knowledge/docs'):
        """Create a fresh ``math_knowledge`` collection and load *knowledge_dir* into it.

        Raises:
            FileNotFoundError: if *knowledge_dir* does not exist (from ``os.listdir``).
        """
        self.knowledge_dir = knowledge_dir
        self.embedder = OpenAIEmbedder()
        self.client = chromadb.Client(Settings(
            anonymized_telemetry=False,
            allow_reset=True
        ))
        # Always reset collection to ensure clean state with correct embeddings.
        try:
            self.client.delete_collection('math_knowledge')
        except Exception:
            # Expected when the collection does not exist yet. The exception
            # class raised for that differs between chromadb versions, so catch
            # broadly -- but no longer swallow KeyboardInterrupt/SystemExit.
            pass
        self.collection = self.client.create_collection(
            'math_knowledge', embedding_function=self.embedder
        )
        self._load_documents()

    def _load_documents(self):
        """Chunk every ``.txt`` file in ``self.knowledge_dir`` and add it to the collection.

        Files are visited in sorted name order, so chunk ids (``"{topic}_{i}"``)
        and insertion order are deterministic. Directories whose names end in
        ``.txt`` are skipped.
        """
        docs = []
        metadatas = []
        ids = []
        for filename in sorted(os.listdir(self.knowledge_dir)):
            filepath = os.path.join(self.knowledge_dir, filename)
            if not filename.endswith('.txt') or not os.path.isfile(filepath):
                continue
            # Strip only the trailing extension; str.replace would also remove
            # '.txt' appearing elsewhere in the name.
            topic = filename[:-len('.txt')]
            with open(filepath, 'r', encoding='utf-8') as f:
                content = f.read()
            for i, chunk in enumerate(self._chunk_document(content)):
                docs.append(chunk)
                metadatas.append({
                    'topic': topic,
                    'source': filename
                })
                ids.append(f"{topic}_{i}")
        # Add in bounded batches so no single embedding request is too large.
        for start in range(0, len(docs), self._ADD_BATCH_SIZE):
            end = start + self._ADD_BATCH_SIZE
            self.collection.add(
                documents=docs[start:end],
                metadatas=metadatas[start:end],
                ids=ids[start:end]
            )

    def _chunk_document(self, content: str, chunk_size: int = 500) -> List[str]:
        """Split *content* at line boundaries into chunks of at most *chunk_size* chars.

        The size includes the ``'\\n'`` separators between joined lines. A
        single line longer than *chunk_size* is kept whole as its own chunk.
        Whitespace-only chunks (e.g. from an empty file) are dropped, since they
        carry no content and the embeddings API rejects empty strings.
        """
        chunks: List[str] = []
        current_chunk: List[str] = []
        current_size = 0  # == len('\n'.join(current_chunk))
        for line in content.split('\n'):
            # Joining adds one '\n' before every line except the first.
            added = len(line) + (1 if current_chunk else 0)
            if current_chunk and current_size + added > chunk_size:
                chunks.append('\n'.join(current_chunk))
                current_chunk = [line]
                current_size = len(line)
            else:
                current_chunk.append(line)
                current_size += added
        if current_chunk:
            chunks.append('\n'.join(current_chunk))
        return [chunk for chunk in chunks if chunk.strip()]

    def search(self, query: str, topic: str = None, k: int = 3) -> List[Dict]:
        """Return up to *k* stored chunks retrieved for *query*.

        Args:
            query: Free-text query, embedded by the collection's embedding function.
            topic: If truthy, restrict the search to chunks whose ``topic``
                metadata equals it; ``None`` or ``''`` searches every topic.
            k: Number of results requested from the collection.

        Returns:
            A list, in the order the collection returns them, of dicts with
            ``'content'`` (chunk text), ``'metadata'`` (the chunk's metadata, or
            ``None`` if the result has no metadatas) and ``'distance'`` (``0``
            if the result has no distances). A blank query returns ``[]``
            without querying the collection.
        """
        # A blank query has nothing to embed (the OpenAI embeddings API rejects
        # empty input), so skip the round-trip.
        if not query or not query.strip():
            return []
        where = {"topic": topic} if topic else None
        results = self.collection.query(
            query_texts=[query],
            n_results=k,
            where=where
        )
        documents = results.get('documents')
        if not documents:
            return []
        # One query text was sent, so each field holds a single inner list.
        # Fields may be absent or None (e.g. when excluded from the query).
        metadatas = (results.get('metadatas') or [None])[0]
        distances = (results.get('distances') or [None])[0]
        return [
            {
                'content': doc,
                'metadata': metadatas[i] if metadatas is not None else None,
                'distance': distances[i] if distances is not None else 0,
            }
            for i, doc in enumerate(documents[0])
        ]