Tanaybh commited on
Commit
3a5fdfb
·
verified ·
1 Parent(s): 422c3d1

Upload 4 files

Browse files
Files changed (4) hide show
  1. agents.py +351 -0
  2. fetch_arxiv_data.py +114 -0
  3. retriever.py +201 -0
  4. utils.py +231 -0
agents.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DocMind - Multi-Agent System
3
+ Implements Retriever, Reader, Critic, and Synthesizer agents
4
+ """
5
+
6
+ from typing import List, Dict, Tuple
7
+ from retriever import PaperRetriever
8
+ import os
9
+
10
+
11
class RetrieverAgent:
    """Agent that locates papers relevant to a user query.

    Thin wrapper around a PaperRetriever so the orchestrator can treat
    retrieval like every other pipeline stage.
    """

    def __init__(self, retriever: PaperRetriever):
        self.retriever = retriever

    def retrieve(self, query: str, top_k: int = 5) -> List[Tuple[Dict, float]]:
        """
        Retrieve relevant papers for the query

        Args:
            query: Free-text search string
            top_k: Maximum number of papers to return

        Returns:
            List of (paper, relevance_score) tuples
        """
        print(f"🔍 Retriever Agent: Searching for '{query}'...")
        hits = self.retriever.search(query, top_k)
        print(f" Found {len(hits)} relevant papers")
        return hits
28
+
29
+
30
class ReaderAgent:
    """Agent responsible for reading and summarizing papers"""

    def __init__(self, llm_client=None):
        """
        Args:
            llm_client: Optional LLM client (OpenAI, Anthropic, etc.)
                If None, uses rule-based summarization
        """
        self.llm_client = llm_client

    def summarize_paper(self, paper: Dict) -> Dict:
        """
        Generate a summary record for a single paper

        Args:
            paper: Paper dictionary with title, abstract, etc.

        Returns:
            Dict with 'title', 'arxiv_id', 'authors' (first 3),
            'summary' and 'year' keys.
            (Fix: was annotated -> str although both helpers return a dict.)
        """
        if self.llm_client:
            return self._llm_summarize(paper)
        return self._rule_based_summarize(paper)

    def _rule_based_summarize(self, paper: Dict) -> Dict:
        """Simple extractive summary (first 3 sentences of the abstract)"""
        sentences = paper['abstract'].split('. ')
        summary = '. '.join(sentences[:3])
        # Fix: only restore the trailing period when it is missing; the old
        # unconditional "+ '.'" doubled it for abstracts of 1-2 sentences.
        if not summary.endswith('.'):
            summary += '.'

        return {
            'title': paper['title'],
            'arxiv_id': paper['arxiv_id'],
            'authors': paper['authors'][:3],
            'summary': summary,
            'year': paper['published'][:4]
        }

    def _llm_summarize(self, paper: Dict) -> Dict:
        """Use LLM to generate intelligent summary (placeholder implementation)"""
        prompt = f"""Summarize this research paper in 2-3 sentences, focusing on:
1. The main contribution/idea
2. The key methodology or approach
3. Important results or implications

Title: {paper['title']}
Abstract: {paper['abstract']}

Summary:"""

        # Call LLM (implementation depends on client)
        # This is a placeholder - replace with actual LLM call; `prompt` is
        # intentionally built so the eventual call site is ready to use it.
        response = "LLM summary would go here"

        return {
            'title': paper['title'],
            'arxiv_id': paper['arxiv_id'],
            'authors': paper['authors'][:3],
            'summary': response,
            'year': paper['published'][:4]
        }

    def read_papers(self, papers: List[Tuple[Dict, float]]) -> List[Dict]:
        """
        Read and summarize multiple papers

        Args:
            papers: List of (paper, score) tuples from retriever

        Returns:
            List of summary dicts, each tagged with its retrieval
            'relevance_score'
        """
        print(f"📖 Reader Agent: Reading {len(papers)} papers...")
        summaries = []

        for paper, score in papers:
            summary = self.summarize_paper(paper)
            summary['relevance_score'] = score
            summaries.append(summary)

        print(f" Generated {len(summaries)} summaries")
        return summaries
114
+
115
+
116
class CriticAgent:
    """Agent responsible for evaluating and filtering summaries"""

    def __init__(self, llm_client=None):
        self.llm_client = llm_client

    def critique(self, summaries: List[Dict], query: str) -> List[Dict]:
        """
        Evaluate summaries for quality and relevance

        Args:
            summaries: List of paper summaries
            query: Original user query

        Returns:
            Summaries above the relevance threshold, annotated with a
            'quality_score' and sorted by combined relevance/quality
        """
        print(f"🔎 Critic Agent: Evaluating {len(summaries)} summaries...")

        kept = []
        for entry in summaries:
            # Relevance threshold: anything at or below 0.3 is discarded.
            if entry['relevance_score'] <= 0.3:
                continue
            entry['quality_score'] = self._assess_quality(entry, query)
            kept.append(entry)

        def combined(entry: Dict) -> float:
            # Relevance dominates; quality acts as a tie-breaker weight.
            return entry['relevance_score'] * 0.7 + entry['quality_score'] * 0.3

        kept.sort(key=combined, reverse=True)

        print(f" Retained {len(kept)} high-quality summaries")
        return kept

    def _assess_quality(self, summary: Dict, query: str) -> float:
        """
        Heuristic quality assessment (can be enhanced with LLM)

        Returns:
            Quality score 0-1
        """
        score = 0.5  # base score
        # Longer summaries might be more informative
        score += 0.2 if len(summary['summary']) > 100 else 0.0
        # Recent papers get a recency bonus
        score += 0.3 if int(summary['year']) >= 2024 else 0.0
        return min(score, 1.0)
171
+
172
+
173
class SynthesizerAgent:
    """Agent responsible for synthesizing final answer"""

    def __init__(self, llm_client=None):
        self.llm_client = llm_client

    def synthesize(
        self,
        summaries: List[Dict],
        query: str,
        max_papers: int = 10
    ) -> str:
        """
        Synthesize final answer from summaries

        Args:
            summaries: List of filtered, quality summaries
            query: Original user query
            max_papers: Maximum papers to include in response

        Returns:
            Final synthesized response with citations
        """
        print(f"✨ Synthesizer Agent: Creating final response...")

        if not summaries:
            return "No relevant papers found for your query."

        top = summaries[:max_papers]  # cap how many papers are cited
        builder = self._llm_synthesize if self.llm_client else self._rule_based_synthesize
        return builder(top, query)

    def _rule_based_synthesize(self, summaries: List[Dict], query: str) -> str:
        """Create structured response without LLM"""
        parts = [
            f"# Research Summary: {query}\n\n",
            f"Based on {len(summaries)} relevant papers from arXiv:\n\n",
        ]

        for rank, entry in enumerate(summaries, 1):
            parts.append(f"## [{rank}] {entry['title']}\n")
            author_line = f"**Authors:** {', '.join(entry['authors'])}"
            # Author lists were truncated to 3 upstream, so 3 names may
            # stand for more — mark with "et al."
            if len(entry['authors']) >= 3:
                author_line += " et al."
            parts.append(author_line)
            parts.append(f"\n**Year:** {entry['year']}\n")
            parts.append(f"**arXiv ID:** {entry['arxiv_id']}\n")
            parts.append(f"**Relevance:** {entry['relevance_score']:.2f}\n\n")
            parts.append(f"{entry['summary']}\n\n")
            parts.append("---\n\n")

        return "".join(parts)

    def _llm_synthesize(self, summaries: List[Dict], query: str) -> str:
        """Use LLM to create coherent synthesis"""
        # Build numbered context the prompt can cite by [n].
        context = "".join(
            f"[{rank}] {entry['title']} ({entry['arxiv_id']})\n"
            f" {entry['summary']}\n\n"
            for rank, entry in enumerate(summaries, 1)
        )

        prompt = f"""You are a research assistant. Based on the following papers, answer this question:

Question: {query}

Papers:
{context}

Provide a comprehensive answer that:
1. Directly addresses the question
2. Synthesizes information across papers
3. Cites papers by number [1], [2], etc.
4. Highlights key findings and consensus/disagreements
5. Is concise but thorough (3-5 paragraphs)

Answer:"""

        # Placeholder for LLM call
        response = "LLM-generated synthesis would go here with citations"

        # Append paper references
        response += "\n\n## References\n"
        for rank, entry in enumerate(summaries, 1):
            response += f"[{rank}] {entry['title']} "
            response += f"({entry['arxiv_id']}, {entry['year']})\n"

        return response
261
+
262
+
263
class DocMindOrchestrator:
    """Main orchestrator that coordinates all agents"""

    def __init__(self, retriever: PaperRetriever, llm_client=None):
        self.retriever_agent = RetrieverAgent(retriever)
        self.reader_agent = ReaderAgent(llm_client)
        self.critic_agent = CriticAgent(llm_client)
        self.synthesizer_agent = SynthesizerAgent(llm_client)

    def process_query(
        self,
        query: str,
        top_k: int = 10,
        max_papers_in_response: int = 5
    ) -> str:
        """
        Run the full Retriever -> Reader -> Critic -> Synthesizer pipeline

        Args:
            query: User question
            top_k: Number of papers to retrieve
            max_papers_in_response: Max papers in final response

        Returns:
            Final synthesized answer
        """
        banner = '=' * 60
        print(f"\n{banner}")
        print(f"Processing query: {query}")
        print(banner)

        # Step 1: Retrieve candidate papers
        papers = self.retriever_agent.retrieve(query, top_k)
        if not papers:
            return "No relevant papers found for your query."

        # Step 2: Read & summarize each candidate
        summaries = self.reader_agent.read_papers(papers)

        # Step 3: Filter out low-relevance / low-quality summaries
        vetted = self.critic_agent.critique(summaries, query)

        # Step 4: Compose the cited final answer
        answer = self.synthesizer_agent.synthesize(
            vetted,
            query,
            max_papers_in_response
        )

        print(f"{banner}\n")
        return answer
318
+
319
+
320
def main():
    """Example usage of multi-agent system"""
    from fetch_arxiv_data import ArxivFetcher

    # Setup
    fetcher = ArxivFetcher()
    retriever = PaperRetriever()

    # Reuse a persisted index when available; otherwise build and save one.
    if not retriever.load_index():
        corpus = fetcher.load_papers("arxiv_papers.json")
        retriever.build_index(corpus)
        retriever.save_index()

    orchestrator = DocMindOrchestrator(retriever)

    # Demo queries exercising the full pipeline
    demo_queries = [
        "What are the latest improvements in diffusion models?",
        "How does RLHF compare to DPO for language model alignment?",
        "What are the main challenges in scaling transformers?"
    ]

    for question in demo_queries:
        answer = orchestrator.process_query(question, top_k=8, max_papers_in_response=3)
        print(answer)
        print("\n" + "=" * 80 + "\n")


if __name__ == "__main__":
    main()
fetch_arxiv_data.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DocMind - arXiv Data Fetcher
3
+ Fetches papers from arXiv API and saves them for indexing
4
+ """
5
+
6
+ import arxiv
7
+ import os
8
+ import json
9
+ from pathlib import Path
10
+ from typing import List, Dict
11
+
12
+
13
class ArxivFetcher:
    """Fetches arXiv paper metadata and persists it as JSON."""

    def __init__(self, data_dir: str = "data/papers"):
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(parents=True, exist_ok=True)

    def fetch_papers(
        self,
        query: str = "machine learning",
        max_results: int = 100,
        category: str = None
    ) -> List[Dict]:
        """
        Fetch papers from arXiv API

        Args:
            query: Search query string
            max_results: Maximum number of papers to fetch
            category: arXiv category (e.g., 'cs.AI', 'cs.LG')

        Returns:
            List of paper dictionaries
        """
        print(f"Fetching papers from arXiv: query='{query}', max={max_results}")

        # Scope the query to a category when one is given.
        search_query = f"cat:{category} AND {query}" if category else query

        search = arxiv.Search(
            query=search_query,
            max_results=max_results,
            sort_by=arxiv.SortCriterion.SubmittedDate
        )

        papers = [
            {
                'arxiv_id': result.entry_id.split('/')[-1],
                'title': result.title,
                'authors': [author.name for author in result.authors],
                'abstract': result.summary,
                'published': result.published.strftime('%Y-%m-%d'),
                'pdf_url': result.pdf_url,
                'categories': result.categories
            }
            for result in search.results()
        ]

        print(f"Successfully fetched {len(papers)} papers")
        return papers

    def save_papers(self, papers: List[Dict], filename: str = "papers.json"):
        """Save papers to JSON file"""
        target = self.data_dir / filename
        with open(target, 'w', encoding='utf-8') as fh:
            json.dump(papers, fh, indent=2, ensure_ascii=False)
        print(f"Saved {len(papers)} papers to {target}")

    def load_papers(self, filename: str = "papers.json") -> List[Dict]:
        """Load papers from JSON file; returns [] when the file is missing"""
        source = self.data_dir / filename
        if not source.exists():
            print(f"No saved papers found at {source}")
            return []

        with open(source, 'r', encoding='utf-8') as fh:
            papers = json.load(fh)
        print(f"Loaded {len(papers)} papers from {source}")
        return papers
82
+
83
+
84
def main():
    """Example usage: Fetch recent ML and AI papers"""
    fetcher = ArxivFetcher()

    # Fetch recent ML papers
    ml_papers = fetcher.fetch_papers(
        query="machine learning OR deep learning",
        max_results=50,
        category="cs.LG"
    )

    # Fetch recent AI papers
    ai_papers = fetcher.fetch_papers(
        query="artificial intelligence OR neural networks",
        max_results=50,
        category="cs.AI"
    )

    # Combine and save
    all_papers = ml_papers + ai_papers
    fetcher.save_papers(all_papers, "arxiv_papers.json")

    # Fix: the API can legitimately return no results (network hiccup,
    # restrictive query); indexing all_papers[0] would raise IndexError.
    if not all_papers:
        print("No papers fetched; nothing to preview.")
        return

    # Show sample
    print("\n=== Sample Paper ===")
    print(f"Title: {all_papers[0]['title']}")
    print(f"Authors: {', '.join(all_papers[0]['authors'][:3])}")
    print(f"Abstract: {all_papers[0]['abstract'][:200]}...")


if __name__ == "__main__":
    main()
retriever.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DocMind - Retriever Module
3
+ Semantic search over arXiv papers using FAISS and sentence-transformers
4
+ """
5
+
6
+ import numpy as np
7
+ import faiss
8
+ from sentence_transformers import SentenceTransformer
9
+ from typing import List, Dict, Tuple
10
+ import pickle
11
+ from pathlib import Path
12
+
13
+
14
class PaperRetriever:
    def __init__(
        self,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        index_path: str = "data/faiss_index"
    ):
        """
        Initialize retriever with embedding model and FAISS index

        Args:
            model_name: HuggingFace sentence-transformer model
            index_path: Directory to save/load FAISS index
        """
        print(f"Loading embedding model: {model_name}")
        self.model = SentenceTransformer(model_name)
        self.index_path = Path(index_path)
        self.index_path.mkdir(parents=True, exist_ok=True)

        self.index = None       # FAISS index; None until built or loaded
        self.papers = []        # paper dicts, aligned with index rows
        self.embeddings = None  # normalized embedding matrix

    def build_index(self, papers: List[Dict]):
        """
        Build FAISS index from papers

        Args:
            papers: List of paper dictionaries with 'title' and 'abstract'
        """
        print(f"Building index for {len(papers)} papers...")
        self.papers = papers

        # Create text to embed: title + abstract
        texts = [
            f"{paper['title']}. {paper['abstract']}"
            for paper in papers
        ]

        # Generate embeddings
        print("Generating embeddings...")
        self.embeddings = self.model.encode(
            texts,
            show_progress_bar=True,
            convert_to_numpy=True
        )

        # Build FAISS index
        dimension = self.embeddings.shape[1]
        self.index = faiss.IndexFlatIP(dimension)  # Inner product (cosine similarity)

        # Normalize embeddings so inner product equals cosine similarity
        faiss.normalize_L2(self.embeddings)
        self.index.add(self.embeddings)

        print(f"Index built with {self.index.ntotal} papers")

    def save_index(self, name: str = "papers"):
        """Save FAISS index and metadata"""
        faiss.write_index(self.index, str(self.index_path / f"{name}.index"))

        with open(self.index_path / f"{name}_papers.pkl", 'wb') as f:
            pickle.dump(self.papers, f)

        with open(self.index_path / f"{name}_embeddings.npy", 'wb') as f:
            np.save(f, self.embeddings)

        print(f"Saved index to {self.index_path}/{name}.*")

    def load_index(self, name: str = "papers") -> bool:
        """Load FAISS index and metadata; returns True on success"""
        index_file = self.index_path / f"{name}.index"
        if not index_file.exists():
            print(f"No index found at {index_file}")
            return False

        self.index = faiss.read_index(str(index_file))

        with open(self.index_path / f"{name}_papers.pkl", 'rb') as f:
            self.papers = pickle.load(f)

        with open(self.index_path / f"{name}_embeddings.npy", 'rb') as f:
            self.embeddings = np.load(f)

        print(f"Loaded index with {len(self.papers)} papers")
        return True

    def search(
        self,
        query: str,
        top_k: int = 5
    ) -> List[Tuple[Dict, float]]:
        """
        Search for relevant papers

        Args:
            query: Search query string
            top_k: Number of results to return (fewer are returned when the
                index holds fewer than top_k papers)

        Returns:
            List of (paper_dict, score) tuples

        Raises:
            ValueError: if no index has been built or loaded
        """
        if self.index is None:
            raise ValueError("Index not built or loaded")

        # Embed query
        query_embedding = self.model.encode([query], convert_to_numpy=True)
        faiss.normalize_L2(query_embedding)

        # Search
        scores, indices = self.index.search(query_embedding, top_k)

        # Return results
        results = []
        for idx, score in zip(indices[0], scores[0]):
            # Fix: FAISS pads with -1 when fewer than top_k vectors exist;
            # self.papers[-1] would silently return the LAST paper, so skip
            # padding entries instead.
            if idx < 0:
                continue
            results.append((self.papers[idx], float(score)))

        return results

    def get_retrieval_context(
        self,
        query: str,
        top_k: int = 5
    ) -> str:
        """
        Get formatted context string for LLM consumption

        Args:
            query: Search query
            top_k: Number of papers to retrieve

        Returns:
            Formatted context string with paper summaries
        """
        results = self.search(query, top_k)

        context = f"Retrieved {len(results)} relevant papers:\n\n"
        for i, (paper, score) in enumerate(results, 1):
            context += f"[{i}] {paper['title']}\n"
            context += f" Authors: {', '.join(paper['authors'][:3])}"
            if len(paper['authors']) > 3:
                context += f" et al."
            context += f"\n arXiv ID: {paper['arxiv_id']}\n"
            context += f" Published: {paper['published']}\n"
            context += f" Relevance: {score:.3f}\n"
            context += f" Abstract: {paper['abstract']}\n\n"

        return context
162
+
163
+
164
def main():
    """Example: Build and test retriever"""
    from fetch_arxiv_data import ArxivFetcher

    # Load previously fetched papers from disk
    papers = ArxivFetcher().load_papers("arxiv_papers.json")
    if not papers:
        print("No papers found. Run fetch_arxiv_data.py first")
        return

    # Build and persist a fresh index
    retriever = PaperRetriever()
    retriever.build_index(papers)
    retriever.save_index()

    # Smoke-test a few representative searches
    sample_queries = [
        "diffusion models for image generation",
        "reinforcement learning from human feedback",
        "large language model alignment"
    ]

    divider = '=' * 60
    for q in sample_queries:
        print(f"\n{divider}")
        print(f"Query: {q}")
        print(divider)

        for rank, (paper, score) in enumerate(retriever.search(q, top_k=3), 1):
            print(f"\n[{rank}] Score: {score:.3f}")
            print(f" {paper['title']}")
            print(f" {paper['arxiv_id']}")


if __name__ == "__main__":
    main()
utils.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DocMind - Utility Functions
3
+ Helper functions for the multi-agent system
4
+ """
5
+
6
+ from typing import List, Dict
7
+ import re
8
+ from datetime import datetime
9
+
10
+
11
def clean_text(text: str) -> str:
    """Normalize whitespace and strip unusual symbols from *text*.

    Collapses runs of whitespace to single spaces, removes characters
    other than word characters, whitespace and basic punctuation, then
    trims leading/trailing spaces.
    """
    collapsed = re.sub(r'\s+', ' ', text)
    cleaned = re.sub(r'[^\w\s.,!?;:()\-]', '', collapsed)
    return cleaned.strip()
18
+
19
+
20
def truncate_text(text: str, max_length: int = 500) -> str:
    """Shorten *text* to at most roughly *max_length* characters.

    Prefers cutting at the last sentence boundary inside the limit;
    otherwise hard-truncates and appends an ellipsis.
    """
    if len(text) <= max_length:
        return text

    clipped = text[:max_length]
    cut = clipped.rfind('.')
    # A period at position 0 is not a usable sentence boundary.
    return clipped[:cut + 1] if cut > 0 else clipped + "..."
32
+
33
+
34
def format_authors(authors: List[str], max_authors: int = 3) -> str:
    """Join author names for display, abbreviating long lists with 'et al.'."""
    shown = ", ".join(authors[:max_authors])
    return shown if len(authors) <= max_authors else shown + " et al."
40
+
41
+
42
def extract_year(date_string: str) -> int:
    """Extract the 4-digit year from an ISO-like date string.

    Falls back to the current year when *date_string* is not a string or
    does not start with an integer year.
    """
    # Fix: replaced a bare `except:` (which swallowed even KeyboardInterrupt)
    # with an explicit type check plus the one exception int() can raise here.
    if not isinstance(date_string, str):
        return datetime.now().year
    try:
        return int(date_string[:4])
    except ValueError:  # empty string or non-numeric prefix
        return datetime.now().year
50
+
51
+
52
def score_recency(year: int, current_year: int = None) -> float:
    """
    Score paper based on recency

    Returns:
        Score from 0-1, where 1 is most recent
    """
    if current_year is None:
        current_year = datetime.now().year

    age = current_year - year
    if age <= 0:
        return 1.0

    # Stepwise decay for the first three years, then a 1/(age+1) tail
    # floored at 0.3.
    for cutoff, value in ((1, 0.9), (2, 0.7), (3, 0.5)):
        if age <= cutoff:
            return value
    return max(0.3, 1.0 / (age + 1))
73
+
74
+
75
def combine_scores(
    relevance: float,
    recency: float,
    quality: float,
    weights: Dict[str, float] = None
) -> float:
    """
    Combine multiple scores with weights

    Args:
        relevance: Relevance score (0-1)
        recency: Recency score (0-1)
        quality: Quality score (0-1)
        weights: Dict with keys 'relevance', 'recency', 'quality'

    Returns:
        Combined score (0-1)
    """
    # Default weighting favors relevance over recency/quality.
    w = weights if weights is not None else {
        'relevance': 0.6,
        'recency': 0.2,
        'quality': 0.2
    }

    weighted = (
        relevance * w['relevance'],
        recency * w['recency'],
        quality * w['quality'],
    )
    return weighted[0] + weighted[1] + weighted[2]
105
+
106
+
107
def deduplicate_papers(papers: List[Dict]) -> List[Dict]:
    """Drop papers whose arXiv ID was already seen, preserving order.

    Papers with a missing or empty 'arxiv_id' are discarded entirely,
    matching the original behavior.
    """
    seen_ids = set()
    deduped = []

    for entry in papers:
        pid = entry.get('arxiv_id', '')
        if not pid or pid in seen_ids:
            continue
        seen_ids.add(pid)
        deduped.append(entry)

    return deduped
119
+
120
+
121
def format_citation(paper: Dict, style: str = 'apa') -> str:
    """
    Format paper citation

    Args:
        paper: Paper dict with title, authors, year, arxiv_id
        style: Citation style ('apa', 'simple', 'markdown')

    Returns:
        Formatted citation string (unknown styles fall back to 'simple')
    """
    authors = format_authors(paper.get('authors', []))
    title = paper.get('title', 'Unknown Title')
    year = extract_year(paper.get('published', ''))
    arxiv_id = paper.get('arxiv_id', '')

    if style == 'apa':
        return f"{authors} ({year}). {title}. arXiv:{arxiv_id}"
    if style == 'markdown':
        return f"**{title}** - {authors} ({year}) - arXiv:[{arxiv_id}](https://arxiv.org/abs/{arxiv_id})"
    # simple (default fallback)
    return f"{title} ({arxiv_id}, {year})"
145
+
146
+
147
def extract_keywords(text: str, top_n: int = 5) -> List[str]:
    """
    Extract simple keywords from text (frequency-based)

    Args:
        text: Input text
        top_n: Number of keywords to return

    Returns:
        List of top keywords (ties keep first-occurrence order)
    """
    from collections import Counter  # local import: keeps module deps unchanged

    # Common words that carry no topical signal in abstracts.
    stop_words = {
        'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
        'of', 'with', 'by', 'from', 'is', 'are', 'was', 'were', 'be', 'been',
        'this', 'that', 'these', 'those', 'we', 'our', 'propose', 'show'
    }

    # Lowercase words of 4+ letters only; shorter tokens are rarely keywords.
    words = re.findall(r'\b[a-z]{4,}\b', text.lower())
    counts = Counter(w for w in words if w not in stop_words)

    # Stable sort on count preserves first-occurrence order for ties,
    # matching the original hand-rolled dict counting.
    ranked = sorted(counts.items(), key=lambda kv: kv[1], reverse=True)
    return [word for word, _ in ranked[:top_n]]
177
+
178
+
179
class ProgressTracker:
    """Simple progress tracker for multi-step processes"""

    def __init__(self, total_steps: int):
        self.total_steps = total_steps  # expected number of steps
        self.current_step = 0           # steps completed so far
        self.step_names = []            # optional labels, in completion order

    def next_step(self, step_name: str = None):
        """Advance one step, optionally recording its name."""
        self.current_step += 1
        if step_name:
            self.step_names.append(step_name)

    def get_progress(self) -> float:
        """Completed fraction as a percentage (0-100)."""
        return (self.current_step / self.total_steps) * 100

    def get_status(self) -> str:
        """Human-readable 'Step i/n (p%)' status line."""
        return f"Step {self.current_step}/{self.total_steps} ({self.get_progress():.1f}%)"
200
+
201
+
202
def validate_paper_dict(paper: Dict) -> bool:
    """Check that a paper dict carries every field the pipeline relies on."""
    for field in ('title', 'abstract', 'arxiv_id', 'authors', 'published'):
        if field not in paper:
            return False
    return True
206
+
207
+
208
def safe_get(dictionary: Dict, key: str, default=None):
    """Safely get a value from a mapping, tolerating non-mapping inputs.

    Returns *default* when *dictionary* has no usable ``.get`` (e.g. None)
    or when *key* is unhashable.
    """
    # Fix: a bare `except:` hid every error, including KeyboardInterrupt;
    # only the two failure modes .get can actually produce are caught now.
    try:
        return dictionary.get(key, default)
    except (AttributeError, TypeError):
        return default
214
+
215
+
216
# Example usage
if __name__ == "__main__":
    # Smoke-test the utilities against a well-known paper.
    sample_paper = {
        'title': 'Attention Is All You Need',
        'authors': ['Vaswani', 'Shazeer', 'Parmar', 'Uszkoreit'],
        'published': '2017-06-12',
        'arxiv_id': '1706.03762',
        'abstract': 'The dominant sequence transduction models are based on complex recurrent or convolutional neural networks...'
    }

    demos = [
        ("Citation (APA):", format_citation(sample_paper, 'apa')),
        ("Citation (Markdown):", format_citation(sample_paper, 'markdown')),
        ("Authors:", format_authors(sample_paper['authors'])),
        ("Recency score:", score_recency(2017)),
        ("Keywords:", extract_keywords(sample_paper['abstract'])),
    ]
    for label, value in demos:
        print(label, value)
+ print("Keywords:", extract_keywords(sample_paper['abstract']))