import asyncio import streamlit as st from arxivkit import ParserConfig, SectionDepth, build_section_paths from arxivkit.db import SessionLocal, init_db from arxivkit.db import crud from arxivkit.db.crud import SYSTEM_USER_ID from arxivkit.downloader import extract_paper_id from arxivkit.models import PaperStructure, Section from arxivkit.pipeline import process_paper # ── Page config ────────────────────────────────────────────────────────────── st.set_page_config( page_title="ArxivKit Explorer", page_icon="📄", layout="wide", initial_sidebar_state="expanded", ) # ── Custom CSS ──────────────────────────────────────────────────────────────── st.markdown( """ """, unsafe_allow_html=True, ) # ── DB init ─────────────────────────────────────────────────────────────────── _DB_READY = False async def _ensure_db() -> None: global _DB_READY if not _DB_READY: await init_db() async with SessionLocal() as session: await crud.ensure_system_user(session) _DB_READY = True # ── Paper loading ───────────────────────────────────────────────────────────── async def _fetch_paper(paper_id: str) -> PaperStructure: """Check DB first; download + parse + save if not found.""" await _ensure_db() async with SessionLocal() as session: paper = await crud.load_paper(session, paper_id) if paper is not None: return paper cfg = ParserConfig( math_in_text=True, theorems_in_text=True, convert_tables_to_md=True, max_section_depth=SectionDepth.SUBPARAGRAPH, ) paper = process_paper(paper_id, cfg) async with SessionLocal() as session: await crud.save_paper(session, paper, user_id=SYSTEM_USER_ID) return paper def _load_paper(paper_id: str) -> PaperStructure: return asyncio.run(_fetch_paper(paper_id)) # ── Render helpers ──────────────────────────────────────────────────────────── _SAMPLE_PAPERS = [ ("Attention Is All You Need", "1706.03762"), ("BERT", "1810.04805"), ("Mamba", "2312.00752"), ("LLaMA", "2302.13971"), ("Chain-of-Thought Prompting", "2201.11903"), ] def _count_recursive(secs: list[Section], attr: str) -> int: return sum(len(getattr(s, attr)) + _count_recursive(s.subsections, attr) for s in secs) def _clean_formula(raw: str) -> str: s = raw.strip() if s.startswith("$$") and s.endswith("$$"): return s[2:-2].strip() if s.startswith("$") and s.endswith("$") and len(s) > 2: return s[1:-1].strip() return s def _render_formula(raw: str) -> None: latex = _clean_formula(raw) if not latex: return try: st.latex(latex) except Exception: st.code(raw, language="latex") def _render_section_content(section: Section, depth: int = 0) -> None: tab_text, tab_formulas, tab_tables = st.tabs(["📝 Text", "∑ Formulas", "⊞ Tables"]) with tab_text: if section.text.strip(): st.markdown(section.text) else: st.markdown('
No text content in this section.
', unsafe_allow_html=True) with tab_formulas: if section.formulas: for i, formula in enumerate(section.formulas, 1): with st.expander(f"Formula {i}", expanded=(len(section.formulas) == 1)): _render_formula(formula) else: st.markdown('No formulas in this section.
', unsafe_allow_html=True) with tab_tables: if section.tables: for i, table in enumerate(section.tables, 1): label = f"Table {i}" + (f": {table.caption}" if table.caption else "") with st.expander(label, expanded=False): if table.caption: st.markdown(f"**Caption:** {table.caption}") if table.content.strip().startswith("|"): st.markdown(table.content) else: st.code(table.content, language="latex") else: st.markdown('No tables in this section.
', unsafe_allow_html=True) if section.subsections: st.markdown("---") for sub in section.subsections: _render_section(sub, depth=depth + 1) def _render_section(section: Section, depth: int = 0) -> None: with st.expander(section.title, expanded=False): _render_section_content(section, depth=depth) def _paper_json(paper: PaperStructure) -> dict: """Serialise paper to dict with the `images` key stripped from every section.""" def strip_images(s: dict) -> dict: s.pop("images", None) s["subsections"] = [strip_images(sub) for sub in s.get("subsections", [])] return s data = paper.model_dump() data["sections"] = [strip_images(s) for s in data.get("sections", [])] return data def _render_paper(paper: PaperStructure) -> None: st.markdown(f"## {paper.title}") if paper.authors: authors_html = " ".join(f'' for a in paper.authors) st.markdown(authors_html, unsafe_allow_html=True) st.markdown( f'