# ArxivKit — src/streamlit_app.py (author: anand004; revision ff508a7, "add warnings")
# Streamlit front-end for exploring the structure of arXiv papers with ArxivKit.
import asyncio
import streamlit as st
from arxivkit import ParserConfig, SectionDepth, build_section_paths
from arxivkit.db import SessionLocal, init_db
from arxivkit.db import crud
from arxivkit.db.crud import SYSTEM_USER_ID
from arxivkit.downloader import extract_paper_id
from arxivkit.models import PaperStructure, Section
from arxivkit.pipeline import process_paper
# ── Page config ──────────────────────────────────────────────────────────────
# Must be the first Streamlit call in the script. `page_icon` has to be a real
# emoji character (the previous value was a mis-encoded byte sequence).
st.set_page_config(
    page_title="ArxivKit Explorer",
    page_icon="📄",
    layout="wide",
    initial_sidebar_state="expanded",
)
# ── Custom CSS ────────────────────────────────────────────────────────────────
# Styles for the classes emitted by the render helpers (empty-state,
# abstract-box, author-pill) plus a padding tweak for expander bodies.
_CUSTOM_CSS = """
<style>
.streamlit-expanderContent { padding: 0.5rem 1rem 1rem 1rem; }
.empty-state { color:#999; font-size:0.85rem; font-style:italic; padding:0.4rem 0; }
.abstract-box {
background: #f8f9fa;
border-left: 4px solid #4a90d9;
border-radius: 0 8px 8px 0;
padding: 0.8rem 1.1rem;
margin: 0.5rem 0 1.2rem 0;
font-size: 0.95rem;
line-height: 1.65;
color: #333;
}
.author-pill {
display: inline-block;
background: #eef2ff;
color: #3730a3;
border-radius: 999px;
padding: 2px 10px;
font-size: 0.8rem;
margin: 2px 3px;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# ── DB init ───────────────────────────────────────────────────────────────────
@st.cache_resource(show_spinner=False)
def _db_init_state() -> dict[str, bool]:
    """Return a process-wide, mutable "DB initialised" flag holder.

    Streamlit re-executes this whole script on every rerun, so a plain module
    global would be reset to False each time. ``st.cache_resource`` returns the
    same dict object on every rerun and in every session, so mutating it
    persists.
    """
    return {"ready": False}


async def _ensure_db() -> None:
    """Create the DB schema and the system user unless this process already did.

    Calls ``init_db()`` and then ``crud.ensure_system_user`` in a fresh session.
    The flag is set only after both succeed, so a failure is retried on the
    next call.
    """
    state = _db_init_state()
    if state["ready"]:
        return
    await init_db()
    async with SessionLocal() as session:
        await crud.ensure_system_user(session)
    state["ready"] = True
# ── Paper loading ─────────────────────────────────────────────────────────────
async def _fetch_paper(paper_id: str) -> PaperStructure:
    """Return the structured paper for ``paper_id``, using the DB as a cache.

    A paper already stored in the database is returned as-is. Otherwise it is
    processed with ``process_paper`` (math, theorems and Markdown-converted
    tables kept in text, sections down to subparagraph depth), saved under
    ``SYSTEM_USER_ID``, and returned.
    """
    await _ensure_db()

    async with SessionLocal() as session:
        cached = await crud.load_paper(session, paper_id)
    if cached is not None:
        return cached

    parsed = process_paper(
        paper_id,
        ParserConfig(
            math_in_text=True,
            theorems_in_text=True,
            convert_tables_to_md=True,
            max_section_depth=SectionDepth.SUBPARAGRAPH,
        ),
    )
    async with SessionLocal() as session:
        await crud.save_paper(session, parsed, user_id=SYSTEM_USER_ID)
    return parsed
def _load_paper(paper_id: str) -> PaperStructure:
    """Synchronously run ``_fetch_paper(paper_id)`` to completion and return its result."""
    # NOTE(review): asyncio.run creates (and closes) a new event loop per call.
    # If the engine behind SessionLocal pools connections across calls, they
    # may end up tied to a closed loop — confirm against arxivkit.db.
    coroutine = _fetch_paper(paper_id)
    return asyncio.run(coroutine)
# ── Render helpers ────────────────────────────────────────────────────────────
# Sidebar shortcuts, as (button label, arXiv ID) pairs, in display order.
_SAMPLE_PAPERS: list[tuple[str, str]] = [
    ("Attention Is All You Need",  "1706.03762"),
    ("BERT",                       "1810.04805"),
    ("Mamba",                      "2312.00752"),
    ("LLaMA",                      "2302.13971"),
    ("Chain-of-Thought Prompting", "2201.11903"),
]
def _count_recursive(secs: list[Section], attr: str) -> int:
    """Total ``len(getattr(section, attr))`` over ``secs`` and every nested subsection.

    Uses an explicit work list instead of recursion; the visiting order differs
    but the sum is the same.
    """
    total = 0
    pending = list(secs)
    while pending:
        node = pending.pop()
        total += len(getattr(node, attr))
        pending.extend(node.subsections)
    return total
def _clean_formula(raw: str) -> str:
    r"""Strip one pair of surrounding math delimiters from ``raw``.

    Recognised wrappers, checked in this order: ``$$...$$``, ``\[...\]``,
    ``\(...\)``, and ``$...$`` (the last only when something lies between the
    dollars). Whitespace is stripped outside and inside the delimiters. Input
    without a recognised wrapper is returned stripped but otherwise unchanged.

    Args:
        raw: A formula string as extracted from the paper.

    Returns:
        The bare LaTeX body; may be empty.
    """
    s = raw.strip()
    if s.startswith("$$") and s.endswith("$$"):
        return s[2:-2].strip()
    # LaTeX bracket delimiters; without this they would be handed to st.latex
    # verbatim alongside the math they wrap.
    for opening, closing in ((r"\[", r"\]"), (r"\(", r"\)")):
        if len(s) >= 4 and s.startswith(opening) and s.endswith(closing):
            return s[2:-2].strip()
    if s.startswith("$") and s.endswith("$") and len(s) > 2:
        return s[1:-1].strip()
    return s
def _render_formula(raw: str) -> None:
    """Show one formula with ``st.latex``, falling back to a LaTeX code block.

    Delimiters are removed via ``_clean_formula``; a formula that is empty
    after cleaning renders nothing. If ``st.latex`` raises, the original
    ``raw`` text is shown as code instead.
    """
    body = _clean_formula(raw)
    if body:
        try:
            st.latex(body)
        except Exception:
            # NOTE(review): st.latex may only hand the string to the browser for
            # rendering, in which case malformed LaTeX never raises here —
            # confirm whether this fallback is reachable.
            st.code(raw, language="latex")
def _render_section_content(section: Section, depth: int = 0) -> None:
    """Render a section's text, formulas and tables as tabs, then its subsections.

    Args:
        section: The section to display.
        depth: Nesting level of ``section``; subsections are rendered with
            ``depth + 1``.
    """

    def _empty_state(message: str) -> None:
        # Styled placeholder (see the .empty-state CSS class) for an empty tab.
        st.markdown(f'<p class="empty-state">{message}</p>', unsafe_allow_html=True)

    tab_text, tab_formulas, tab_tables = st.tabs(["📝 Text", "∑ Formulas", "⊞ Tables"])

    with tab_text:
        if section.text.strip():
            st.markdown(section.text)
        else:
            _empty_state("No text content in this section.")

    with tab_formulas:
        if section.formulas:
            # Auto-expand only when the section has exactly one formula.
            expand_single = len(section.formulas) == 1
            for i, formula in enumerate(section.formulas, 1):
                with st.expander(f"Formula {i}", expanded=expand_single):
                    _render_formula(formula)
        else:
            _empty_state("No formulas in this section.")

    with tab_tables:
        if section.tables:
            for i, table in enumerate(section.tables, 1):
                label = f"Table {i}" + (f": {table.caption}" if table.caption else "")
                with st.expander(label, expanded=False):
                    if table.caption:
                        st.markdown(f"**Caption:** {table.caption}")
                    # Content starting with "|" is treated as a Markdown table
                    # (the parser is configured with convert_tables_to_md=True);
                    # anything else is shown as LaTeX source.
                    if table.content.strip().startswith("|"):
                        st.markdown(table.content)
                    else:
                        st.code(table.content, language="latex")
        else:
            _empty_state("No tables in this section.")

    if section.subsections:
        st.markdown("---")
        # NOTE(review): this is called inside the parent section's expander, so
        # expanders end up nested (subsections here, formula/table expanders
        # above). Some Streamlit releases reject nested expanders — confirm the
        # deployed version allows it.
        for sub in section.subsections:
            _render_section(sub, depth=depth + 1)
def _render_section(section: Section, depth: int = 0) -> None:
    """Draw ``section`` inside a collapsed expander labelled with its title.

    ``depth`` is the section's nesting level and is forwarded unchanged to
    ``_render_section_content``.
    """
    container = st.expander(section.title, expanded=False)
    with container:
        _render_section_content(section, depth=depth)
def _paper_json(paper: PaperStructure) -> dict:
    """Return ``paper.model_dump()`` with every section's ``images`` entry removed.

    The removal recurses through ``subsections``. Every section dict ends up
    with a ``subsections`` list (empty if it had none), and the top level
    always has a ``sections`` list.
    """

    def _without_images(node: dict) -> dict:
        # Mutates the freshly dumped dict in place and returns it.
        node.pop("images", None)
        children = node.get("subsections", [])
        node["subsections"] = list(map(_without_images, children))
        return node

    payload = paper.model_dump()
    payload["sections"] = list(map(_without_images, payload.get("sections", [])))
    return payload
def _render_paper(paper: PaperStructure) -> None:
    """Render a parsed paper: title, authors, abstract, counts and detail tabs.

    The tabs show the collapsible section tree ("Explore"), the paths from
    ``build_section_paths`` ("Sections"), and the paper serialised by
    ``_paper_json`` ("JSON").
    """
    import html
    import json

    # Collapse internal newlines/whitespace so the whole title stays on the
    # single line a Markdown "##" heading covers.
    title = " ".join(str(paper.title).split())
    st.markdown(f"## {title}")

    # Authors and abstract are paper-supplied text interpolated into raw HTML
    # (unsafe_allow_html=True). Escape them so characters such as "<" or "&"
    # (e.g. "n<k" in an abstract) are shown literally, not parsed as markup.
    if paper.authors:
        authors_html = " ".join(
            f'<span class="author-pill">{html.escape(str(author))}</span>'
            for author in paper.authors
        )
        st.markdown(authors_html, unsafe_allow_html=True)

    abstract_html = html.escape(str(paper.abstract or ""))
    st.markdown(
        f'<div class="abstract-box"><strong>Abstract</strong><br>{abstract_html}</div>',
        unsafe_allow_html=True,
    )

    m1, m2, m3 = st.columns(3)
    m1.metric("Sections", len(paper.sections))  # top-level sections only
    m2.metric("Formulas", _count_recursive(paper.sections, "formulas"))
    m3.metric("Tables", _count_recursive(paper.sections, "tables"))

    st.divider()

    tab_explore, tab_sections, tab_json = st.tabs(["📄 Explore", "🗂 Sections", "{ } JSON"])
    with tab_explore:
        if paper.sections:
            for section in paper.sections:
                _render_section(section)
        else:
            st.info("No sections were extracted from this paper.")
    with tab_sections:
        for path in build_section_paths(paper.sections):
            st.text(path)
    with tab_json:
        # ensure_ascii=False keeps non-ASCII author names readable instead of
        # \u-escaped.
        st.code(
            json.dumps(_paper_json(paper), indent=2, ensure_ascii=False),
            language="json",
        )
# ── Sidebar ───────────────────────────────────────────────────────────────────
with st.sidebar:
    st.header("🧪 Try a sample")
    st.caption("Click any paper to load it.")
    for title, paper_id in _SAMPLE_PAPERS:
        if st.button(title, key=f"sample_{paper_id}", use_container_width=True):
            # Select the sample and drop any previous result or error, so the
            # loading phase fetches it on the rerun (this also retries a
            # sample that failed before).
            st.session_state["requested_id"] = paper_id
            st.session_state.pop("loaded_paper", None)
            st.session_state.pop("load_error", None)
            st.rerun()
    st.divider()
    st.caption("Powered by ArxivKit")
    st.caption(
        "⚠️ Results depend on arXiv API availability and LaTeX source quality. "
        "Parsing errors and incomplete output are expected in this experimental release."
    )
# ── Main app ──────────────────────────────────────────────────────────────────
st.title("📄 ArxivKit (Structured Arxiv Data Extractor)")
st.caption("Paste an arXiv paper URL or ID to explore its structure.")
st.warning(
    "**Experimental & Unstable Release** — This tool is heavily dependent on the "
    "arXiv API and LaTeX source availability. Not all papers will parse correctly, "
    "and some results may include extraction errors, garbled text, or missing content. "
    "Do not rely on the output blindly — always verify against the original paper.",
    icon="⚠️",
)
# `icon` must be a single valid emoji; a mis-encoded value is rejected.
st.info(
    "**Coming soon:** Direct upload of LaTeX source files (.tex / .tar.gz) for local parsing — "
    "no arXiv API dependency required.",
    icon="🔜",
)
# Input row: a wide text box for the arXiv ID/URL beside a narrow Parse button.
col_input, col_btn = st.columns([5, 1])
raw_input = col_input.text_input(
    label="Paper",
    placeholder="e.g. 2301.07041 or https://arxiv.org/abs/2301.07041",
    label_visibility="collapsed",
)
parse_clicked = col_btn.button("Parse", type="primary", use_container_width=True)
# ── State transitions ─────────────────────────────────────────────────────────
# A Parse click with non-blank input selects a paper ID; the loading phase
# below performs the actual fetch.
if parse_clicked and raw_input.strip():
    query = raw_input.strip()
    try:
        new_id = extract_paper_id(query)
    except ValueError:
        st.error(f"Could not parse an arXiv ID from: `{query}`")
        st.stop()
    if st.session_state.get("requested_id") != new_id:
        # A different paper: discard the previously loaded one.
        st.session_state.pop("loaded_paper", None)
    # Always clear a stale error, so re-submitting the same ID after a failure
    # retries the load instead of redisplaying the old error forever.
    st.session_state.pop("load_error", None)
    st.session_state["requested_id"] = new_id
# ── Loading phase ─────────────────────────────────────────────────────────────
# Runs when an ID is requested but has neither a result nor an error yet. It
# stores exactly one of the two in session state, then reruns the script.
if (
    "requested_id" in st.session_state
    and "loaded_paper" not in st.session_state
    and "load_error" not in st.session_state
):
    paper_id = st.session_state["requested_id"]
    with st.spinner(f"Fetching and parsing **{paper_id}** …"):
        try:
            paper = _load_paper(paper_id)
            st.session_state["loaded_paper"] = paper
        except Exception as exc:  # UI boundary: report any failure to the user
            # Some exceptions stringify to "" (e.g. a bare TimeoutError());
            # fall back to the class name so the error message is never blank.
            st.session_state["load_error"] = str(exc) or type(exc).__name__
    st.rerun()
# ── Display phase ─────────────────────────────────────────────────────────────
# Show whichever outcome the loading phase recorded. Before any paper has been
# requested, neither key exists and nothing is drawn.
if "load_error" in st.session_state:
    failed_id = st.session_state["requested_id"]
    reason = st.session_state["load_error"]
    st.error(f"Failed to load paper `{failed_id}`: {reason}")
elif "loaded_paper" in st.session_state:
    _render_paper(st.session_state["loaded_paper"])