import sys
import time
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
import streamlit as st
from foto import (
MODEL_LABELS, MODEL_REGISTRY, get_model, CostTracker,
InputParser, PaperSearcher, PaperTriager,
PDFStore, FigureExtractor, FigureScorer,
build_zip, format_authors, get_confidence, confidence_badge_class,
)
from foto.models import PROVIDER_DISPLAY
from foto.llm_client import LLMClient
from foto.persistence import load_stats, log_search, log_rating
# Page chrome: browser-tab title/icon, wide layout, sidebar collapsed.
st.set_page_config(
    page_title="FOTO",
    page_icon="๐",
    initial_sidebar_state="collapsed",
    layout="wide",
)

# Raw HTML/CSS injection. NOTE(review): the literal is empty in this copy —
# markup presumably stripped; confirm against the original source.
st.markdown("""
""", unsafe_allow_html=True)

# Per-session state: each key is created only if this session lacks it, so
# values survive Streamlit reruns.
_SESSION_DEFAULTS = {
    "pdf_cache": {},
    "results": None,
    "running": False,
    "log": [],
    "tally": {"searches": 0, "ratings": []},
}
for _state_key, _initial_value in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value

# Persisted usage stats are loaded at most once per session.
if "global_stats" not in st.session_state:
    st.session_state.global_stats = load_stats()

# Header markup (empty in this copy — NOTE(review): confirm).
st.markdown("""
""", unsafe_allow_html=True)

# Two equal-width columns: inputs on the left, progress/results on the right.
col_left, col_right = st.columns([1, 1], gap="large")
# Render a labelled password input for a provider's API key and return its value.
#
# provider:  key into PROVIDER_DISPLAY; the label falls back to "API Key" when the
#            provider is not listed there.
# model_cfg: not read anywhere in this function (callers pass the model config anyway).
# key_state: Streamlit widget key for the input (e.g. "primary_api_key").
# Returns whatever the text input currently holds.
def render_api_key_field(provider: str, model_cfg, key_state: str) -> str:
display_name = PROVIDER_DISPLAY.get(provider, "API Key")
# NOTE(review): the literal below spans two physical lines inside single quotes,
# which is a SyntaxError as shown — markup appears to have been stripped from
# this copy; confirm against the original source.
st.markdown(f'{display_name}
', unsafe_allow_html=True)
key = st.text_input(
display_name, type="password", label_visibility="collapsed",
placeholder="...", key=key_state,
)
# NOTE(review): this renders an empty f-string, so it shows no text as written;
# presumably a help/link snippet was stripped here — confirm.
st.markdown(
f'',
unsafe_allow_html=True,
)
return key
# Left column: model / API-key selection, query inputs, and the Search button.
# NOTE(review): several st.markdown literals below span multiple physical lines
# inside single quotes (a SyntaxError as shown); HTML markup appears to have been
# stripped from this copy — confirm against the original before editing them.
with col_left:
st.markdown('Primary Model
', unsafe_allow_html=True)
primary_label = st.selectbox(
"Primary Model", options=MODEL_LABELS,
label_visibility="collapsed", key="primary_model",
)
# Resolve the chosen label to its model config; its provider selects the key field label.
primary_cfg = get_model(primary_label)
primary_key = render_api_key_field(primary_cfg.provider, primary_cfg, "primary_api_key")
use_pathfinder = st.checkbox(
"Use Pathfinder (recommended)",
value=True,
key="use_pathfinder",
)
st.markdown(
''
'Based on
arXiv:2408.01556 ยท '
'OpenAI key required for query embedding ยท '
'~$1 per 2M queries'
'
',
unsafe_allow_html=True,
)
# The OpenAI key is only requested when Pathfinder is enabled; otherwise it stays "".
openai_key = ""
if use_pathfinder:
st.markdown('OpenAI API Key
', unsafe_allow_html=True)
openai_key = st.text_input(
"OpenAI Key", type="password", label_visibility="collapsed",
placeholder="sk-...", key="openai_key",
)
st.markdown(
'Used to embed queries with text-embedding-3-small. '
'
Get key โ ',
unsafe_allow_html=True,
)
# The Semantic Scholar key is only offered when Pathfinder is unchecked; otherwise it stays "".
s2_key = ""
if not use_pathfinder:
st.markdown('Semantic Scholar Key (optional)
', unsafe_allow_html=True)
s2_key = st.text_input(
"S2 Key", type="password", label_visibility="collapsed",
placeholder="(improves keyword search)", key="s2_key",
)
st.markdown(
'Optional โ speeds up the keyword-based paper search. '
'
Get key โ ',
unsafe_allow_html=True,
)
st.markdown('Describe the figure
', unsafe_allow_html=True)
user_text = st.text_area(
"Figure description", label_visibility="collapsed", height=120,
placeholder='e.g. "scatter plot of cosmological parameter constraints from wavelet scattering transform, Omega_m vs sigma_8"',
)
st.markdown('Upload a sketch (optional)
', unsafe_allow_html=True)
sketch_file = st.file_uploader("Sketch", type=["png", "jpg", "jpeg", "webp"], label_visibility="collapsed")
run_verify = st.checkbox("Secondary verification (recommended)", value=True, key="run_verify")
st.markdown(
''
'Double-checks top matches. Uses the primary model by default.
',
unsafe_allow_html=True,
)
# Verification falls back to the primary model and key unless another model is chosen below.
verify_cfg = primary_cfg
verify_key = primary_key
if run_verify:
verify_options = ["Same as primary"] + MODEL_LABELS
verify_choice = st.selectbox(
"Verification model", options=verify_options,
label_visibility="collapsed", key="verify_model",
)
if verify_choice != "Same as primary":
verify_cfg = get_model(verify_choice)
# A second key field appears only when the verification model uses a different
# provider; a same-provider model reuses the primary key.
if verify_cfg.provider != primary_cfg.provider:
verify_key = render_api_key_field(
verify_cfg.provider, verify_cfg, "verify_api_key",
)
else:
verify_key = primary_key
num_papers = st.slider("Papers to search", min_value=5, max_value=50, value=20, step=5)
# Search button; disabled while st.session_state.running is set.
run_btn = st.button("๐ญ Search", use_container_width=True, type="primary", disabled=st.session_state.running)
with col_right:
    # Idle hint: shown only when there are no results yet and no run is in progress.
    if not (st.session_state.results or st.session_state.running):
        st.markdown("""
๐ญ
Your search progress will appear here
""", unsafe_allow_html=True)
# Run the search pipeline when Search is pressed: validate the inputs, then
# parse -> search -> triage -> fetch PDFs -> extract -> score -> (verify),
# streaming progress into the right column and storing the outcome in
# st.session_state.results for the results view further down the page.
if run_btn:
    # Same fallback label as render_api_key_field, so an unlisted provider
    # yields a readable message instead of a KeyError.
    primary_key_label = PROVIDER_DISPLAY.get(primary_cfg.provider, "API Key")
    verify_key_label = PROVIDER_DISPLAY.get(verify_cfg.provider, "API Key")
    if not primary_key:
        st.error(f"Please enter your {primary_key_label}.")
    elif use_pathfinder and not openai_key:
        st.error("Pathfinder is checked โ please enter your OpenAI API key, or uncheck Pathfinder.")
    elif run_verify and verify_cfg.provider != primary_cfg.provider and not verify_key:
        st.error(f"Verification model needs its own key โ please enter your {verify_key_label}.")
    elif not user_text and not sketch_file:
        st.error("Please enter a description or upload a sketch (or both).")
    else:
        st.session_state.running = True
        st.session_state.results = None
        tracker = CostTracker()
        primary_client = LLMClient(provider=primary_cfg.provider, api_key=primary_key)
        # Reuse the primary client when verification talks to the same provider.
        verify_client = (
            primary_client if verify_cfg.provider == primary_cfg.provider
            else LLMClient(provider=verify_cfg.provider, api_key=verify_key)
        )
        # getvalue() returns the whole upload regardless of the stream position,
        # whereas read() returns b"" once the buffer has been consumed.
        sketch_bytes = sketch_file.getvalue() if sketch_file else None
        # Minimum confidence a figure needs to survive scoring and verification.
        min_confidence = 0.5

        with col_right:
            log_placeholder = st.empty()
            progress_placeholder = st.empty()

            def log(msg):
                """Append *msg* to the session log and redraw its last 20 entries."""
                st.session_state.log.append(msg)
                # NOTE(review): in the reviewed copy this literal spanned two physical
                # lines (a SyntaxError; markup appears stripped). It is written with an
                # explicit "\n" here; restore any original markup wrapper.
                log_placeholder.markdown(
                    "\n".join(f"{m}\n" for m in st.session_state.log[-20:]),
                    unsafe_allow_html=True,
                )

            st.session_state.log = []
            succeeded = False
            try:
                log("โณ Parsing your description...")
                parser = InputParser(
                    primary_client, primary_cfg.model_id, primary_cfg.prices, tracker,
                    max_tokens=primary_cfg.triage_max_tokens * 4,
                )
                spec = parser.parse(text=user_text or None, sketch_bytes=sketch_bytes)
                # .get: a parse result without "science_query" falls back to the raw
                # text instead of aborting the whole run with a KeyError.
                query = spec.get("science_query") or user_text or "(no query)"
                log(f"โ Query: {query}")
                if spec.get("plot_type"):
                    log(f" Plot type: {spec['plot_type']}")

                searcher = PaperSearcher(s2_key=s2_key or None)
                if use_pathfinder:
                    all_papers = searcher.expanded_search_pathfinder(query, openai_key, log=log)
                else:
                    all_papers = searcher.expanded_search(
                        query, primary_client, primary_cfg.model_id, primary_cfg.prices, tracker,
                        max_tokens=primary_cfg.triage_max_tokens * 4, log=log,
                    )
                log(f"โ {len(all_papers)} unique papers found")

                log("โณ Triaging papers (please wait)...")
                triager = PaperTriager(
                    primary_client, primary_cfg.model_id, primary_cfg.prices, tracker,
                    max_tokens=primary_cfg.triage_max_tokens,
                    batch_size=primary_cfg.score_batch_size,
                )
                triaged = triager.triage(all_papers, spec)
                top = triaged[:num_papers]
                log(f"โ {len(top)} papers passed triage")
                # NOTE(review): these same dicts gain "_pdf_bytes" below and end up in
                # st.session_state.results, so every downloaded PDF stays in session
                # memory — confirm whether build_zip needs them before stripping.
                paper_lookup = {p["paperId"]: p for p in top}

                log("โณ Fetching PDFs...")
                downloaded = []
                for i, paper in enumerate(top):
                    progress_placeholder.progress((i + 1) / len(top), text=f"Fetching PDF {i+1}/{len(top)}")
                    pdf_bytes, _reason = PDFStore.fetch(paper)
                    # `or ""` guards against a title that is present but None.
                    short_title = (paper.get("title") or "")[:60]
                    if pdf_bytes:
                        paper["_pdf_bytes"] = pdf_bytes
                        downloaded.append(paper)
                        log(f" โ {short_title}")
                    else:
                        log(f" โ {short_title}")
                    time.sleep(0.5)  # presumably throttles requests to PDF hosts — confirm
                progress_placeholder.empty()
                log(f"โ {len(downloaded)} PDFs ready")

                log("โณ Extracting figures...")
                extractor = FigureExtractor()
                all_figures = []
                for paper in downloaded:
                    try:
                        figs = extractor.extract(paper["_pdf_bytes"], paper["paperId"])
                        all_figures.extend(figs)
                    except Exception as e:
                        # A single unreadable PDF is reported and skipped.
                        log(f" โ {(paper.get('title') or '')[:40]}: {e}")
                filtered = extractor.caption_filter(all_figures, query)
                log(f" {len(filtered)} figures after caption filter (from {len(all_figures)} total)")

                log(f"โณ Scoring {len(filtered)} figures...")
                scorer = FigureScorer(
                    primary_client, primary_cfg.model_id, primary_cfg.prices, tracker,
                    score_max_tokens=primary_cfg.score_max_tokens,
                    verify_max_tokens=primary_cfg.verify_max_tokens,
                    batch_size=primary_cfg.score_batch_size,
                )
                results = scorer.score_batch(filtered, spec)
                primary_matches = [
                    fig for fig, result in zip(filtered, results)
                    if result.get("confidence", 0) >= min_confidence
                ]
                log(f"โ {len(primary_matches)} primary matches")

                verified = primary_matches
                if run_verify and primary_matches:
                    log(f"โณ Verifying {len(primary_matches)} matches...")
                    verifier = FigureScorer(
                        verify_client, verify_cfg.model_id, verify_cfg.prices, tracker,
                        score_max_tokens=verify_cfg.score_max_tokens,
                        verify_max_tokens=verify_cfg.verify_max_tokens,
                        batch_size=1,
                    )
                    verified = []
                    for i, fig in enumerate(primary_matches):
                        progress_placeholder.progress(
                            (i + 1) / len(primary_matches),
                            text=f"Verifying {i+1}/{len(primary_matches)}",
                        )
                        result = verifier.verify(fig, spec)
                        if result.get("confidence", 0) >= min_confidence:
                            verified.append(fig)
                    progress_placeholder.empty()
                    log(f"โ {len(verified)} verified matches")
                # Highest confidence first; list.sort is stable even with reverse=True,
                # so ties keep their scoring order.
                verified.sort(key=get_confidence, reverse=True)

                total_cost = tracker.total()
                log(f"โ Done โ ${total_cost:.4f}")
                st.session_state.results = {
                    "matches": verified,
                    "paper_lookup": paper_lookup,
                    "spec": spec,
                    "query": query,
                    "cost": total_cost,
                    "n_papers": len(downloaded),
                    "n_figures": len(all_figures),
                }
                log_search(query, len(downloaded), len(verified), total_cost)
                st.session_state.global_stats["searches"] += 1
                succeeded = True
            except Exception as e:
                import traceback
                traceback.print_exc()
                log(f"โ Error: {e}")
                st.error(f"Pipeline error: {e}")
            finally:
                st.session_state.running = False
            # Rerun only after success so the page redraws with the results view.
            # Rerunning after a failure would immediately discard the error message
            # and progress log rendered above.
            if succeeded:
                st.rerun()
# Results view: summary metrics, a zip of every match, then one card per
# matched figure (metadata, image, individual download button).
if st.session_state.results:
    res = st.session_state.results
    matches = res["matches"]
    paper_lookup = res["paper_lookup"]
    with col_right:
        st.markdown(f"""
{res['n_figures']}
Figures
${res['cost']:.3f}
API Cost
""", unsafe_allow_html=True)
        if not matches:
            st.info("No matches found. Try broadening your description or increasing the number of papers.")
        else:
            zip_bytes = build_zip(matches, paper_lookup)
            st.download_button(
                "โฌ Download all matched figures (.zip)",
                data=zip_bytes, file_name="foto_results.zip",
                mime="application/zip", use_container_width=True,
            )
            st.markdown("---")
            for i, match in enumerate(matches):
                paper = paper_lookup.get(match["paper_id"], {})
                verdict = match.get("_verify") or match.get("_primary") or {}
                conf = get_confidence(match)
                # Paper metadata fields may be present but None; `or` keeps the
                # slicing and .get calls below from raising on such papers.
                title = paper.get("title") or "Unknown"
                year = paper.get("year") or ""
                arxiv_id = (paper.get("externalIds") or {}).get("ArXiv")
                arxiv_link = f' ยท arXiv:{arxiv_id}' if arxiv_id else ""
                st.markdown(f"""
{title[:100]}
{format_authors(paper)} ({year}){arxiv_link}
conf {conf:.2f}
{verdict.get('plot_type', '')}
page {match['page']}
{verdict.get('what_is_plotted', '')}
""", unsafe_allow_html=True)
                try:
                    st.image(match["image_bytes"], use_container_width=True)
                except Exception:
                    st.warning("Could not display image.")
                st.download_button(
                    "โฌ Download figure",
                    data=match["image_bytes"],
                    file_name=f"figure_{i+1:02d}_{(paper.get('title') or 'unknown')[:30].replace(' ','_')}.png",
                    mime="image/png", key=f"dl_{i}",
                )
                st.markdown("---")
# Feedback widget (only after a search that produced matches), followed by
# the session-wide usage stats footer.
latest_results = st.session_state.results
if latest_results and latest_results.get("matches"):
    st.markdown("""
""", unsafe_allow_html=True)
    rating = st.select_slider(
        "Was one of the top matches what you were looking for? (1 = not at all, 5 = perfect match)",
        options=[1, 2, 3, 4, 5],
        value=3,
    )
    # Accept one rating per search: repeated clicks used to log duplicate
    # ratings and skew the average. The flag lives on the results dict, so a
    # new search (which replaces st.session_state.results) re-enables it.
    if latest_results.get("_rated"):
        st.success("Thanks!")
    elif st.button("Submit feedback", key="submit_feedback"):
        log_rating(rating)
        st.session_state.global_stats["ratings"].append(rating)
        latest_results["_rated"] = True
        st.success("Thanks!")

stats = st.session_state.global_stats
n_ratings = len(stats["ratings"])
avg = sum(stats["ratings"]) / n_ratings if n_ratings else 0
st.markdown(f"""
Overall stats
{stats['searches']}
Searches
{"โ" if not n_ratings else f"{avg:.1f}"}
Avg score
""", unsafe_allow_html=True)