"""
Retriever Agent: Search arXiv, download papers, and chunk for RAG.
Includes intelligent fallback from MCP/FastMCP to direct arXiv API.
"""
import logging
from typing import Dict, Any, Optional, List
from pathlib import Path
from utils.arxiv_client import ArxivClient
from utils.pdf_processor import PDFProcessor
from utils.schemas import AgentState, PaperChunk, Paper
from rag.vector_store import VectorStore
from rag.embeddings import EmbeddingGenerator
from utils.langfuse_client import observe
# NOTE(review): calling basicConfig at import time configures the *root*
# logger as a module side effect; kept as-is to preserve behavior, but a
# library module would normally leave configuration to the application.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Import MCP clients for type hints only.  Each import is optional: when the
# dependency is missing the name is bound to None so later isinstance/type
# checks elsewhere can still reference it safely.
try:
    from utils.mcp_arxiv_client import MCPArxivClient
except ImportError:
    MCPArxivClient = None
try:
    from utils.fastmcp_arxiv_client import FastMCPArxivClient
except ImportError:
    FastMCPArxivClient = None
class RetrieverAgent:
    """Agent for retrieving and processing papers from arXiv with intelligent fallback.

    Pipeline (see :meth:`run`): search -> validate -> download -> chunk ->
    embed -> store in the vector database.  Search and download each try the
    primary client first and, on error or empty result, retry with
    ``fallback_client`` when one is configured.
    """

    def __init__(
        self,
        arxiv_client: Any,
        pdf_processor: PDFProcessor,
        vector_store: VectorStore,
        embedding_generator: EmbeddingGenerator,
        fallback_client: Optional[Any] = None
    ):
        """
        Initialize Retriever Agent with fallback support.

        Args:
            arxiv_client: Primary client (ArxivClient, MCPArxivClient, or FastMCPArxivClient)
            pdf_processor: PDFProcessor instance
            vector_store: VectorStore instance
            embedding_generator: EmbeddingGenerator instance
            fallback_client: Optional fallback client (usually direct ArxivClient) used if primary fails
        """
        self.arxiv_client = arxiv_client
        self.pdf_processor = pdf_processor
        self.vector_store = vector_store
        self.embedding_generator = embedding_generator
        self.fallback_client = fallback_client

        # Log client configuration so the active primary/fallback pair is
        # visible in the logs at startup.
        client_name = type(arxiv_client).__name__
        logger.info(f"RetrieverAgent initialized with primary client: {client_name}")
        if fallback_client:
            fallback_name = type(fallback_client).__name__
            logger.info(f"Fallback client configured: {fallback_name}")

    def _search_with_fallback(
        self,
        query: str,
        max_results: int,
        category: Optional[str]
    ) -> Optional[List[Paper]]:
        """
        Search for papers with automatic fallback.

        The fallback client is tried both when the primary raises and when it
        returns an empty result.

        Args:
            query: Search query
            max_results: Maximum number of papers
            category: Optional category filter

        Returns:
            List of Paper objects, or None if both primary and fallback fail
        """
        # Try primary client
        try:
            logger.info(f"Searching with primary client ({type(self.arxiv_client).__name__})")
            papers = self.arxiv_client.search_papers(
                query=query,
                max_results=max_results,
                category=category
            )
            if papers:
                logger.info(f"Primary client found {len(papers)} papers")
                return papers
            else:
                logger.warning("Primary client returned no papers")
        except Exception as e:
            logger.error(f"Primary client search failed: {str(e)}")

        # Try fallback client if available
        if self.fallback_client:
            try:
                logger.warning(f"Attempting fallback with {type(self.fallback_client).__name__}")
                papers = self.fallback_client.search_papers(
                    query=query,
                    max_results=max_results,
                    category=category
                )
                if papers:
                    logger.info(f"Fallback client found {len(papers)} papers")
                    return papers
                else:
                    logger.error("Fallback client returned no papers")
            except Exception as e:
                logger.error(f"Fallback client search failed: {str(e)}")

        logger.error("All search attempts failed")
        return None

    def _download_with_fallback(self, paper: Paper) -> Optional[Path]:
        """
        Download paper with automatic fallback.

        Args:
            paper: Paper object to download

        Returns:
            Path to downloaded PDF, or None if both primary and fallback fail
        """
        # Try primary client
        try:
            path = self.arxiv_client.download_paper(paper)
            if path:
                logger.debug(f"Primary client downloaded {paper.arxiv_id}")
                return path
            else:
                logger.warning(f"Primary client failed to download {paper.arxiv_id}")
        except Exception as e:
            logger.error(f"Primary client download error for {paper.arxiv_id}: {str(e)}")

        # Try fallback client if available
        if self.fallback_client:
            try:
                logger.debug(f"Attempting fallback download for {paper.arxiv_id}")
                path = self.fallback_client.download_paper(paper)
                if path:
                    logger.info(f"Fallback client downloaded {paper.arxiv_id}")
                    return path
                else:
                    logger.error(f"Fallback client failed to download {paper.arxiv_id}")
            except Exception as e:
                logger.error(f"Fallback client download error for {paper.arxiv_id}: {str(e)}")

        logger.error(f"All download attempts failed for {paper.arxiv_id}")
        return None

    @staticmethod
    def _validate_paper_fields(paper: Paper) -> List[str]:
        """Return a list of human-readable data-quality issues for *paper*.

        Purely diagnostic: Pydantic validators on the Paper schema should
        already have coerced these fields, so callers log the issues but
        keep the paper.
        """
        issues: List[str] = []
        # Validate authors field
        if not isinstance(paper.authors, list):
            issues.append(f"authors is {type(paper.authors).__name__} instead of list")
        elif len(paper.authors) == 0:
            issues.append("authors list is empty")
        # Validate categories field
        if not isinstance(paper.categories, list):
            issues.append(f"categories is {type(paper.categories).__name__} instead of list")
        # Validate string fields
        if not isinstance(paper.title, str):
            issues.append(f"title is {type(paper.title).__name__} instead of str")
        if not isinstance(paper.pdf_url, str):
            issues.append(f"pdf_url is {type(paper.pdf_url).__name__} instead of str")
        if not isinstance(paper.abstract, str):
            issues.append(f"abstract is {type(paper.abstract).__name__} instead of str")
        return issues

    @observe(name="retriever_agent_run", as_type="generation")
    def run(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute retriever agent.

        Reads ``query``, ``category`` and ``num_papers`` from *state*; writes
        ``papers`` and ``chunks`` back, appends to ``errors`` on failure, and
        increments ``token_usage["embedding_tokens"]`` with an estimate.

        Args:
            state: Current agent state

        Returns:
            Updated state with papers and chunks
        """
        try:
            logger.info("=== Retriever Agent Started ===")
            query = state.get("query")
            category = state.get("category")
            num_papers = state.get("num_papers", 5)
            logger.info(f"Query: {query}")
            logger.info(f"Category: {category}")
            logger.info(f"Number of papers: {num_papers}")

            # Step 1: Search arXiv (with fallback)
            logger.info("Step 1: Searching arXiv...")
            papers = self._search_with_fallback(
                query=query,
                max_results=num_papers,
                category=category
            )
            if not papers:
                error_msg = "No papers found for the given query (tried all available clients)"
                logger.error(error_msg)
                state["errors"].append(error_msg)
                return state
            logger.info(f"Found {len(papers)} papers")

            # Validate paper data quality after MCP parsing.  Papers whose
            # attribute access itself raises are dropped; papers with mere
            # data-quality warnings are kept (Pydantic validators should
            # already have fixed them -- this is a diagnostic check only).
            validated_papers = []
            for paper in papers:
                try:
                    issues = self._validate_paper_fields(paper)
                    if issues:
                        logger.warning(f"Paper {paper.arxiv_id} has data quality issues: {', '.join(issues)}")
                    validated_papers.append(paper)
                except Exception as e:
                    error_msg = f"Failed to validate paper {getattr(paper, 'arxiv_id', 'unknown')}: {str(e)}"
                    logger.error(error_msg)
                    state["errors"].append(error_msg)
                    # Skip this paper but continue with others

            if not validated_papers:
                error_msg = "All papers failed validation checks"
                logger.error(error_msg)
                state["errors"].append(error_msg)
                return state
            logger.info(f"Validated {len(validated_papers)} papers (filtered out {len(papers) - len(validated_papers)})")
            state["papers"] = validated_papers

            # Step 2: Download papers (with fallback)
            # BUGFIX: iterate over validated_papers, not the raw search
            # results -- previously papers that failed validation (and were
            # excluded from state["papers"]) were still downloaded, chunked
            # and stored.
            logger.info("Step 2: Downloading papers...")
            pdf_paths = []
            for paper in validated_papers:
                path = self._download_with_fallback(paper)
                if path:
                    pdf_paths.append((paper, path))
                else:
                    logger.warning(f"Failed to download paper {paper.arxiv_id} (all clients failed)")
            logger.info(f"Downloaded {len(pdf_paths)} papers")

            # Step 3: Process PDFs and chunk
            logger.info("Step 3: Processing PDFs and chunking...")
            all_chunks = []
            for paper, pdf_path in pdf_paths:
                try:
                    chunks = self.pdf_processor.process_paper(pdf_path, paper)
                    if chunks:
                        all_chunks.extend(chunks)
                        logger.info(f"Processed {len(chunks)} chunks from {paper.arxiv_id}")
                    else:
                        error_msg = f"Failed to process paper {paper.arxiv_id}"
                        logger.warning(error_msg)
                        state["errors"].append(error_msg)
                except Exception as e:
                    error_msg = f"Error processing paper {paper.arxiv_id}: {str(e)}"
                    logger.error(error_msg)
                    state["errors"].append(error_msg)

            if not all_chunks:
                error_msg = "Failed to extract text from any papers"
                logger.error(error_msg)
                state["errors"].append(error_msg)
                return state
            logger.info(f"Total chunks created: {len(all_chunks)}")
            state["chunks"] = all_chunks

            # Step 4: Generate embeddings
            logger.info("Step 4: Generating embeddings...")
            chunk_texts = [chunk.content for chunk in all_chunks]
            embeddings = self.embedding_generator.generate_embeddings_batch(chunk_texts)
            logger.info(f"Generated {len(embeddings)} embeddings")

            # Estimate embedding tokens (Azure doesn't return usage for embeddings)
            # Estimate ~300 tokens per chunk on average
            estimated_embedding_tokens = len(chunk_texts) * 300
            state["token_usage"]["embedding_tokens"] += estimated_embedding_tokens
            logger.info(f"Estimated embedding tokens: {estimated_embedding_tokens}")

            # Step 5: Store in vector database
            logger.info("Step 5: Storing in vector database...")
            self.vector_store.add_chunks(all_chunks, embeddings)

            logger.info("=== Retriever Agent Completed Successfully ===")
            return state
        except Exception as e:
            # Top-level boundary: record the failure in state rather than
            # crashing the surrounding agent graph.
            error_msg = f"Retriever Agent error: {str(e)}"
            logger.error(error_msg)
            state["errors"].append(error_msg)
            return state