Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -4,7 +4,7 @@ import urllib.parse
|
|
| 4 |
import re
|
| 5 |
import xml.etree.ElementTree as ET
|
| 6 |
from dataclasses import dataclass, field
|
| 7 |
-
from typing import Dict, List, Optional
|
| 8 |
import sys
|
| 9 |
from loguru import logger
|
| 10 |
|
|
@@ -12,13 +12,23 @@ import aiohttp
|
|
| 12 |
import gradio as gr
|
| 13 |
|
| 14 |
from langchain.prompts import PromptTemplate
|
| 15 |
-
|
| 16 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 17 |
|
| 18 |
import bibtexparser
|
| 19 |
from bibtexparser.bwriter import BibTexWriter
|
| 20 |
from bibtexparser.bibdatabase import BibDatabase
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
@dataclass
|
| 23 |
class Config:
|
| 24 |
gemini_api_key: str
|
|
@@ -28,18 +38,25 @@ class Config:
|
|
| 28 |
max_citations_per_query: int = 10
|
| 29 |
arxiv_base_url: str = 'http://export.arxiv.org/api/query?'
|
| 30 |
crossref_base_url: str = 'https://api.crossref.org/works'
|
| 31 |
-
default_headers:
|
| 32 |
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
|
| 33 |
})
|
| 34 |
log_level: str = 'DEBUG'
|
| 35 |
|
|
|
|
| 36 |
class ArxivXmlParser:
|
|
|
|
|
|
|
|
|
|
| 37 |
NS = {
|
| 38 |
'atom': 'http://www.w3.org/2005/Atom',
|
| 39 |
'arxiv': 'http://arxiv.org/schemas/atom'
|
| 40 |
}
|
| 41 |
|
| 42 |
-
def parse_papers(self, data: str) -> List[Dict]:
|
|
|
|
|
|
|
|
|
|
| 43 |
try:
|
| 44 |
root = ET.fromstring(data)
|
| 45 |
papers = []
|
|
@@ -52,12 +69,15 @@ class ArxivXmlParser:
|
|
| 52 |
logger.error(f"Error parsing ArXiv XML: {e}")
|
| 53 |
return []
|
| 54 |
|
| 55 |
-
def parse_entry(self, entry) -> Optional[
|
|
|
|
|
|
|
|
|
|
| 56 |
try:
|
| 57 |
title_node = entry.find('atom:title', self.NS)
|
| 58 |
if title_node is None:
|
| 59 |
return None
|
| 60 |
-
title = title_node.text.strip()
|
| 61 |
|
| 62 |
authors = []
|
| 63 |
for author in entry.findall('atom:author', self.NS):
|
|
@@ -66,15 +86,15 @@ class ArxivXmlParser:
|
|
| 66 |
authors.append(self._format_author_name(author_name_node.text.strip()))
|
| 67 |
|
| 68 |
arxiv_id_node = entry.find('atom:id', self.NS)
|
| 69 |
-
if arxiv_id_node is None:
|
| 70 |
return None
|
| 71 |
arxiv_id = arxiv_id_node.text.split('/')[-1]
|
| 72 |
|
| 73 |
published_node = entry.find('atom:published', self.NS)
|
| 74 |
-
year = published_node.text[:4] if published_node is not None else "Unknown"
|
| 75 |
|
| 76 |
abstract_node = entry.find('atom:summary', self.NS)
|
| 77 |
-
abstract = abstract_node.text.strip() if abstract_node is not None else ""
|
| 78 |
|
| 79 |
bibtex_key = f"{authors[0].split(',')[0]}{arxiv_id.replace('.', '')}" if authors else f"unknown{arxiv_id.replace('.', '')}"
|
| 80 |
bibtex_entry = self._generate_bibtex_entry(bibtex_key, title, authors, arxiv_id, year)
|
|
@@ -94,12 +114,18 @@ class ArxivXmlParser:
|
|
| 94 |
|
| 95 |
@staticmethod
|
| 96 |
def _format_author_name(author: str) -> str:
|
|
|
|
|
|
|
|
|
|
| 97 |
names = author.split()
|
| 98 |
if len(names) > 1:
|
| 99 |
return f"{names[-1]}, {' '.join(names[:-1])}"
|
| 100 |
return author
|
| 101 |
|
| 102 |
def _generate_bibtex_entry(self, key: str, title: str, authors: List[str], arxiv_id: str, year: str) -> str:
|
|
|
|
|
|
|
|
|
|
| 103 |
db = BibDatabase()
|
| 104 |
db.entries = [{
|
| 105 |
'ENTRYTYPE': 'article',
|
|
@@ -109,13 +135,15 @@ class ArxivXmlParser:
|
|
| 109 |
'journal': f'arXiv preprint arXiv:{arxiv_id}',
|
| 110 |
'year': year
|
| 111 |
}]
|
| 112 |
-
writer =
|
| 113 |
-
writer.indent = ' '
|
| 114 |
-
writer.comma_first = False
|
| 115 |
return writer.write(db).strip()
|
| 116 |
|
|
|
|
| 117 |
class AsyncContextManager:
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
| 119 |
self._session = aiohttp.ClientSession()
|
| 120 |
return self._session
|
| 121 |
|
|
@@ -123,13 +151,17 @@ class AsyncContextManager:
|
|
| 123 |
if self._session:
|
| 124 |
await self._session.close()
|
| 125 |
|
|
|
|
| 126 |
class CitationGenerator:
|
| 127 |
-
|
|
|
|
|
|
|
|
|
|
| 128 |
self.config = config
|
| 129 |
self.xml_parser = ArxivXmlParser()
|
| 130 |
self.async_context = AsyncContextManager()
|
| 131 |
self.llm = ChatGoogleGenerativeAI(
|
| 132 |
-
model="gemini-2.0-flash
|
| 133 |
temperature=0.3,
|
| 134 |
google_api_key=config.gemini_api_key,
|
| 135 |
streaming=True
|
|
@@ -165,6 +197,9 @@ class CitationGenerator:
|
|
| 165 |
logger.add(sys.stderr, level=config.log_level)
|
| 166 |
|
| 167 |
async def generate_queries(self, text: str, num_queries: int) -> List[str]:
|
|
|
|
|
|
|
|
|
|
| 168 |
input_map = {
|
| 169 |
"text": text,
|
| 170 |
"num_queries": num_queries
|
|
@@ -186,14 +221,15 @@ class CitationGenerator:
|
|
| 186 |
lines = [line.strip() for line in content.split('\n')
|
| 187 |
if line.strip() and not line.strip().startswith(('[', ']'))]
|
| 188 |
return lines[:num_queries]
|
| 189 |
-
|
| 190 |
return ["deep learning neural networks"]
|
| 191 |
-
|
| 192 |
except Exception as e:
|
| 193 |
logger.error(f"Error generating queries: {e}")
|
| 194 |
return ["deep learning neural networks"]
|
| 195 |
|
| 196 |
-
async def search_arxiv(self, session: aiohttp.ClientSession, query: str, max_results: int) -> List[Dict]:
|
|
|
|
|
|
|
|
|
|
| 197 |
try:
|
| 198 |
params = {
|
| 199 |
'search_query': f'all:{urllib.parse.quote(query)}',
|
|
@@ -202,8 +238,9 @@ class CitationGenerator:
|
|
| 202 |
'sortBy': 'relevance',
|
| 203 |
'sortOrder': 'descending'
|
| 204 |
}
|
|
|
|
| 205 |
async with session.get(
|
| 206 |
-
|
| 207 |
headers=self.config.default_headers,
|
| 208 |
timeout=30
|
| 209 |
) as response:
|
|
@@ -215,20 +252,23 @@ class CitationGenerator:
|
|
| 215 |
return []
|
| 216 |
|
| 217 |
async def fix_author_name(self, author: str) -> str:
|
|
|
|
|
|
|
|
|
|
| 218 |
if not re.search(r'[�]', author):
|
| 219 |
return author
|
| 220 |
try:
|
| 221 |
prompt = f"""Fix this author name that contains corrupted characters (�):
|
| 222 |
|
| 223 |
-
|
| 224 |
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
response = await self.llm.ainvoke(prompt)
|
| 233 |
fixed_name = response.content.strip()
|
| 234 |
return fixed_name if fixed_name else author
|
|
@@ -237,6 +277,9 @@ class CitationGenerator:
|
|
| 237 |
return author
|
| 238 |
|
| 239 |
async def format_bibtex_author_names(self, text: str) -> str:
|
|
|
|
|
|
|
|
|
|
| 240 |
try:
|
| 241 |
bib_database = bibtexparser.loads(text)
|
| 242 |
for entry in bib_database.entries:
|
|
@@ -247,15 +290,16 @@ class CitationGenerator:
|
|
| 247 |
fixed_author = await self.fix_author_name(author)
|
| 248 |
cleaned_authors.append(fixed_author)
|
| 249 |
entry['author'] = ' and '.join(cleaned_authors)
|
| 250 |
-
writer =
|
| 251 |
-
writer.indent = ' '
|
| 252 |
-
writer.comma_first = False
|
| 253 |
return writer.write(bib_database).strip()
|
| 254 |
except Exception as e:
|
| 255 |
logger.error(f"Error cleaning BibTeX special characters: {e}")
|
| 256 |
return text
|
| 257 |
|
| 258 |
-
async def search_crossref(self, session: aiohttp.ClientSession, query: str, max_results: int) -> List[Dict]:
|
|
|
|
|
|
|
|
|
|
| 259 |
try:
|
| 260 |
cleaned_query = query.replace("'", "").replace('"', "")
|
| 261 |
if ' ' in cleaned_query:
|
|
@@ -316,7 +360,6 @@ class CitationGenerator:
|
|
| 316 |
continue
|
| 317 |
|
| 318 |
bibtex_text = await bibtex_response.text()
|
| 319 |
-
|
| 320 |
bib_database = bibtexparser.loads(bibtex_text)
|
| 321 |
if not bib_database.entries:
|
| 322 |
continue
|
|
@@ -335,9 +378,7 @@ class CitationGenerator:
|
|
| 335 |
entry['ID'] = key
|
| 336 |
existing_keys.add(key)
|
| 337 |
|
| 338 |
-
writer =
|
| 339 |
-
writer.indent = ' '
|
| 340 |
-
writer.comma_first = False
|
| 341 |
formatted_bibtex = writer.write(bib_database).strip()
|
| 342 |
|
| 343 |
papers.append({
|
|
@@ -364,7 +405,10 @@ class CitationGenerator:
|
|
| 364 |
logger.error(f"Error searching CrossRef: {e}")
|
| 365 |
return []
|
| 366 |
|
| 367 |
-
def _generate_unique_bibtex_key(self, entry: Dict, existing_keys: set) -> str:
|
|
|
|
|
|
|
|
|
|
| 368 |
entry_type = entry.get('ENTRYTYPE', '').lower()
|
| 369 |
author_field = entry.get('author', '')
|
| 370 |
year = entry.get('year', '')
|
|
@@ -373,10 +417,10 @@ class CitationGenerator:
|
|
| 373 |
|
| 374 |
if entry_type == 'inbook':
|
| 375 |
booktitle = entry.get('booktitle', '')
|
| 376 |
-
title_word = re.sub(r'\W+', '', booktitle.split()[0]) if booktitle else 'untitled'
|
| 377 |
else:
|
| 378 |
title = entry.get('title', '')
|
| 379 |
-
title_word = re.sub(r'\W+', '', title.split()[0]) if title else 'untitled'
|
| 380 |
|
| 381 |
base_key = f"{first_author_last_name}{year}{title_word}"
|
| 382 |
key = base_key
|
|
@@ -387,17 +431,20 @@ class CitationGenerator:
|
|
| 387 |
return key
|
| 388 |
|
| 389 |
async def process_text(self, text: str, num_queries: int, citations_per_query: int,
|
| 390 |
-
use_arxiv: bool = True, use_crossref: bool = True) ->
|
|
|
|
|
|
|
|
|
|
| 391 |
if not (use_arxiv or use_crossref):
|
| 392 |
return "Please select at least one source (ArXiv or CrossRef)", "", ""
|
| 393 |
|
| 394 |
num_queries = min(max(1, num_queries), self.config.max_queries)
|
| 395 |
citations_per_query = min(max(1, citations_per_query), self.config.max_citations_per_query)
|
| 396 |
|
| 397 |
-
async def generate_queries_tool(input_data:
|
| 398 |
return await self.generate_queries(input_data["text"], input_data["num_queries"])
|
| 399 |
|
| 400 |
-
async def search_papers_tool(input_data:
|
| 401 |
queries = input_data["queries"]
|
| 402 |
papers = []
|
| 403 |
async with self.async_context as session:
|
|
@@ -411,7 +458,7 @@ class CitationGenerator:
|
|
| 411 |
for r in results:
|
| 412 |
if not isinstance(r, Exception):
|
| 413 |
papers.extend(r)
|
| 414 |
-
#
|
| 415 |
unique_papers = []
|
| 416 |
seen_keys = set()
|
| 417 |
for p in papers:
|
|
@@ -420,7 +467,7 @@ class CitationGenerator:
|
|
| 420 |
unique_papers.append(p)
|
| 421 |
return unique_papers
|
| 422 |
|
| 423 |
-
async def cite_text_tool(input_data:
|
| 424 |
try:
|
| 425 |
citation_input = {
|
| 426 |
"text": input_data["text"],
|
|
@@ -430,7 +477,6 @@ class CitationGenerator:
|
|
| 430 |
response = await self.llm.ainvoke(prompt)
|
| 431 |
cited_text = response.content.strip()
|
| 432 |
|
| 433 |
-
# Aggregate BibTeX entries
|
| 434 |
bib_database = BibDatabase()
|
| 435 |
for p in input_data["papers"]:
|
| 436 |
if 'bibtex_entry' in p:
|
|
@@ -439,16 +485,14 @@ class CitationGenerator:
|
|
| 439 |
bib_database.entries.append(bib_db.entries[0])
|
| 440 |
else:
|
| 441 |
logger.warning(f"Empty BibTeX entry for key: {p['bibtex_key']}")
|
| 442 |
-
writer =
|
| 443 |
-
writer.indent = ' '
|
| 444 |
-
writer.comma_first = False
|
| 445 |
bibtex_entries = writer.write(bib_database).strip()
|
| 446 |
return cited_text, bibtex_entries
|
| 447 |
except Exception as e:
|
| 448 |
logger.error(f"Error inserting citations: {e}")
|
| 449 |
return input_data["text"], ""
|
| 450 |
|
| 451 |
-
async def agent_run(input_data:
|
| 452 |
queries = await generate_queries_tool(input_data)
|
| 453 |
papers = await search_papers_tool({
|
| 454 |
"queries": queries,
|
|
@@ -473,9 +517,13 @@ class CitationGenerator:
|
|
| 473 |
})
|
| 474 |
return final_text, final_bibtex, final_queries
|
| 475 |
|
|
|
|
| 476 |
def create_gradio_interface() -> gr.Interface:
|
|
|
|
|
|
|
|
|
|
| 477 |
async def process(api_key: str, text: str, num_queries: int, citations_per_query: int,
|
| 478 |
-
|
| 479 |
if not api_key.strip():
|
| 480 |
return "Please enter your Gemini API Key.", "", ""
|
| 481 |
if not text.strip():
|
|
@@ -494,14 +542,14 @@ def create_gradio_interface() -> gr.Interface:
|
|
| 494 |
|
| 495 |
css = """
|
| 496 |
:root {
|
| 497 |
-
/* Modern
|
| 498 |
--primary-bg: #F8F9FA;
|
| 499 |
--secondary-bg: #FFFFFF;
|
| 500 |
-
--accent-1: #4A90E2;
|
| 501 |
-
--accent-2: #50C878;
|
| 502 |
-
--accent-3: #F5B041;
|
| 503 |
-
--text-primary: #2C3E50;
|
| 504 |
-
--text-secondary: #566573;
|
| 505 |
--border: #E5E7E9;
|
| 506 |
--shadow: rgba(0, 0, 0, 0.1);
|
| 507 |
}
|
|
@@ -690,6 +738,7 @@ def create_gradio_interface() -> gr.Interface:
|
|
| 690 |
|
| 691 |
return demo
|
| 692 |
|
|
|
|
| 693 |
if __name__ == "__main__":
|
| 694 |
demo = create_gradio_interface()
|
| 695 |
try:
|
|
|
|
| 4 |
import re
|
| 5 |
import xml.etree.ElementTree as ET
|
| 6 |
from dataclasses import dataclass, field
|
| 7 |
+
from typing import Dict, List, Optional, Any, Tuple
|
| 8 |
import sys
|
| 9 |
from loguru import logger
|
| 10 |
|
|
|
|
| 12 |
import gradio as gr
|
| 13 |
|
| 14 |
from langchain.prompts import PromptTemplate
|
|
|
|
| 15 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 16 |
|
| 17 |
import bibtexparser
|
| 18 |
from bibtexparser.bwriter import BibTexWriter
|
| 19 |
from bibtexparser.bibdatabase import BibDatabase
|
| 20 |
|
| 21 |
+
|
| 22 |
+
def get_bibtex_writer() -> BibTexWriter:
    """Build a fresh BibTexWriter with this module's formatting settings.

    Returns:
        A new BibTexWriter whose ``indent`` is a single space and whose
        ``comma_first`` option is disabled.
    """
    bib_writer = BibTexWriter()
    bib_writer.comma_first = False
    bib_writer.indent = ' '
    return bib_writer
|
| 30 |
+
|
| 31 |
+
|
| 32 |
@dataclass
|
| 33 |
class Config:
|
| 34 |
gemini_api_key: str
|
|
|
|
| 38 |
max_citations_per_query: int = 10
|
| 39 |
arxiv_base_url: str = 'http://export.arxiv.org/api/query?'
|
| 40 |
crossref_base_url: str = 'https://api.crossref.org/works'
|
| 41 |
+
default_headers: Dict[str, str] = field(default_factory=lambda: {
|
| 42 |
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
|
| 43 |
})
|
| 44 |
log_level: str = 'DEBUG'
|
| 45 |
|
| 46 |
+
|
| 47 |
class ArxivXmlParser:
|
| 48 |
+
"""
|
| 49 |
+
Class to parse ArXiv XML responses.
|
| 50 |
+
"""
|
| 51 |
NS = {
|
| 52 |
'atom': 'http://www.w3.org/2005/Atom',
|
| 53 |
'arxiv': 'http://arxiv.org/schemas/atom'
|
| 54 |
}
|
| 55 |
|
| 56 |
+
def parse_papers(self, data: str) -> List[Dict[str, Any]]:
|
| 57 |
+
"""
|
| 58 |
+
Parse ArXiv XML data and return a list of paper dictionaries.
|
| 59 |
+
"""
|
| 60 |
try:
|
| 61 |
root = ET.fromstring(data)
|
| 62 |
papers = []
|
|
|
|
| 69 |
logger.error(f"Error parsing ArXiv XML: {e}")
|
| 70 |
return []
|
| 71 |
|
| 72 |
+
def parse_entry(self, entry: ET.Element) -> Optional[Dict[str, Any]]:
|
| 73 |
+
"""
|
| 74 |
+
Parse a single ArXiv entry element and return a dictionary with paper details.
|
| 75 |
+
"""
|
| 76 |
try:
|
| 77 |
title_node = entry.find('atom:title', self.NS)
|
| 78 |
if title_node is None:
|
| 79 |
return None
|
| 80 |
+
title = title_node.text.strip() if title_node.text else ""
|
| 81 |
|
| 82 |
authors = []
|
| 83 |
for author in entry.findall('atom:author', self.NS):
|
|
|
|
| 86 |
authors.append(self._format_author_name(author_name_node.text.strip()))
|
| 87 |
|
| 88 |
arxiv_id_node = entry.find('atom:id', self.NS)
|
| 89 |
+
if arxiv_id_node is None or not arxiv_id_node.text:
|
| 90 |
return None
|
| 91 |
arxiv_id = arxiv_id_node.text.split('/')[-1]
|
| 92 |
|
| 93 |
published_node = entry.find('atom:published', self.NS)
|
| 94 |
+
year = published_node.text[:4] if (published_node is not None and published_node.text) else "Unknown"
|
| 95 |
|
| 96 |
abstract_node = entry.find('atom:summary', self.NS)
|
| 97 |
+
abstract = abstract_node.text.strip() if (abstract_node is not None and abstract_node.text) else ""
|
| 98 |
|
| 99 |
bibtex_key = f"{authors[0].split(',')[0]}{arxiv_id.replace('.', '')}" if authors else f"unknown{arxiv_id.replace('.', '')}"
|
| 100 |
bibtex_entry = self._generate_bibtex_entry(bibtex_key, title, authors, arxiv_id, year)
|
|
|
|
| 114 |
|
| 115 |
@staticmethod
|
| 116 |
def _format_author_name(author: str) -> str:
|
| 117 |
+
"""
|
| 118 |
+
Format an author name as 'Lastname, Firstname'.
|
| 119 |
+
"""
|
| 120 |
names = author.split()
|
| 121 |
if len(names) > 1:
|
| 122 |
return f"{names[-1]}, {' '.join(names[:-1])}"
|
| 123 |
return author
|
| 124 |
|
| 125 |
def _generate_bibtex_entry(self, key: str, title: str, authors: List[str], arxiv_id: str, year: str) -> str:
|
| 126 |
+
"""
|
| 127 |
+
Generate a BibTeX entry for a paper.
|
| 128 |
+
"""
|
| 129 |
db = BibDatabase()
|
| 130 |
db.entries = [{
|
| 131 |
'ENTRYTYPE': 'article',
|
|
|
|
| 135 |
'journal': f'arXiv preprint arXiv:{arxiv_id}',
|
| 136 |
'year': year
|
| 137 |
}]
|
| 138 |
+
writer = get_bibtex_writer()
|
|
|
|
|
|
|
| 139 |
return writer.write(db).strip()
|
| 140 |
|
| 141 |
+
|
| 142 |
class AsyncContextManager:
|
| 143 |
+
"""
|
| 144 |
+
Asynchronous context manager to handle aiohttp ClientSession.
|
| 145 |
+
"""
|
| 146 |
+
async def __aenter__(self) -> aiohttp.ClientSession:
|
| 147 |
self._session = aiohttp.ClientSession()
|
| 148 |
return self._session
|
| 149 |
|
|
|
|
| 151 |
if self._session:
|
| 152 |
await self._session.close()
|
| 153 |
|
| 154 |
+
|
| 155 |
class CitationGenerator:
|
| 156 |
+
"""
|
| 157 |
+
Class that handles generating citations using AI and searching for academic papers.
|
| 158 |
+
"""
|
| 159 |
+
def __init__(self, config: Config) -> None:
|
| 160 |
self.config = config
|
| 161 |
self.xml_parser = ArxivXmlParser()
|
| 162 |
self.async_context = AsyncContextManager()
|
| 163 |
self.llm = ChatGoogleGenerativeAI(
|
| 164 |
+
model="gemini-2.0-flash",
|
| 165 |
temperature=0.3,
|
| 166 |
google_api_key=config.gemini_api_key,
|
| 167 |
streaming=True
|
|
|
|
| 197 |
logger.add(sys.stderr, level=config.log_level)
|
| 198 |
|
| 199 |
async def generate_queries(self, text: str, num_queries: int) -> List[str]:
|
| 200 |
+
"""
|
| 201 |
+
Generate a list of academic search queries from the input text.
|
| 202 |
+
"""
|
| 203 |
input_map = {
|
| 204 |
"text": text,
|
| 205 |
"num_queries": num_queries
|
|
|
|
| 221 |
lines = [line.strip() for line in content.split('\n')
|
| 222 |
if line.strip() and not line.strip().startswith(('[', ']'))]
|
| 223 |
return lines[:num_queries]
|
|
|
|
| 224 |
return ["deep learning neural networks"]
|
|
|
|
| 225 |
except Exception as e:
|
| 226 |
logger.error(f"Error generating queries: {e}")
|
| 227 |
return ["deep learning neural networks"]
|
| 228 |
|
| 229 |
+
async def search_arxiv(self, session: aiohttp.ClientSession, query: str, max_results: int) -> List[Dict[str, Any]]:
|
| 230 |
+
"""
|
| 231 |
+
Search ArXiv for papers matching the query.
|
| 232 |
+
"""
|
| 233 |
try:
|
| 234 |
params = {
|
| 235 |
'search_query': f'all:{urllib.parse.quote(query)}',
|
|
|
|
| 238 |
'sortBy': 'relevance',
|
| 239 |
'sortOrder': 'descending'
|
| 240 |
}
|
| 241 |
+
url = self.config.arxiv_base_url + urllib.parse.urlencode(params)
|
| 242 |
async with session.get(
|
| 243 |
+
url,
|
| 244 |
headers=self.config.default_headers,
|
| 245 |
timeout=30
|
| 246 |
) as response:
|
|
|
|
| 252 |
return []
|
| 253 |
|
| 254 |
async def fix_author_name(self, author: str) -> str:
|
| 255 |
+
"""
|
| 256 |
+
Correct an author name that contains corrupted characters.
|
| 257 |
+
"""
|
| 258 |
if not re.search(r'[�]', author):
|
| 259 |
return author
|
| 260 |
try:
|
| 261 |
prompt = f"""Fix this author name that contains corrupted characters (�):
|
| 262 |
|
| 263 |
+
Name: {author}
|
| 264 |
|
| 265 |
+
Requirements:
|
| 266 |
+
1. Return ONLY the fixed author name
|
| 267 |
+
2. Use proper diacritical marks for names
|
| 268 |
+
3. Consider common name patterns and languages
|
| 269 |
+
4. If unsure, use the most likely letter
|
| 270 |
+
5. Maintain the format: "Lastname, Firstname"
|
| 271 |
+
"""
|
| 272 |
response = await self.llm.ainvoke(prompt)
|
| 273 |
fixed_name = response.content.strip()
|
| 274 |
return fixed_name if fixed_name else author
|
|
|
|
| 277 |
return author
|
| 278 |
|
| 279 |
async def format_bibtex_author_names(self, text: str) -> str:
|
| 280 |
+
"""
|
| 281 |
+
Clean and format author names in a BibTeX string.
|
| 282 |
+
"""
|
| 283 |
try:
|
| 284 |
bib_database = bibtexparser.loads(text)
|
| 285 |
for entry in bib_database.entries:
|
|
|
|
| 290 |
fixed_author = await self.fix_author_name(author)
|
| 291 |
cleaned_authors.append(fixed_author)
|
| 292 |
entry['author'] = ' and '.join(cleaned_authors)
|
| 293 |
+
writer = get_bibtex_writer()
|
|
|
|
|
|
|
| 294 |
return writer.write(bib_database).strip()
|
| 295 |
except Exception as e:
|
| 296 |
logger.error(f"Error cleaning BibTeX special characters: {e}")
|
| 297 |
return text
|
| 298 |
|
| 299 |
+
async def search_crossref(self, session: aiohttp.ClientSession, query: str, max_results: int) -> List[Dict[str, Any]]:
|
| 300 |
+
"""
|
| 301 |
+
Search CrossRef for papers matching the query.
|
| 302 |
+
"""
|
| 303 |
try:
|
| 304 |
cleaned_query = query.replace("'", "").replace('"', "")
|
| 305 |
if ' ' in cleaned_query:
|
|
|
|
| 360 |
continue
|
| 361 |
|
| 362 |
bibtex_text = await bibtex_response.text()
|
|
|
|
| 363 |
bib_database = bibtexparser.loads(bibtex_text)
|
| 364 |
if not bib_database.entries:
|
| 365 |
continue
|
|
|
|
| 378 |
entry['ID'] = key
|
| 379 |
existing_keys.add(key)
|
| 380 |
|
| 381 |
+
writer = get_bibtex_writer()
|
|
|
|
|
|
|
| 382 |
formatted_bibtex = writer.write(bib_database).strip()
|
| 383 |
|
| 384 |
papers.append({
|
|
|
|
| 405 |
logger.error(f"Error searching CrossRef: {e}")
|
| 406 |
return []
|
| 407 |
|
| 408 |
+
def _generate_unique_bibtex_key(self, entry: Dict[str, Any], existing_keys: set) -> str:
|
| 409 |
+
"""
|
| 410 |
+
Generate a unique BibTeX key for an entry.
|
| 411 |
+
"""
|
| 412 |
entry_type = entry.get('ENTRYTYPE', '').lower()
|
| 413 |
author_field = entry.get('author', '')
|
| 414 |
year = entry.get('year', '')
|
|
|
|
| 417 |
|
| 418 |
if entry_type == 'inbook':
|
| 419 |
booktitle = entry.get('booktitle', '')
|
| 420 |
+
title_word = re.sub(r'\W+', '', booktitle.split()[0]) if booktitle.split() else 'untitled'
|
| 421 |
else:
|
| 422 |
title = entry.get('title', '')
|
| 423 |
+
title_word = re.sub(r'\W+', '', title.split()[0]) if title.split() else 'untitled'
|
| 424 |
|
| 425 |
base_key = f"{first_author_last_name}{year}{title_word}"
|
| 426 |
key = base_key
|
|
|
|
| 431 |
return key
|
| 432 |
|
| 433 |
async def process_text(self, text: str, num_queries: int, citations_per_query: int,
|
| 434 |
+
use_arxiv: bool = True, use_crossref: bool = True) -> Tuple[str, str, str]:
|
| 435 |
+
"""
|
| 436 |
+
Process the input text to generate citations and corresponding BibTeX entries.
|
| 437 |
+
"""
|
| 438 |
if not (use_arxiv or use_crossref):
|
| 439 |
return "Please select at least one source (ArXiv or CrossRef)", "", ""
|
| 440 |
|
| 441 |
num_queries = min(max(1, num_queries), self.config.max_queries)
|
| 442 |
citations_per_query = min(max(1, citations_per_query), self.config.max_citations_per_query)
|
| 443 |
|
| 444 |
+
async def generate_queries_tool(input_data: Dict[str, Any]) -> List[str]:
|
| 445 |
return await self.generate_queries(input_data["text"], input_data["num_queries"])
|
| 446 |
|
| 447 |
+
async def search_papers_tool(input_data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
| 448 |
queries = input_data["queries"]
|
| 449 |
papers = []
|
| 450 |
async with self.async_context as session:
|
|
|
|
| 458 |
for r in results:
|
| 459 |
if not isinstance(r, Exception):
|
| 460 |
papers.extend(r)
|
| 461 |
+
# Remove duplicate papers
|
| 462 |
unique_papers = []
|
| 463 |
seen_keys = set()
|
| 464 |
for p in papers:
|
|
|
|
| 467 |
unique_papers.append(p)
|
| 468 |
return unique_papers
|
| 469 |
|
| 470 |
+
async def cite_text_tool(input_data: Dict[str, Any]) -> Tuple[str, str]:
|
| 471 |
try:
|
| 472 |
citation_input = {
|
| 473 |
"text": input_data["text"],
|
|
|
|
| 477 |
response = await self.llm.ainvoke(prompt)
|
| 478 |
cited_text = response.content.strip()
|
| 479 |
|
|
|
|
| 480 |
bib_database = BibDatabase()
|
| 481 |
for p in input_data["papers"]:
|
| 482 |
if 'bibtex_entry' in p:
|
|
|
|
| 485 |
bib_database.entries.append(bib_db.entries[0])
|
| 486 |
else:
|
| 487 |
logger.warning(f"Empty BibTeX entry for key: {p['bibtex_key']}")
|
| 488 |
+
writer = get_bibtex_writer()
|
|
|
|
|
|
|
| 489 |
bibtex_entries = writer.write(bib_database).strip()
|
| 490 |
return cited_text, bibtex_entries
|
| 491 |
except Exception as e:
|
| 492 |
logger.error(f"Error inserting citations: {e}")
|
| 493 |
return input_data["text"], ""
|
| 494 |
|
| 495 |
+
async def agent_run(input_data: Dict[str, Any]) -> Tuple[str, str, str]:
|
| 496 |
queries = await generate_queries_tool(input_data)
|
| 497 |
papers = await search_papers_tool({
|
| 498 |
"queries": queries,
|
|
|
|
| 517 |
})
|
| 518 |
return final_text, final_bibtex, final_queries
|
| 519 |
|
| 520 |
+
|
| 521 |
def create_gradio_interface() -> gr.Interface:
|
| 522 |
+
"""
|
| 523 |
+
Create and return a Gradio interface for the citation generator.
|
| 524 |
+
"""
|
| 525 |
async def process(api_key: str, text: str, num_queries: int, citations_per_query: int,
|
| 526 |
+
use_arxiv: bool, use_crossref: bool) -> Tuple[str, str, str]:
|
| 527 |
if not api_key.strip():
|
| 528 |
return "Please enter your Gemini API Key.", "", ""
|
| 529 |
if not text.strip():
|
|
|
|
| 542 |
|
| 543 |
css = """
|
| 544 |
:root {
|
| 545 |
+
/* Modern color palette */
|
| 546 |
--primary-bg: #F8F9FA;
|
| 547 |
--secondary-bg: #FFFFFF;
|
| 548 |
+
--accent-1: #4A90E2;
|
| 549 |
+
--accent-2: #50C878;
|
| 550 |
+
--accent-3: #F5B041;
|
| 551 |
+
--text-primary: #2C3E50;
|
| 552 |
+
--text-secondary: #566573;
|
| 553 |
--border: #E5E7E9;
|
| 554 |
--shadow: rgba(0, 0, 0, 0.1);
|
| 555 |
}
|
|
|
|
| 738 |
|
| 739 |
return demo
|
| 740 |
|
| 741 |
+
|
| 742 |
if __name__ == "__main__":
|
| 743 |
demo = create_gradio_interface()
|
| 744 |
try:
|