| | """ |
| | Retriever Agent: Search arXiv, download papers, and chunk for RAG. |
| | Includes intelligent fallback from MCP/FastMCP to direct arXiv API. |
| | """ |
| | import logging |
| | from typing import Dict, Any, Optional, List |
| | from pathlib import Path |
| |
|
| | from utils.arxiv_client import ArxivClient |
| | from utils.pdf_processor import PDFProcessor |
| | from utils.schemas import AgentState, PaperChunk, Paper |
| | from rag.vector_store import VectorStore |
| | from rag.embeddings import EmbeddingGenerator |
| | from utils.langfuse_client import observe |
| |
|
| | logging.basicConfig( |
| | level=logging.INFO, |
| | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' |
| | ) |
| | logger = logging.getLogger(__name__) |
| |
|
| | |
# Optional MCP-based clients: left as None when their packages are unavailable.
try:
    from utils.mcp_arxiv_client import MCPArxivClient
except ImportError:
    MCPArxivClient = None

try:
    from utils.fastmcp_arxiv_client import FastMCPArxivClient
except ImportError:
    FastMCPArxivClient = None

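# Note: the primary client handed to RetrieverAgent can be any of the clients
# above -- it only needs to expose search_papers() and download_paper() (duck
# typing). A common arrangement (an assumption, not enforced by this module)
# is an MCP/FastMCP client as primary with a direct ArxivClient as
# fallback_client; see the illustrative sketch at the bottom of this file.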
class RetrieverAgent:
    """Agent for retrieving and processing papers from arXiv with intelligent fallback."""

    def __init__(
        self,
        arxiv_client: Any,
        pdf_processor: PDFProcessor,
        vector_store: VectorStore,
        embedding_generator: EmbeddingGenerator,
        fallback_client: Optional[Any] = None
    ):
        """
        Initialize Retriever Agent with fallback support.

        Args:
            arxiv_client: Primary client (ArxivClient, MCPArxivClient, or FastMCPArxivClient)
            pdf_processor: PDFProcessor instance
            vector_store: VectorStore instance
            embedding_generator: EmbeddingGenerator instance
            fallback_client: Optional fallback client (typically a direct ArxivClient) used if the primary fails
        """
        self.arxiv_client = arxiv_client
        self.pdf_processor = pdf_processor
        self.vector_store = vector_store
        self.embedding_generator = embedding_generator
        self.fallback_client = fallback_client

        client_name = type(arxiv_client).__name__
        logger.info(f"RetrieverAgent initialized with primary client: {client_name}")
        if fallback_client:
            fallback_name = type(fallback_client).__name__
            logger.info(f"Fallback client configured: {fallback_name}")

    def _search_with_fallback(
        self,
        query: str,
        max_results: int,
        category: Optional[str]
    ) -> Optional[List[Paper]]:
        """
        Search for papers with automatic fallback.

        Args:
            query: Search query
            max_results: Maximum number of papers
            category: Optional category filter

        Returns:
            List of Paper objects, or None if both primary and fallback fail
        """
        # Try the primary client first.
        try:
            logger.info(f"Searching with primary client ({type(self.arxiv_client).__name__})")
            papers = self.arxiv_client.search_papers(
                query=query,
                max_results=max_results,
                category=category
            )
            if papers:
                logger.info(f"Primary client found {len(papers)} papers")
                return papers
            else:
                logger.warning("Primary client returned no papers")
        except Exception as e:
            logger.error(f"Primary client search failed: {str(e)}")

        # Primary failed or returned nothing: try the fallback client if configured.
        if self.fallback_client:
            try:
                logger.warning(f"Attempting fallback with {type(self.fallback_client).__name__}")
                papers = self.fallback_client.search_papers(
                    query=query,
                    max_results=max_results,
                    category=category
                )
                if papers:
                    logger.info(f"Fallback client found {len(papers)} papers")
                    return papers
                else:
                    logger.error("Fallback client returned no papers")
            except Exception as e:
                logger.error(f"Fallback client search failed: {str(e)}")

        logger.error("All search attempts failed")
        return None

    def _download_with_fallback(self, paper: Paper) -> Optional[Path]:
        """
        Download paper with automatic fallback.

        Args:
            paper: Paper object to download

        Returns:
            Path to downloaded PDF, or None if both primary and fallback fail
        """
        # Try the primary client first.
        try:
            path = self.arxiv_client.download_paper(paper)
            if path:
                logger.debug(f"Primary client downloaded {paper.arxiv_id}")
                return path
            else:
                logger.warning(f"Primary client failed to download {paper.arxiv_id}")
        except Exception as e:
            logger.error(f"Primary client download error for {paper.arxiv_id}: {str(e)}")

        # Primary failed: try the fallback client if configured.
        if self.fallback_client:
            try:
                logger.debug(f"Attempting fallback download for {paper.arxiv_id}")
                path = self.fallback_client.download_paper(paper)
                if path:
                    logger.info(f"Fallback client downloaded {paper.arxiv_id}")
                    return path
                else:
                    logger.error(f"Fallback client failed to download {paper.arxiv_id}")
            except Exception as e:
                logger.error(f"Fallback client download error for {paper.arxiv_id}: {str(e)}")

        logger.error(f"All download attempts failed for {paper.arxiv_id}")
        return None

    @observe(name="retriever_agent_run", as_type="generation")
    def run(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute retriever agent.

        Args:
            state: Current agent state

        Returns:
            Updated state with papers and chunks
        """
        try:
            logger.info("=== Retriever Agent Started ===")

            query = state.get("query")
            category = state.get("category")
            num_papers = state.get("num_papers", 5)

            logger.info(f"Query: {query}")
            logger.info(f"Category: {category}")
            logger.info(f"Number of papers: {num_papers}")

            # Step 1: search arXiv (primary client, then fallback).
            logger.info("Step 1: Searching arXiv...")
            papers = self._search_with_fallback(
                query=query,
                max_results=num_papers,
                category=category
            )

            if not papers:
                error_msg = "No papers found for the given query (tried all available clients)"
                logger.error(error_msg)
                state["errors"].append(error_msg)
                return state

            logger.info(f"Found {len(papers)} papers")

            # Validate paper metadata before further processing.
            validated_papers = []
            for paper in papers:
                try:
                    # Collect data-quality issues for this paper.
                    issues = []

                    # authors should be a non-empty list
                    if not isinstance(paper.authors, list):
                        issues.append(f"authors is {type(paper.authors).__name__} instead of list")
                    elif len(paper.authors) == 0:
                        issues.append("authors list is empty")

                    # categories should be a list
                    if not isinstance(paper.categories, list):
                        issues.append(f"categories is {type(paper.categories).__name__} instead of list")

                    # title, pdf_url, and abstract should be strings
                    if not isinstance(paper.title, str):
                        issues.append(f"title is {type(paper.title).__name__} instead of str")
                    if not isinstance(paper.pdf_url, str):
                        issues.append(f"pdf_url is {type(paper.pdf_url).__name__} instead of str")
                    if not isinstance(paper.abstract, str):
                        issues.append(f"abstract is {type(paper.abstract).__name__} instead of str")

                    if issues:
                        # Log quality issues but keep the paper; only papers that
                        # raise during validation are dropped.
                        logger.warning(f"Paper {paper.arxiv_id} has data quality issues: {', '.join(issues)}")

                    validated_papers.append(paper)

                except Exception as e:
                    error_msg = f"Failed to validate paper {getattr(paper, 'arxiv_id', 'unknown')}: {str(e)}"
                    logger.error(error_msg)
                    state["errors"].append(error_msg)
                    # Skip this paper and continue with the rest.

            if not validated_papers:
                error_msg = "All papers failed validation checks"
                logger.error(error_msg)
                state["errors"].append(error_msg)
                return state

            logger.info(f"Validated {len(validated_papers)} papers (filtered out {len(papers) - len(validated_papers)})")
            state["papers"] = validated_papers

            # Step 2: download PDFs (primary client, then fallback) for the validated papers.
            logger.info("Step 2: Downloading papers...")
            pdf_paths = []
            for paper in validated_papers:
                path = self._download_with_fallback(paper)
                if path:
                    pdf_paths.append((paper, path))
                else:
                    logger.warning(f"Failed to download paper {paper.arxiv_id} (all clients failed)")

            logger.info(f"Downloaded {len(pdf_paths)} papers")

            # Step 3: extract text from each PDF and split it into chunks.
            logger.info("Step 3: Processing PDFs and chunking...")
            all_chunks = []
            for paper, pdf_path in pdf_paths:
                try:
                    chunks = self.pdf_processor.process_paper(pdf_path, paper)
                    if chunks:
                        all_chunks.extend(chunks)
                        logger.info(f"Processed {len(chunks)} chunks from {paper.arxiv_id}")
                    else:
                        error_msg = f"Failed to process paper {paper.arxiv_id}"
                        logger.warning(error_msg)
                        state["errors"].append(error_msg)
                except Exception as e:
                    error_msg = f"Error processing paper {paper.arxiv_id}: {str(e)}"
                    logger.error(error_msg)
                    state["errors"].append(error_msg)

            if not all_chunks:
                error_msg = "Failed to extract text from any papers"
                logger.error(error_msg)
                state["errors"].append(error_msg)
                return state

            logger.info(f"Total chunks created: {len(all_chunks)}")
            state["chunks"] = all_chunks

            # Step 4: embed each chunk for vector search.
            logger.info("Step 4: Generating embeddings...")
            chunk_texts = [chunk.content for chunk in all_chunks]
            embeddings = self.embedding_generator.generate_embeddings_batch(chunk_texts)
            logger.info(f"Generated {len(embeddings)} embeddings")

            # Track embedding cost with a rough heuristic of ~300 tokens per chunk
            # (an estimate, not an exact tokenizer count).
            estimated_embedding_tokens = len(chunk_texts) * 300
            state["token_usage"]["embedding_tokens"] += estimated_embedding_tokens
            logger.info(f"Estimated embedding tokens: {estimated_embedding_tokens}")

            # Step 5: persist chunks and embeddings in the vector store.
            logger.info("Step 5: Storing in vector database...")
            self.vector_store.add_chunks(all_chunks, embeddings)

            logger.info("=== Retriever Agent Completed Successfully ===")
            return state

        except Exception as e:
            error_msg = f"Retriever Agent error: {str(e)}"
            logger.error(error_msg)
            state["errors"].append(error_msg)
            return state

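
# ---------------------------------------------------------------------------
# Illustrative wiring sketch (a hedged example, not part of the agent itself).
# The no-argument constructors below are assumptions -- ArxivClient,
# PDFProcessor, VectorStore, and EmbeddingGenerator may require configuration
# in a real setup -- but the fallback pattern and the state keys match what
# run() reads and writes above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Prefer a FastMCP client when its package is available; otherwise use the
    # direct arXiv API client. Keep a direct client as the fallback either way.
    primary = FastMCPArxivClient() if FastMCPArxivClient else ArxivClient()

    agent = RetrieverAgent(
        arxiv_client=primary,
        pdf_processor=PDFProcessor(),
        vector_store=VectorStore(),
        embedding_generator=EmbeddingGenerator(),
        fallback_client=ArxivClient(),
    )

    # Minimal state expected by run(): the query fields it reads plus the
    # mutable bookkeeping fields (errors, token_usage) it appends to.
    state = {
        "query": "retrieval augmented generation",
        "category": "cs.CL",
        "num_papers": 3,
        "errors": [],
        "token_usage": {"embedding_tokens": 0},
    }
    result = agent.run(state)
    print(f"papers: {len(result.get('papers', []))}, chunks: {len(result.get('chunks', []))}")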