Spaces:
Sleeping
Sleeping
| import os | |
| from typing import List, Dict | |
| import openai | |
| import chromadb | |
| from chromadb.config import Settings | |
| from chromadb import EmbeddingFunction | |
class OpenAIEmbedder(EmbeddingFunction):
    """Embedding function that delegates to ``openai.Embedding.create``.

    Args:
        model: Name of the OpenAI embedding model to request. Defaults to
            ``"text-embedding-ada-002"``.
    """

    def __init__(self, model: str = "text-embedding-ada-002"):
        self.model = model
        # Only assign the key when the variable is actually set, so a key
        # configured elsewhere is not overwritten with None.
        api_key = os.getenv('OPENAI_API_KEY')
        if api_key:
            openai.api_key = api_key

    def __call__(self, input: List[str]) -> List[List[float]]:
        """Return one embedding vector per string in ``input``.

        Args:
            input: Texts to embed, sent to the API in a single request.

        Returns:
            The ``embedding`` field of each item in the response's ``data``
            list, in the order the response provides them.
        """
        # NOTE(review): `openai.Embedding.create` looks like the pre-1.0 SDK
        # interface -- confirm it matches the pinned openai version.
        response = openai.Embedding.create(input=input, model=self.model)
        return [item['embedding'] for item in response['data']]
class KnowledgeBase:
    """Chroma collection built from the ``.txt`` files of a directory.

    On construction the ``math_knowledge`` collection is dropped (if it
    exists), recreated with an ``OpenAIEmbedder``, and filled with chunks of
    every ``.txt`` file found directly inside ``knowledge_dir``.

    Args:
        knowledge_dir: Directory that holds the ``.txt`` source documents.

    Raises:
        OSError: If ``knowledge_dir`` cannot be listed or a document cannot
            be read.
    """

    COLLECTION_NAME = 'math_knowledge'
    DOC_SUFFIX = '.txt'

    def __init__(self, knowledge_dir='knowledge/docs'):
        self.knowledge_dir = knowledge_dir
        self.embedder = OpenAIEmbedder()
        self.client = chromadb.Client(Settings(
            anonymized_telemetry=False,
            allow_reset=True
        ))
        # Always reset collection to ensure clean state with correct embeddings
        try:
            self.client.delete_collection(self.COLLECTION_NAME)
        except Exception:
            # Best effort: the collection may simply not exist yet.
            # `Exception` (rather than a bare except) keeps Ctrl-C and
            # SystemExit working. NOTE(review): the exact exception type for
            # a missing collection appears to vary by chromadb release.
            pass
        self.collection = self.client.create_collection(
            self.COLLECTION_NAME,
            embedding_function=self.embedder,
        )
        self._load_documents()

    def _load_documents(self):
        """Read, chunk and index every ``.txt`` file in ``knowledge_dir``.

        Each chunk is stored with metadata ``{'topic', 'source'}`` and the id
        ``"<topic>_<chunk index>"``, where topic is the file name without its
        trailing ``.txt``.
        """
        docs = []
        metadatas = []
        ids = []
        suffix = self.DOC_SUFFIX
        # sorted() makes the insertion order independent of the order in
        # which the filesystem happens to list entries.
        for filename in sorted(os.listdir(self.knowledge_dir)):
            if not filename.endswith(suffix):
                continue
            filepath = os.path.join(self.knowledge_dir, filename)
            if not os.path.isfile(filepath):
                # e.g. a directory whose name ends in ".txt"
                continue
            # Strip only the trailing suffix: str.replace would also remove
            # ".txt" from the middle of the name and could make two files
            # share a topic (and therefore collide on ids).
            topic = filename[:-len(suffix)]
            with open(filepath, 'r', encoding='utf-8') as f:
                content = f.read()
            for i, chunk in enumerate(self._chunk_document(content)):
                docs.append(chunk)
                metadatas.append({
                    'topic': topic,
                    'source': filename
                })
                ids.append(f"{topic}_{i}")
        if docs:
            self.collection.add(
                documents=docs,
                metadatas=metadatas,
                ids=ids
            )

    def _chunk_document(self, content: str, chunk_size: int = 500) -> List[str]:
        """Split ``content`` into line-aligned chunks.

        Lines are accumulated until adding the next one would push the summed
        line lengths (newlines not counted) above ``chunk_size``; the
        accumulated lines are then emitted as one chunk. Lines are never
        split, so a single line longer than ``chunk_size`` becomes its own
        oversized chunk.

        Args:
            content: Full document text.
            chunk_size: Soft upper bound on the summed line lengths per chunk.

        Returns:
            The chunks in document order. Chunks that contain only whitespace
            are omitted, so empty content yields an empty list.
        """
        chunks = []
        current_chunk = []
        current_size = 0
        for line in content.split('\n'):
            line_size = len(line)
            if current_chunk and current_size + line_size > chunk_size:
                chunks.append('\n'.join(current_chunk))
                current_chunk = []
                current_size = 0
            current_chunk.append(line)
            current_size += line_size
        if current_chunk:
            chunks.append('\n'.join(current_chunk))
        # Blank chunks carry no information and would be sent for embedding
        # as empty text, so drop them.
        return [chunk for chunk in chunks if chunk.strip()]

    def search(self, query: str, topic: str = None, k: int = 3) -> List[Dict]:
        """Query the collection for the chunks closest to ``query``.

        Args:
            query: Free-text query.
            topic: When truthy, restrict results to chunks whose ``topic``
                metadata equals this value.
            k: Number of results requested.

        Returns:
            A list of dicts with keys ``content``, ``metadata`` and
            ``distance``. ``distance`` is 0 and ``metadata`` is ``None`` when
            the query result does not carry that field.
        """
        where = {"topic": topic} if topic else None
        results = self.collection.query(
            query_texts=[query],
            n_results=k,
            where=where
        )
        documents = results.get('documents')
        if not documents:
            return []
        # Either field may be absent or None; guard before indexing.
        metadatas = results.get('metadatas')
        distances = results.get('distances')
        retrieved = []
        for i, doc in enumerate(documents[0]):
            retrieved.append({
                'content': doc,
                'metadata': metadatas[0][i] if metadatas else None,
                'distance': distances[0][i] if distances else 0
            })
        return retrieved