foto / app.py
Hurum Maksora Tohfa
Change batch_size to score_batch_size in FigureScorer
bdecc37 unverified
import html
import sys
import time
from pathlib import Path

# Put this file's directory at the front of sys.path before importing `foto`.
sys.path.insert(0, str(Path(__file__).parent))

import streamlit as st

from foto import (
    MODEL_LABELS, MODEL_REGISTRY, get_model, CostTracker,
    InputParser, PaperSearcher, PaperTriager,
    PDFStore, FigureExtractor, FigureScorer,
    build_zip, format_authors, get_confidence, confidence_badge_class,
)
from foto.models import PROVIDER_DISPLAY
from foto.llm_client import LLMClient
from foto.persistence import load_stats, log_search, log_rating
# Page-level Streamlit settings: browser title and icon, wide layout, and a
# sidebar that starts collapsed.
_PAGE_SETTINGS = {
    "page_title": "FOTO",
    "page_icon": "🐇",
    "layout": "wide",
    "initial_sidebar_state": "collapsed",
}
st.set_page_config(**_PAGE_SETTINGS)
# Global stylesheet injected once per run. It defines the fonts, the
# header/result-card/badge/stats/tally classes used by the HTML snippets
# emitted further down, and overrides for a few built-in Streamlit widgets.
_PAGE_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=DM+Serif+Display:ital@0;1&family=DM+Mono:wght@400;500&family=Inter:wght@300;400;500&display=swap');
html, body, [class*="css"] { font-family: 'Inter', sans-serif; font-weight: 300; }
.stApp { background: #fafaf8; color: #1a1a1a; }
.foto-header { padding: 3rem 0 1.5rem 0; border-bottom: 1px solid #e0e0d8; margin-bottom: 2.5rem; }
.foto-subtitle { font-size: 0.85rem; color: #888; letter-spacing: 0.08em; text-transform: uppercase; margin-top: 0.4rem; font-family: 'DM Mono', monospace; }
.foto-tagline { font-size: 1rem; color: #555; margin-top: 0.8rem; font-weight: 300; max-width: 560px; }
.section-label { font-family: 'DM Mono', monospace; font-size: 0.7rem; letter-spacing: 0.12em; text-transform: uppercase; color: #999; margin-bottom: 0.5rem; }
.api-help { font-size: 0.75rem; color: #777; margin-top: -0.5rem; margin-bottom: 0.8rem; }
.api-help a { color: #4a5568; text-decoration: underline; }
.result-card { background: white; border: 1px solid #e8e8e0; border-radius: 4px; padding: 1.2rem 1.4rem; margin-bottom: 1.2rem; }
.result-title { font-family: 'DM Serif Display', serif; font-size: 1.05rem; color: #1a1a1a; margin-bottom: 0.2rem; line-height: 1.3; }
.result-meta { font-size: 0.8rem; color: #777; font-style: italic; margin-bottom: 0.6rem; }
.result-badges { display: flex; gap: 0.5rem; flex-wrap: wrap; margin-bottom: 0.8rem; }
.badge { font-family: 'DM Mono', monospace; font-size: 0.68rem; padding: 0.15rem 0.5rem; border-radius: 2px; letter-spacing: 0.04em; }
.badge-high { background: #e8f4e8; color: #2d6a2d; }
.badge-mid { background: #fef3e2; color: #8a5700; }
.badge-low { background: #fce8e8; color: #8a2020; }
.badge-type { background: #eef2ff; color: #3d4eac; }
.stats-row { display: flex; gap: 2rem; padding: 1rem 0; border-top: 1px solid #e0e0d8; border-bottom: 1px solid #e0e0d8; margin: 1.5rem 0; }
.stat-item { text-align: center; }
.stat-num { font-family: 'DM Serif Display', serif; font-size: 1.8rem; color: #1a1a1a; line-height: 1; }
.stat-label { font-family: 'DM Mono', monospace; font-size: 0.65rem; color: #999; letter-spacing: 0.1em; text-transform: uppercase; margin-top: 0.2rem; }
.progress-item { font-family: 'DM Mono', monospace; font-size: 0.8rem; color: #555; padding: 0.2rem 0; }
.feedback-box { background: white; border: 1px solid #e8e8e0; border-radius: 4px; padding: 1.5rem 1.8rem; margin-top: 2rem; }
.feedback-title { font-family: 'DM Serif Display', serif; font-size: 1.2rem; margin-bottom: 1rem; }
.tally-box { background: #1a1a1a; color: #fafaf8; padding: 1.2rem 1.8rem; border-radius: 4px; margin-top: 1.5rem; }
.tally-title { font-family: 'DM Mono', monospace; font-size: 0.7rem; letter-spacing: 0.12em; text-transform: uppercase; color: #888; margin-bottom: 0.8rem; }
.tally-row { display: flex; gap: 2rem; }
.tally-num { font-family: 'DM Serif Display', serif; font-size: 2rem; color: #fafaf8; }
.tally-label { font-family: 'DM Mono', monospace; font-size: 0.65rem; color: #888; letter-spacing: 0.08em; margin-top: 0.2rem; }
.pathfinder-cite { font-family: 'DM Mono', monospace; font-size: 0.72rem; color: #888; line-height: 1.5; }
.pathfinder-cite a { color: #4a5568; text-decoration: underline; }
.stTextArea textarea { font-family: 'Inter', sans-serif; font-size: 0.95rem; border: 1px solid #ddd; border-radius: 3px; }
.stButton button { font-family: 'DM Mono', monospace; font-size: 0.8rem; letter-spacing: 0.06em; border-radius: 3px; }
div[data-testid="stSelectbox"] label,
div[data-testid="stTextInput"] label { font-family: 'DM Mono', monospace; font-size: 0.75rem; letter-spacing: 0.08em; text-transform: uppercase; color: #888; }
div[data-testid="stCheckbox"] label p {
font-family: 'Inter', sans-serif !important;
font-size: 0.9rem !important;
letter-spacing: normal !important;
text-transform: none !important;
color: #1a1a1a !important;
font-weight: 400 !important;
}
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
# Seed per-session state. Keys that already exist in st.session_state are
# left untouched; only missing ones receive their initial value.
_SESSION_DEFAULTS = {
    "pdf_cache": {},
    "results": None,
    "running": False,
    "log": [],
    "tally": {"searches": 0, "ratings": []},
}
for state_key, initial_value in _SESSION_DEFAULTS.items():
    if state_key in st.session_state:
        continue
    st.session_state[state_key] = initial_value

# Usage stats are fetched with load_stats() only when the session has none yet.
if "global_stats" not in st.session_state:
    st.session_state["global_stats"] = load_stats()
# Page header: the subtitle spells out the FOTO acronym (capital letters mark
# the initials) and the tagline gives a one-sentence pitch.
st.markdown("""
<div class="foto-header">
<p class="foto-subtitle">Figure frOm Text & illustratiOns</p>
<p class="foto-tagline">
Describe a scientific figure in words, upload a sketch, or both —
and FOTO searches the literature to find it.
</p>
</div>
""", unsafe_allow_html=True)

# Two equal-width columns: the input form goes on the left, search progress
# and results on the right.
col_left, col_right = st.columns([1, 1], gap="large")
def render_api_key_field(provider: str, model_cfg, key_state: str) -> str:
    """Render a labelled, masked API-key input followed by a help line.

    Args:
        provider: Provider identifier looked up in ``PROVIDER_DISPLAY``;
            providers not listed there are labelled "API Key".
        model_cfg: Object supplying ``api_help_text`` and ``api_help_url``
            for the help line shown under the input.
        key_state: Streamlit widget key for the text input.

    Returns:
        Whatever ``st.text_input`` returns for the field on this run.
    """
    label = PROVIDER_DISPLAY.get(provider, "API Key")

    st.markdown(
        f'<p class="section-label" style="margin-top:0.8rem;">{label}</p>',
        unsafe_allow_html=True,
    )
    entered_key = st.text_input(
        label,
        type="password",
        label_visibility="collapsed",
        placeholder="...",
        key=key_state,
    )

    # Help text plus a link to where the key can be obtained.
    help_html = (
        f'<div class="api-help">{model_cfg.api_help_text} '
        f'<a href="{model_cfg.api_help_url}" target="_blank">Get key →</a></div>'
    )
    st.markdown(help_html, unsafe_allow_html=True)
    return entered_key
with col_left:
    # --- Primary model and its API key -----------------------------------
    st.markdown('<p class="section-label">Primary Model</p>', unsafe_allow_html=True)
    primary_label = st.selectbox(
        "Primary Model",
        options=MODEL_LABELS,
        label_visibility="collapsed",
        key="primary_model",
    )
    primary_cfg = get_model(primary_label)
    primary_key = render_api_key_field(primary_cfg.provider, primary_cfg, "primary_api_key")

    # --- Paper-search backend: Pathfinder or keyword search ---------------
    use_pathfinder = st.checkbox("Use Pathfinder (recommended)", value=True, key="use_pathfinder")
    pathfinder_note = (
        '<div class="pathfinder-cite" style="margin-top:-0.4rem;margin-left:1.8rem;">'
        'Based on <a href="https://arxiv.org/abs/2408.01556" target="_blank">arXiv:2408.01556</a> · '
        'OpenAI key required for query embedding · '
        '~$1 per 2M queries'
        '</div>'
    )
    st.markdown(pathfinder_note, unsafe_allow_html=True)

    # Exactly one of the two key fields is drawn; the other key stays "".
    openai_key = ""
    s2_key = ""
    if use_pathfinder:
        st.markdown(
            '<p class="section-label" style="margin-top:0.8rem;">OpenAI API Key</p>',
            unsafe_allow_html=True,
        )
        openai_key = st.text_input(
            "OpenAI Key",
            type="password",
            label_visibility="collapsed",
            placeholder="sk-...",
            key="openai_key",
        )
        st.markdown(
            '<div class="api-help">Used to embed queries with text-embedding-3-small. '
            '<a href="https://platform.openai.com/api-keys" target="_blank">Get key →</a></div>',
            unsafe_allow_html=True,
        )
    else:
        st.markdown(
            '<p class="section-label" style="margin-top:0.8rem;">Semantic Scholar Key (optional)</p>',
            unsafe_allow_html=True,
        )
        s2_key = st.text_input(
            "S2 Key",
            type="password",
            label_visibility="collapsed",
            placeholder="(improves keyword search)",
            key="s2_key",
        )
        st.markdown(
            '<div class="api-help">Optional — speeds up the keyword-based paper search. '
            '<a href="https://www.semanticscholar.org/product/api" target="_blank">Get key →</a></div>',
            unsafe_allow_html=True,
        )

    # --- What to look for: free text and/or an uploaded sketch ------------
    st.markdown(
        '<p class="section-label" style="margin-top:1.2rem;">Describe the figure</p>',
        unsafe_allow_html=True,
    )
    user_text = st.text_area(
        "Figure description",
        label_visibility="collapsed",
        height=120,
        placeholder='e.g. "scatter plot of cosmological parameter constraints from wavelet scattering transform, Omega_m vs sigma_8"',
    )
    st.markdown(
        '<p class="section-label" style="margin-top:0.8rem;">Upload a sketch (optional)</p>',
        unsafe_allow_html=True,
    )
    sketch_file = st.file_uploader(
        "Sketch",
        type=["png", "jpg", "jpeg", "webp"],
        label_visibility="collapsed",
    )

    # --- Optional second-pass verification --------------------------------
    run_verify = st.checkbox("Secondary verification (recommended)", value=True, key="run_verify")
    st.markdown(
        '<p style="font-size:0.78rem;color:#888;margin-top:-0.6rem;margin-left:1.8rem;">'
        'Double-checks top matches. Uses the primary model by default.</p>',
        unsafe_allow_html=True,
    )

    # Verification reuses the primary model and key unless another model is
    # picked; a separate key is requested only when the provider differs.
    verify_cfg, verify_key = primary_cfg, primary_key
    if run_verify:
        same_as_primary = "Same as primary"
        verify_options = [same_as_primary] + MODEL_LABELS
        verify_choice = st.selectbox(
            "Verification model",
            options=verify_options,
            label_visibility="collapsed",
            key="verify_model",
        )
        if verify_choice != same_as_primary:
            verify_cfg = get_model(verify_choice)
            if verify_cfg.provider == primary_cfg.provider:
                verify_key = primary_key
            else:
                verify_key = render_api_key_field(
                    verify_cfg.provider, verify_cfg, "verify_api_key",
                )

    # --- Search size and launch button ------------------------------------
    num_papers = st.slider("Papers to search", min_value=5, max_value=50, value=20, step=5)
    run_btn = st.button(
        "🔭 Search",
        use_container_width=True,
        type="primary",
        disabled=st.session_state.running,
    )
with col_right:
    # Idle state: no stored results and no search in flight, so show a
    # placeholder where the progress log will later appear.
    is_idle = not (st.session_state.results or st.session_state.running)
    if is_idle:
        st.markdown("""
<div style="padding: 3rem 2rem; color: #aaa; text-align: center;">
<div style="font-size: 3rem; margin-bottom: 1rem;">🔭</div>
<div style="font-family: 'DM Mono', monospace; font-size: 0.75rem; letter-spacing: 0.1em; text-transform: uppercase;">Your search progress will appear here</div>
</div>
""", unsafe_allow_html=True)
# Search pipeline, run when the Search button was clicked on this script run:
# validate inputs -> parse description -> search papers -> triage -> fetch PDFs
# -> extract figures -> score -> (optionally) verify -> store results.
if run_btn:
    # Providers missing from PROVIDER_DISPLAY get the same generic label that
    # render_api_key_field() uses, instead of raising KeyError mid-validation.
    primary_key_label = PROVIDER_DISPLAY.get(primary_cfg.provider, "API Key")
    verify_key_label = PROVIDER_DISPLAY.get(verify_cfg.provider, "API Key")
    verify_needs_own_key = run_verify and verify_cfg.provider != primary_cfg.provider
    # Whitespace-only text counts as "no description".
    description = (user_text or "").strip()

    if not primary_key:
        st.error(f"Please enter your {primary_key_label}.")
    elif use_pathfinder and not openai_key:
        st.error("Pathfinder is checked — please enter your OpenAI API key, or uncheck Pathfinder.")
    elif verify_needs_own_key and not verify_key:
        st.error(f"Verification model needs its own key — please enter your {verify_key_label}.")
    elif not description and not sketch_file:
        st.error("Please enter a description or upload a sketch (or both).")
    else:
        st.session_state.running = True
        st.session_state.results = None
        st.session_state.log = []

        with col_right:
            log_placeholder = st.empty()
            progress_placeholder = st.empty()

        def log(msg):
            """Append *msg* to the session log and redraw its last 20 lines.

            Lines are rendered as raw HTML, so any untrusted text must be
            escaped by the caller before it is interpolated into *msg*.
            """
            st.session_state.log.append(msg)
            log_placeholder.markdown(
                "\n".join(f'<div class="progress-item">{m}</div>' for m in st.session_state.log[-20:]),
                unsafe_allow_html=True,
            )

        def _esc(value, limit=None):
            """HTML-escape *value* (None becomes ""), truncated to *limit* chars first."""
            text = "" if value is None else str(value)
            return html.escape(text[:limit])

        # Set only when the whole pipeline completes; gates the final rerun.
        succeeded = False
        try:
            # Client construction and the upload read happen inside the try so
            # that a failure here still resets the `running` flag below.
            tracker = CostTracker()
            primary_client = LLMClient(provider=primary_cfg.provider, api_key=primary_key)
            verify_client = (
                primary_client if verify_cfg.provider == primary_cfg.provider
                else LLMClient(provider=verify_cfg.provider, api_key=verify_key)
            )
            sketch_bytes = sketch_file.read() if sketch_file else None

            # 1. Turn the description / sketch into a structured spec.
            log("⟳ Parsing your description...")
            parser = InputParser(
                primary_client, primary_cfg.model_id, primary_cfg.prices, tracker,
                max_tokens=primary_cfg.triage_max_tokens * 4,
            )
            spec = parser.parse(text=description or None, sketch_bytes=sketch_bytes)
            query = spec.get("science_query") or description or "(no query)"
            log(f"✓ Query: <em>{_esc(query)}</em>")
            if spec.get("plot_type"):
                log(f" Plot type: {_esc(spec['plot_type'])}")

            # 2. Collect candidate papers.
            searcher = PaperSearcher(s2_key=s2_key or None)
            if use_pathfinder:
                all_papers = searcher.expanded_search_pathfinder(query, openai_key, log=log)
            else:
                all_papers = searcher.expanded_search(
                    query, primary_client, primary_cfg.model_id, primary_cfg.prices, tracker,
                    max_tokens=primary_cfg.triage_max_tokens * 4, log=log,
                )
            log(f"✓ {len(all_papers)} unique papers found")

            # 3. Triage and keep at most `num_papers` candidates.
            log("⟳ Triaging papers (please wait)...")
            triager = PaperTriager(
                primary_client, primary_cfg.model_id, primary_cfg.prices, tracker,
                max_tokens=primary_cfg.triage_max_tokens,
                batch_size=primary_cfg.score_batch_size,
            )
            triaged = triager.triage(all_papers, spec)
            top = triaged[:num_papers]
            log(f"✓ {len(top)} papers passed triage")
            paper_lookup = {p["paperId"]: p for p in top}

            # 4. Download the PDFs, keeping only papers whose fetch succeeded.
            log("⟳ Fetching PDFs...")
            downloaded = []
            for i, paper in enumerate(top):
                progress_placeholder.progress((i + 1) / len(top), text=f"Fetching PDF {i+1}/{len(top)}")
                pdf_bytes, _reason = PDFStore.fetch(paper)
                if pdf_bytes:
                    paper["_pdf_bytes"] = pdf_bytes
                    downloaded.append(paper)
                    log(f" ✓ {_esc(paper.get('title'), 60)}")
                else:
                    log(f" ✗ {_esc(paper.get('title'), 60)}")
                # Short pause between fetches (presumably to avoid hammering
                # the PDF hosts — confirm intent).
                time.sleep(0.5)
            progress_placeholder.empty()
            log(f"✓ {len(downloaded)} PDFs ready")

            # 5. Extract figures; one bad PDF must not abort the whole run.
            log("⟳ Extracting figures...")
            extractor = FigureExtractor()
            all_figures = []
            for paper in downloaded:
                try:
                    figs = extractor.extract(paper["_pdf_bytes"], paper["paperId"])
                    all_figures.extend(figs)
                except Exception as e:
                    log(f" ✗ {_esc(paper.get('title'), 40)}: {_esc(e)}")
            filtered = extractor.caption_filter(all_figures, query)
            log(f" {len(filtered)} figures after caption filter (from {len(all_figures)} total)")

            # 6. Score every remaining figure with the primary model.
            log(f"⟳ Scoring {len(filtered)} figures...")
            scorer = FigureScorer(
                primary_client, primary_cfg.model_id, primary_cfg.prices, tracker,
                score_max_tokens=primary_cfg.score_max_tokens,
                verify_max_tokens=primary_cfg.verify_max_tokens,
                batch_size=primary_cfg.score_batch_size,
            )
            score_results = scorer.score_batch(filtered, spec)
            primary_matches = [
                fig for fig, result in zip(filtered, score_results)
                if result.get("confidence", 0) >= 0.5
            ]
            log(f"✓ {len(primary_matches)} primary matches")

            # 7. Optional second pass, one figure at a time.
            verified = primary_matches
            if run_verify and primary_matches:
                log(f"⟳ Verifying {len(primary_matches)} matches...")
                verifier = FigureScorer(
                    verify_client, verify_cfg.model_id, verify_cfg.prices, tracker,
                    score_max_tokens=verify_cfg.score_max_tokens,
                    verify_max_tokens=verify_cfg.verify_max_tokens,
                    batch_size=1,
                )
                verified = []
                for i, fig in enumerate(primary_matches):
                    progress_placeholder.progress(
                        (i + 1) / len(primary_matches),
                        text=f"Verifying {i+1}/{len(primary_matches)}",
                    )
                    result = verifier.verify(fig, spec)
                    if result.get("confidence", 0) >= 0.5:
                        verified.append(fig)
                progress_placeholder.empty()
                log(f"✓ {len(verified)} verified matches")

            # Highest confidence first; the sort is stable, so ties keep order.
            verified.sort(key=get_confidence, reverse=True)
            total_cost = tracker.total()
            log(f"✓ Done — ${total_cost:.4f}")

            # 8. Persist results for the rendering code and record the search.
            st.session_state.results = {
                "matches": verified,
                "paper_lookup": paper_lookup,
                "spec": spec,
                "query": query,
                "cost": total_cost,
                "n_papers": len(downloaded),
                "n_figures": len(all_figures),
            }
            log_search(query, len(downloaded), len(verified), total_cost)
            st.session_state.global_stats["searches"] += 1
            succeeded = True
        except Exception as e:
            import traceback
            traceback.print_exc()
            # Drop any half-finished progress bar so it does not linger next
            # to the error message.
            progress_placeholder.empty()
            log(f"✗ Error: {_esc(e)}")
            st.error(f"Pipeline error: {e}")
        finally:
            st.session_state.running = False

        # Rerun only after a successful search. Rerunning after a failure
        # would immediately wipe the error message and the progress log.
        if succeeded:
            st.rerun()
# Render the stored search results in the right-hand column: a summary row,
# a zip download of all matches, then one card + image + download per match.
if st.session_state.results:
    res = st.session_state.results
    matches = res["matches"]
    paper_lookup = res["paper_lookup"]
    with col_right:
        st.markdown(f"""
<div class="stats-row">
<div class="stat-item"><div class="stat-num">{len(matches)}</div><div class="stat-label">Matches</div></div>
<div class="stat-item"><div class="stat-num">{res['n_papers']}</div><div class="stat-label">Papers</div></div>
<div class="stat-item"><div class="stat-num">{res['n_figures']}</div><div class="stat-label">Figures</div></div>
<div class="stat-item"><div class="stat-num">${res['cost']:.3f}</div><div class="stat-label">API Cost</div></div>
</div>
""", unsafe_allow_html=True)
        if not matches:
            st.info("No matches found. Try broadening your description or increasing the number of papers.")
        else:
            zip_bytes = build_zip(matches, paper_lookup)
            st.download_button(
                "⬇ Download all matched figures (.zip)",
                data=zip_bytes, file_name="foto_results.zip",
                mime="application/zip", use_container_width=True,
            )
            st.markdown("---")
            for i, match in enumerate(matches):
                paper = paper_lookup.get(match["paper_id"], {})
                verdict = match.get("_verify") or match.get("_primary") or {}
                conf = get_confidence(match)

                # Paper metadata and model verdicts are external text that is
                # rendered as raw HTML, so escape every interpolated value.
                # `or` fallbacks also cover keys that are present but None.
                title_html = html.escape((paper.get("title") or "Unknown")[:100])
                authors_html = html.escape(str(format_authors(paper)))
                year_html = html.escape(str(paper.get("year") or ""))
                plot_type_html = html.escape(str(verdict.get("plot_type") or ""))
                plotted_html = html.escape(str(verdict.get("what_is_plotted") or ""))
                page_html = html.escape(str(match["page"]))

                arxiv_id = (paper.get("externalIds") or {}).get("ArXiv")
                if arxiv_id:
                    arxiv_html = html.escape(str(arxiv_id))
                    arxiv_link = (
                        f' · <a href="https://arxiv.org/abs/{arxiv_html}" '
                        f'target="_blank">arXiv:{arxiv_html}</a>'
                    )
                else:
                    arxiv_link = ""

                st.markdown(f"""
<div class="result-card">
<div class="result-title">{title_html}</div>
<div class="result-meta">{authors_html} ({year_html}){arxiv_link}</div>
<div class="result-badges">
<span class="badge {confidence_badge_class(conf)}">conf {conf:.2f}</span>
<span class="badge badge-type">{plot_type_html}</span>
<span class="badge" style="background:#f5f5f0;color:#555;">page {page_html}</span>
</div>
<div style="font-size:0.82rem;color:#555;margin-bottom:0.5rem;">{plotted_html}</div>
</div>
""", unsafe_allow_html=True)
                # Best-effort preview: an undecodable image must not break the
                # rest of the results list.
                try:
                    st.image(match["image_bytes"], use_container_width=True)
                except Exception:
                    st.warning("Could not display image.")
                st.download_button(
                    "⬇ Download figure",
                    data=match["image_bytes"],
                    file_name=f"figure_{i+1:02d}_{(paper.get('title') or 'unknown')[:30].replace(' ','_')}.png",
                    mime="image/png", key=f"dl_{i}",
                )
                st.markdown("---")
# Feedback form: drawn only when the stored results contain at least one match.
_stored_results = st.session_state.results
if _stored_results and _stored_results.get("matches"):
    st.markdown("""
<div class="feedback-box">
<div class="feedback-title">How did FOTO do?</div>
</div>
""", unsafe_allow_html=True)
    rating = st.select_slider(
        "Was one of the top matches what you were looking for? (1 = not at all, 5 = perfect match)",
        options=[1, 2, 3, 4, 5],
        value=3,
    )
    feedback_clicked = st.button("Submit feedback", key="submit_feedback")
    if feedback_clicked:
        # Pass the rating to log_rating() and add it to this session's tally.
        log_rating(rating)
        st.session_state.global_stats["ratings"].append(rating)
        st.success("Thanks!")
# Footer tally: number of searches, number of ratings, and the mean rating
# (an em dash is shown while there are no ratings yet).
stats = st.session_state.global_stats
ratings = stats["ratings"]
n_ratings = len(ratings)
if n_ratings:
    avg_display = f"{sum(ratings) / n_ratings:.1f}"
else:
    avg_display = "—"
st.markdown(f"""
<div class="tally-box">
<div class="tally-title">Overall stats</div>
<div class="tally-row">
<div class="stat-item"><div class="tally-num">{stats['searches']}</div><div class="tally-label">Searches</div></div>
<div class="stat-item"><div class="tally-num">{n_ratings}</div><div class="tally-label">Rated</div></div>
<div class="stat-item"><div class="tally-num">{avg_display}</div><div class="tally-label">Avg score</div></div>
</div>
</div>
""", unsafe_allow_html=True)