import json
import os
import re
import math
from collections import Counter
from typing import List, Dict, Any, Tuple
class RAGEngine:
    """
    Real-time Retrieval Augmented Generation Engine.

    Provides document retrieval for ECH0's knowledge base using a lightweight,
    dependency-free TF-IDF index held entirely in memory.
    """

    def __init__(self, knowledge_file: str):
        """
        Build the engine and index all available knowledge immediately.

        Args:
            knowledge_file: Path to the JSON tool/lab knowledge base.
        """
        self.knowledge_file = knowledge_file
        # Project-level "data" directory: three levels up from this module.
        self.data_dir = os.path.join(
            os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "data"
        )
        # Each document: {"id": str, "content": str, "type": str, "metadata": dict}
        self.documents: List[Dict[str, Any]] = []
        self.index: Dict[str, int] = {}        # word -> doc_freq (reserved)
        self.idf: Dict[str, float] = {}        # term -> inverse document frequency
        self.doc_vectors: List[Dict[str, float]] = []
        self.load_knowledge()
        self.load_scientific_data()  # Index extra papers/data
        self.build_index()

    def load_knowledge(self):
        """Load the tool knowledge base and flatten it into retrievable chunks."""
        if not os.path.exists(self.knowledge_file):
            print(f"[RAG] Warning: Knowledge file {self.knowledge_file} not found.")
            return
        try:
            with open(self.knowledge_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except Exception as e:
            print(f"[RAG] Error loading tools: {e}")
            return
        # Guard against a malformed root: iterating a dict/str here would
        # raise a confusing TypeError deeper in the loop.
        if not isinstance(data, list):
            print(f"[RAG] Error loading tools: expected a list, got {type(data).__name__}")
            return
        # Flatten the knowledge base into retrievable chunks.
        for lab in data:
            # One overview document per lab.
            lab_text = f"{lab['name']} {lab.get('description', '')}"
            self.documents.append({
                "id": f"MSG_LAB_{lab['id']}",
                "content": lab_text,
                "type": "lab_overview",
                "metadata": lab
            })
            # One document per capability (tool) so tools are individually retrievable.
            for cap in lab.get('capabilities', []):
                cap_text = (
                    f"{lab['name']} {cap['tool_name']} {cap.get('doc', '')} "
                    f"{' '.join(cap.get('args', []))}"
                )
                self.documents.append({
                    "id": f"MSG_TOOL_{lab['id']}_{cap['tool_name']}",
                    "content": cap_text,
                    "type": "tool",
                    "metadata": {**cap, "lab_id": lab['id'], "lab_name": lab['name']}
                })

    def load_scientific_data(self):
        """Scan the data directory for scientific papers / research JSONs and index them."""
        if not os.path.exists(self.data_dir):
            return
        # Specific high-value sources for indexing.
        targets = [
            "arxiv_ingestion/papers_20251103_daily.json",
            "ech0_cancer_language_research.json",
            "materials_db_expanded.json"
        ]
        for rel_path in targets:
            full_path = os.path.join(self.data_dir, rel_path)
            if not os.path.exists(full_path):
                continue
            try:
                with open(full_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                # Handle a list of objects (papers/records).
                if isinstance(data, list):
                    for item in data[:200]:  # Cap per file to avoid index bloat.
                        # Skip non-dict entries instead of aborting the whole file:
                        # one bad record previously raised AttributeError and
                        # discarded every remaining record in this source.
                        if not isinstance(item, dict):
                            continue
                        title = item.get('title', item.get('name', 'Record'))
                        abstract = item.get(
                            'abstract', item.get('summary', item.get('description', ''))
                        )
                        content = f"{title} {abstract}"
                        self.documents.append({
                            "id": f"MSG_DATA_{rel_path}_{item.get('id', hash(content))}",
                            "content": content,
                            "type": "scientific_data",
                            "metadata": item
                        })
            except Exception as e:
                print(f"[RAG] Error indexing {rel_path}: {e}")

    def _tokenize(self, text: str) -> List[str]:
        """Lowercased word tokens of length > 2 (drops noise like 'a', 'of', 'is')."""
        return [w.lower() for w in re.findall(r'\w+', text) if len(w) > 2]

    def build_index(self):
        """Build the TF-IDF index over ``self.documents``."""
        # Document frequency: number of documents containing each term.
        doc_counts = Counter()
        for doc in self.documents:
            terms = set(self._tokenize(doc['content']))
            for term in terms:
                doc_counts[term] += 1
        # Smoothed IDF: log((1+N)/(1+df)) + 1 is strictly positive.
        # The previous log(N/(df+1)) went to zero or NEGATIVE for common terms
        # (and for any term in a tiny corpus), which corrupted cosine scores.
        n_docs = len(self.documents)
        self.idf = {
            term: math.log((1 + n_docs) / (1 + count)) + 1.0
            for term, count in doc_counts.items()
        }
        # Precompute one normalized vector per document.
        self.doc_vectors = []
        for doc in self.documents:
            self.doc_vectors.append(self._text_to_vector(doc['content']))

    def _text_to_vector(self, text: str) -> Dict[str, float]:
        """Return an L2-normalized sparse TF-IDF vector for ``text``."""
        tf = Counter(self._tokenize(text))
        vector: Dict[str, float] = {}
        for term, count in tf.items():
            # Terms unseen at index time carry no IDF and are dropped.
            if term in self.idf:
                vector[term] = count * self.idf[term]
        # Normalize so dot products are cosine similarities.
        norm = math.sqrt(sum(v * v for v in vector.values()))
        if norm > 0:
            for term in vector:
                vector[term] /= norm
        return vector

    def _cosine_similarity(self, vec1: Dict[str, float], vec2: Dict[str, float]) -> float:
        """Cosine similarity of two pre-normalized sparse vectors (plain dot product)."""
        intersection = set(vec1.keys()) & set(vec2.keys())
        return sum(vec1[term] * vec2[term] for term in intersection)

    def retrieve(self, query: str, top_k: int = 5) -> List[Dict]:
        """Return up to ``top_k`` documents most similar to ``query`` (noise filtered)."""
        query_vec = self._text_to_vector(query)
        scores = []
        for i, doc_vec in enumerate(self.doc_vectors):
            score = self._cosine_similarity(query_vec, doc_vec)
            scores.append((score, self.documents[i]))
        scores.sort(key=lambda x: x[0], reverse=True)
        return [item[1] for item in scores[:top_k] if item[0] > 0.05]  # Filter noise

    def get_context_formatted(self, query: str) -> str:
        """Return retrieved knowledge as a string formatted for LLM context."""
        docs = self.retrieve(query, top_k=7)
        if not docs:
            return "No specific lab knowledge found for this query."
        context = "RELEVANT KNOWLEDGE:\n"
        for doc in docs:
            meta = doc['metadata']
            if doc['type'] == 'tool':
                context += (
                    f"- Tool: {meta['lab_name']}.{meta['tool_name']}"
                    f"({', '.join(meta.get('args', []))})\n Desc: {meta.get('doc', '')}\n"
                )
            else:
                # Scientific-data records carry title/abstract rather than
                # name/description; fall back gracefully instead of raising
                # KeyError (the previous meta['name'] / meta['description']).
                name = meta.get('name', meta.get('title', doc['id']))
                desc = meta.get('description', meta.get('abstract', doc['content']))
                context += f"- Lab: {name}: {str(desc)[:200]}...\n"
        return context

    def recommend_tools(self, query: str, top_k: int = 15) -> List[str]:
        """Return deduplicated '<lab>.<tool>' names recommended for the query."""
        docs = self.retrieve(query, top_k=top_k)
        tool_names = []
        for doc in docs:
            if doc['type'] == 'tool':
                meta = doc['metadata']
                tool_names.append(
                    f"{meta['lab_id'].replace('_lab.py', '')}.{meta['tool_name']}"
                )
        return list(dict.fromkeys(tool_names))  # Deduplicate, order-preserving
# Module-level singleton: the engine is constructed once at import time,
# indexing the knowledge file that ships alongside this module.
rag_engine = RAGEngine(os.path.join(os.path.dirname(__file__), "ech0_knowledge.json"))