Upload 4 files
Browse files- agents.py +351 -0
- fetch_arxiv_data.py +114 -0
- retriever.py +201 -0
- utils.py +231 -0
agents.py
ADDED
|
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DocMind - Multi-Agent System
|
| 3 |
+
Implements Retriever, Reader, Critic, and Synthesizer agents
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from typing import List, Dict, Tuple
|
| 7 |
+
from retriever import PaperRetriever
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class RetrieverAgent:
|
| 12 |
+
"""Agent responsible for finding relevant papers"""
|
| 13 |
+
|
| 14 |
+
def __init__(self, retriever: PaperRetriever):
|
| 15 |
+
self.retriever = retriever
|
| 16 |
+
|
| 17 |
+
def retrieve(self, query: str, top_k: int = 5) -> List[Tuple[Dict, float]]:
|
| 18 |
+
"""
|
| 19 |
+
Retrieve relevant papers for the query
|
| 20 |
+
|
| 21 |
+
Returns:
|
| 22 |
+
List of (paper, relevance_score) tuples
|
| 23 |
+
"""
|
| 24 |
+
print(f"🔍 Retriever Agent: Searching for '{query}'...")
|
| 25 |
+
results = self.retriever.search(query, top_k)
|
| 26 |
+
print(f" Found {len(results)} relevant papers")
|
| 27 |
+
return results
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class ReaderAgent:
|
| 31 |
+
"""Agent responsible for reading and summarizing papers"""
|
| 32 |
+
|
| 33 |
+
def __init__(self, llm_client=None):
|
| 34 |
+
"""
|
| 35 |
+
Args:
|
| 36 |
+
llm_client: Optional LLM client (OpenAI, Anthropic, etc.)
|
| 37 |
+
If None, uses rule-based summarization
|
| 38 |
+
"""
|
| 39 |
+
self.llm_client = llm_client
|
| 40 |
+
|
| 41 |
+
def summarize_paper(self, paper: Dict) -> str:
|
| 42 |
+
"""
|
| 43 |
+
Generate a summary of a single paper
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
paper: Paper dictionary with title, abstract, etc.
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
Summary string
|
| 50 |
+
"""
|
| 51 |
+
if self.llm_client:
|
| 52 |
+
return self._llm_summarize(paper)
|
| 53 |
+
else:
|
| 54 |
+
return self._rule_based_summarize(paper)
|
| 55 |
+
|
| 56 |
+
def _rule_based_summarize(self, paper: Dict) -> str:
|
| 57 |
+
"""Simple extractive summary (first 3 sentences)"""
|
| 58 |
+
abstract = paper['abstract']
|
| 59 |
+
sentences = abstract.split('. ')
|
| 60 |
+
summary = '. '.join(sentences[:3]) + '.'
|
| 61 |
+
|
| 62 |
+
return {
|
| 63 |
+
'title': paper['title'],
|
| 64 |
+
'arxiv_id': paper['arxiv_id'],
|
| 65 |
+
'authors': paper['authors'][:3],
|
| 66 |
+
'summary': summary,
|
| 67 |
+
'year': paper['published'][:4]
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
def _llm_summarize(self, paper: Dict) -> str:
|
| 71 |
+
"""Use LLM to generate intelligent summary"""
|
| 72 |
+
prompt = f"""Summarize this research paper in 2-3 sentences, focusing on:
|
| 73 |
+
1. The main contribution/idea
|
| 74 |
+
2. The key methodology or approach
|
| 75 |
+
3. Important results or implications
|
| 76 |
+
|
| 77 |
+
Title: {paper['title']}
|
| 78 |
+
Abstract: {paper['abstract']}
|
| 79 |
+
|
| 80 |
+
Summary:"""
|
| 81 |
+
|
| 82 |
+
# Call LLM (implementation depends on client)
|
| 83 |
+
# This is a placeholder - replace with actual LLM call
|
| 84 |
+
response = "LLM summary would go here"
|
| 85 |
+
|
| 86 |
+
return {
|
| 87 |
+
'title': paper['title'],
|
| 88 |
+
'arxiv_id': paper['arxiv_id'],
|
| 89 |
+
'authors': paper['authors'][:3],
|
| 90 |
+
'summary': response,
|
| 91 |
+
'year': paper['published'][:4]
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
def read_papers(self, papers: List[Tuple[Dict, float]]) -> List[Dict]:
|
| 95 |
+
"""
|
| 96 |
+
Read and summarize multiple papers
|
| 97 |
+
|
| 98 |
+
Args:
|
| 99 |
+
papers: List of (paper, score) tuples from retriever
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
List of summaries
|
| 103 |
+
"""
|
| 104 |
+
print(f"📖 Reader Agent: Reading {len(papers)} papers...")
|
| 105 |
+
summaries = []
|
| 106 |
+
|
| 107 |
+
for paper, score in papers:
|
| 108 |
+
summary = self.summarize_paper(paper)
|
| 109 |
+
summary['relevance_score'] = score
|
| 110 |
+
summaries.append(summary)
|
| 111 |
+
|
| 112 |
+
print(f" Generated {len(summaries)} summaries")
|
| 113 |
+
return summaries
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class CriticAgent:
|
| 117 |
+
"""Agent responsible for evaluating and filtering summaries"""
|
| 118 |
+
|
| 119 |
+
def __init__(self, llm_client=None):
|
| 120 |
+
self.llm_client = llm_client
|
| 121 |
+
|
| 122 |
+
def critique(self, summaries: List[Dict], query: str) -> List[Dict]:
|
| 123 |
+
"""
|
| 124 |
+
Evaluate summaries for quality and relevance
|
| 125 |
+
|
| 126 |
+
Args:
|
| 127 |
+
summaries: List of paper summaries
|
| 128 |
+
query: Original user query
|
| 129 |
+
|
| 130 |
+
Returns:
|
| 131 |
+
Filtered and scored summaries
|
| 132 |
+
"""
|
| 133 |
+
print(f"🔎 Critic Agent: Evaluating {len(summaries)} summaries...")
|
| 134 |
+
|
| 135 |
+
# Simple rule-based filtering
|
| 136 |
+
filtered = []
|
| 137 |
+
for summary in summaries:
|
| 138 |
+
# Check relevance score threshold
|
| 139 |
+
if summary['relevance_score'] > 0.3:
|
| 140 |
+
# Add quality score (can be enhanced with LLM)
|
| 141 |
+
summary['quality_score'] = self._assess_quality(summary, query)
|
| 142 |
+
filtered.append(summary)
|
| 143 |
+
|
| 144 |
+
# Sort by combined score
|
| 145 |
+
filtered.sort(
|
| 146 |
+
key=lambda x: x['relevance_score'] * 0.7 + x['quality_score'] * 0.3,
|
| 147 |
+
reverse=True
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
print(f" Retained {len(filtered)} high-quality summaries")
|
| 151 |
+
return filtered
|
| 152 |
+
|
| 153 |
+
def _assess_quality(self, summary: Dict, query: str) -> float:
|
| 154 |
+
"""
|
| 155 |
+
Simple quality assessment (can be enhanced with LLM)
|
| 156 |
+
|
| 157 |
+
Returns:
|
| 158 |
+
Quality score 0-1
|
| 159 |
+
"""
|
| 160 |
+
score = 0.5 # Base score
|
| 161 |
+
|
| 162 |
+
# Longer summaries might be more informative
|
| 163 |
+
if len(summary['summary']) > 100:
|
| 164 |
+
score += 0.2
|
| 165 |
+
|
| 166 |
+
# Recent papers get bonus
|
| 167 |
+
if int(summary['year']) >= 2024:
|
| 168 |
+
score += 0.3
|
| 169 |
+
|
| 170 |
+
return min(score, 1.0)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class SynthesizerAgent:
|
| 174 |
+
"""Agent responsible for synthesizing final answer"""
|
| 175 |
+
|
| 176 |
+
def __init__(self, llm_client=None):
|
| 177 |
+
self.llm_client = llm_client
|
| 178 |
+
|
| 179 |
+
def synthesize(
|
| 180 |
+
self,
|
| 181 |
+
summaries: List[Dict],
|
| 182 |
+
query: str,
|
| 183 |
+
max_papers: int = 10
|
| 184 |
+
) -> str:
|
| 185 |
+
"""
|
| 186 |
+
Synthesize final answer from summaries
|
| 187 |
+
|
| 188 |
+
Args:
|
| 189 |
+
summaries: List of filtered, quality summaries
|
| 190 |
+
query: Original user query
|
| 191 |
+
max_papers: Maximum papers to include in response
|
| 192 |
+
|
| 193 |
+
Returns:
|
| 194 |
+
Final synthesized response with citations
|
| 195 |
+
"""
|
| 196 |
+
print(f"✨ Synthesizer Agent: Creating final response...")
|
| 197 |
+
|
| 198 |
+
if not summaries:
|
| 199 |
+
return "No relevant papers found for your query."
|
| 200 |
+
|
| 201 |
+
# Limit to top papers
|
| 202 |
+
top_summaries = summaries[:max_papers]
|
| 203 |
+
|
| 204 |
+
if self.llm_client:
|
| 205 |
+
return self._llm_synthesize(top_summaries, query)
|
| 206 |
+
else:
|
| 207 |
+
return self._rule_based_synthesize(top_summaries, query)
|
| 208 |
+
|
| 209 |
+
def _rule_based_synthesize(self, summaries: List[Dict], query: str) -> str:
|
| 210 |
+
"""Create structured response without LLM"""
|
| 211 |
+
response = f"# Research Summary: {query}\n\n"
|
| 212 |
+
response += f"Based on {len(summaries)} relevant papers from arXiv:\n\n"
|
| 213 |
+
|
| 214 |
+
for i, summary in enumerate(summaries, 1):
|
| 215 |
+
response += f"## [{i}] {summary['title']}\n"
|
| 216 |
+
response += f"**Authors:** {', '.join(summary['authors'])}"
|
| 217 |
+
if len(summary['authors']) >= 3:
|
| 218 |
+
response += " et al."
|
| 219 |
+
response += f"\n**Year:** {summary['year']}\n"
|
| 220 |
+
response += f"**arXiv ID:** {summary['arxiv_id']}\n"
|
| 221 |
+
response += f"**Relevance:** {summary['relevance_score']:.2f}\n\n"
|
| 222 |
+
response += f"{summary['summary']}\n\n"
|
| 223 |
+
response += "---\n\n"
|
| 224 |
+
|
| 225 |
+
return response
|
| 226 |
+
|
| 227 |
+
def _llm_synthesize(self, summaries: List[Dict], query: str) -> str:
|
| 228 |
+
"""Use LLM to create coherent synthesis"""
|
| 229 |
+
# Build context from summaries
|
| 230 |
+
context = ""
|
| 231 |
+
for i, summary in enumerate(summaries, 1):
|
| 232 |
+
context += f"[{i}] {summary['title']} ({summary['arxiv_id']})\n"
|
| 233 |
+
context += f" {summary['summary']}\n\n"
|
| 234 |
+
|
| 235 |
+
prompt = f"""You are a research assistant. Based on the following papers, answer this question:
|
| 236 |
+
|
| 237 |
+
Question: {query}
|
| 238 |
+
|
| 239 |
+
Papers:
|
| 240 |
+
{context}
|
| 241 |
+
|
| 242 |
+
Provide a comprehensive answer that:
|
| 243 |
+
1. Directly addresses the question
|
| 244 |
+
2. Synthesizes information across papers
|
| 245 |
+
3. Cites papers by number [1], [2], etc.
|
| 246 |
+
4. Highlights key findings and consensus/disagreements
|
| 247 |
+
5. Is concise but thorough (3-5 paragraphs)
|
| 248 |
+
|
| 249 |
+
Answer:"""
|
| 250 |
+
|
| 251 |
+
# Placeholder for LLM call
|
| 252 |
+
response = "LLM-generated synthesis would go here with citations"
|
| 253 |
+
|
| 254 |
+
# Append paper references
|
| 255 |
+
response += "\n\n## References\n"
|
| 256 |
+
for i, summary in enumerate(summaries, 1):
|
| 257 |
+
response += f"[{i}] {summary['title']} "
|
| 258 |
+
response += f"({summary['arxiv_id']}, {summary['year']})\n"
|
| 259 |
+
|
| 260 |
+
return response
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
class DocMindOrchestrator:
|
| 264 |
+
"""Main orchestrator that coordinates all agents"""
|
| 265 |
+
|
| 266 |
+
def __init__(
|
| 267 |
+
self,
|
| 268 |
+
retriever: PaperRetriever,
|
| 269 |
+
llm_client=None
|
| 270 |
+
):
|
| 271 |
+
self.retriever_agent = RetrieverAgent(retriever)
|
| 272 |
+
self.reader_agent = ReaderAgent(llm_client)
|
| 273 |
+
self.critic_agent = CriticAgent(llm_client)
|
| 274 |
+
self.synthesizer_agent = SynthesizerAgent(llm_client)
|
| 275 |
+
|
| 276 |
+
def process_query(
|
| 277 |
+
self,
|
| 278 |
+
query: str,
|
| 279 |
+
top_k: int = 10,
|
| 280 |
+
max_papers_in_response: int = 5
|
| 281 |
+
) -> str:
|
| 282 |
+
"""
|
| 283 |
+
Process user query through full agent pipeline
|
| 284 |
+
|
| 285 |
+
Args:
|
| 286 |
+
query: User question
|
| 287 |
+
top_k: Number of papers to retrieve
|
| 288 |
+
max_papers_in_response: Max papers in final response
|
| 289 |
+
|
| 290 |
+
Returns:
|
| 291 |
+
Final synthesized answer
|
| 292 |
+
"""
|
| 293 |
+
print(f"\n{'=' * 60}")
|
| 294 |
+
print(f"Processing query: {query}")
|
| 295 |
+
print('=' * 60)
|
| 296 |
+
|
| 297 |
+
# Step 1: Retrieve
|
| 298 |
+
papers = self.retriever_agent.retrieve(query, top_k)
|
| 299 |
+
|
| 300 |
+
if not papers:
|
| 301 |
+
return "No relevant papers found for your query."
|
| 302 |
+
|
| 303 |
+
# Step 2: Read & Summarize
|
| 304 |
+
summaries = self.reader_agent.read_papers(papers)
|
| 305 |
+
|
| 306 |
+
# Step 3: Critique & Filter
|
| 307 |
+
quality_summaries = self.critic_agent.critique(summaries, query)
|
| 308 |
+
|
| 309 |
+
# Step 4: Synthesize
|
| 310 |
+
final_response = self.synthesizer_agent.synthesize(
|
| 311 |
+
quality_summaries,
|
| 312 |
+
query,
|
| 313 |
+
max_papers_in_response
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
print(f"{'=' * 60}\n")
|
| 317 |
+
return final_response
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def main():
|
| 321 |
+
"""Example usage of multi-agent system"""
|
| 322 |
+
from fetch_arxiv_data import ArxivFetcher
|
| 323 |
+
|
| 324 |
+
# Setup
|
| 325 |
+
fetcher = ArxivFetcher()
|
| 326 |
+
retriever = PaperRetriever()
|
| 327 |
+
|
| 328 |
+
# Load or build index
|
| 329 |
+
if not retriever.load_index():
|
| 330 |
+
papers = fetcher.load_papers("arxiv_papers.json")
|
| 331 |
+
retriever.build_index(papers)
|
| 332 |
+
retriever.save_index()
|
| 333 |
+
|
| 334 |
+
# Create orchestrator
|
| 335 |
+
orchestrator = DocMindOrchestrator(retriever)
|
| 336 |
+
|
| 337 |
+
# Test queries
|
| 338 |
+
test_queries = [
|
| 339 |
+
"What are the latest improvements in diffusion models?",
|
| 340 |
+
"How does RLHF compare to DPO for language model alignment?",
|
| 341 |
+
"What are the main challenges in scaling transformers?"
|
| 342 |
+
]
|
| 343 |
+
|
| 344 |
+
for query in test_queries:
|
| 345 |
+
response = orchestrator.process_query(query, top_k=8, max_papers_in_response=3)
|
| 346 |
+
print(response)
|
| 347 |
+
print("\n" + "=" * 80 + "\n")
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
if __name__ == "__main__":
|
| 351 |
+
main()
|
fetch_arxiv_data.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DocMind - arXiv Data Fetcher
|
| 3 |
+
Fetches papers from arXiv API and saves them for indexing
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import arxiv
|
| 7 |
+
import os
|
| 8 |
+
import json
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import List, Dict
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ArxivFetcher:
|
| 14 |
+
def __init__(self, data_dir: str = "data/papers"):
|
| 15 |
+
self.data_dir = Path(data_dir)
|
| 16 |
+
self.data_dir.mkdir(parents=True, exist_ok=True)
|
| 17 |
+
|
| 18 |
+
def fetch_papers(
|
| 19 |
+
self,
|
| 20 |
+
query: str = "machine learning",
|
| 21 |
+
max_results: int = 100,
|
| 22 |
+
category: str = None
|
| 23 |
+
) -> List[Dict]:
|
| 24 |
+
"""
|
| 25 |
+
Fetch papers from arXiv API
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
query: Search query string
|
| 29 |
+
max_results: Maximum number of papers to fetch
|
| 30 |
+
category: arXiv category (e.g., 'cs.AI', 'cs.LG')
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
List of paper dictionaries
|
| 34 |
+
"""
|
| 35 |
+
print(f"Fetching papers from arXiv: query='{query}', max={max_results}")
|
| 36 |
+
|
| 37 |
+
# Build search query
|
| 38 |
+
search_query = query
|
| 39 |
+
if category:
|
| 40 |
+
search_query = f"cat:{category} AND {query}"
|
| 41 |
+
|
| 42 |
+
search = arxiv.Search(
|
| 43 |
+
query=search_query,
|
| 44 |
+
max_results=max_results,
|
| 45 |
+
sort_by=arxiv.SortCriterion.SubmittedDate
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
papers = []
|
| 49 |
+
for result in search.results():
|
| 50 |
+
paper = {
|
| 51 |
+
'arxiv_id': result.entry_id.split('/')[-1],
|
| 52 |
+
'title': result.title,
|
| 53 |
+
'authors': [author.name for author in result.authors],
|
| 54 |
+
'abstract': result.summary,
|
| 55 |
+
'published': result.published.strftime('%Y-%m-%d'),
|
| 56 |
+
'pdf_url': result.pdf_url,
|
| 57 |
+
'categories': result.categories
|
| 58 |
+
}
|
| 59 |
+
papers.append(paper)
|
| 60 |
+
|
| 61 |
+
print(f"Successfully fetched {len(papers)} papers")
|
| 62 |
+
return papers
|
| 63 |
+
|
| 64 |
+
def save_papers(self, papers: List[Dict], filename: str = "papers.json"):
|
| 65 |
+
"""Save papers to JSON file"""
|
| 66 |
+
filepath = self.data_dir / filename
|
| 67 |
+
with open(filepath, 'w', encoding='utf-8') as f:
|
| 68 |
+
json.dump(papers, f, indent=2, ensure_ascii=False)
|
| 69 |
+
print(f"Saved {len(papers)} papers to {filepath}")
|
| 70 |
+
|
| 71 |
+
def load_papers(self, filename: str = "papers.json") -> List[Dict]:
|
| 72 |
+
"""Load papers from JSON file"""
|
| 73 |
+
filepath = self.data_dir / filename
|
| 74 |
+
if not filepath.exists():
|
| 75 |
+
print(f"No saved papers found at {filepath}")
|
| 76 |
+
return []
|
| 77 |
+
|
| 78 |
+
with open(filepath, 'r', encoding='utf-8') as f:
|
| 79 |
+
papers = json.load(f)
|
| 80 |
+
print(f"Loaded {len(papers)} papers from {filepath}")
|
| 81 |
+
return papers
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def main():
|
| 85 |
+
"""Example usage: Fetch recent ML and AI papers"""
|
| 86 |
+
fetcher = ArxivFetcher()
|
| 87 |
+
|
| 88 |
+
# Fetch recent ML papers
|
| 89 |
+
ml_papers = fetcher.fetch_papers(
|
| 90 |
+
query="machine learning OR deep learning",
|
| 91 |
+
max_results=50,
|
| 92 |
+
category="cs.LG"
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
# Fetch recent AI papers
|
| 96 |
+
ai_papers = fetcher.fetch_papers(
|
| 97 |
+
query="artificial intelligence OR neural networks",
|
| 98 |
+
max_results=50,
|
| 99 |
+
category="cs.AI"
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
# Combine and save
|
| 103 |
+
all_papers = ml_papers + ai_papers
|
| 104 |
+
fetcher.save_papers(all_papers, "arxiv_papers.json")
|
| 105 |
+
|
| 106 |
+
# Show sample
|
| 107 |
+
print("\n=== Sample Paper ===")
|
| 108 |
+
print(f"Title: {all_papers[0]['title']}")
|
| 109 |
+
print(f"Authors: {', '.join(all_papers[0]['authors'][:3])}")
|
| 110 |
+
print(f"Abstract: {all_papers[0]['abstract'][:200]}...")
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
if __name__ == "__main__":
|
| 114 |
+
main()
|
retriever.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DocMind - Retriever Module
|
| 3 |
+
Semantic search over arXiv papers using FAISS and sentence-transformers
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import faiss
|
| 8 |
+
from sentence_transformers import SentenceTransformer
|
| 9 |
+
from typing import List, Dict, Tuple
|
| 10 |
+
import pickle
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class PaperRetriever:
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
|
| 18 |
+
index_path: str = "data/faiss_index"
|
| 19 |
+
):
|
| 20 |
+
"""
|
| 21 |
+
Initialize retriever with embedding model and FAISS index
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
model_name: HuggingFace sentence-transformer model
|
| 25 |
+
index_path: Directory to save/load FAISS index
|
| 26 |
+
"""
|
| 27 |
+
print(f"Loading embedding model: {model_name}")
|
| 28 |
+
self.model = SentenceTransformer(model_name)
|
| 29 |
+
self.index_path = Path(index_path)
|
| 30 |
+
self.index_path.mkdir(parents=True, exist_ok=True)
|
| 31 |
+
|
| 32 |
+
self.index = None
|
| 33 |
+
self.papers = []
|
| 34 |
+
self.embeddings = None
|
| 35 |
+
|
| 36 |
+
def build_index(self, papers: List[Dict]):
|
| 37 |
+
"""
|
| 38 |
+
Build FAISS index from papers
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
papers: List of paper dictionaries with 'title' and 'abstract'
|
| 42 |
+
"""
|
| 43 |
+
print(f"Building index for {len(papers)} papers...")
|
| 44 |
+
self.papers = papers
|
| 45 |
+
|
| 46 |
+
# Create text to embed: title + abstract
|
| 47 |
+
texts = [
|
| 48 |
+
f"{paper['title']}. {paper['abstract']}"
|
| 49 |
+
for paper in papers
|
| 50 |
+
]
|
| 51 |
+
|
| 52 |
+
# Generate embeddings
|
| 53 |
+
print("Generating embeddings...")
|
| 54 |
+
self.embeddings = self.model.encode(
|
| 55 |
+
texts,
|
| 56 |
+
show_progress_bar=True,
|
| 57 |
+
convert_to_numpy=True
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
# Build FAISS index
|
| 61 |
+
dimension = self.embeddings.shape[1]
|
| 62 |
+
self.index = faiss.IndexFlatIP(dimension) # Inner product (cosine similarity)
|
| 63 |
+
|
| 64 |
+
# Normalize embeddings for cosine similarity
|
| 65 |
+
faiss.normalize_L2(self.embeddings)
|
| 66 |
+
self.index.add(self.embeddings)
|
| 67 |
+
|
| 68 |
+
print(f"Index built with {self.index.ntotal} papers")
|
| 69 |
+
|
| 70 |
+
def save_index(self, name: str = "papers"):
|
| 71 |
+
"""Save FAISS index and metadata"""
|
| 72 |
+
faiss.write_index(self.index, str(self.index_path / f"{name}.index"))
|
| 73 |
+
|
| 74 |
+
with open(self.index_path / f"{name}_papers.pkl", 'wb') as f:
|
| 75 |
+
pickle.dump(self.papers, f)
|
| 76 |
+
|
| 77 |
+
with open(self.index_path / f"{name}_embeddings.npy", 'wb') as f:
|
| 78 |
+
np.save(f, self.embeddings)
|
| 79 |
+
|
| 80 |
+
print(f"Saved index to {self.index_path}/{name}.*")
|
| 81 |
+
|
| 82 |
+
def load_index(self, name: str = "papers"):
|
| 83 |
+
"""Load FAISS index and metadata"""
|
| 84 |
+
index_file = self.index_path / f"{name}.index"
|
| 85 |
+
if not index_file.exists():
|
| 86 |
+
print(f"No index found at {index_file}")
|
| 87 |
+
return False
|
| 88 |
+
|
| 89 |
+
self.index = faiss.read_index(str(index_file))
|
| 90 |
+
|
| 91 |
+
with open(self.index_path / f"{name}_papers.pkl", 'rb') as f:
|
| 92 |
+
self.papers = pickle.load(f)
|
| 93 |
+
|
| 94 |
+
with open(self.index_path / f"{name}_embeddings.npy", 'rb') as f:
|
| 95 |
+
self.embeddings = np.load(f)
|
| 96 |
+
|
| 97 |
+
print(f"Loaded index with {len(self.papers)} papers")
|
| 98 |
+
return True
|
| 99 |
+
|
| 100 |
+
def search(
|
| 101 |
+
self,
|
| 102 |
+
query: str,
|
| 103 |
+
top_k: int = 5
|
| 104 |
+
) -> List[Tuple[Dict, float]]:
|
| 105 |
+
"""
|
| 106 |
+
Search for relevant papers
|
| 107 |
+
|
| 108 |
+
Args:
|
| 109 |
+
query: Search query string
|
| 110 |
+
top_k: Number of results to return
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
List of (paper_dict, score) tuples
|
| 114 |
+
"""
|
| 115 |
+
if self.index is None:
|
| 116 |
+
raise ValueError("Index not built or loaded")
|
| 117 |
+
|
| 118 |
+
# Embed query
|
| 119 |
+
query_embedding = self.model.encode([query], convert_to_numpy=True)
|
| 120 |
+
faiss.normalize_L2(query_embedding)
|
| 121 |
+
|
| 122 |
+
# Search
|
| 123 |
+
scores, indices = self.index.search(query_embedding, top_k)
|
| 124 |
+
|
| 125 |
+
# Return results
|
| 126 |
+
results = []
|
| 127 |
+
for idx, score in zip(indices[0], scores[0]):
|
| 128 |
+
paper = self.papers[idx]
|
| 129 |
+
results.append((paper, float(score)))
|
| 130 |
+
|
| 131 |
+
return results
|
| 132 |
+
|
| 133 |
+
def get_retrieval_context(
|
| 134 |
+
self,
|
| 135 |
+
query: str,
|
| 136 |
+
top_k: int = 5
|
| 137 |
+
) -> str:
|
| 138 |
+
"""
|
| 139 |
+
Get formatted context string for LLM consumption
|
| 140 |
+
|
| 141 |
+
Args:
|
| 142 |
+
query: Search query
|
| 143 |
+
top_k: Number of papers to retrieve
|
| 144 |
+
|
| 145 |
+
Returns:
|
| 146 |
+
Formatted context string with paper summaries
|
| 147 |
+
"""
|
| 148 |
+
results = self.search(query, top_k)
|
| 149 |
+
|
| 150 |
+
context = f"Retrieved {len(results)} relevant papers:\n\n"
|
| 151 |
+
for i, (paper, score) in enumerate(results, 1):
|
| 152 |
+
context += f"[{i}] {paper['title']}\n"
|
| 153 |
+
context += f" Authors: {', '.join(paper['authors'][:3])}"
|
| 154 |
+
if len(paper['authors']) > 3:
|
| 155 |
+
context += f" et al."
|
| 156 |
+
context += f"\n arXiv ID: {paper['arxiv_id']}\n"
|
| 157 |
+
context += f" Published: {paper['published']}\n"
|
| 158 |
+
context += f" Relevance: {score:.3f}\n"
|
| 159 |
+
context += f" Abstract: {paper['abstract']}\n\n"
|
| 160 |
+
|
| 161 |
+
return context
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def main():
|
| 165 |
+
"""Example: Build and test retriever"""
|
| 166 |
+
from fetch_arxiv_data import ArxivFetcher
|
| 167 |
+
|
| 168 |
+
# Load papers
|
| 169 |
+
fetcher = ArxivFetcher()
|
| 170 |
+
papers = fetcher.load_papers("arxiv_papers.json")
|
| 171 |
+
|
| 172 |
+
if not papers:
|
| 173 |
+
print("No papers found. Run fetch_arxiv_data.py first")
|
| 174 |
+
return
|
| 175 |
+
|
| 176 |
+
# Build index
|
| 177 |
+
retriever = PaperRetriever()
|
| 178 |
+
retriever.build_index(papers)
|
| 179 |
+
retriever.save_index()
|
| 180 |
+
|
| 181 |
+
# Test search
|
| 182 |
+
test_queries = [
|
| 183 |
+
"diffusion models for image generation",
|
| 184 |
+
"reinforcement learning from human feedback",
|
| 185 |
+
"large language model alignment"
|
| 186 |
+
]
|
| 187 |
+
|
| 188 |
+
for query in test_queries:
|
| 189 |
+
print(f"\n{'=' * 60}")
|
| 190 |
+
print(f"Query: {query}")
|
| 191 |
+
print('=' * 60)
|
| 192 |
+
|
| 193 |
+
results = retriever.search(query, top_k=3)
|
| 194 |
+
for i, (paper, score) in enumerate(results, 1):
|
| 195 |
+
print(f"\n[{i}] Score: {score:.3f}")
|
| 196 |
+
print(f" {paper['title']}")
|
| 197 |
+
print(f" {paper['arxiv_id']}")
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
if __name__ == "__main__":
|
| 201 |
+
main()
|
utils.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DocMind - Utility Functions
|
| 3 |
+
Helper functions for the multi-agent system
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from typing import List, Dict
|
| 7 |
+
import re
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def clean_text(text: str) -> str:
|
| 12 |
+
"""Clean and normalize text"""
|
| 13 |
+
# Remove extra whitespace
|
| 14 |
+
text = re.sub(r'\s+', ' ', text)
|
| 15 |
+
# Remove special characters but keep basic punctuation
|
| 16 |
+
text = re.sub(r'[^\w\s.,!?;:()\-]', '', text)
|
| 17 |
+
return text.strip()
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def truncate_text(text: str, max_length: int = 500) -> str:
|
| 21 |
+
"""Truncate text to maximum length, ending at sentence boundary"""
|
| 22 |
+
if len(text) <= max_length:
|
| 23 |
+
return text
|
| 24 |
+
|
| 25 |
+
# Find last sentence boundary before max_length
|
| 26 |
+
truncated = text[:max_length]
|
| 27 |
+
last_period = truncated.rfind('.')
|
| 28 |
+
|
| 29 |
+
if last_period > 0:
|
| 30 |
+
return truncated[:last_period + 1]
|
| 31 |
+
return truncated + "..."
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def format_authors(authors: List[str], max_authors: int = 3) -> str:
|
| 35 |
+
"""Format author list for display"""
|
| 36 |
+
if len(authors) <= max_authors:
|
| 37 |
+
return ", ".join(authors)
|
| 38 |
+
else:
|
| 39 |
+
return ", ".join(authors[:max_authors]) + " et al."
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def extract_year(date_string: str) -> int:
|
| 43 |
+
"""Extract year from date string"""
|
| 44 |
+
try:
|
| 45 |
+
if isinstance(date_string, str):
|
| 46 |
+
return int(date_string[:4])
|
| 47 |
+
return datetime.now().year
|
| 48 |
+
except:
|
| 49 |
+
return datetime.now().year
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def score_recency(year: int, current_year: int = None) -> float:
|
| 53 |
+
"""
|
| 54 |
+
Score paper based on recency
|
| 55 |
+
|
| 56 |
+
Returns:
|
| 57 |
+
Score from 0-1, where 1 is most recent
|
| 58 |
+
"""
|
| 59 |
+
if current_year is None:
|
| 60 |
+
current_year = datetime.now().year
|
| 61 |
+
|
| 62 |
+
age = current_year - year
|
| 63 |
+
if age <= 0:
|
| 64 |
+
return 1.0
|
| 65 |
+
elif age <= 1:
|
| 66 |
+
return 0.9
|
| 67 |
+
elif age <= 2:
|
| 68 |
+
return 0.7
|
| 69 |
+
elif age <= 3:
|
| 70 |
+
return 0.5
|
| 71 |
+
else:
|
| 72 |
+
return max(0.3, 1.0 / (age + 1))
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def combine_scores(
|
| 76 |
+
relevance: float,
|
| 77 |
+
recency: float,
|
| 78 |
+
quality: float,
|
| 79 |
+
weights: Dict[str, float] = None
|
| 80 |
+
) -> float:
|
| 81 |
+
"""
|
| 82 |
+
Combine multiple scores with weights
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
relevance: Relevance score (0-1)
|
| 86 |
+
recency: Recency score (0-1)
|
| 87 |
+
quality: Quality score (0-1)
|
| 88 |
+
weights: Dict with keys 'relevance', 'recency', 'quality'
|
| 89 |
+
|
| 90 |
+
Returns:
|
| 91 |
+
Combined score (0-1)
|
| 92 |
+
"""
|
| 93 |
+
if weights is None:
|
| 94 |
+
weights = {
|
| 95 |
+
'relevance': 0.6,
|
| 96 |
+
'recency': 0.2,
|
| 97 |
+
'quality': 0.2
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
return (
|
| 101 |
+
relevance * weights['relevance'] +
|
| 102 |
+
recency * weights['recency'] +
|
| 103 |
+
quality * weights['quality']
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def deduplicate_papers(papers: List[Dict]) -> List[Dict]:
|
| 108 |
+
"""Remove duplicate papers based on arXiv ID"""
|
| 109 |
+
seen = set()
|
| 110 |
+
unique = []
|
| 111 |
+
|
| 112 |
+
for paper in papers:
|
| 113 |
+
paper_id = paper.get('arxiv_id', '')
|
| 114 |
+
if paper_id and paper_id not in seen:
|
| 115 |
+
seen.add(paper_id)
|
| 116 |
+
unique.append(paper)
|
| 117 |
+
|
| 118 |
+
return unique
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def format_citation(paper: Dict, style: str = 'apa') -> str:
|
| 122 |
+
"""
|
| 123 |
+
Format paper citation
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
paper: Paper dict with title, authors, year, arxiv_id
|
| 127 |
+
style: Citation style ('apa', 'simple', 'markdown')
|
| 128 |
+
|
| 129 |
+
Returns:
|
| 130 |
+
Formatted citation string
|
| 131 |
+
"""
|
| 132 |
+
authors = format_authors(paper.get('authors', []))
|
| 133 |
+
title = paper.get('title', 'Unknown Title')
|
| 134 |
+
year = extract_year(paper.get('published', ''))
|
| 135 |
+
arxiv_id = paper.get('arxiv_id', '')
|
| 136 |
+
|
| 137 |
+
if style == 'apa':
|
| 138 |
+
return f"{authors} ({year}). {title}. arXiv:{arxiv_id}"
|
| 139 |
+
|
| 140 |
+
elif style == 'markdown':
|
| 141 |
+
return f"**{title}** - {authors} ({year}) - arXiv:[{arxiv_id}](https://arxiv.org/abs/{arxiv_id})"
|
| 142 |
+
|
| 143 |
+
else: # simple
|
| 144 |
+
return f"{title} ({arxiv_id}, {year})"
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def extract_keywords(text: str, top_n: int = 5) -> List[str]:
|
| 148 |
+
"""
|
| 149 |
+
Extract simple keywords from text (frequency-based)
|
| 150 |
+
|
| 151 |
+
Args:
|
| 152 |
+
text: Input text
|
| 153 |
+
top_n: Number of keywords to return
|
| 154 |
+
|
| 155 |
+
Returns:
|
| 156 |
+
List of top keywords
|
| 157 |
+
"""
|
| 158 |
+
# Simple word frequency approach
|
| 159 |
+
# Remove common words
|
| 160 |
+
stop_words = {
|
| 161 |
+
'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
|
| 162 |
+
'of', 'with', 'by', 'from', 'is', 'are', 'was', 'were', 'be', 'been',
|
| 163 |
+
'this', 'that', 'these', 'those', 'we', 'our', 'propose', 'show'
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
# Tokenize and count
|
| 167 |
+
words = re.findall(r'\b[a-z]{4,}\b', text.lower())
|
| 168 |
+
word_freq = {}
|
| 169 |
+
|
| 170 |
+
for word in words:
|
| 171 |
+
if word not in stop_words:
|
| 172 |
+
word_freq[word] = word_freq.get(word, 0) + 1
|
| 173 |
+
|
| 174 |
+
# Get top N
|
| 175 |
+
sorted_words = sorted(word_freq.items(), key=lambda x: x[1], reverse=True)
|
| 176 |
+
return [word for word, freq in sorted_words[:top_n]]
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class ProgressTracker:
|
| 180 |
+
"""Simple progress tracker for multi-step processes"""
|
| 181 |
+
|
| 182 |
+
def __init__(self, total_steps: int):
|
| 183 |
+
self.total_steps = total_steps
|
| 184 |
+
self.current_step = 0
|
| 185 |
+
self.step_names = []
|
| 186 |
+
|
| 187 |
+
def next_step(self, step_name: str = None):
|
| 188 |
+
"""Move to next step"""
|
| 189 |
+
self.current_step += 1
|
| 190 |
+
if step_name:
|
| 191 |
+
self.step_names.append(step_name)
|
| 192 |
+
|
| 193 |
+
def get_progress(self) -> float:
|
| 194 |
+
"""Get progress as percentage"""
|
| 195 |
+
return (self.current_step / self.total_steps) * 100
|
| 196 |
+
|
| 197 |
+
def get_status(self) -> str:
|
| 198 |
+
"""Get status string"""
|
| 199 |
+
return f"Step {self.current_step}/{self.total_steps} ({self.get_progress():.1f}%)"
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def validate_paper_dict(paper: Dict) -> bool:
|
| 203 |
+
"""Validate that paper dictionary has required fields"""
|
| 204 |
+
required_fields = ['title', 'abstract', 'arxiv_id', 'authors', 'published']
|
| 205 |
+
return all(field in paper for field in required_fields)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def safe_get(dictionary: Dict, key: str, default=None):
|
| 209 |
+
"""Safely get value from dictionary with fallback"""
|
| 210 |
+
try:
|
| 211 |
+
return dictionary.get(key, default)
|
| 212 |
+
except:
|
| 213 |
+
return default
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
# Example usage
|
| 217 |
+
if __name__ == "__main__":
|
| 218 |
+
# Test utilities
|
| 219 |
+
sample_paper = {
|
| 220 |
+
'title': 'Attention Is All You Need',
|
| 221 |
+
'authors': ['Vaswani', 'Shazeer', 'Parmar', 'Uszkoreit'],
|
| 222 |
+
'published': '2017-06-12',
|
| 223 |
+
'arxiv_id': '1706.03762',
|
| 224 |
+
'abstract': 'The dominant sequence transduction models are based on complex recurrent or convolutional neural networks...'
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
print("Citation (APA):", format_citation(sample_paper, 'apa'))
|
| 228 |
+
print("Citation (Markdown):", format_citation(sample_paper, 'markdown'))
|
| 229 |
+
print("Authors:", format_authors(sample_paper['authors']))
|
| 230 |
+
print("Recency score:", score_recency(2017))
|
| 231 |
+
print("Keywords:", extract_keywords(sample_paper['abstract']))
|