INLEXIO committed on
Commit
d8d5f96
Β·
verified Β·
1 Parent(s): e8c6cf8

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +390 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,392 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
1
  import streamlit as st
2
+ import requests
3
+ from sentence_transformers import SentenceTransformer
4
+ import numpy as np
5
+ from collections import defaultdict
6
+ import time
7
+
8
# Page config: must run before any other st.* call in the script.
st.set_page_config(
    page_title="OpenAlex Semantic Search",
    page_icon="πŸ”¬",  # NOTE(review): mojibake in source — presumably an emoji; confirm encoding
    layout="wide"
)
14
+
15
# Cache the model loading so the encoder is built once per server process.
@st.cache_resource
def load_model():
    """Load and cache the sentence-transformer encoder used for embeddings."""
    encoder = SentenceTransformer('all-MiniLM-L6-v2')
    return encoder
20
+
21
@st.cache_data(ttl=3600)
def search_openalex_papers(query, num_results=50):
    """
    Search OpenAlex for works matching the query.

    Args:
        query: free-text search string.
        num_results: number of works to request in one page.

    Returns:
        List of work dicts from the API, or [] on any request failure
        (the error is surfaced to the UI via st.error).
    """
    base_url = "https://api.openalex.org/works"

    params = {
        "search": query,
        "per_page": num_results,
        # NOTE(review): OpenAlex documents the page-size parameter as
        # "per-page"; confirm the underscore spelling is honored by the API.
        "select": "id,title,abstract_inverted_index,authorships,publication_year,cited_by_count,display_name",
        "mailto": "user@example.com"  # Polite pool
    }

    try:
        response = requests.get(base_url, params=params, timeout=30)
        response.raise_for_status()
        data = response.json()
        return data.get("results", [])
    except (requests.RequestException, ValueError) as e:
        # Narrowed from a bare `except Exception`: only network/HTTP failures
        # and malformed JSON are expected here; programming errors propagate.
        st.error(f"Error fetching papers: {str(e)}")
        return []
43
+
44
def reconstruct_abstract(inverted_index):
    """Rebuild a plain-text abstract from OpenAlex's inverted-index format.

    The inverted index maps each word to the list of positions where it
    occurs; emitting words in position order recovers the original text.
    Returns "" when no index is available.
    """
    if not inverted_index:
        return ""

    # Flatten to (position, word) pairs, then order by position only —
    # the stable sort preserves insertion order for any duplicate positions.
    pairs = [(pos, word)
             for word, positions in inverted_index.items()
             for pos in positions]
    pairs.sort(key=lambda pair: pair[0])
    return " ".join(word for _, word in pairs)
60
+
61
@st.cache_data(ttl=3600)
def get_author_details(author_id):
    """
    Fetch detailed author information from OpenAlex.

    Args:
        author_id: bare OpenAlex author ID (e.g. "A1234567890").

    Returns:
        Parsed JSON dict for the author, or None when the request fails —
        callers treat None as "metrics unavailable".
    """
    base_url = f"https://api.openalex.org/authors/{author_id}"

    params = {
        "mailto": "user@example.com"  # joins OpenAlex's polite pool
    }

    try:
        response = requests.get(base_url, params=params, timeout=10)
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError):
        # Narrowed from `except Exception as e` (with `e` unused): only
        # network/HTTP failures and malformed JSON are swallowed here.
        return None
78
+
79
def calculate_semantic_similarity(query_embedding, paper_embeddings):
    """
    Calculate cosine similarity between the query and each paper embedding.

    Args:
        query_embedding: 1-D numpy array.
        paper_embeddings: 2-D numpy array, one row per paper.

    Returns:
        1-D numpy array of cosine similarities. Fix over the original:
        a zero-norm query or zero-norm paper row yields 0.0 instead of
        dividing by zero and producing NaN.
    """
    query_magnitude = np.linalg.norm(query_embedding)
    if query_magnitude == 0:
        # A zero query vector has no direction; nothing is similar to it.
        return np.zeros(len(paper_embeddings))
    query_norm = query_embedding / query_magnitude

    row_norms = np.linalg.norm(paper_embeddings, axis=1, keepdims=True)
    # Substitute 1.0 for zero norms so degenerate rows divide safely
    # (their dot product with the query is 0 anyway).
    safe_norms = np.where(row_norms == 0, 1.0, row_norms)
    paper_norms = paper_embeddings / safe_norms

    # Cosine similarity = dot product of unit vectors
    similarities = np.dot(paper_norms, query_norm)
    return similarities
90
+
91
def rank_authors(papers, paper_scores, model, query_embedding, min_papers=2):
    """
    Extract authors from papers and rank them by a composite of:
    - semantic relevance (mean of their papers' similarity scores, 50%)
    - h-index (30%)
    - total citations, log-scaled (20%)

    Args:
        papers: list of OpenAlex work dicts.
        paper_scores: similarity score per paper, parallel to `papers`.
        model, query_embedding: unused here; kept so the caller's
            positional call signature remains compatible.
        min_papers: minimum relevant papers an author needs to be ranked.

    Returns:
        List of author summary dicts sorted by composite score, descending;
        [] when no author meets the `min_papers` threshold.
    """
    author_data = defaultdict(lambda: {
        'name': '',
        'id': '',
        'paper_scores': [],
        'paper_ids': [],
        'total_citations': 0,
        'works_count': 0,
        'h_index': 0,
        'institution': ''
    })

    # Collect author information from papers
    for paper, score in zip(papers, paper_scores):
        for authorship in paper.get('authorships', []):
            author = authorship.get('author', {})
            author_id = author.get('id', '').split('/')[-1] if author.get('id') else None

            # OpenAlex author IDs start with 'A'; skip malformed entries.
            if author_id and author_id.startswith('A'):
                author_data[author_id]['name'] = author.get('display_name', 'Unknown')
                author_data[author_id]['id'] = author_id
                author_data[author_id]['paper_scores'].append(score)
                author_data[author_id]['paper_ids'].append(paper.get('id', ''))

                # Get institution (keep the first one seen for this author)
                institutions = authorship.get('institutions', [])
                if institutions and not author_data[author_id]['institution']:
                    author_data[author_id]['institution'] = institutions[0].get('display_name', '')

    # Filter authors with minimum paper count
    filtered_authors = {
        aid: data for aid, data in author_data.items()
        if len(data['paper_scores']) >= min_papers
    }

    # Fix: bail out early so the UI doesn't show a spinner and an empty
    # progress bar when no author qualifies (same [] result as before).
    if not filtered_authors:
        return []

    # Fetch detailed metrics for each author
    with st.spinner(f"Fetching metrics for {len(filtered_authors)} authors..."):
        progress_bar = st.progress(0)
        for idx, (author_id, data) in enumerate(filtered_authors.items()):
            author_details = get_author_details(author_id)
            if author_details:
                data['h_index'] = author_details.get('summary_stats', {}).get('h_index', 0)
                data['total_citations'] = author_details.get('cited_by_count', 0)
                data['works_count'] = author_details.get('works_count', 0)

            progress_bar.progress((idx + 1) / len(filtered_authors))
            time.sleep(0.1)  # Rate limiting between author API calls

        progress_bar.empty()

    # Calculate composite score for ranking
    ranked_authors = []
    for author_id, data in filtered_authors.items():
        avg_relevance = np.mean(data['paper_scores'])

        # Normalize metrics (using log scale for citations)
        normalized_h_index = data['h_index'] / 100.0  # Assume max h-index of 100
        normalized_citations = np.log1p(data['total_citations']) / 15.0  # Log scale

        # Composite score: weighted combination
        composite_score = (
            0.5 * avg_relevance +        # 50% semantic relevance
            0.3 * normalized_h_index +   # 30% h-index
            0.2 * normalized_citations   # 20% citations
        )

        ranked_authors.append({
            'author_id': author_id,
            'name': data['name'],
            'institution': data['institution'],
            'h_index': data['h_index'],
            'total_citations': data['total_citations'],
            'works_count': data['works_count'],
            'num_relevant_papers': len(data['paper_scores']),
            'avg_relevance_score': avg_relevance,
            'composite_score': composite_score,
            'openalex_url': f"https://openalex.org/{author_id}"
        })

    # Sort by composite score
    ranked_authors.sort(key=lambda x: x['composite_score'], reverse=True)

    return ranked_authors
180
+
181
def main():
    """Render the Streamlit UI: search OpenAlex for papers, rank them by
    semantic similarity to the query, rank their authors, and offer a CSV
    download of the author rankings.

    NOTE(review): emoji literals below appear mojibake'd in the source
    (e.g. "πŸ”¬") — presumably an encoding artifact of the scrape; confirm
    against the original file before relying on them.
    """
    st.title("πŸ”¬ OpenAlex Semantic Search")
    st.markdown("""
    Search for academic papers and discover top researchers using semantic search powered by OpenAlex.

    **How it works:**
    1. Enter your search terms (e.g., "machine learning for drug discovery")
    2. The app finds relevant papers using semantic similarity
    3. Authors are ranked by relevance, h-index, and citation metrics
    """)

    # Sidebar controls
    st.sidebar.header("Search Settings")

    # Page size passed through to the OpenAlex works request.
    num_papers = st.sidebar.slider(
        "Number of papers to fetch",
        min_value=20,
        max_value=100,
        value=50,
        step=10
    )

    top_papers_display = st.sidebar.slider(
        "Top papers to display",
        min_value=5,
        max_value=30,
        value=10,
        step=5
    )

    top_authors_display = st.sidebar.slider(
        "Top authors to display",
        min_value=5,
        max_value=50,
        value=20,
        step=5
    )

    min_papers_per_author = st.sidebar.slider(
        "Minimum papers per author",
        min_value=1,
        max_value=5,
        value=2,
        step=1,
        help="Minimum number of relevant papers an author must have to be included"
    )

    # Main search input
    query = st.text_input(
        "Enter your search query:",
        placeholder="e.g., 'graph neural networks for protein structure prediction'",
        help="Enter keywords or a description of what you're looking for"
    )

    search_button = st.button("πŸ” Search", type="primary")

    # Everything below runs only on an explicit search with a non-empty query.
    if search_button and query:
        # Load model (cached across reruns by @st.cache_resource)
        with st.spinner("Loading semantic model..."):
            model = load_model()

        # Search papers (cached for an hour by @st.cache_data)
        with st.spinner(f"Searching OpenAlex for papers about '{query}'..."):
            papers = search_openalex_papers(query, num_papers)

        if not papers:
            st.warning("No papers found. Try different search terms.")
            return

        st.success(f"Found {len(papers)} papers!")

        # Prepare papers for semantic search
        with st.spinner("Analyzing papers with semantic search..."):
            paper_texts = []
            valid_papers = []

            for paper in papers:
                title = paper.get('display_name', '') or paper.get('title', '')
                abstract = reconstruct_abstract(paper.get('abstract_inverted_index', {}))

                # Combine title and abstract (title weighted more)
                text = f"{title} {title} {abstract}"  # Title appears twice for emphasis

                # Drop works that have neither title nor abstract text.
                if text.strip():
                    paper_texts.append(text)
                    valid_papers.append(paper)

            if not paper_texts:
                st.error("No valid paper content found.")
                return

            # Generate embeddings
            query_embedding = model.encode(query, convert_to_tensor=False)
            paper_embeddings = model.encode(paper_texts, convert_to_tensor=False, show_progress_bar=True)

            # Calculate similarities
            similarities = calculate_semantic_similarity(query_embedding, paper_embeddings)

            # Sort papers by similarity, highest first
            sorted_indices = np.argsort(similarities)[::-1]
            sorted_papers = [valid_papers[i] for i in sorted_indices]
            sorted_scores = [similarities[i] for i in sorted_indices]

        # Display top papers
        st.header(f"πŸ“„ Top {top_papers_display} Most Relevant Papers")

        for idx, (paper, score) in enumerate(zip(sorted_papers[:top_papers_display], sorted_scores[:top_papers_display])):
            with st.expander(f"**{idx+1}. {paper.get('display_name', 'Untitled')}** (Relevance: {score:.3f})"):
                col1, col2 = st.columns([3, 1])

                with col1:
                    abstract = reconstruct_abstract(paper.get('abstract_inverted_index', {}))
                    if abstract:
                        # Truncate long abstracts to 500 chars with an ellipsis.
                        st.markdown(f"**Abstract:** {abstract[:500]}{'...' if len(abstract) > 500 else ''}")
                    else:
                        st.markdown("*No abstract available*")

                    # Authors (first five, with an ellipsis if more)
                    authors = [a.get('author', {}).get('display_name', 'Unknown')
                               for a in paper.get('authorships', [])]
                    if authors:
                        st.markdown(f"**Authors:** {', '.join(authors[:5])}{'...' if len(authors) > 5 else ''}")

                with col2:
                    st.metric("Year", paper.get('publication_year', 'N/A'))
                    st.metric("Citations", paper.get('cited_by_count', 0))

                    # Bare work ID (last URL segment) for the OpenAlex link.
                    paper_id = paper.get('id', '').split('/')[-1]
                    if paper_id:
                        st.markdown(f"[View on OpenAlex](https://openalex.org/{paper_id})")

        # Rank authors
        st.header(f"πŸ‘¨β€πŸ”¬ Top {top_authors_display} Researchers")

        ranked_authors = rank_authors(
            sorted_papers,
            sorted_scores,
            model,
            query_embedding,
            min_papers=min_papers_per_author
        )

        if not ranked_authors:
            st.warning(f"No authors found with at least {min_papers_per_author} relevant papers.")
            return

        # Display authors in a table-like layout
        st.markdown(f"Found {len(ranked_authors)} researchers with at least {min_papers_per_author} relevant papers.")

        for idx, author in enumerate(ranked_authors[:top_authors_display], 1):
            with st.container():
                col1, col2, col3, col4 = st.columns([3, 1, 1, 1])

                with col1:
                    st.markdown(f"**{idx}. [{author['name']}]({author['openalex_url']})**")
                    if author['institution']:
                        st.caption(author['institution'])

                with col2:
                    st.metric("H-Index", author['h_index'])

                with col3:
                    st.metric("Citations", f"{author['total_citations']:,}")

                with col4:
                    st.metric("Relevance", f"{author['avg_relevance_score']:.3f}")

                st.caption(f"Total works: {author['works_count']} | Relevant papers: {author['num_relevant_papers']}")
                st.divider()

        # Download results
        st.header("πŸ“₯ Download Results")

        # Prepare CSV data for authors (function-local imports, as in the original)
        import io
        import csv

        csv_buffer = io.StringIO()
        csv_writer = csv.writer(csv_buffer)

        # Write header
        csv_writer.writerow([
            'Rank', 'Name', 'Institution', 'H-Index', 'Total Citations',
            'Total Works', 'Relevant Papers', 'Avg Relevance Score', 'Composite Score', 'OpenAlex URL'
        ])

        # Write data: one row per ranked author (all of them, not just displayed)
        for idx, author in enumerate(ranked_authors, 1):
            csv_writer.writerow([
                idx,
                author['name'],
                author['institution'],
                author['h_index'],
                author['total_citations'],
                author['works_count'],
                author['num_relevant_papers'],
                f"{author['avg_relevance_score']:.4f}",
                f"{author['composite_score']:.4f}",
                author['openalex_url']
            ])

        csv_data = csv_buffer.getvalue()

        st.download_button(
            label="Download Author Rankings (CSV)",
            data=csv_data,
            # File name derived from the query, spaces underscored, capped at 30 chars.
            file_name=f"openalex_authors_{query.replace(' ', '_')[:30]}.csv",
            mime="text/csv"
        )
390
 
391
# Script entry point: run the Streamlit app.
if __name__ == "__main__":
    main()