"""FOTO Streamlit app.

The user describes a scientific figure in words and/or uploads a sketch. The
app then searches the literature, downloads candidate PDFs, extracts their
figures and ranks them with an LLM. Ranking can optionally be double-checked
by a second "verification" model.
"""

import html
import re
import sys
import time
import traceback
from pathlib import Path

# Make the sibling ``foto`` package importable regardless of the working
# directory ``streamlit run`` was launched from.
sys.path.insert(0, str(Path(__file__).parent))

import streamlit as st

from foto import (
    MODEL_LABELS,
    MODEL_REGISTRY,
    get_model,
    CostTracker,
    InputParser,
    PaperSearcher,
    PaperTriager,
    PDFStore,
    FigureExtractor,
    FigureScorer,
    build_zip,
    format_authors,
    get_confidence,
    confidence_badge_class,
)
from foto.models import PROVIDER_DISPLAY
from foto.llm_client import LLMClient
from foto.persistence import load_stats, log_search, log_rating

# Minimum confidence for a figure to count as a match (primary and verify pass).
MATCH_THRESHOLD = 0.5
# Sentinel option in the verification-model selectbox.
SAME_AS_PRIMARY = "Same as primary"


def _md(text: str) -> None:
    """Render *text* as markdown with raw HTML allowed (used for all UI chrome)."""
    st.markdown(text, unsafe_allow_html=True)


def _esc(value) -> str:
    """HTML-escape *value* before interpolating it into ``unsafe_allow_html`` markdown.

    Paper metadata, LLM output and exception messages are untrusted, so they
    must not be able to inject markup into the page. ``None`` becomes "".
    """
    return "" if value is None else html.escape(str(value))


def _provider_name(provider: str) -> str:
    """Return the display label for *provider*'s API key ("API Key" if unknown)."""
    return PROVIDER_DISPLAY.get(provider, "API Key")


def _title(paper: dict, limit: int, default: str = "") -> str:
    """Return *paper*'s title truncated to *limit* chars, or *default* if missing/None."""
    return (paper.get("title") or default)[:limit]


def _safe_filename_part(text: str) -> str:
    """Replace every character that is not a word character or '-' with '_'."""
    return re.sub(r"[^\w\-]", "_", text)


def render_api_key_field(provider: str, model_cfg, key_state: str) -> str:
    """Render a labelled password input for *provider*'s API key.

    Args:
        provider: Provider id; its label is looked up in ``PROVIDER_DISPLAY``.
        model_cfg: Model config whose ``api_help_text`` is shown under the field.
        key_state: Streamlit widget key, so the value persists across reruns.

    Returns:
        The value currently entered in the field.
    """
    display_name = _provider_name(provider)
    _md(display_name)
    key = st.text_input(
        display_name,
        type="password",
        label_visibility="collapsed",
        placeholder="...",
        key=key_state,
    )
    _md(f"{model_cfg.api_help_text} Get key →")
    return key


# NOTE(review): the original icon bytes were mis-encoded (one byte was lost),
# so 🐇 is a best-guess reconstruction — confirm the intended emoji.
st.set_page_config(
    page_title="FOTO",
    page_icon="🐇",
    layout="wide",
    initial_sidebar_state="collapsed",
)

# Custom CSS slot (currently empty).
st.markdown(""" """, unsafe_allow_html=True)

# Per-session state; only missing keys are initialised so values survive reruns.
for key, default in {
    "pdf_cache": {},
    "results": None,
    "running": False,
    "log": [],
    "tally": {"searches": 0, "ratings": []},
}.items():
    if key not in st.session_state:
        st.session_state[key] = default

if "global_stats" not in st.session_state:
    st.session_state.global_stats = load_stats()

st.markdown("""

Figure frOm for Text & illustratiOns

Describe a scientific figure in words, upload a sketch, or both and FOTO searches the literature to find it.

""", unsafe_allow_html=True)

col_left, col_right = st.columns([1, 1], gap="large")

# ---------------------------------------------------------------- inputs ----
with col_left:
    _md("Primary Model")
    primary_label = st.selectbox(
        "Primary Model",
        options=MODEL_LABELS,
        label_visibility="collapsed",
        key="primary_model",
    )
    primary_cfg = get_model(primary_label)
    primary_key = render_api_key_field(primary_cfg.provider, primary_cfg, "primary_api_key")

    use_pathfinder = st.checkbox(
        "Use Pathfinder (recommended)",
        value=True,
        key="use_pathfinder",
    )
    _md(
        "Based on arXiv:2408.01556 · "
        "OpenAI key required for query embedding · "
        "~$1 per 2M queries"
    )

    openai_key = ""
    s2_key = ""
    if use_pathfinder:
        _md("OpenAI API Key")
        openai_key = st.text_input(
            "OpenAI Key",
            type="password",
            label_visibility="collapsed",
            placeholder="sk-...",
            key="openai_key",
        )
        _md("Used to embed queries with text-embedding-3-small. Get key →")
    else:
        # The Semantic Scholar key only matters for the keyword-search path.
        _md("Semantic Scholar Key (optional)")
        s2_key = st.text_input(
            "S2 Key",
            type="password",
            label_visibility="collapsed",
            placeholder="(improves keyword search)",
            key="s2_key",
        )
        _md("Optional — speeds up the keyword-based paper search. Get key →")

    _md("Describe the figure")
    user_text = st.text_area(
        "Figure description",
        label_visibility="collapsed",
        height=120,
        placeholder='e.g. "scatter plot of cosmological parameter constraints from wavelet scattering transform, Omega_m vs sigma_8"',
    )

    _md("Upload a sketch (optional)")
    sketch_file = st.file_uploader(
        "Sketch", type=["png", "jpg", "jpeg", "webp"], label_visibility="collapsed"
    )

    run_verify = st.checkbox("Secondary verification (recommended)", value=True, key="run_verify")
    _md("Double-checks top matches. Uses the primary model by default.")

    # Verification defaults to the primary model and key.
    verify_cfg = primary_cfg
    verify_key = primary_key
    if run_verify:
        verify_choice = st.selectbox(
            "Verification model",
            options=[SAME_AS_PRIMARY, *MODEL_LABELS],
            label_visibility="collapsed",
            key="verify_model",
        )
        if verify_choice != SAME_AS_PRIMARY:
            verify_cfg = get_model(verify_choice)
            if verify_cfg.provider != primary_cfg.provider:
                # A different provider needs its own key.
                verify_key = render_api_key_field(
                    verify_cfg.provider, verify_cfg, "verify_api_key"
                )
            else:
                verify_key = primary_key

    num_papers = st.slider("Papers to search", min_value=5, max_value=50, value=20, step=5)
    run_btn = st.button(
        "🔭 Search",
        use_container_width=True,
        type="primary",
        disabled=st.session_state.running,
    )

with col_right:
    if not st.session_state.results and not st.session_state.running:
        st.markdown("""
🔭
Your search progress will appear here
""", unsafe_allow_html=True)

# -------------------------------------------------------------- pipeline ----
if run_btn:
    if not primary_key:
        st.error(f"Please enter your {_provider_name(primary_cfg.provider)}.")
    elif use_pathfinder and not openai_key:
        st.error("Pathfinder is checked — please enter your OpenAI API key, or uncheck Pathfinder.")
    elif run_verify and verify_cfg.provider != primary_cfg.provider and not verify_key:
        st.error(
            "Verification model needs its own key — please enter your "
            f"{_provider_name(verify_cfg.provider)}."
        )
    elif not user_text and not sketch_file:
        st.error("Please enter a description or upload a sketch (or both).")
    else:
        st.session_state.running = True
        st.session_state.results = None

        with col_right:
            log_placeholder = st.empty()
            progress_placeholder = st.empty()

        def log(msg: str) -> None:
            """Append *msg* to the session log and re-render its last 20 lines."""
            st.session_state.log.append(msg)
            log_placeholder.markdown(
                "\n".join(f"\n{_esc(m)}\n" for m in st.session_state.log[-20:]),
                unsafe_allow_html=True,
            )

        st.session_state.log = []
        succeeded = False
        try:
            # Client setup lives inside the try so a failure here still resets
            # the ``running`` flag in ``finally``.
            tracker = CostTracker()
            primary_client = LLMClient(provider=primary_cfg.provider, api_key=primary_key)
            verify_client = (
                primary_client
                if verify_cfg.provider == primary_cfg.provider
                else LLMClient(provider=verify_cfg.provider, api_key=verify_key)
            )
            sketch_bytes = sketch_file.getvalue() if sketch_file else None

            log("⟳ Parsing your description...")
            parser = InputParser(
                primary_client,
                primary_cfg.model_id,
                primary_cfg.prices,
                tracker,
                max_tokens=primary_cfg.triage_max_tokens * 4,
            )
            spec = parser.parse(text=user_text or None, sketch_bytes=sketch_bytes)
            query = spec.get("science_query") or user_text or "(no query)"
            log(f"✓ Query: {query}")
            if spec.get("plot_type"):
                log(f" Plot type: {spec['plot_type']}")

            searcher = PaperSearcher(s2_key=s2_key or None)
            if use_pathfinder:
                all_papers = searcher.expanded_search_pathfinder(query, openai_key, log=log)
            else:
                all_papers = searcher.expanded_search(
                    query,
                    primary_client,
                    primary_cfg.model_id,
                    primary_cfg.prices,
                    tracker,
                    max_tokens=primary_cfg.triage_max_tokens * 4,
                    log=log,
                )
            log(f"✓ {len(all_papers)} unique papers found")

            log("⟳ Triaging papers (please wait)...")
            triager = PaperTriager(
                primary_client,
                primary_cfg.model_id,
                primary_cfg.prices,
                tracker,
                max_tokens=primary_cfg.triage_max_tokens,
                batch_size=primary_cfg.score_batch_size,
            )
            triaged = triager.triage(all_papers, spec)
            top = triaged[:num_papers]
            log(f"✓ {len(top)} papers passed triage")
            # NOTE(review): these dicts also carry ``_pdf_bytes`` and are kept in
            # session state — consider dropping the bytes if build_zip doesn't need them.
            paper_lookup = {p["paperId"]: p for p in top}

            log("⟳ Fetching PDFs...")
            downloaded = []
            for n, paper in enumerate(top, start=1):
                progress_placeholder.progress(n / len(top), text=f"Fetching PDF {n}/{len(top)}")
                pdf_bytes, _reason = PDFStore.fetch(paper)
                if pdf_bytes:
                    paper["_pdf_bytes"] = pdf_bytes
                    downloaded.append(paper)
                    log(f" ✓ {_title(paper, 60)}")
                else:
                    log(f" ✗ {_title(paper, 60)}")
                # Pause between downloads (presumably to go easy on PDF hosts).
                time.sleep(0.5)
            progress_placeholder.empty()
            log(f"✓ {len(downloaded)} PDFs ready")

            log("⟳ Extracting figures...")
            extractor = FigureExtractor()
            all_figures = []
            for paper in downloaded:
                try:
                    all_figures.extend(extractor.extract(paper["_pdf_bytes"], paper["paperId"]))
                except Exception as e:
                    # One bad PDF should not abort the whole search.
                    log(f" ✗ {_title(paper, 40)}: {e}")
            filtered = extractor.caption_filter(all_figures, query)
            log(f" {len(filtered)} figures after caption filter (from {len(all_figures)} total)")

            log(f"⟳ Scoring {len(filtered)} figures...")
            scorer = FigureScorer(
                primary_client,
                primary_cfg.model_id,
                primary_cfg.prices,
                tracker,
                score_max_tokens=primary_cfg.score_max_tokens,
                verify_max_tokens=primary_cfg.verify_max_tokens,
                batch_size=primary_cfg.score_batch_size,
            )
            scores = scorer.score_batch(filtered, spec)
            primary_matches = [
                fig
                for fig, score in zip(filtered, scores)
                if score.get("confidence", 0) >= MATCH_THRESHOLD
            ]
            log(f"✓ {len(primary_matches)} primary matches")

            verified = primary_matches
            if run_verify and primary_matches:
                log(f"⟳ Verifying {len(primary_matches)} matches...")
                verifier = FigureScorer(
                    verify_client,
                    verify_cfg.model_id,
                    verify_cfg.prices,
                    tracker,
                    score_max_tokens=verify_cfg.score_max_tokens,
                    verify_max_tokens=verify_cfg.verify_max_tokens,
                    batch_size=1,
                )
                verified = []
                for n, fig in enumerate(primary_matches, start=1):
                    progress_placeholder.progress(
                        n / len(primary_matches),
                        text=f"Verifying {n}/{len(primary_matches)}",
                    )
                    verdict = verifier.verify(fig, spec)
                    if verdict.get("confidence", 0) >= MATCH_THRESHOLD:
                        verified.append(fig)
                progress_placeholder.empty()
                log(f"✓ {len(verified)} verified matches")

            # Highest confidence first; sorted() avoids mutating primary_matches.
            verified = sorted(verified, key=get_confidence, reverse=True)

            total_cost = tracker.total()
            log(f"✓ Done — ${total_cost:.4f}")
            st.session_state.results = {
                "matches": verified,
                "paper_lookup": paper_lookup,
                "spec": spec,
                "query": query,
                "cost": total_cost,
                "n_papers": len(downloaded),
                "n_figures": len(all_figures),
            }
            log_search(query, len(downloaded), len(verified), total_cost)
            st.session_state.global_stats["searches"] += 1
            succeeded = True
        except Exception as e:
            traceback.print_exc()
            log(f"✗ Error: {e}")
            st.error(f"Pipeline error: {e}")
        finally:
            st.session_state.running = False

        # Rerun only on success: rerunning after a failure would immediately
        # wipe the error message and log before the user could read them.
        if succeeded:
            st.rerun()

# --------------------------------------------------------------- results ----
if st.session_state.results:
    res = st.session_state.results
    matches = res["matches"]
    paper_lookup = res["paper_lookup"]

    with col_right:
        st.markdown(f"""
{len(matches)}
Matches
{res['n_papers']}
Papers
{res['n_figures']}
Figures
${res['cost']:.3f}
API Cost
""", unsafe_allow_html=True)

        if not matches:
            st.info("No matches found. Try broadening your description or increasing the number of papers.")
        else:
            zip_bytes = build_zip(matches, paper_lookup)
            st.download_button(
                "⬇ Download all matched figures (.zip)",
                data=zip_bytes,
                file_name="foto_results.zip",
                mime="application/zip",
                use_container_width=True,
            )
            st.markdown("---")

            for i, match in enumerate(matches):
                paper = paper_lookup.get(match["paper_id"], {})
                # Prefer the verification verdict, fall back to the primary one.
                verdict = match.get("_verify") or match.get("_primary") or {}
                conf = get_confidence(match)
                # ``externalIds`` may be present but None, hence ``or {}``.
                arxiv_id = (paper.get("externalIds") or {}).get("ArXiv")
                arxiv_link = f" · arXiv:{_esc(arxiv_id)}" if arxiv_id else ""
                st.markdown(f"""
{_esc(_title(paper, 100, 'Unknown'))}
{_esc(format_authors(paper))} ({_esc(paper.get('year'))}){arxiv_link}
conf {conf:.2f} {_esc(verdict.get('plot_type'))} page {_esc(match['page'])}
{_esc(verdict.get('what_is_plotted'))}
""", unsafe_allow_html=True)
                try:
                    st.image(match["image_bytes"], use_container_width=True)
                except Exception:
                    st.warning("Could not display image.")
                st.download_button(
                    "⬇ Download figure",
                    data=match["image_bytes"],
                    file_name=f"figure_{i + 1:02d}_{_safe_filename_part(_title(paper, 30, 'unknown'))}.png",
                    mime="image/png",
                    key=f"dl_{i}",
                )
                st.markdown("---")

# -------------------------------------------------------------- feedback ----
if st.session_state.results and st.session_state.results.get("matches"):
    st.markdown("""
How did FOTO do?
""", unsafe_allow_html=True)
    rating = st.select_slider(
        "Was one of the top matches what you were looking for? (1 = not at all, 5 = perfect match)",
        options=[1, 2, 3, 4, 5],
        value=3,
    )
    if st.button("Submit feedback", key="submit_feedback"):
        log_rating(rating)
        st.session_state.global_stats["ratings"].append(rating)
        st.success("Thanks!")

stats = st.session_state.global_stats
n_ratings = len(stats["ratings"])
avg_text = f"{sum(stats['ratings']) / n_ratings:.1f}" if n_ratings else "—"
st.markdown(f"""
Overall stats
{stats['searches']}
Searches
{n_ratings}
Rated
{avg_text}
Avg score
""", unsafe_allow_html=True)