"""Streamlit front-end for exploring arXiv papers parsed by ArxivKit.

The user supplies an arXiv ID or URL (or clicks a sample paper in the
sidebar). The paper is loaded from the database when it has been stored
before; otherwise it is downloaded and parsed with ``process_paper`` and then
saved. The result is shown as collapsible sections, a flat section-path list
and a JSON dump.
"""

import asyncio
import json

import streamlit as st

from arxivkit import ParserConfig, SectionDepth, build_section_paths
from arxivkit.db import SessionLocal, init_db
from arxivkit.db import crud
from arxivkit.db.crud import SYSTEM_USER_ID
from arxivkit.downloader import extract_paper_id
from arxivkit.models import PaperStructure, Section
from arxivkit.pipeline import process_paper

# ── Page config ──────────────────────────────────────────────────────────────

st.set_page_config(
    page_title="ArxivKit Explorer",
    page_icon="📄",
    layout="wide",
    initial_sidebar_state="expanded",
)

# ── Custom CSS ────────────────────────────────────────────────────────────────

# NOTE(review): the stylesheet passed here is empty in this copy of the file;
# the original <style> content appears to have been lost — restore if needed.
st.markdown(
    """ """,
    unsafe_allow_html=True,
)

# ── DB init ───────────────────────────────────────────────────────────────────


@st.cache_resource(show_spinner=False)
def _db_init_state() -> dict[str, bool]:
    """Return a mutable flag dict shared across reruns and sessions.

    Streamlit re-executes this script from the top on every interaction, so a
    plain module-level flag would be reset to ``False`` on each rerun.
    ``st.cache_resource`` hands back the same object every time instead.
    """
    return {"ready": False}


async def _ensure_db() -> None:
    """Run ``init_db`` and ``crud.ensure_system_user`` once per server process."""
    state = _db_init_state()
    if state["ready"]:
        return
    await init_db()
    async with SessionLocal() as session:
        await crud.ensure_system_user(session)
    state["ready"] = True


# ── Paper loading ─────────────────────────────────────────────────────────────


async def _fetch_paper(paper_id: str) -> PaperStructure:
    """Check DB first; download + parse + save if not found.

    Args:
        paper_id: arXiv identifier, as returned by ``extract_paper_id``.

    Returns:
        The stored paper when present, otherwise the freshly parsed paper
        (which is saved under ``SYSTEM_USER_ID`` before being returned).
    """
    await _ensure_db()
    async with SessionLocal() as session:
        cached = await crud.load_paper(session, paper_id)
    if cached is not None:
        return cached

    cfg = ParserConfig(
        math_in_text=True,
        theorems_in_text=True,
        convert_tables_to_md=True,
        max_section_depth=SectionDepth.SUBPARAGRAPH,
    )
    # process_paper is called without await, so it blocks this event loop
    # while it runs; nothing else is scheduled on the loop here.
    paper = process_paper(paper_id, cfg)
    async with SessionLocal() as session:
        await crud.save_paper(session, paper, user_id=SYSTEM_USER_ID)
    return paper


def _load_paper(paper_id: str) -> PaperStructure:
    """Synchronous entry point used by the script to run ``_fetch_paper``."""
    # NOTE(review): asyncio.run creates a new event loop on every call. If
    # SessionLocal is backed by a pooled async engine, pooled connections may
    # stay bound to a previous (closed) loop — verify against arxivkit.db.
    return asyncio.run(_fetch_paper(paper_id))


# ── Render helpers ────────────────────────────────────────────────────────────

_SAMPLE_PAPERS = [
    ("Attention Is All You Need", "1706.03762"),
    ("BERT", "1810.04805"),
    ("Mamba", "2312.00752"),
    ("LLaMA", "2302.13971"),
    ("Chain-of-Thought Prompting", "2201.11903"),
]


def _count_recursive(secs: list[Section], attr: str) -> int:
    """Sum ``len(getattr(s, attr))`` over *secs* and all nested subsections."""
    return sum(len(getattr(s, attr)) + _count_recursive(s.subsections, attr) for s in secs)


def _clean_formula(raw: str) -> str:
    """Strip one pair of surrounding math delimiters so ``st.latex`` can render it.

    Handles ``$$…$$``, ``\\[…\\]``, ``$…$`` and ``\\(…\\)``; anything else is
    returned stripped of surrounding whitespace.
    """
    s = raw.strip()
    # "$$" is checked before "$" so display math is not treated as inline math.
    if s.startswith("$$") and s.endswith("$$"):
        return s[2:-2].strip()
    if s.startswith("\\[") and s.endswith("\\]"):
        return s[2:-2].strip()
    if s.startswith("$") and s.endswith("$") and len(s) > 2:
        return s[1:-1].strip()
    if s.startswith("\\(") and s.endswith("\\)"):
        return s[2:-2].strip()
    return s


def _render_formula(raw: str) -> None:
    """Render one formula with ``st.latex``; skip it if empty after cleaning."""
    latex = _clean_formula(raw)
    if not latex:
        return
    try:
        st.latex(latex)
    except Exception:
        # Best-effort fallback to the raw source.
        # NOTE(review): st.latex typically renders client-side, so malformed
        # LaTeX may not raise here at all.
        st.code(raw, language="latex")


def _empty_note(message: str) -> None:
    """Show a muted placeholder for a tab that has no content."""
    st.caption(message)


def _render_section_content(section: Section, depth: int = 0) -> None:
    """Render a section's Text / Formulas / Tables tabs, then its subsections.

    Args:
        section: The section to render.
        depth: Nesting level of *section*; subsections are rendered at
            ``depth + 1``.
    """
    tab_text, tab_formulas, tab_tables = st.tabs(["📝 Text", "∑ Formulas", "⊞ Tables"])

    with tab_text:
        if (section.text or "").strip():
            st.markdown(section.text)
        else:
            _empty_note("No text content in this section.")

    with tab_formulas:
        if section.formulas:
            # A lone formula is shown expanded; several start collapsed.
            single = len(section.formulas) == 1
            for i, formula in enumerate(section.formulas, 1):
                with st.expander(f"Formula {i}", expanded=single):
                    _render_formula(formula)
        else:
            _empty_note("No formulas in this section.")

    with tab_tables:
        if section.tables:
            for i, table in enumerate(section.tables, 1):
                label = f"Table {i}" + (f": {table.caption}" if table.caption else "")
                with st.expander(label, expanded=False):
                    if table.caption:
                        st.markdown(f"**Caption:** {table.caption}")
                    # Content starting with "|" is rendered as a markdown table
                    # (presumably from convert_tables_to_md); anything else is
                    # shown as raw LaTeX source.
                    if table.content.strip().startswith("|"):
                        st.markdown(table.content)
                    else:
                        st.code(table.content, language="latex")
        else:
            _empty_note("No tables in this section.")

    if section.subsections:
        st.markdown("---")
        for sub in section.subsections:
            _render_section(sub, depth=depth + 1)


def _render_section(section: Section, depth: int = 0) -> None:
    """Render *section* inside a collapsed expander labelled with its title."""
    # NOTE(review): subsections and the formula/table expanders are rendered
    # inside this expander; some Streamlit versions reject nested expanders —
    # confirm against the pinned Streamlit version.
    with st.expander(section.title, expanded=False):
        _render_section_content(section, depth=depth)


def _paper_json(paper: PaperStructure) -> dict:
    """Serialise paper to a JSON-compatible dict with every `images` key stripped.

    ``model_dump(mode="json")`` converts non-JSON values (dates, enums, …) so
    the result can always be passed to ``json.dumps``.
    """

    def strip_images(node: dict) -> dict:
        # Mutates the freshly dumped dict in place; recurses into subsections.
        node.pop("images", None)
        node["subsections"] = [strip_images(sub) for sub in node.get("subsections", [])]
        return node

    data = paper.model_dump(mode="json")
    data["sections"] = [strip_images(s) for s in data.get("sections", [])]
    return data


def _render_paper(paper: PaperStructure) -> None:
    """Render the paper header, summary metrics and the Explore/Sections/JSON tabs."""
    st.markdown(f"## {paper.title}")

    # Authors and abstract come from external arXiv metadata, so they are
    # rendered without unsafe_allow_html to avoid injecting raw HTML.
    if paper.authors:
        st.caption(", ".join(str(a) for a in paper.authors))
    if paper.abstract:
        st.markdown("**Abstract**")
        st.markdown(paper.abstract)

    m1, m2, m3 = st.columns(3)
    m1.metric("Sections", len(paper.sections))
    m2.metric("Formulas", _count_recursive(paper.sections, "formulas"))
    m3.metric("Tables", _count_recursive(paper.sections, "tables"))

    st.divider()

    tab_explore, tab_sections, tab_json = st.tabs(["📄 Explore", "🗂 Sections", "{ } JSON"])

    with tab_explore:
        if paper.sections:
            for section in paper.sections:
                _render_section(section)
        else:
            st.info("No sections were extracted from this paper.")

    with tab_sections:
        for path in build_section_paths(paper.sections):
            st.text(path)

    with tab_json:
        st.code(
            json.dumps(_paper_json(paper), indent=2, ensure_ascii=False),
            language="json",
        )


# ── Sidebar ───────────────────────────────────────────────────────────────────

with st.sidebar:
    st.header("🧪 Try a sample")
    st.caption("Click any paper to load it.")
    for title, paper_id in _SAMPLE_PAPERS:
        if st.button(title, key=f"sample_{paper_id}", use_container_width=True):
            # Selecting a sample always forces a fresh load.
            st.session_state["requested_id"] = paper_id
            st.session_state.pop("loaded_paper", None)
            st.session_state.pop("load_error", None)
            st.rerun()
    st.divider()
    st.caption("Powered by ArxivKit")
    st.caption(
        "⚠️ Results depend on arXiv API availability and LaTeX source quality. "
        "Parsing errors and incomplete output are expected in this experimental release."
    )

# ── Main app ──────────────────────────────────────────────────────────────────

st.title("📄 ArxivKit (Structured Arxiv Data Extractor)")
st.caption("Paste an arXiv paper URL or ID to explore its structure.")
st.warning(
    "**Experimental & Unstable Release** — This tool is heavily dependent on the "
    "arXiv API and LaTeX source availability. Not all papers will parse correctly, "
    "and some results may include extraction errors, garbled text, or missing content. "
    "Do not rely on the output blindly — always verify against the original paper.",
    icon="⚠️",
)
st.info(
    "**Coming soon:** Direct upload of LaTeX source files (.tex / .tar.gz) for local parsing — "
    "no arXiv API dependency required.",
    icon="🔜",
)

col_input, col_btn = st.columns([5, 1])
with col_input:
    raw_input = st.text_input(
        label="Paper",
        placeholder="e.g. 2301.07041 or https://arxiv.org/abs/2301.07041",
        label_visibility="collapsed",
    )
with col_btn:
    parse_clicked = st.button("Parse", type="primary", use_container_width=True)

# ── State transitions ─────────────────────────────────────────────────────────

if parse_clicked and raw_input.strip():
    cleaned_input = raw_input.strip()
    try:
        new_id = extract_paper_id(cleaned_input)
    except ValueError:
        st.error(f"Could not parse an arXiv ID from: `{cleaned_input}`")
        st.stop()
    if st.session_state.get("requested_id") != new_id:
        st.session_state.pop("loaded_paper", None)
    # Always clear a previous failure so pressing Parse again retries the load
    # even when the requested ID has not changed.
    st.session_state.pop("load_error", None)
    st.session_state["requested_id"] = new_id

# ── Loading phase ─────────────────────────────────────────────────────────────

_load_pending = (
    "requested_id" in st.session_state
    and "loaded_paper" not in st.session_state
    and "load_error" not in st.session_state
)
if _load_pending:
    paper_id = st.session_state["requested_id"]
    with st.spinner(f"Fetching and parsing **{paper_id}** …"):
        try:
            st.session_state["loaded_paper"] = _load_paper(paper_id)
        except Exception as exc:
            # Top-level boundary: surface any failure to the user. Fall back
            # to the exception type when the message is empty.
            st.session_state["load_error"] = str(exc) or type(exc).__name__
    st.rerun()

# ── Display phase ─────────────────────────────────────────────────────────────

if "load_error" in st.session_state:
    st.error(f"Failed to load paper `{st.session_state['requested_id']}`: {st.session_state['load_error']}")
elif "loaded_paper" in st.session_state:
    _render_paper(st.session_state["loaded_paper"])