# GitHub Actions
# Clean sync from GitHub - no large files in history
# aca8ab4
"""
Retriever Agent: Search arXiv, download papers, and chunk for RAG.
Includes intelligent fallback from MCP/FastMCP to direct arXiv API.
"""
import logging
from typing import Dict, Any, Optional, List
from pathlib import Path
from utils.arxiv_client import ArxivClient
from utils.pdf_processor import PDFProcessor
from utils.schemas import AgentState, PaperChunk, Paper
from rag.vector_store import VectorStore
from rag.embeddings import EmbeddingGenerator
from utils.langfuse_client import observe
# Module-level logging setup.
# NOTE(review): logging.basicConfig configures the ROOT logger at import time,
# which affects the whole application — confirm this is intended for a library
# module (the usual pattern is to configure logging only in the entry point).
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Import MCP clients for type hints. Each import is optional: when the
# dependency is unavailable the name is bound to None so the rest of the
# module (and any isinstance checks elsewhere) can degrade gracefully.
try:
from utils.mcp_arxiv_client import MCPArxivClient
except ImportError:
MCPArxivClient = None
try:
from utils.fastmcp_arxiv_client import FastMCPArxivClient
except ImportError:
FastMCPArxivClient = None
class RetrieverAgent:
    """Agent for retrieving and processing papers from arXiv with intelligent fallback.

    Pipeline (see :meth:`run`): search arXiv -> validate paper metadata ->
    download PDFs -> chunk text -> embed chunks -> store in the vector DB.
    Search and download each try the primary client first and automatically
    fall back to ``fallback_client`` when the primary raises or returns
    nothing.
    """

    # Azure does not return usage data for embedding calls, so token
    # accounting uses this rough per-chunk estimate (~300 tokens/chunk).
    ESTIMATED_TOKENS_PER_CHUNK = 300

    def __init__(
        self,
        arxiv_client: Any,
        pdf_processor: PDFProcessor,
        vector_store: VectorStore,
        embedding_generator: EmbeddingGenerator,
        fallback_client: Optional[Any] = None
    ):
        """
        Initialize Retriever Agent with fallback support.

        Args:
            arxiv_client: Primary client (ArxivClient, MCPArxivClient, or
                FastMCPArxivClient). Must expose ``search_papers`` and
                ``download_paper``.
            pdf_processor: PDFProcessor instance used to extract and chunk text.
            vector_store: VectorStore instance that receives chunk embeddings.
            embedding_generator: EmbeddingGenerator instance.
            fallback_client: Optional fallback client (usually direct
                ArxivClient) used if primary fails.
        """
        self.arxiv_client = arxiv_client
        self.pdf_processor = pdf_processor
        self.vector_store = vector_store
        self.embedding_generator = embedding_generator
        self.fallback_client = fallback_client
        # Record which client types are in play so failures in the logs can be
        # traced back to the transport (MCP vs direct API).
        client_name = type(arxiv_client).__name__
        logger.info(f"RetrieverAgent initialized with primary client: {client_name}")
        if fallback_client:
            fallback_name = type(fallback_client).__name__
            logger.info(f"Fallback client configured: {fallback_name}")

    def _search_with_fallback(
        self,
        query: str,
        max_results: int,
        category: Optional[str]
    ) -> Optional[List[Paper]]:
        """
        Search for papers with automatic fallback.

        The fallback is attempted both when the primary client raises and
        when it returns an empty result, so transient MCP failures degrade
        gracefully to the direct API.

        Args:
            query: Search query.
            max_results: Maximum number of papers.
            category: Optional category filter.

        Returns:
            List of Paper objects, or None if both primary and fallback fail.
        """
        # Try primary client first.
        try:
            logger.info(f"Searching with primary client ({type(self.arxiv_client).__name__})")
            papers = self.arxiv_client.search_papers(
                query=query,
                max_results=max_results,
                category=category
            )
            if papers:
                logger.info(f"Primary client found {len(papers)} papers")
                return papers
            logger.warning("Primary client returned no papers")
        except Exception as e:
            logger.error(f"Primary client search failed: {str(e)}")
        # Primary failed or was empty: try the fallback client if configured.
        if self.fallback_client:
            try:
                logger.warning(f"Attempting fallback with {type(self.fallback_client).__name__}")
                papers = self.fallback_client.search_papers(
                    query=query,
                    max_results=max_results,
                    category=category
                )
                if papers:
                    logger.info(f"Fallback client found {len(papers)} papers")
                    return papers
                logger.error("Fallback client returned no papers")
            except Exception as e:
                logger.error(f"Fallback client search failed: {str(e)}")
        logger.error("All search attempts failed")
        return None

    def _download_with_fallback(self, paper: Paper) -> Optional[Path]:
        """
        Download a paper's PDF with automatic fallback.

        Args:
            paper: Paper object to download.

        Returns:
            Path to downloaded PDF, or None if both primary and fallback fail.
        """
        # Try primary client first.
        try:
            path = self.arxiv_client.download_paper(paper)
            if path:
                logger.debug(f"Primary client downloaded {paper.arxiv_id}")
                return path
            logger.warning(f"Primary client failed to download {paper.arxiv_id}")
        except Exception as e:
            logger.error(f"Primary client download error for {paper.arxiv_id}: {str(e)}")
        # Primary failed: try the fallback client if configured.
        if self.fallback_client:
            try:
                logger.debug(f"Attempting fallback download for {paper.arxiv_id}")
                path = self.fallback_client.download_paper(paper)
                if path:
                    logger.info(f"Fallback client downloaded {paper.arxiv_id}")
                    return path
                logger.error(f"Fallback client failed to download {paper.arxiv_id}")
            except Exception as e:
                logger.error(f"Fallback client download error for {paper.arxiv_id}: {str(e)}")
        logger.error(f"All download attempts failed for {paper.arxiv_id}")
        return None

    def _validate_papers(
        self,
        papers: List[Paper],
        state: Dict[str, Any]
    ) -> List[Paper]:
        """
        Run diagnostic data-quality checks on parsed paper metadata.

        Field-type issues are logged as warnings but do NOT exclude a paper
        (Pydantic validators are expected to have normalized them upstream —
        this is a diagnostic pass). A paper is skipped only when the checks
        themselves raise, in which case the error is recorded in
        ``state["errors"]``.

        Args:
            papers: Papers returned by the search step.
            state: Agent state; ``state["errors"]`` is appended to on failure.

        Returns:
            The papers that passed (possibly all of them).
        """
        validated_papers: List[Paper] = []
        for paper in papers:
            try:
                issues = []
                # Authors must be a non-empty list.
                if not isinstance(paper.authors, list):
                    issues.append(f"authors is {type(paper.authors).__name__} instead of list")
                elif len(paper.authors) == 0:
                    issues.append("authors list is empty")
                # Categories must be a list.
                if not isinstance(paper.categories, list):
                    issues.append(f"categories is {type(paper.categories).__name__} instead of list")
                # Core text fields must be strings.
                if not isinstance(paper.title, str):
                    issues.append(f"title is {type(paper.title).__name__} instead of str")
                if not isinstance(paper.pdf_url, str):
                    issues.append(f"pdf_url is {type(paper.pdf_url).__name__} instead of str")
                if not isinstance(paper.abstract, str):
                    issues.append(f"abstract is {type(paper.abstract).__name__} instead of str")
                if issues:
                    logger.warning(f"Paper {paper.arxiv_id} has data quality issues: {', '.join(issues)}")
                    # Note: Thanks to Pydantic validators, these should already be fixed
                    # This is just a diagnostic check
                validated_papers.append(paper)
            except Exception as e:
                error_msg = f"Failed to validate paper {getattr(paper, 'arxiv_id', 'unknown')}: {str(e)}"
                logger.error(error_msg)
                state["errors"].append(error_msg)
                # Skip this paper but continue with others.
        return validated_papers

    @observe(name="retriever_agent_run", as_type="generation")
    def run(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute the retriever agent pipeline.

        Reads ``query``, ``category`` and ``num_papers`` (default 5) from the
        state; mutates ``errors``, ``papers``, ``chunks`` and
        ``token_usage["embedding_tokens"]`` in place.

        Args:
            state: Current agent state.

        Returns:
            Updated state with papers and chunks on success; on failure the
            error is appended to ``state["errors"]`` and the state returned.
        """
        try:
            logger.info("=== Retriever Agent Started ===")
            query = state.get("query")
            category = state.get("category")
            num_papers = state.get("num_papers", 5)
            logger.info(f"Query: {query}")
            logger.info(f"Category: {category}")
            logger.info(f"Number of papers: {num_papers}")

            # Step 1: Search arXiv (with fallback)
            logger.info("Step 1: Searching arXiv...")
            papers = self._search_with_fallback(
                query=query,
                max_results=num_papers,
                category=category
            )
            if not papers:
                error_msg = "No papers found for the given query (tried all available clients)"
                logger.error(error_msg)
                state["errors"].append(error_msg)
                return state
            logger.info(f"Found {len(papers)} papers")

            # Validate paper data quality after MCP parsing.
            validated_papers = self._validate_papers(papers, state)
            if not validated_papers:
                error_msg = "All papers failed validation checks"
                logger.error(error_msg)
                state["errors"].append(error_msg)
                return state
            logger.info(f"Validated {len(validated_papers)} papers (filtered out {len(papers) - len(validated_papers)})")
            state["papers"] = validated_papers

            # Step 2: Download papers (with fallback)
            # BUG FIX: iterate validated_papers (not the raw search results) so
            # papers that failed validation are not downloaded and processed.
            logger.info("Step 2: Downloading papers...")
            pdf_paths = []
            for paper in validated_papers:
                path = self._download_with_fallback(paper)
                if path:
                    pdf_paths.append((paper, path))
                else:
                    logger.warning(f"Failed to download paper {paper.arxiv_id} (all clients failed)")
            logger.info(f"Downloaded {len(pdf_paths)} papers")

            # Step 3: Process PDFs and chunk
            logger.info("Step 3: Processing PDFs and chunking...")
            all_chunks = []
            for paper, pdf_path in pdf_paths:
                try:
                    chunks = self.pdf_processor.process_paper(pdf_path, paper)
                    if chunks:
                        all_chunks.extend(chunks)
                        logger.info(f"Processed {len(chunks)} chunks from {paper.arxiv_id}")
                    else:
                        error_msg = f"Failed to process paper {paper.arxiv_id}"
                        logger.warning(error_msg)
                        state["errors"].append(error_msg)
                except Exception as e:
                    error_msg = f"Error processing paper {paper.arxiv_id}: {str(e)}"
                    logger.error(error_msg)
                    state["errors"].append(error_msg)
            if not all_chunks:
                error_msg = "Failed to extract text from any papers"
                logger.error(error_msg)
                state["errors"].append(error_msg)
                return state
            logger.info(f"Total chunks created: {len(all_chunks)}")
            state["chunks"] = all_chunks

            # Step 4: Generate embeddings
            logger.info("Step 4: Generating embeddings...")
            chunk_texts = [chunk.content for chunk in all_chunks]
            embeddings = self.embedding_generator.generate_embeddings_batch(chunk_texts)
            logger.info(f"Generated {len(embeddings)} embeddings")
            # Estimate embedding tokens (Azure doesn't return usage for embeddings).
            estimated_embedding_tokens = len(chunk_texts) * self.ESTIMATED_TOKENS_PER_CHUNK
            state["token_usage"]["embedding_tokens"] += estimated_embedding_tokens
            logger.info(f"Estimated embedding tokens: {estimated_embedding_tokens}")

            # Step 5: Store in vector database
            logger.info("Step 5: Storing in vector database...")
            self.vector_store.add_chunks(all_chunks, embeddings)
            logger.info("=== Retriever Agent Completed Successfully ===")
            return state
        except Exception as e:
            error_msg = f"Retriever Agent error: {str(e)}"
            logger.error(error_msg)
            state["errors"].append(error_msg)
            return state