# qulab-infinite / agent_lab / backend / rag_engine.py
# QuLab MCP Server: Complete Experiment Taxonomy Deployment
# (commit 91994bf, author: workofarttattoo)
import json
import os
import re
import math
from collections import Counter
from typing import List, Dict, Any, Tuple
class RAGEngine:
    """
    Real-time Retrieval Augmented Generation engine.

    Builds a lightweight, dependency-free TF-IDF index over ECH0's tool
    knowledge base plus selected scientific data files, and answers queries
    via cosine similarity over L2-normalized sparse term vectors.
    """

    # Documents scoring at or below this cosine value are treated as noise.
    MIN_SCORE = 0.05

    def __init__(self, knowledge_file: str):
        """
        Args:
            knowledge_file: Path to the JSON knowledge base describing labs
                and their tool capabilities.
        """
        self.knowledge_file = knowledge_file
        # <repo_root>/data — assumes this file lives three levels below the
        # repo root (backend/ under agent_lab/); TODO confirm layout.
        self.data_dir = os.path.join(
            os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "data"
        )
        # Each entry: {"id": str, "content": str, "type": str, "metadata": dict}
        self.documents: List[Dict[str, Any]] = []
        self.idf: Dict[str, float] = {}  # term -> inverse document frequency
        self.doc_vectors: List[Dict[str, float]] = []  # normalized TF-IDF vectors
        self.load_knowledge()
        self.load_scientific_data()  # Index extra papers/data
        self.build_index()

    def load_knowledge(self):
        """Loads the tool knowledge base and flattens it into documents.

        Missing or unreadable files are logged and skipped (best effort),
        leaving the engine with an empty (or partial) document set.
        """
        if not os.path.exists(self.knowledge_file):
            print(f"[RAG] Warning: Knowledge file {self.knowledge_file} not found.")
            return
        try:
            with open(self.knowledge_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except Exception as e:
            print(f"[RAG] Error loading tools: {e}")
            return
        # Flatten the knowledge base into retrievable chunks:
        # one overview document per lab, plus one document per capability.
        for lab in data:
            lab_text = f"{lab['name']} {lab.get('description', '')}"
            self.documents.append({
                "id": f"MSG_LAB_{lab['id']}",
                "content": lab_text,
                "type": "lab_overview",
                "metadata": lab
            })
            for cap in lab.get('capabilities', []):
                cap_text = (
                    f"{lab['name']} {cap['tool_name']} "
                    f"{cap.get('doc', '')} {' '.join(cap.get('args', []))}"
                )
                self.documents.append({
                    "id": f"MSG_TOOL_{lab['id']}_{cap['tool_name']}",
                    "content": cap_text,
                    "type": "tool",
                    "metadata": {**cap, "lab_id": lab['id'], "lab_name": lab['name']}
                })

    def load_scientific_data(self):
        """Scans the data directory for scientific papers and research JSONs."""
        if not os.path.exists(self.data_dir):
            return
        # Specific high-value sources for indexing
        targets = [
            "arxiv_ingestion/papers_20251103_daily.json",
            "ech0_cancer_language_research.json",
            "materials_db_expanded.json"
        ]
        for rel_path in targets:
            full_path = os.path.join(self.data_dir, rel_path)
            if not os.path.exists(full_path):
                continue
            try:
                with open(full_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                # Handle list of objects (papers/records)
                if isinstance(data, list):
                    for idx, item in enumerate(data[:200]):  # Cap per file to avoid bloat
                        title = item.get('title', item.get('name', 'Record'))
                        abstract = item.get(
                            'abstract', item.get('summary', item.get('description', ''))
                        )
                        content = f"{title} {abstract}"
                        # Fall back to the list position for the id: stable
                        # across runs, unlike hash(content) which varies with
                        # PYTHONHASHSEED.
                        self.documents.append({
                            "id": f"MSG_DATA_{rel_path}_{item.get('id', idx)}",
                            "content": content,
                            "type": "scientific_data",
                            "metadata": item
                        })
            except Exception as e:
                print(f"[RAG] Error indexing {rel_path}: {e}")

    def _tokenize(self, text: str) -> List[str]:
        """Returns lowercased word tokens, dropping tokens of length <= 2."""
        return [w.lower() for w in re.findall(r'\w+', text) if len(w) > 2]

    def build_index(self):
        """Builds the smoothed IDF table and per-document normalized vectors."""
        # Document frequency per term (count each term once per document).
        doc_counts = Counter()
        for doc in self.documents:
            for term in set(self._tokenize(doc['content'])):
                doc_counts[term] += 1
        n_docs = len(self.documents)
        # Smoothed IDF (sklearn convention): log((1+N)/(1+df)) + 1.
        # Always positive — the previous log(N/(df+1)) went negative for
        # terms present in nearly every document, corrupting cosine scores.
        self.idf = {
            term: math.log((1 + n_docs) / (1 + count)) + 1
            for term, count in doc_counts.items()
        }
        self.doc_vectors = [self._text_to_vector(doc['content']) for doc in self.documents]

    def _text_to_vector(self, text: str) -> Dict[str, float]:
        """Returns a sparse, L2-normalized TF-IDF vector for `text`.

        Terms absent from the index are ignored; an all-unknown text yields
        an empty vector (which scores 0 against everything).
        """
        tf = Counter(self._tokenize(text))
        vector = {
            term: count * self.idf[term]
            for term, count in tf.items()
            if term in self.idf
        }
        norm = math.sqrt(sum(v * v for v in vector.values()))
        if norm > 0:
            vector = {term: v / norm for term, v in vector.items()}
        return vector

    def _cosine_similarity(self, vec1: Dict[str, float], vec2: Dict[str, float]) -> float:
        """Dot product of two already-normalized sparse vectors (== cosine)."""
        # Iterate the smaller vector for speed.
        if len(vec1) > len(vec2):
            vec1, vec2 = vec2, vec1
        return sum(w * vec2[term] for term, w in vec1.items() if term in vec2)

    def retrieve(self, query: str, top_k: int = 5) -> List[Dict]:
        """Retrieves up to `top_k` documents most similar to `query`.

        Documents scoring <= MIN_SCORE are filtered out as noise; the result
        may therefore be shorter than `top_k`, or empty.
        """
        query_vec = self._text_to_vector(query)
        scored = [
            (self._cosine_similarity(query_vec, doc_vec), doc)
            for doc_vec, doc in zip(self.doc_vectors, self.documents)
        ]
        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [doc for score, doc in scored[:top_k] if score > self.MIN_SCORE]

    def get_context_formatted(self, query: str) -> str:
        """Returns retrieved knowledge formatted as an LLM context string."""
        docs = self.retrieve(query, top_k=7)
        if not docs:
            return "No specific lab knowledge found for this query."
        lines = ["RELEVANT KNOWLEDGE:\n"]
        for doc in docs:
            meta = doc['metadata']
            if doc['type'] == 'tool':
                lines.append(
                    f"- Tool: {meta['lab_name']}.{meta['tool_name']}"
                    f"({', '.join(meta.get('args', []))})\n"
                    f"  Desc: {meta.get('doc', '')}\n"
                )
            elif doc['type'] == 'lab_overview':
                lines.append(f"- Lab: {meta['name']}: {meta.get('description', '')[:200]}...\n")
            else:
                # scientific_data records have no 'name'/'description' keys —
                # the old code raised KeyError here. Summarize their content.
                lines.append(f"- Data: {doc['content'][:200]}...\n")
        return "".join(lines)

    def recommend_tools(self, query: str, top_k: int = 15) -> List[str]:
        """Returns deduplicated 'lab.tool' names recommended for the query."""
        docs = self.retrieve(query, top_k=top_k)
        tool_names = [
            f"{doc['metadata']['lab_id'].replace('_lab.py', '')}.{doc['metadata']['tool_name']}"
            for doc in docs
            if doc['type'] == 'tool'
        ]
        return list(dict.fromkeys(tool_names))  # Preserve order, drop duplicates
# Singleton instance
# NOTE: built eagerly at import time from ech0_knowledge.json located next to
# this module — importing this module therefore performs file I/O and prints a
# warning if the knowledge file is missing.
rag_engine = RAGEngine(os.path.join(os.path.dirname(__file__), "ech0_knowledge.json"))