Jaykay73's picture
Upload 39 files
76d540d verified
"""
AI Research Paper Helper - RAG Engine Service
Retrieval-Augmented Generation for paper Q&A.
"""
import logging
from typing import List, Optional, Dict, Tuple
from dataclasses import dataclass
import numpy as np
import asyncio
try:
import faiss
FAISS_AVAILABLE = True
except ImportError:
FAISS_AVAILABLE = False
import httpx
from config import settings, get_llm_config
from ml.embeddings import get_embedding_service
from ml.text_processor import get_text_processor, TextChunk
logger = logging.getLogger(__name__)
@dataclass
class Source:
"""A source chunk used to answer a question."""
chunk_id: str
text: str
section: Optional[str]
score: float
@dataclass
class RAGResponse:
"""Response from RAG query."""
answer: str
sources: List[Source]
confidence: float
class PaperIndex:
"""Index for a single paper."""
def __init__(self, paper_id: str, title: str):
self.paper_id = paper_id
self.title = title
self.chunks: List[TextChunk] = []
self.embeddings: Optional[np.ndarray] = None
self.faiss_index = None
def is_indexed(self) -> bool:
return self.embeddings is not None and len(self.chunks) > 0
class RAGEngine:
"""
Retrieval-Augmented Generation engine for paper Q&A.
Features:
- FAISS-based vector similarity search
- Section-aware chunking
- Answer grounding with citations
"""
def __init__(self):
self.embedding_service = get_embedding_service()
self.text_processor = get_text_processor()
self.paper_indices: Dict[str, PaperIndex] = {}
async def index_paper(
self,
paper_id: str,
title: str,
content: str,
abstract: Optional[str] = None,
sections: Optional[List[Dict]] = None
) -> Dict:
"""
Index a paper for RAG queries.
Args:
paper_id: Unique identifier for the paper
title: Paper title
content: Full paper content
abstract: Paper abstract
sections: Pre-extracted sections
Returns:
Dict with indexing status and stats
"""
logger.info(f"Indexing paper: {paper_id}")
# Create chunks
chunks = self.text_processor.chunk_document(
content=content,
abstract=abstract,
sections=sections
)
if not chunks:
logger.warning(f"No chunks created for paper {paper_id}")
return {"success": False, "error": "No content to index"}
# Generate embeddings
chunk_texts = [chunk.text for chunk in chunks]
embeddings = await self.embedding_service.embed_texts(chunk_texts)
embeddings_array = np.array(embeddings).astype('float32')
# Normalize embeddings for cosine similarity
faiss.normalize_L2(embeddings_array)
# Create FAISS index
dimension = embeddings_array.shape[1]
faiss_index = faiss.IndexFlatIP(dimension) # Inner product = cosine after normalization
faiss_index.add(embeddings_array)
# Store index
paper_index = PaperIndex(paper_id, title)
paper_index.chunks = chunks
paper_index.embeddings = embeddings_array
paper_index.faiss_index = faiss_index
self.paper_indices[paper_id] = paper_index
logger.info(f"Indexed {len(chunks)} chunks for paper {paper_id}")
return {
"success": True,
"paper_id": paper_id,
"chunks_indexed": len(chunks),
"embedding_dimension": dimension
}
async def query(
self,
query: str,
paper_id: str,
top_k: int = 5
) -> RAGResponse:
"""
Query a paper using RAG.
Args:
query: Natural language question
paper_id: ID of the indexed paper
top_k: Number of chunks to retrieve
Returns:
RAGResponse with answer and sources
"""
# Get paper index
paper_index = self.paper_indices.get(paper_id)
if not paper_index or not paper_index.is_indexed():
return RAGResponse(
answer="Paper not indexed. Please analyze the paper first.",
sources=[],
confidence=0.0
)
# Embed query
query_embedding = await self.embedding_service.embed_single(query)
query_embedding = query_embedding.astype('float32').reshape(1, -1)
faiss.normalize_L2(query_embedding)
# Search FAISS index
scores, indices = paper_index.faiss_index.search(query_embedding, min(top_k, len(paper_index.chunks)))
# Get relevant chunks
sources = []
context_chunks = []
for i, (idx, score) in enumerate(zip(indices[0], scores[0])):
if idx < 0 or idx >= len(paper_index.chunks):
continue
chunk = paper_index.chunks[idx]
sources.append(Source(
chunk_id=chunk.id,
text=chunk.text[:300] + "..." if len(chunk.text) > 300 else chunk.text,
section=chunk.section,
score=float(score)
))
context_chunks.append(chunk.text)
if not context_chunks:
return RAGResponse(
answer="No relevant content found for your question.",
sources=[],
confidence=0.0
)
# Generate answer
answer = await self._generate_answer(query, context_chunks, paper_index.title)
# Calculate confidence based on top retrieval scores
confidence = float(np.mean(scores[0][:3])) if len(scores[0]) >= 3 else float(np.mean(scores[0]))
confidence = min(1.0, max(0.0, confidence))
return RAGResponse(
answer=answer,
sources=sources,
confidence=confidence
)
async def _generate_answer(
self,
query: str,
context_chunks: List[str],
title: str
) -> str:
"""Generate answer from retrieved context."""
context = "\n\n".join([f"[Source {i+1}]: {chunk}" for i, chunk in enumerate(context_chunks)])
# Use LLM if available
config = get_llm_config()
if config and settings.api_mode in ['api', 'hybrid']:
return await self._answer_with_llm(query, context, title, config)
else:
return self._answer_extractive(query, context_chunks)
async def _answer_with_llm(
self,
query: str,
context: str,
title: str,
config: dict
) -> str:
"""Generate answer using LLM."""
prompt = f"""Based on the following excerpts from the paper "{title}", answer the question.
CONTEXT:
{context}
QUESTION: {query}
INSTRUCTIONS:
- Answer ONLY based on the provided context
- If the answer is not in the context, say "I couldn't find this information in the paper"
- Be concise but complete
- Reference source numbers [1], [2] etc. when citing specific information
ANSWER:"""
try:
async with httpx.AsyncClient(timeout=45.0) as client:
headers = {
"Authorization": f"Bearer {config['api_key']}",
"Content-Type": "application/json"
}
if config['provider'] == 'openrouter':
headers["HTTP-Referer"] = "https://ai-research-helper.local"
response = await client.post(
f"{config['base_url']}/chat/completions",
headers=headers,
json={
"model": config['model'],
"messages": [
{"role": "system", "content": "You are a helpful research assistant answering questions about academic papers. Only answer based on the provided context."},
{"role": "user", "content": prompt}
],
"temperature": 0.3,
"max_tokens": 500
}
)
response.raise_for_status()
return response.json()['choices'][0]['message']['content']
except Exception as e:
logger.error(f"LLM answer generation failed: {e}")
return self._answer_extractive(query, context.split("\n\n"))
def _answer_extractive(self, query: str, chunks: List[str]) -> str:
"""Simple extractive answer when LLM is not available."""
# Return the most relevant chunk as the answer
if chunks:
# Combine top chunks
combined = " ".join(chunks[:2])[:1000]
return f"Based on the paper: {combined}"
return "Unable to find relevant information."
def is_paper_indexed(self, paper_id: str) -> bool:
"""Check if a paper is indexed."""
paper_index = self.paper_indices.get(paper_id)
return paper_index is not None and paper_index.is_indexed()
def get_indexed_papers(self) -> List[str]:
"""Get list of indexed paper IDs."""
return [pid for pid, idx in self.paper_indices.items() if idx.is_indexed()]
def clear_index(self, paper_id: str) -> bool:
"""Clear index for a paper."""
if paper_id in self.paper_indices:
del self.paper_indices[paper_id]
return True
return False
# Singleton instance
_rag_engine = None
def get_rag_engine() -> RAGEngine:
"""Get the singleton RAG engine instance."""
global _rag_engine
if _rag_engine is None:
if not FAISS_AVAILABLE:
raise ImportError("FAISS is required for RAG. Install with: pip install faiss-cpu")
_rag_engine = RAGEngine()
return _rag_engine