File size: 12,109 Bytes
aca8ab4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
"""
Retriever Agent: Search arXiv, download papers, and chunk for RAG.
Includes intelligent fallback from MCP/FastMCP to direct arXiv API.
"""
import logging
from typing import Dict, Any, Optional, List
from pathlib import Path

from utils.arxiv_client import ArxivClient
from utils.pdf_processor import PDFProcessor
from utils.schemas import AgentState, PaperChunk, Paper
from rag.vector_store import VectorStore
from rag.embeddings import EmbeddingGenerator
from utils.langfuse_client import observe

# NOTE(review): basicConfig configures the *root* logger globally; in a library
# module this is usually left to the application entry point — confirm intended.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Import MCP clients for type hints
# These imports are optional: the MCP transports may not be installed in every
# deployment. Fall back to None so isinstance/type-hint checks degrade gracefully.
try:
    from utils.mcp_arxiv_client import MCPArxivClient
except ImportError:
    MCPArxivClient = None

try:
    from utils.fastmcp_arxiv_client import FastMCPArxivClient
except ImportError:
    FastMCPArxivClient = None



class RetrieverAgent:
    """Agent for retrieving and processing papers from arXiv with intelligent fallback.

    Pipeline (see :meth:`run`):
        1. Search arXiv with the primary client, falling back to the secondary
           client (usually a direct ``ArxivClient``) on failure.
        2. Download PDFs for the *validated* papers, with the same fallback.
        3. Extract text from each PDF and chunk it.
        4. Generate embeddings for every chunk.
        5. Store chunks and embeddings in the vector database.

    All recoverable failures are appended to ``state["errors"]`` rather than
    raised, so the surrounding agent graph can continue or report.
    """

    def __init__(
        self,
        arxiv_client: Any,
        pdf_processor: PDFProcessor,
        vector_store: VectorStore,
        embedding_generator: EmbeddingGenerator,
        fallback_client: Optional[Any] = None
    ):
        """
        Initialize Retriever Agent with fallback support.

        Args:
            arxiv_client: Primary client (ArxivClient, MCPArxivClient, or FastMCPArxivClient)
            pdf_processor: PDFProcessor instance
            vector_store: VectorStore instance
            embedding_generator: EmbeddingGenerator instance
            fallback_client: Optional fallback client (usually direct ArxivClient) used if primary fails
        """
        self.arxiv_client = arxiv_client
        self.pdf_processor = pdf_processor
        self.vector_store = vector_store
        self.embedding_generator = embedding_generator
        self.fallback_client = fallback_client

        # Log the client configuration so misconfigured deployments are
        # visible at startup rather than at first search.
        logger.info(f"RetrieverAgent initialized with primary client: {type(arxiv_client).__name__}")
        if fallback_client:
            logger.info(f"Fallback client configured: {type(fallback_client).__name__}")

    def _search_with_fallback(
        self,
        query: str,
        max_results: int,
        category: Optional[str]
    ) -> Optional[List[Paper]]:
        """
        Search for papers with automatic fallback.

        Args:
            query: Search query
            max_results: Maximum number of papers
            category: Optional category filter

        Returns:
            List of Paper objects, or None if both primary and fallback fail
        """
        # Try primary client first; an empty result is treated the same as an
        # exception (both trigger the fallback).
        try:
            logger.info(f"Searching with primary client ({type(self.arxiv_client).__name__})")
            papers = self.arxiv_client.search_papers(
                query=query,
                max_results=max_results,
                category=category
            )
            if papers:
                logger.info(f"Primary client found {len(papers)} papers")
                return papers
            logger.warning("Primary client returned no papers")
        except Exception as e:
            logger.error(f"Primary client search failed: {str(e)}")

        # Try fallback client if one was configured.
        if self.fallback_client:
            try:
                logger.warning(f"Attempting fallback with {type(self.fallback_client).__name__}")
                papers = self.fallback_client.search_papers(
                    query=query,
                    max_results=max_results,
                    category=category
                )
                if papers:
                    logger.info(f"Fallback client found {len(papers)} papers")
                    return papers
                logger.error("Fallback client returned no papers")
            except Exception as e:
                logger.error(f"Fallback client search failed: {str(e)}")

        logger.error("All search attempts failed")
        return None

    def _download_with_fallback(self, paper: Paper) -> Optional[Path]:
        """
        Download paper with automatic fallback.

        Args:
            paper: Paper object to download

        Returns:
            Path to downloaded PDF, or None if both primary and fallback fail
        """
        # Try primary client; a falsy path counts as failure.
        try:
            path = self.arxiv_client.download_paper(paper)
            if path:
                logger.debug(f"Primary client downloaded {paper.arxiv_id}")
                return path
            logger.warning(f"Primary client failed to download {paper.arxiv_id}")
        except Exception as e:
            logger.error(f"Primary client download error for {paper.arxiv_id}: {str(e)}")

        # Try fallback client if one was configured.
        if self.fallback_client:
            try:
                logger.debug(f"Attempting fallback download for {paper.arxiv_id}")
                path = self.fallback_client.download_paper(paper)
                if path:
                    logger.info(f"Fallback client downloaded {paper.arxiv_id}")
                    return path
                logger.error(f"Fallback client failed to download {paper.arxiv_id}")
            except Exception as e:
                logger.error(f"Fallback client download error for {paper.arxiv_id}: {str(e)}")

        logger.error(f"All download attempts failed for {paper.arxiv_id}")
        return None

    def _validate_papers(self, papers: List[Paper], errors: List[str]) -> List[Paper]:
        """
        Run diagnostic data-quality checks on parsed papers.

        Pydantic validators should already have coerced these fields, so
        findings are logged as warnings and the paper is still kept; only
        papers whose inspection *raises* are dropped (with the error appended
        to ``errors``).

        Args:
            papers: Papers returned by the search client
            errors: Mutable error list (normally ``state["errors"]``)

        Returns:
            Papers that passed inspection without raising.
        """
        validated_papers: List[Paper] = []
        for paper in papers:
            try:
                issues = []

                # Validate authors field
                if not isinstance(paper.authors, list):
                    issues.append(f"authors is {type(paper.authors).__name__} instead of list")
                elif len(paper.authors) == 0:
                    issues.append("authors list is empty")

                # Validate categories field
                if not isinstance(paper.categories, list):
                    issues.append(f"categories is {type(paper.categories).__name__} instead of list")

                # Validate string fields
                if not isinstance(paper.title, str):
                    issues.append(f"title is {type(paper.title).__name__} instead of str")
                if not isinstance(paper.pdf_url, str):
                    issues.append(f"pdf_url is {type(paper.pdf_url).__name__} instead of str")
                if not isinstance(paper.abstract, str):
                    issues.append(f"abstract is {type(paper.abstract).__name__} instead of str")

                if issues:
                    # Diagnostic only: Pydantic validators should have fixed these.
                    logger.warning(f"Paper {paper.arxiv_id} has data quality issues: {', '.join(issues)}")

                validated_papers.append(paper)

            except Exception as e:
                error_msg = f"Failed to validate paper {getattr(paper, 'arxiv_id', 'unknown')}: {str(e)}"
                logger.error(error_msg)
                errors.append(error_msg)
                # Skip this paper but continue with others

        return validated_papers

    @observe(name="retriever_agent_run", as_type="generation")
    def run(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute retriever agent.

        Args:
            state: Current agent state. Reads ``query``, ``category`` and
                ``num_papers`` (default 5); writes ``papers``, ``chunks``,
                ``token_usage["embedding_tokens"]`` and appends to ``errors``.

        Returns:
            Updated state with papers and chunks
        """
        try:
            logger.info("=== Retriever Agent Started ===")

            # Be tolerant of callers that did not pre-populate these keys.
            errors = state.setdefault("errors", [])

            query = state.get("query")
            category = state.get("category")
            num_papers = state.get("num_papers", 5)

            logger.info(f"Query: {query}")
            logger.info(f"Category: {category}")
            logger.info(f"Number of papers: {num_papers}")

            # Step 1: Search arXiv (with fallback)
            logger.info("Step 1: Searching arXiv...")
            papers = self._search_with_fallback(
                query=query,
                max_results=num_papers,
                category=category
            )

            if not papers:
                error_msg = "No papers found for the given query (tried all available clients)"
                logger.error(error_msg)
                errors.append(error_msg)
                return state

            logger.info(f"Found {len(papers)} papers")

            # Validate paper data quality after MCP parsing
            validated_papers = self._validate_papers(papers, errors)

            if not validated_papers:
                error_msg = "All papers failed validation checks"
                logger.error(error_msg)
                errors.append(error_msg)
                return state

            logger.info(f"Validated {len(validated_papers)} papers (filtered out {len(papers) - len(validated_papers)})")
            state["papers"] = validated_papers

            # Step 2: Download papers (with fallback)
            # BUGFIX: iterate over validated_papers, not the raw search results,
            # so papers dropped during validation are never downloaded/processed.
            logger.info("Step 2: Downloading papers...")
            pdf_paths = []
            for paper in validated_papers:
                path = self._download_with_fallback(paper)
                if path:
                    pdf_paths.append((paper, path))
                else:
                    logger.warning(f"Failed to download paper {paper.arxiv_id} (all clients failed)")

            logger.info(f"Downloaded {len(pdf_paths)} papers")

            # Step 3: Process PDFs and chunk
            logger.info("Step 3: Processing PDFs and chunking...")
            all_chunks = []
            for paper, pdf_path in pdf_paths:
                try:
                    chunks = self.pdf_processor.process_paper(pdf_path, paper)
                    if chunks:
                        all_chunks.extend(chunks)
                        logger.info(f"Processed {len(chunks)} chunks from {paper.arxiv_id}")
                    else:
                        error_msg = f"Failed to process paper {paper.arxiv_id}"
                        logger.warning(error_msg)
                        errors.append(error_msg)
                except Exception as e:
                    error_msg = f"Error processing paper {paper.arxiv_id}: {str(e)}"
                    logger.error(error_msg)
                    errors.append(error_msg)

            if not all_chunks:
                error_msg = "Failed to extract text from any papers"
                logger.error(error_msg)
                errors.append(error_msg)
                return state

            logger.info(f"Total chunks created: {len(all_chunks)}")
            state["chunks"] = all_chunks

            # Step 4: Generate embeddings
            logger.info("Step 4: Generating embeddings...")
            chunk_texts = [chunk.content for chunk in all_chunks]
            embeddings = self.embedding_generator.generate_embeddings_batch(chunk_texts)
            logger.info(f"Generated {len(embeddings)} embeddings")

            # Estimate embedding tokens (Azure doesn't return usage for embeddings)
            # Estimate ~300 tokens per chunk on average
            estimated_embedding_tokens = len(chunk_texts) * 300
            token_usage = state.setdefault("token_usage", {})
            token_usage["embedding_tokens"] = token_usage.get("embedding_tokens", 0) + estimated_embedding_tokens
            logger.info(f"Estimated embedding tokens: {estimated_embedding_tokens}")

            # Step 5: Store in vector database
            logger.info("Step 5: Storing in vector database...")
            self.vector_store.add_chunks(all_chunks, embeddings)

            logger.info("=== Retriever Agent Completed Successfully ===")
            return state

        except Exception as e:
            error_msg = f"Retriever Agent error: {str(e)}"
            logger.error(error_msg)
            state.setdefault("errors", []).append(error_msg)
            return state