Hurum Maksora Tohfa committed on
Commit
dfdfaa2
·
unverified ·
1 Parent(s): 28edfc5

Refactor API key handling and improve UI elements

Browse files
Files changed (1) hide show
  1. app.py +148 -81
app.py CHANGED
@@ -1,17 +1,19 @@
1
  import sys
2
  import time
3
  from pathlib import Path
 
4
  sys.path.insert(0, str(Path(__file__).parent))
5
 
6
  import streamlit as st
7
- from anthropic import Anthropic
8
 
9
  from foto import (
10
- MODEL_LABELS, get_model, CostTracker,
11
  InputParser, PaperSearcher, PaperTriager,
12
  PDFStore, FigureExtractor, FigureScorer,
13
  build_zip, format_authors, get_confidence, confidence_badge_class,
14
  )
 
 
15
  from foto.persistence import load_stats, log_search, log_rating
16
 
17
  st.set_page_config(
@@ -29,13 +31,14 @@ html, body, [class*="css"] { font-family: 'Inter', sans-serif; font-weight: 300;
29
  .stApp { background: #fafaf8; color: #1a1a1a; }
30
 
31
  .foto-header { padding: 3rem 0 1.5rem 0; border-bottom: 1px solid #e0e0d8; margin-bottom: 2.5rem; }
32
- .foto-title { font-family: 'DM Serif Display', serif; font-size: 2.8rem; font-weight: 400; letter-spacing: -0.02em; line-height: 1.1; color: #1a1a1a; margin: 0; }
33
- .foto-title em { font-style: italic; color: #4a5568; }
34
  .foto-subtitle { font-size: 0.85rem; color: #888; letter-spacing: 0.08em; text-transform: uppercase; margin-top: 0.4rem; font-family: 'DM Mono', monospace; }
35
  .foto-tagline { font-size: 1rem; color: #555; margin-top: 0.8rem; font-weight: 300; max-width: 560px; }
36
 
37
  .section-label { font-family: 'DM Mono', monospace; font-size: 0.7rem; letter-spacing: 0.12em; text-transform: uppercase; color: #999; margin-bottom: 0.5rem; }
38
 
 
 
 
39
  .result-card { background: white; border: 1px solid #e8e8e0; border-radius: 4px; padding: 1.2rem 1.4rem; margin-bottom: 1.2rem; }
40
  .result-title { font-family: 'DM Serif Display', serif; font-size: 1.05rem; color: #1a1a1a; margin-bottom: 0.2rem; line-height: 1.3; }
41
  .result-meta { font-size: 0.8rem; color: #777; font-style: italic; margin-bottom: 0.6rem; }
@@ -62,20 +65,27 @@ html, body, [class*="css"] { font-family: 'Inter', sans-serif; font-weight: 300;
62
  .tally-num { font-family: 'DM Serif Display', serif; font-size: 2rem; color: #fafaf8; }
63
  .tally-label { font-family: 'DM Mono', monospace; font-size: 0.65rem; color: #888; letter-spacing: 0.08em; margin-top: 0.2rem; }
64
 
 
 
 
65
  .stTextArea textarea { font-family: 'Inter', sans-serif; font-size: 0.95rem; border: 1px solid #ddd; border-radius: 3px; }
66
  .stButton button { font-family: 'DM Mono', monospace; font-size: 0.8rem; letter-spacing: 0.06em; border-radius: 3px; }
67
- div[data-testid="stSelectbox"] label,
68
- div[data-testid="stTextInput"] label,
69
- div[data-testid="stCheckbox"] label { font-family: 'DM Mono', monospace; font-size: 0.75rem; letter-spacing: 0.08em; text-transform: uppercase; color: #888; }
70
 
71
- .pathfinder-row { display: flex; align-items: center; gap: 0.5rem; }
72
- .pathfinder-cite { font-family: 'DM Mono', monospace; font-size: 0.7rem; color: #888; }
73
- .pathfinder-cite a { color: #888; text-decoration: underline; }
 
 
 
 
 
 
 
 
74
  </style>
75
  """, unsafe_allow_html=True)
76
 
77
 
78
- # Per-session state for caching, run status, log buffer, and feedback tally
79
  for key, default in {
80
  "pdf_cache": {},
81
  "results": None,
@@ -102,55 +112,104 @@ st.markdown("""
102
 
103
  col_left, col_right = st.columns([1, 1], gap="large")
104
 
105
- with col_left:
106
- st.markdown('<p class="section-label">Model</p>', unsafe_allow_html=True)
107
- model_label = st.selectbox("Model", options=MODEL_LABELS, label_visibility="collapsed")
108
 
109
- st.markdown('<p class="section-label" style="margin-top:1.2rem;">Anthropic API Key</p>', unsafe_allow_html=True)
110
- api_key = st.text_input("Anthropic API Key", type="password", label_visibility="collapsed", placeholder="sk-ant-...")
 
 
 
 
 
 
 
 
 
 
 
111
 
112
- st.markdown('<p class="section-label" style="margin-top:0.8rem;">Semantic Scholar Key (optional)</p>', unsafe_allow_html=True)
113
- s2_key = st.text_input("S2 Key", type="password", label_visibility="collapsed", placeholder="(Recommended for keyword fallback)")
114
 
115
- st.markdown('<p class="section-label" style="margin-top:1.5rem;">Describe the figure</p>', unsafe_allow_html=True)
116
- user_text = st.text_area(
117
- "Figure description", label_visibility="collapsed", height=120,
118
- placeholder='e.g. "scatter plot of cosmological parameter constraints from wavelet scattering transform, Omega_m vs sigma_8"',
 
119
  )
 
 
120
 
121
- st.markdown('<p class="section-label" style="margin-top:0.8rem;">Upload a sketch (optional)</p>', unsafe_allow_html=True)
122
- sketch_file = st.file_uploader("Sketch", type=["png", "jpg", "jpeg", "webp"], label_visibility="collapsed")
 
 
 
 
 
 
 
 
 
 
 
123
 
124
- # Pathfinder toggle + inline citation link
125
- pf_col1, pf_col2 = st.columns([2, 3])
126
- with pf_col1:
127
- use_pathfinder = st.checkbox("Use Pathfinder", value=True)
128
- with pf_col2:
 
 
129
  st.markdown(
130
- '<div class="pathfinder-cite" style="padding-top:0.55rem;">'
131
- 'based on <a href="https://arxiv.org/abs/2408.01556" target="_blank">arXiv:2408.01556</a>'
132
- '</div>',
133
  unsafe_allow_html=True,
134
  )
135
 
136
- # OpenAI key only needed when Pathfinder is active (used to embed queries
137
- # with text-embedding-3-small against the Pathfinder corpus)
138
- openai_key = None
139
- if use_pathfinder:
140
- st.markdown('<p class="section-label" style="margin-top:0.6rem;">OpenAI API Key</p>', unsafe_allow_html=True)
141
- openai_key = st.text_input(
142
- "OpenAI Key", type="password", label_visibility="collapsed",
143
- placeholder="sk-...",
144
  )
145
  st.markdown(
146
- '<p style="font-size:0.78rem;color:#888;margin-top:-0.4rem;">'
147
- 'Used to embed queries with text-embedding-3-small (~$0.40 per million queries).'
148
- '</p>',
149
  unsafe_allow_html=True,
150
  )
151
 
152
- run_verify = st.checkbox("Secondary verification — recommended, adds ~$0.05", value=True)
153
- st.markdown('<p style="font-size:0.78rem;color:#888;margin-top:-0.8rem;margin-left:1.8rem;">Uses a smarter model to double-check top matches. Best results, small extra cost.</p>', unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  num_papers = st.slider("Papers to search", min_value=5, max_value=50, value=20, step=5)
155
 
156
  run_btn = st.button("🔭 Search", use_container_width=True, type="primary", disabled=st.session_state.running)
@@ -160,29 +219,30 @@ with col_right:
160
  st.markdown("""
161
  <div style="padding: 3rem 2rem; color: #aaa; text-align: center;">
162
  <div style="font-size: 3rem; margin-bottom: 1rem;">🔭</div>
163
- <div style="font-family: 'DM Mono', monospace; font-size: 0.75rem; letter-spacing: 0.1em; text-transform: uppercase;">Results will appear here</div>
164
- <div style="margin-top: 0.8rem; font-size: 0.85rem; max-width: 300px; margin-left: auto; margin-right: auto; line-height: 1.6;">
165
- Enter a description, an optional sketch, and your API key then hit Search (On average per search cost 20-30¢).
166
- </div>
167
  </div>
168
  """, unsafe_allow_html=True)
169
 
170
 
171
- # Full pipeline runs on button press
172
  if run_btn:
173
- if not api_key:
174
- st.error("Please enter your Anthropic API key.")
175
  elif use_pathfinder and not openai_key:
176
- st.error("Pathfinder is checked — please enter your OpenAI API key, or uncheck Pathfinder to use keyword search.")
 
 
177
  elif not user_text and not sketch_file:
178
  st.error("Please enter a description or upload a sketch (or both).")
179
  else:
180
  st.session_state.running = True
181
  st.session_state.results = None
182
 
183
- model_cfg = get_model(model_label)
184
- client = Anthropic(api_key=api_key)
185
- tracker = CostTracker(model_cfg.prices)
 
 
 
186
  sketch_bytes = sketch_file.read() if sketch_file else None
187
 
188
  with col_right:
@@ -199,33 +259,38 @@ if run_btn:
199
  st.session_state.log = []
200
 
201
  try:
202
- # Parse text + optional sketch into a structured search spec
203
  log("⟳ Parsing your description...")
204
- parser = InputParser(client, model_cfg.smart, tracker)
 
 
 
205
  spec = parser.parse(text=user_text or None, sketch_bytes=sketch_bytes)
206
  query = spec["science_query"] or user_text or "(no query)"
207
  log(f"✓ Query: <em>{query}</em>")
208
  if spec.get("plot_type"):
209
  log(f" Plot type: {spec['plot_type']}")
210
 
211
- # Pathfinder semantic retrieval, or legacy keyword expansion as fallback
212
  searcher = PaperSearcher(s2_key=s2_key or None)
213
  if use_pathfinder:
214
  all_papers = searcher.expanded_search_pathfinder(query, openai_key, log=log)
215
  else:
216
  all_papers = searcher.expanded_search(
217
- query, client, model_cfg.smart, tracker, log=log)
 
 
218
  log(f"✓ {len(all_papers)} unique papers found")
219
 
220
- # Abstract-level relevance filter to cut downstream cost
221
- log("⟳ Triaging with Claude...")
222
- triager = PaperTriager(client, model_cfg.cheap, tracker)
 
 
 
223
  triaged = triager.triage(all_papers, spec)
224
  top = triaged[:num_papers]
225
  log(f"✓ {len(top)} papers passed triage")
226
  paper_lookup = {p["paperId"]: p for p in top}
227
 
228
- # PDF fetch with arxiv-first URL preference, polite spacing
229
  log("⟳ Fetching PDFs...")
230
  downloaded = []
231
  for i, paper in enumerate(top):
@@ -241,7 +306,6 @@ if run_btn:
241
  progress_placeholder.empty()
242
  log(f"✓ {len(downloaded)} PDFs ready")
243
 
244
- # Pull raster figures + captions from each PDF, then caption pre-filter
245
  log("⟳ Extracting figures...")
246
  extractor = FigureExtractor()
247
  all_figures = []
@@ -255,23 +319,27 @@ if run_btn:
255
  filtered = extractor.caption_filter(all_figures, query)
256
  log(f" {len(filtered)} figures after caption filter (from {len(all_figures)} total)")
257
 
258
- # Cheap vision pass: score every surviving figure against the spec
259
- log(f"⟳ Scoring {len(filtered)} figures...")
260
- scorer = FigureScorer(client, model_cfg.cheap, tracker)
261
- primary_matches = []
262
- for i, fig in enumerate(filtered):
263
- progress_placeholder.progress((i + 1) / len(filtered), text=f"Scoring {i+1}/{len(filtered)}")
264
- result = scorer.score(fig, spec)
265
- if result.get("confidence", 0) >= 0.5:
266
- primary_matches.append(fig)
267
- progress_placeholder.empty()
268
  log(f"✓ {len(primary_matches)} primary matches")
269
 
270
- # Optional smart-model verification on figures that passed primary scoring
271
  verified = primary_matches
272
  if run_verify and primary_matches:
273
  log(f"⟳ Verifying {len(primary_matches)} matches...")
274
- verifier = FigureScorer(client, model_cfg.smart, tracker)
 
 
 
 
 
275
  verified = []
276
  for i, fig in enumerate(primary_matches):
277
  progress_placeholder.progress((i + 1) / len(primary_matches), text=f"Verifying {i+1}/{len(primary_matches)}")
@@ -298,6 +366,8 @@ if run_btn:
298
  st.session_state.global_stats["searches"] += 1
299
 
300
  except Exception as e:
 
 
301
  log(f"✗ Error: {e}")
302
  st.error(f"Pipeline error: {e}")
303
  finally:
@@ -306,7 +376,6 @@ if run_btn:
306
  st.rerun()
307
 
308
 
309
- # Render results: stats row, downloadable zip, then per-figure cards with metadata
310
  if st.session_state.results:
311
  res = st.session_state.results
312
  matches = res["matches"]
@@ -367,7 +436,6 @@ if st.session_state.results:
367
  st.markdown("---")
368
 
369
 
370
- # Post-search feedback slider; submission logs to persistence layer
371
  if st.session_state.results and st.session_state.results.get("matches"):
372
  st.markdown("""
373
  <div class="feedback-box">
@@ -387,7 +455,6 @@ if st.session_state.results and st.session_state.results.get("matches"):
387
  st.success("Thanks!")
388
 
389
 
390
- # Aggregate stats across all sessions, loaded from persistence
391
  stats = st.session_state.global_stats
392
  n_ratings = len(stats["ratings"])
393
  avg = sum(stats["ratings"]) / n_ratings if n_ratings else 0
@@ -401,4 +468,4 @@ st.markdown(f"""
401
  <div class="stat-item"><div class="tally-num">{"—" if not n_ratings else f"{avg:.1f}"}</div><div class="tally-label">Avg score</div></div>
402
  </div>
403
  </div>
404
- """, unsafe_allow_html=True)
 
1
  import sys
2
  import time
3
  from pathlib import Path
4
+
5
  sys.path.insert(0, str(Path(__file__).parent))
6
 
7
  import streamlit as st
 
8
 
9
  from foto import (
10
+ MODEL_LABELS, MODEL_REGISTRY, get_model, CostTracker,
11
  InputParser, PaperSearcher, PaperTriager,
12
  PDFStore, FigureExtractor, FigureScorer,
13
  build_zip, format_authors, get_confidence, confidence_badge_class,
14
  )
15
+ from foto.models import PROVIDER_DISPLAY
16
+ from foto.llm_client import LLMClient
17
  from foto.persistence import load_stats, log_search, log_rating
18
 
19
  st.set_page_config(
 
31
  .stApp { background: #fafaf8; color: #1a1a1a; }
32
 
33
  .foto-header { padding: 3rem 0 1.5rem 0; border-bottom: 1px solid #e0e0d8; margin-bottom: 2.5rem; }
 
 
34
  .foto-subtitle { font-size: 0.85rem; color: #888; letter-spacing: 0.08em; text-transform: uppercase; margin-top: 0.4rem; font-family: 'DM Mono', monospace; }
35
  .foto-tagline { font-size: 1rem; color: #555; margin-top: 0.8rem; font-weight: 300; max-width: 560px; }
36
 
37
  .section-label { font-family: 'DM Mono', monospace; font-size: 0.7rem; letter-spacing: 0.12em; text-transform: uppercase; color: #999; margin-bottom: 0.5rem; }
38
 
39
+ .api-help { font-size: 0.75rem; color: #777; margin-top: -0.5rem; margin-bottom: 0.8rem; }
40
+ .api-help a { color: #4a5568; text-decoration: underline; }
41
+
42
  .result-card { background: white; border: 1px solid #e8e8e0; border-radius: 4px; padding: 1.2rem 1.4rem; margin-bottom: 1.2rem; }
43
  .result-title { font-family: 'DM Serif Display', serif; font-size: 1.05rem; color: #1a1a1a; margin-bottom: 0.2rem; line-height: 1.3; }
44
  .result-meta { font-size: 0.8rem; color: #777; font-style: italic; margin-bottom: 0.6rem; }
 
65
  .tally-num { font-family: 'DM Serif Display', serif; font-size: 2rem; color: #fafaf8; }
66
  .tally-label { font-family: 'DM Mono', monospace; font-size: 0.65rem; color: #888; letter-spacing: 0.08em; margin-top: 0.2rem; }
67
 
68
+ .pathfinder-cite { font-family: 'DM Mono', monospace; font-size: 0.72rem; color: #888; line-height: 1.5; }
69
+ .pathfinder-cite a { color: #4a5568; text-decoration: underline; }
70
+
71
  .stTextArea textarea { font-family: 'Inter', sans-serif; font-size: 0.95rem; border: 1px solid #ddd; border-radius: 3px; }
72
  .stButton button { font-family: 'DM Mono', monospace; font-size: 0.8rem; letter-spacing: 0.06em; border-radius: 3px; }
 
 
 
73
 
74
+ div[data-testid="stSelectbox"] label,
75
+ div[data-testid="stTextInput"] label { font-family: 'DM Mono', monospace; font-size: 0.75rem; letter-spacing: 0.08em; text-transform: uppercase; color: #888; }
76
+
77
+ div[data-testid="stCheckbox"] label p {
78
+ font-family: 'Inter', sans-serif !important;
79
+ font-size: 0.9rem !important;
80
+ letter-spacing: normal !important;
81
+ text-transform: none !important;
82
+ color: #1a1a1a !important;
83
+ font-weight: 400 !important;
84
+ }
85
  </style>
86
  """, unsafe_allow_html=True)
87
 
88
 
 
89
  for key, default in {
90
  "pdf_cache": {},
91
  "results": None,
 
112
 
113
  col_left, col_right = st.columns([1, 1], gap="large")
114
 
 
 
 
115
 
116
+ def render_api_key_field(provider: str, model_cfg, key_state: str) -> str:
117
+ display_name = PROVIDER_DISPLAY.get(provider, "API Key")
118
+ st.markdown(f'<p class="section-label" style="margin-top:0.8rem;">{display_name}</p>', unsafe_allow_html=True)
119
+ key = st.text_input(
120
+ display_name, type="password", label_visibility="collapsed",
121
+ placeholder="...", key=key_state,
122
+ )
123
+ st.markdown(
124
+ f'<div class="api-help">{model_cfg.api_help_text} '
125
+ f'<a href="{model_cfg.api_help_url}" target="_blank">Get key →</a></div>',
126
+ unsafe_allow_html=True,
127
+ )
128
+ return key
129
 
 
 
130
 
131
+ with col_left:
132
+ st.markdown('<p class="section-label">Primary Model</p>', unsafe_allow_html=True)
133
+ primary_label = st.selectbox(
134
+ "Primary Model", options=MODEL_LABELS,
135
+ label_visibility="collapsed", key="primary_model",
136
  )
137
+ primary_cfg = get_model(primary_label)
138
+ primary_key = render_api_key_field(primary_cfg.provider, primary_cfg, "primary_api_key")
139
 
140
+ use_pathfinder = st.checkbox(
141
+ "Use Pathfinder (recommended)",
142
+ value=True,
143
+ key="use_pathfinder",
144
+ )
145
+ st.markdown(
146
+ '<div class="pathfinder-cite" style="margin-top:-0.4rem;margin-left:1.8rem;">'
147
+ 'Based on <a href="https://arxiv.org/abs/2408.01556" target="_blank">arXiv:2408.01556</a> · '
148
+ 'OpenAI key required for query embedding · '
149
+ '~$1 per 2M queries'
150
+ '</div>',
151
+ unsafe_allow_html=True,
152
+ )
153
 
154
+ openai_key = ""
155
+ if use_pathfinder:
156
+ st.markdown('<p class="section-label" style="margin-top:0.8rem;">OpenAI API Key</p>', unsafe_allow_html=True)
157
+ openai_key = st.text_input(
158
+ "OpenAI Key", type="password", label_visibility="collapsed",
159
+ placeholder="sk-...", key="openai_key",
160
+ )
161
  st.markdown(
162
+ '<div class="api-help">Used to embed queries with text-embedding-3-small. '
163
+ '<a href="https://platform.openai.com/api-keys" target="_blank">Get key →</a></div>',
 
164
  unsafe_allow_html=True,
165
  )
166
 
167
+ s2_key = ""
168
+ if not use_pathfinder:
169
+ st.markdown('<p class="section-label" style="margin-top:0.8rem;">Semantic Scholar Key (optional)</p>', unsafe_allow_html=True)
170
+ s2_key = st.text_input(
171
+ "S2 Key", type="password", label_visibility="collapsed",
172
+ placeholder="(improves keyword search)", key="s2_key",
 
 
173
  )
174
  st.markdown(
175
+ '<div class="api-help">Optional — speeds up the keyword-based paper search. '
176
+ '<a href="https://www.semanticscholar.org/product/api" target="_blank">Get key →</a></div>',
 
177
  unsafe_allow_html=True,
178
  )
179
 
180
+ st.markdown('<p class="section-label" style="margin-top:1.2rem;">Describe the figure</p>', unsafe_allow_html=True)
181
+ user_text = st.text_area(
182
+ "Figure description", label_visibility="collapsed", height=120,
183
+ placeholder='e.g. "scatter plot of cosmological parameter constraints from wavelet scattering transform, Omega_m vs sigma_8"',
184
+ )
185
+
186
+ st.markdown('<p class="section-label" style="margin-top:0.8rem;">Upload a sketch (optional)</p>', unsafe_allow_html=True)
187
+ sketch_file = st.file_uploader("Sketch", type=["png", "jpg", "jpeg", "webp"], label_visibility="collapsed")
188
+
189
+ run_verify = st.checkbox("Secondary verification (recommended)", value=True, key="run_verify")
190
+ st.markdown(
191
+ '<p style="font-size:0.78rem;color:#888;margin-top:-0.6rem;margin-left:1.8rem;">'
192
+ 'Double-checks top matches. Uses the primary model by default.</p>',
193
+ unsafe_allow_html=True,
194
+ )
195
+
196
+ verify_cfg = primary_cfg
197
+ verify_key = primary_key
198
+ if run_verify:
199
+ verify_options = ["Same as primary"] + MODEL_LABELS
200
+ verify_choice = st.selectbox(
201
+ "Verification model", options=verify_options,
202
+ label_visibility="collapsed", key="verify_model",
203
+ )
204
+ if verify_choice != "Same as primary":
205
+ verify_cfg = get_model(verify_choice)
206
+ if verify_cfg.provider != primary_cfg.provider:
207
+ verify_key = render_api_key_field(
208
+ verify_cfg.provider, verify_cfg, "verify_api_key",
209
+ )
210
+ else:
211
+ verify_key = primary_key
212
+
213
  num_papers = st.slider("Papers to search", min_value=5, max_value=50, value=20, step=5)
214
 
215
  run_btn = st.button("🔭 Search", use_container_width=True, type="primary", disabled=st.session_state.running)
 
219
  st.markdown("""
220
  <div style="padding: 3rem 2rem; color: #aaa; text-align: center;">
221
  <div style="font-size: 3rem; margin-bottom: 1rem;">🔭</div>
222
+ <div style="font-family: 'DM Mono', monospace; font-size: 0.75rem; letter-spacing: 0.1em; text-transform: uppercase;">Your search progress will appear here</div>
 
 
 
223
  </div>
224
  """, unsafe_allow_html=True)
225
 
226
 
 
227
  if run_btn:
228
+ if not primary_key:
229
+ st.error(f"Please enter your {PROVIDER_DISPLAY[primary_cfg.provider]}.")
230
  elif use_pathfinder and not openai_key:
231
+ st.error("Pathfinder is checked — please enter your OpenAI API key, or uncheck Pathfinder.")
232
+ elif run_verify and verify_cfg.provider != primary_cfg.provider and not verify_key:
233
+ st.error(f"Verification model needs its own key — please enter your {PROVIDER_DISPLAY[verify_cfg.provider]}.")
234
  elif not user_text and not sketch_file:
235
  st.error("Please enter a description or upload a sketch (or both).")
236
  else:
237
  st.session_state.running = True
238
  st.session_state.results = None
239
 
240
+ tracker = CostTracker()
241
+ primary_client = LLMClient(provider=primary_cfg.provider, api_key=primary_key)
242
+ verify_client = (
243
+ primary_client if verify_cfg.provider == primary_cfg.provider
244
+ else LLMClient(provider=verify_cfg.provider, api_key=verify_key)
245
+ )
246
  sketch_bytes = sketch_file.read() if sketch_file else None
247
 
248
  with col_right:
 
259
  st.session_state.log = []
260
 
261
  try:
 
262
  log("⟳ Parsing your description...")
263
+ parser = InputParser(
264
+ primary_client, primary_cfg.model_id, primary_cfg.prices, tracker,
265
+ max_tokens=primary_cfg.triage_max_tokens * 4,
266
+ )
267
  spec = parser.parse(text=user_text or None, sketch_bytes=sketch_bytes)
268
  query = spec["science_query"] or user_text or "(no query)"
269
  log(f"✓ Query: <em>{query}</em>")
270
  if spec.get("plot_type"):
271
  log(f" Plot type: {spec['plot_type']}")
272
 
 
273
  searcher = PaperSearcher(s2_key=s2_key or None)
274
  if use_pathfinder:
275
  all_papers = searcher.expanded_search_pathfinder(query, openai_key, log=log)
276
  else:
277
  all_papers = searcher.expanded_search(
278
+ query, primary_client, primary_cfg.model_id, primary_cfg.prices, tracker,
279
+ max_tokens=primary_cfg.triage_max_tokens * 4, log=log,
280
+ )
281
  log(f"✓ {len(all_papers)} unique papers found")
282
 
283
+ log(f"⟳ Triaging papers (batches of {primary_cfg.batch_size})...")
284
+ triager = PaperTriager(
285
+ primary_client, primary_cfg.model_id, primary_cfg.prices, tracker,
286
+ max_tokens=primary_cfg.triage_max_tokens,
287
+ batch_size=primary_cfg.batch_size,
288
+ )
289
  triaged = triager.triage(all_papers, spec)
290
  top = triaged[:num_papers]
291
  log(f"✓ {len(top)} papers passed triage")
292
  paper_lookup = {p["paperId"]: p for p in top}
293
 
 
294
  log("⟳ Fetching PDFs...")
295
  downloaded = []
296
  for i, paper in enumerate(top):
 
306
  progress_placeholder.empty()
307
  log(f"✓ {len(downloaded)} PDFs ready")
308
 
 
309
  log("⟳ Extracting figures...")
310
  extractor = FigureExtractor()
311
  all_figures = []
 
319
  filtered = extractor.caption_filter(all_figures, query)
320
  log(f" {len(filtered)} figures after caption filter (from {len(all_figures)} total)")
321
 
322
+ log(f"⟳ Scoring {len(filtered)} figures (batches of {primary_cfg.batch_size})...")
323
+ scorer = FigureScorer(
324
+ primary_client, primary_cfg.model_id, primary_cfg.prices, tracker,
325
+ score_max_tokens=primary_cfg.score_max_tokens,
326
+ verify_max_tokens=primary_cfg.verify_max_tokens,
327
+ batch_size=primary_cfg.batch_size,
328
+ )
329
+ results = scorer.score_batch(filtered, spec)
330
+ primary_matches = [fig for fig, result in zip(filtered, results)
331
+ if result.get("confidence", 0) >= 0.5]
332
  log(f"✓ {len(primary_matches)} primary matches")
333
 
 
334
  verified = primary_matches
335
  if run_verify and primary_matches:
336
  log(f"⟳ Verifying {len(primary_matches)} matches...")
337
+ verifier = FigureScorer(
338
+ verify_client, verify_cfg.model_id, verify_cfg.prices, tracker,
339
+ score_max_tokens=verify_cfg.score_max_tokens,
340
+ verify_max_tokens=verify_cfg.verify_max_tokens,
341
+ batch_size=1,
342
+ )
343
  verified = []
344
  for i, fig in enumerate(primary_matches):
345
  progress_placeholder.progress((i + 1) / len(primary_matches), text=f"Verifying {i+1}/{len(primary_matches)}")
 
366
  st.session_state.global_stats["searches"] += 1
367
 
368
  except Exception as e:
369
+ import traceback
370
+ traceback.print_exc()
371
  log(f"✗ Error: {e}")
372
  st.error(f"Pipeline error: {e}")
373
  finally:
 
376
  st.rerun()
377
 
378
 
 
379
  if st.session_state.results:
380
  res = st.session_state.results
381
  matches = res["matches"]
 
436
  st.markdown("---")
437
 
438
 
 
439
  if st.session_state.results and st.session_state.results.get("matches"):
440
  st.markdown("""
441
  <div class="feedback-box">
 
455
  st.success("Thanks!")
456
 
457
 
 
458
  stats = st.session_state.global_stats
459
  n_ratings = len(stats["ratings"])
460
  avg = sum(stats["ratings"]) / n_ratings if n_ratings else 0
 
468
  <div class="stat-item"><div class="tally-num">{"—" if not n_ratings else f"{avg:.1f}"}</div><div class="tally-label">Avg score</div></div>
469
  </div>
470
  </div>
471
+ """, unsafe_allow_html=True)