htohfa commited on
Commit
2502a3c
·
1 Parent(s): 6cc68be

adding pathfinder capabilities

Browse files
Files changed (5) hide show
  1. .DS_Store +0 -0
  2. app.py +53 -16
  3. foto/pathfinder_search.py +118 -0
  4. foto/search.py +20 -4
  5. requirements.txt +4 -0
.DS_Store ADDED
Binary file (6.15 kB). View file
 
app.py CHANGED
@@ -68,11 +68,15 @@ html, body, [class*="css"] { font-family: 'Inter', sans-serif; font-weight: 300;
68
  div[data-testid="stSelectbox"] label,
69
  div[data-testid="stTextInput"] label,
70
  div[data-testid="stCheckbox"] label { font-family: 'DM Mono', monospace; font-size: 0.75rem; letter-spacing: 0.08em; text-transform: uppercase; color: #888; }
 
 
 
 
71
  </style>
72
  """, unsafe_allow_html=True)
73
 
74
 
75
- # Session state
76
  for key, default in {
77
  "pdf_cache": {},
78
  "results": None,
@@ -107,7 +111,7 @@ with col_left:
107
  api_key = st.text_input("Anthropic API Key", type="password", label_visibility="collapsed", placeholder="sk-ant-...")
108
 
109
  st.markdown('<p class="section-label" style="margin-top:0.8rem;">Semantic Scholar Key (optional)</p>', unsafe_allow_html=True)
110
- s2_key = st.text_input("S2 Key", type="password", label_visibility="collapsed", placeholder="(Recommended)")
111
 
112
  st.markdown('<p class="section-label" style="margin-top:1.5rem;">Describe the figure</p>', unsafe_allow_html=True)
113
  user_text = st.text_area(
@@ -118,6 +122,34 @@ with col_left:
118
  st.markdown('<p class="section-label" style="margin-top:0.8rem;">Upload a sketch (optional)</p>', unsafe_allow_html=True)
119
  sketch_file = st.file_uploader("Sketch", type=["png", "jpg", "jpeg", "webp"], label_visibility="collapsed")
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  run_verify = st.checkbox("Secondary verification — recommended, adds ~$0.05", value=True)
122
  st.markdown('<p style="font-size:0.78rem;color:#888;margin-top:-0.8rem;margin-left:1.8rem;">Uses a smarter model to double-check top matches. Best results, small extra cost.</p>', unsafe_allow_html=True)
123
  num_papers = st.slider("Papers to search", min_value=5, max_value=50, value=20, step=5)
@@ -137,10 +169,12 @@ with col_right:
137
  """, unsafe_allow_html=True)
138
 
139
 
140
- # Pipeline
141
  if run_btn:
142
  if not api_key:
143
  st.error("Please enter your Anthropic API key.")
 
 
144
  elif not user_text and not sketch_file:
145
  st.error("Please enter a description or upload a sketch (or both).")
146
  else:
@@ -166,7 +200,7 @@ if run_btn:
166
  st.session_state.log = []
167
 
168
  try:
169
- # Parse input
170
  log("⟳ Parsing your description...")
171
  parser = InputParser(client, model_cfg.smart, tracker)
172
  spec = parser.parse(text=user_text or None, sketch_bytes=sketch_bytes)
@@ -175,13 +209,16 @@ if run_btn:
175
  if spec.get("plot_type"):
176
  log(f" Plot type: {spec['plot_type']}")
177
 
178
- # Search
179
  searcher = PaperSearcher(s2_key=s2_key or None)
180
- all_papers = searcher.expanded_search(
181
- query, client, model_cfg.smart, tracker, log=log)
 
 
 
182
  log(f"✓ {len(all_papers)} unique papers found")
183
 
184
- # Triage
185
  log("⟳ Triaging with Claude...")
186
  triager = PaperTriager(client, model_cfg.cheap, tracker)
187
  triaged = triager.triage(all_papers, spec)
@@ -189,7 +226,7 @@ if run_btn:
189
  log(f"✓ {len(top)} papers passed triage")
190
  paper_lookup = {p["paperId"]: p for p in top}
191
 
192
- # Fetch PDFs
193
  log("⟳ Fetching PDFs...")
194
  downloaded = []
195
  for i, paper in enumerate(top):
@@ -205,7 +242,7 @@ if run_btn:
205
  progress_placeholder.empty()
206
  log(f"✓ {len(downloaded)} PDFs ready")
207
 
208
- # Extract figures
209
  log("⟳ Extracting figures...")
210
  extractor = FigureExtractor()
211
  all_figures = []
@@ -219,7 +256,7 @@ if run_btn:
219
  filtered = extractor.caption_filter(all_figures, query)
220
  log(f" {len(filtered)} figures after caption filter (from {len(all_figures)} total)")
221
 
222
- # Primary scoring
223
  log(f"⟳ Scoring {len(filtered)} figures...")
224
  scorer = FigureScorer(client, model_cfg.cheap, tracker)
225
  primary_matches = []
@@ -231,7 +268,7 @@ if run_btn:
231
  progress_placeholder.empty()
232
  log(f"✓ {len(primary_matches)} primary matches")
233
 
234
- # Verification
235
  verified = primary_matches
236
  if run_verify and primary_matches:
237
  log(f"⟳ Verifying {len(primary_matches)} matches...")
@@ -270,7 +307,7 @@ if run_btn:
270
  st.rerun()
271
 
272
 
273
- # Results
274
  if st.session_state.results:
275
  res = st.session_state.results
276
  matches = res["matches"]
@@ -331,7 +368,7 @@ if st.session_state.results:
331
  st.markdown("---")
332
 
333
 
334
- # Feedback
335
  if st.session_state.results and st.session_state.results.get("matches"):
336
  st.markdown("""
337
  <div class="feedback-box">
@@ -351,7 +388,7 @@ if st.session_state.results and st.session_state.results.get("matches"):
351
  st.success("Thanks!")
352
 
353
 
354
- # Persistent tally
355
  stats = st.session_state.global_stats
356
  n_ratings = len(stats["ratings"])
357
  avg = sum(stats["ratings"]) / n_ratings if n_ratings else 0
@@ -365,4 +402,4 @@ st.markdown(f"""
365
  <div class="stat-item"><div class="tally-num">{"—" if not n_ratings else f"{avg:.1f}"}</div><div class="tally-label">Avg score</div></div>
366
  </div>
367
  </div>
368
- """, unsafe_allow_html=True)
 
68
  div[data-testid="stSelectbox"] label,
69
  div[data-testid="stTextInput"] label,
70
  div[data-testid="stCheckbox"] label { font-family: 'DM Mono', monospace; font-size: 0.75rem; letter-spacing: 0.08em; text-transform: uppercase; color: #888; }
71
+
72
+ .pathfinder-row { display: flex; align-items: center; gap: 0.5rem; }
73
+ .pathfinder-cite { font-family: 'DM Mono', monospace; font-size: 0.7rem; color: #888; }
74
+ .pathfinder-cite a { color: #888; text-decoration: underline; }
75
  </style>
76
  """, unsafe_allow_html=True)
77
 
78
 
79
+ # Per-session state for caching, run status, log buffer, and feedback tally
80
  for key, default in {
81
  "pdf_cache": {},
82
  "results": None,
 
111
  api_key = st.text_input("Anthropic API Key", type="password", label_visibility="collapsed", placeholder="sk-ant-...")
112
 
113
  st.markdown('<p class="section-label" style="margin-top:0.8rem;">Semantic Scholar Key (optional)</p>', unsafe_allow_html=True)
114
+ s2_key = st.text_input("S2 Key", type="password", label_visibility="collapsed", placeholder="(Recommended for keyword fallback)")
115
 
116
  st.markdown('<p class="section-label" style="margin-top:1.5rem;">Describe the figure</p>', unsafe_allow_html=True)
117
  user_text = st.text_area(
 
122
  st.markdown('<p class="section-label" style="margin-top:0.8rem;">Upload a sketch (optional)</p>', unsafe_allow_html=True)
123
  sketch_file = st.file_uploader("Sketch", type=["png", "jpg", "jpeg", "webp"], label_visibility="collapsed")
124
 
125
+ # Pathfinder toggle + inline citation link
126
+ pf_col1, pf_col2 = st.columns([2, 3])
127
+ with pf_col1:
128
+ use_pathfinder = st.checkbox("Use Pathfinder", value=True)
129
+ with pf_col2:
130
+ st.markdown(
131
+ '<div class="pathfinder-cite" style="padding-top:0.55rem;">'
132
+ 'based on <a href="https://arxiv.org/abs/2408.01556" target="_blank">arXiv:2408.01556</a>'
133
+ '</div>',
134
+ unsafe_allow_html=True,
135
+ )
136
+
137
+ # OpenAI key only needed when Pathfinder is active (used to embed queries
138
+ # with text-embedding-3-small against the Pathfinder corpus)
139
+ openai_key = None
140
+ if use_pathfinder:
141
+ st.markdown('<p class="section-label" style="margin-top:0.6rem;">OpenAI API Key</p>', unsafe_allow_html=True)
142
+ openai_key = st.text_input(
143
+ "OpenAI Key", type="password", label_visibility="collapsed",
144
+ placeholder="sk-...",
145
+ )
146
+ st.markdown(
147
+ '<p style="font-size:0.78rem;color:#888;margin-top:-0.4rem;">'
148
+ 'Used to embed queries with text-embedding-3-small (~$0.40 per million queries).'
149
+ '</p>',
150
+ unsafe_allow_html=True,
151
+ )
152
+
153
  run_verify = st.checkbox("Secondary verification — recommended, adds ~$0.05", value=True)
154
  st.markdown('<p style="font-size:0.78rem;color:#888;margin-top:-0.8rem;margin-left:1.8rem;">Uses a smarter model to double-check top matches. Best results, small extra cost.</p>', unsafe_allow_html=True)
155
  num_papers = st.slider("Papers to search", min_value=5, max_value=50, value=20, step=5)
 
169
  """, unsafe_allow_html=True)
170
 
171
 
172
+ # Full pipeline runs on button press
173
  if run_btn:
174
  if not api_key:
175
  st.error("Please enter your Anthropic API key.")
176
+ elif use_pathfinder and not openai_key:
177
+ st.error("Pathfinder is checked — please enter your OpenAI API key, or uncheck Pathfinder to use keyword search.")
178
  elif not user_text and not sketch_file:
179
  st.error("Please enter a description or upload a sketch (or both).")
180
  else:
 
200
  st.session_state.log = []
201
 
202
  try:
203
+ # Parse text + optional sketch into a structured search spec
204
  log("⟳ Parsing your description...")
205
  parser = InputParser(client, model_cfg.smart, tracker)
206
  spec = parser.parse(text=user_text or None, sketch_bytes=sketch_bytes)
 
209
  if spec.get("plot_type"):
210
  log(f" Plot type: {spec['plot_type']}")
211
 
212
+ # Pathfinder semantic retrieval, or legacy keyword expansion as fallback
213
  searcher = PaperSearcher(s2_key=s2_key or None)
214
+ if use_pathfinder:
215
+ all_papers = searcher.expanded_search_pathfinder(query, openai_key, log=log)
216
+ else:
217
+ all_papers = searcher.expanded_search(
218
+ query, client, model_cfg.smart, tracker, log=log)
219
  log(f"✓ {len(all_papers)} unique papers found")
220
 
221
+ # Abstract-level relevance filter to cut downstream cost
222
  log("⟳ Triaging with Claude...")
223
  triager = PaperTriager(client, model_cfg.cheap, tracker)
224
  triaged = triager.triage(all_papers, spec)
 
226
  log(f"✓ {len(top)} papers passed triage")
227
  paper_lookup = {p["paperId"]: p for p in top}
228
 
229
+ # PDF fetch with arxiv-first URL preference, polite spacing
230
  log("⟳ Fetching PDFs...")
231
  downloaded = []
232
  for i, paper in enumerate(top):
 
242
  progress_placeholder.empty()
243
  log(f"✓ {len(downloaded)} PDFs ready")
244
 
245
+ # Pull raster figures + captions from each PDF, then caption pre-filter
246
  log("⟳ Extracting figures...")
247
  extractor = FigureExtractor()
248
  all_figures = []
 
256
  filtered = extractor.caption_filter(all_figures, query)
257
  log(f" {len(filtered)} figures after caption filter (from {len(all_figures)} total)")
258
 
259
+ # Cheap vision pass: score every surviving figure against the spec
260
  log(f"⟳ Scoring {len(filtered)} figures...")
261
  scorer = FigureScorer(client, model_cfg.cheap, tracker)
262
  primary_matches = []
 
268
  progress_placeholder.empty()
269
  log(f"✓ {len(primary_matches)} primary matches")
270
 
271
+ # Optional smart-model verification on figures that passed primary scoring
272
  verified = primary_matches
273
  if run_verify and primary_matches:
274
  log(f"⟳ Verifying {len(primary_matches)} matches...")
 
307
  st.rerun()
308
 
309
 
310
+ # Render results: stats row, downloadable zip, then per-figure cards with metadata
311
  if st.session_state.results:
312
  res = st.session_state.results
313
  matches = res["matches"]
 
368
  st.markdown("---")
369
 
370
 
371
+ # Post-search feedback slider; submission logs to persistence layer
372
  if st.session_state.results and st.session_state.results.get("matches"):
373
  st.markdown("""
374
  <div class="feedback-box">
 
388
  st.success("Thanks!")
389
 
390
 
391
+ # Aggregate stats across all sessions, loaded from persistence
392
  stats = st.session_state.global_stats
393
  n_ratings = len(stats["ratings"])
394
  avg = sum(stats["ratings"]) / n_ratings if n_ratings else 0
 
402
  <div class="stat-item"><div class="tally-num">{"—" if not n_ratings else f"{avg:.1f}"}</div><div class="tally-label">Avg score</div></div>
403
  </div>
404
  </div>
405
+ """, unsafe_allow_html=True)
foto/pathfinder_search.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Semantic search over the Pathfinder astronomy corpus (Iyer et al. 2024,
2
+ arXiv:2408.01556).
3
+
4
+ The corpus ships with pre-computed text-embedding-3-small vectors for each
5
+ paper. Queries are embedded with the same OpenAI model and matched via FAISS
6
+ on the embedding column. Output format matches the rest of foto's search
7
+ layer so it slots in interchangeably with keyword search.
8
+ """
9
+ import os
10
+ from pathlib import Path
11
+ from typing import Optional
12
+
13
+ import numpy as np
14
+ import streamlit as st
15
+ from datasets import load_from_disk, load_dataset
16
+
17
+
18
+ DATASET_NAME = "kiyer/pathfinder_arxiv_data"
19
+ EMBEDDING_MODEL = "text-embedding-3-small"
20
+
21
+ # First-run download lands here; subsequent runs load_from_disk straight from cache
22
+ DATA_DIR = Path.home() / ".cache" / "foto" / "pathfinder_data"
23
+
24
+
25
+ @st.cache_resource(show_spinner="Loading Pathfinder corpus (~5 GB on first run)...")
26
+ def load_pathfinder_corpus():
27
+ """Returns the dataset with a FAISS index attached to the embed column.
28
+ Downloads from HF on first call, reuses local cache afterward."""
29
+ if not DATA_DIR.exists():
30
+ DATA_DIR.mkdir(parents=True, exist_ok=True)
31
+ ds = load_dataset(DATASET_NAME, split="train")
32
+ ds.save_to_disk(str(DATA_DIR))
33
+ else:
34
+ ds = load_from_disk(str(DATA_DIR))
35
+
36
+ if not ds.is_index_initialized("embed"):
37
+ ds.add_faiss_index(column="embed")
38
+
39
+ return ds
40
+
41
+
42
+ def make_embedder(openai_key: str):
43
+ """Returns a function that embeds text into a 1536-dim vector with
44
+ text-embedding-3-small. Key is user-supplied so this is not cached."""
45
+ from openai import OpenAI
46
+
47
+ if not openai_key:
48
+ raise RuntimeError(
49
+ "Pathfinder uses text-embedding-3-small from OpenAI. Set an "
50
+ "OpenAI API key in the sidebar (get one at platform.openai.com)."
51
+ )
52
+ client = OpenAI(api_key=openai_key)
53
+
54
+ def embed(text: str) -> np.ndarray:
55
+ resp = client.embeddings.create(model=EMBEDDING_MODEL, input=[text])
56
+ return np.array(resp.data[0].embedding, dtype=np.float32)
57
+
58
+ return embed
59
+
60
+
61
+ def _row_to_paper(row: dict, similarity: float) -> dict:
62
+ """Map a Pathfinder dataset row into foto's paper dict shape."""
63
+ arxiv_id = row.get("arxiv_id") or ""
64
+
65
+ year = None
66
+ d = row.get("date")
67
+ if d is not None:
68
+ try:
69
+ year = d.year
70
+ except AttributeError:
71
+ s = str(d)
72
+ year = int(s[:4]) if s[:4].isdigit() else None
73
+
74
+ # Pathfinder stores authors as a list of strings; foto wants [{"name": ...}, ...]
75
+ raw_authors = row.get("authors") or []
76
+ if raw_authors and isinstance(raw_authors[0], str):
77
+ authors = [{"name": a} for a in raw_authors]
78
+ else:
79
+ authors = raw_authors
80
+
81
+ return {
82
+ "paperId": f"arxiv_{arxiv_id}" if arxiv_id else f"ads_{row.get('ads_id', '')}",
83
+ "title": row.get("title", ""),
84
+ "abstract": row.get("abstract", ""),
85
+ "year": year,
86
+ "authors": authors,
87
+ "externalIds": {"ArXiv": arxiv_id} if arxiv_id else {},
88
+ "openAccessPdf": {"url": f"https://arxiv.org/pdf/{arxiv_id}"} if arxiv_id else {},
89
+ "citationCount": row.get("cites", 0) or 0,
90
+ "_source": "pathfinder",
91
+ "_pathfinder_score": similarity,
92
+ }
93
+
94
+
95
+ class PathfinderSearcher:
96
+ """Semantic retrieval over the Pathfinder corpus.
97
+ Output format matches PaperSearcher.search_s2 so downstream code is unchanged."""
98
+
99
+ def __init__(self, openai_key: str):
100
+ self.dataset = load_pathfinder_corpus()
101
+ self.embed = make_embedder(openai_key)
102
+
103
+ def search(self, query: str, limit: int = 50) -> list[dict]:
104
+ query_vec = self.embed(query)
105
+
106
+ tmp = self.dataset.search("embed", query_vec, k=limit)
107
+
108
+ results = []
109
+ for idx, dist in zip(tmp.indices, tmp.scores):
110
+ row = self.dataset[int(idx)]
111
+ # Skip papers with no arxiv_id — the download step needs it
112
+ if not row.get("arxiv_id"):
113
+ continue
114
+ # Convert FAISS distance to similarity, matching Pathfinder's convention
115
+ similarity = 1.0 / (1.0 + float(dist))
116
+ results.append(_row_to_paper(row, similarity))
117
+
118
+ return results
foto/search.py CHANGED
@@ -146,6 +146,8 @@ class PaperSearcher:
146
  return scored[:top_n]
147
 
148
  def expanded_search(self, query: str, client, model: str, tracker, log=None) -> list[dict]:
 
 
149
  def _log(msg):
150
  if log:
151
  log(msg)
@@ -153,7 +155,7 @@ class PaperSearcher:
153
  seen_ids, seen_titles = set(), set()
154
  all_results = []
155
 
156
- # Round 1: expanded queries from Claude
157
  try:
158
  response = client.messages.create(
159
  model=model, max_tokens=200,
@@ -175,7 +177,7 @@ class PaperSearcher:
175
 
176
  _log(f" {len(all_results)} papers after round 1")
177
 
178
- # Round 2: author search
179
  top_authors = self._top_authors(all_results)
180
  if top_authors:
181
  for author in top_authors:
@@ -185,7 +187,7 @@ class PaperSearcher:
185
  time.sleep(1)
186
  _log(f" {len(all_results)} papers after author search")
187
 
188
- # Round 3: adjacent topics from landmarks
189
  landmarks = self._landmarks(all_results)
190
  if landmarks:
191
  titles = "\n".join(f"- {p['title']}" for p in landmarks)
@@ -210,6 +212,20 @@ class PaperSearcher:
210
  all_results.sort(key=lambda p: -(p.get("citationCount", 0) or 0))
211
  return all_results
212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
 
214
  class PaperTriager:
215
  def __init__(self, client, model: str, tracker):
@@ -243,4 +259,4 @@ class PaperTriager:
243
 
244
  scored.sort(key=lambda p: -(p.get("citationCount") or 0) * 0.1
245
  - p["_triage"].get("confidence", 0))
246
- return scored
 
146
  return scored[:top_n]
147
 
148
  def expanded_search(self, query: str, client, model: str, tracker, log=None) -> list[dict]:
149
+ """Keyword-based retrieval with 3 rounds of LLM-driven query expansion.
150
+ Used as fallback when Pathfinder semantic search is disabled."""
151
  def _log(msg):
152
  if log:
153
  log(msg)
 
155
  seen_ids, seen_titles = set(), set()
156
  all_results = []
157
 
158
+ # Round 1: LLM expands the original query into subfield-specific variants
159
  try:
160
  response = client.messages.create(
161
  model=model, max_tokens=200,
 
177
 
178
  _log(f" {len(all_results)} papers after round 1")
179
 
180
+ # Round 2: pull recent papers from top citation-weighted authors
181
  top_authors = self._top_authors(all_results)
182
  if top_authors:
183
  for author in top_authors:
 
187
  time.sleep(1)
188
  _log(f" {len(all_results)} papers after author search")
189
 
190
+ # Round 3: LLM proposes adjacent subfields based on landmark titles
191
  landmarks = self._landmarks(all_results)
192
  if landmarks:
193
  titles = "\n".join(f"- {p['title']}" for p in landmarks)
 
212
  all_results.sort(key=lambda p: -(p.get("citationCount", 0) or 0))
213
  return all_results
214
 
215
+ def expanded_search_pathfinder(self, query: str, openai_key: str, log=None) -> list[dict]:
216
+ """Semantic retrieval against the Pathfinder corpus (Iyer et al. 2024,
217
+ arXiv:2408.01556). Embeds query with text-embedding-3-small, returns
218
+ top-K papers by FAISS similarity over precomputed embeddings."""
219
+ from .pathfinder_search import PathfinderSearcher
220
+
221
+ if log:
222
+ log("⟳ Semantic search via Pathfinder corpus...")
223
+ searcher = PathfinderSearcher(openai_key=openai_key)
224
+ results = searcher.search(query, limit=50)
225
+ if log:
226
+ log(f" ✓ {len(results)} papers retrieved (semantic similarity)")
227
+ return results
228
+
229
 
230
  class PaperTriager:
231
  def __init__(self, client, model: str, tracker):
 
259
 
260
  scored.sort(key=lambda p: -(p.get("citationCount") or 0) * 0.1
261
  - p["_triage"].get("confidence", 0))
262
+ return scored
requirements.txt CHANGED
@@ -1,6 +1,10 @@
1
  streamlit>=1.35.0
2
  anthropic>=0.25.0
 
3
  requests>=2.31.0
4
  PyMuPDF>=1.24.0
5
  Pillow>=10.0.0
6
  gspread>=6.0.0
 
 
 
 
1
  streamlit>=1.35.0
2
  anthropic>=0.25.0
3
+ openai>=1.0.0
4
  requests>=2.31.0
5
  PyMuPDF>=1.24.0
6
  Pillow>=10.0.0
7
  gspread>=6.0.0
8
+ google-auth>=2.0.0
9
+ datasets>=2.18.0
10
+ faiss-cpu>=1.7.4