livctr commited on
Commit
6c2a7c2
·
0 Parent(s):

first commit

Browse files
.gitignore ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .env
2
+ data/*
3
+ runs/*
4
+ logs/*
5
+ nbs/*
6
+
7
+
8
+
9
+ # Byte-compiled / optimized / DLL files
10
+ __pycache__/
11
+ *.py[cod]
12
+ *$py.class
13
+
14
+ # C extensions
15
+ *.so
16
+
17
+ # Distribution / packaging
18
+ .Python
19
+ build/
20
+ develop-eggs/
21
+ dist/
22
+ downloads/
23
+ eggs/
24
+ .eggs/
25
+ lib/
26
+ lib64/
27
+ parts/
28
+ sdist/
29
+ var/
30
+ wheels/
31
+ share/python-wheels/
32
+ *.egg-info/
33
+ .installed.cfg
34
+ *.egg
35
+ MANIFEST
36
+
37
+ # PyInstaller
38
+ # Usually these files are written by a python script from a template
39
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
40
+ *.manifest
41
+ *.spec
42
+
43
+ # Installer logs
44
+ pip-log.txt
45
+ pip-delete-this-directory.txt
46
+
47
+ # Unit test / coverage reports
48
+ htmlcov/
49
+ .tox/
50
+ .nox/
51
+ .coverage
52
+ .coverage.*
53
+ .cache
54
+ nosetests.xml
55
+ coverage.xml
56
+ *.cover
57
+ *.py,cover
58
+ .hypothesis/
59
+ .pytest_cache/
60
+ cover/
61
+
62
+ # Translations
63
+ *.mo
64
+ *.pot
65
+
66
+ # Django stuff:
67
+ *.log
68
+ local_settings.py
69
+ db.sqlite3
70
+ db.sqlite3-journal
71
+
72
+ # Flask stuff:
73
+ instance/
74
+ .webassets-cache
75
+
76
+ # Scrapy stuff:
77
+ .scrapy
78
+
79
+ # Sphinx documentation
80
+ docs/_build/
81
+
82
+ # PyBuilder
83
+ .pybuilder/
84
+ target/
85
+
86
+ # Jupyter Notebook
87
+ .ipynb_checkpoints
88
+
89
+ # IPython
90
+ profile_default/
91
+ ipython_config.py
92
+
93
+ # pyenv
94
+ # For a library or package, you might want to ignore these files since the code is
95
+ # intended to run in multiple environments; otherwise, check them in:
96
+ # .python-version
97
+
98
+ # pipenv
99
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
100
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
101
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
102
+ # install all needed dependencies.
103
+ #Pipfile.lock
104
+
105
+ # poetry
106
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
107
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
108
+ # commonly ignored for libraries.
109
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
110
+ #poetry.lock
111
+
112
+ # pdm
113
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
114
+ #pdm.lock
115
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
116
+ # in version control.
117
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
118
+ .pdm.toml
119
+ .pdm-python
120
+ .pdm-build/
121
+
122
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
123
+ __pypackages__/
124
+
125
+ # Celery stuff
126
+ celerybeat-schedule
127
+ celerybeat.pid
128
+
129
+ # SageMath parsed files
130
+ *.sage.py
131
+
132
+ # Environments
133
+ .env
134
+ .venv
135
+ env/
136
+ venv/
137
+ ENV/
138
+ env.bak/
139
+ venv.bak/
140
+
141
+ # Spyder project settings
142
+ .spyderproject
143
+ .spyproject
144
+
145
+ # Rope project settings
146
+ .ropeproject
147
+
148
+ # mkdocs documentation
149
+ /site
150
+
151
+ # mypy
152
+ .mypy_cache/
153
+ .dmypy.json
154
+ dmypy.json
155
+
156
+ # Pyre type checker
157
+ .pyre/
158
+
159
+ # pytype static type analyzer
160
+ .pytype/
161
+
162
+ # Cython debug symbols
163
+ cython_debug/
164
+
165
+ # PyCharm
166
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
167
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
168
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
169
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
170
+ #.idea/
README.md ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # U.S. ML PhD Recomendation System
2
+
3
+ Disclaimer: results are not 100% accurate and there is likely some bias to how papers / professors are filtered.
4
+
5
+ ### Data Pipeline
6
+
7
+ First, a list of authors are gathered from recent conference proceedings. A batched RAG pipeline is used to determine which persons are U.S. professors (unsure how accurate the LLM here is). This can be reproduced as follows:
8
+
9
+ #### Repeat research until satisfactory
10
+
11
+ ```python
12
+ # Scrape top conferences for potential U.S.-based professors, ~45 mins
13
+ python -m data_pipeline.conference_scraper
14
+ ```
15
+ **Selected conferences**
16
+ - NeurIPS: 2022, 2023
17
+ - ICML: 2023, 2024
18
+ - AISTATS: 2023, 2024
19
+ - COLT: 2023, 2024
20
+ - AAAI: 2023, 2024
21
+ - EMNLP: 2023, 2024
22
+ - CVPR: 2023, 2024
23
+
24
+ ```python
25
+ # Search authors and locally store search results. Uses Bing web search API.
26
+ python -m data_pipeline.us_professor_verifier --batch_search
27
+ ```
28
+
29
+ NOTE 1: you may encounter caught exceptions due to HTTPError or invalid JSON outputs from the LLM. Would suggest to run the above multiple times until results are satisfactory.
30
+
31
+ NOTE 2: This pipeline does not handle name collisions, name changes, initials.
32
+
33
+ #### Create file containing U.S. professor data
34
+
35
+ ```python
36
+ # Use locally stored search results as input to an LLM.
37
+ # Sends as batches, each one waiting for the previous to finish.
38
+ python -m data_pipeline.us_professor_verifier --batch_analyze
39
+ # After some time (at most 24 hrs per batch, ~5 batches), the batch results become available for retrieval.
40
+ # Took ~1 hr for me
41
+ python -m data_pipeline.us_professor_verifier --batch_retrieve
42
+ ```
43
+
44
+ #### Extract embeddings for the relevant papers
45
+ ```python
46
+ # Fetch arxiv data and extract embeddings
47
+ python -m data_pipeline.download_arxiv_kaggle
48
+ ```
data_pipeline/__init__.py ADDED
File without changes
data_pipeline/conference_scraper.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Scrape data from some famous ML conferences and saves into data/conference.
2
+
3
+ Every scrape function returns a list of 3-lists of the form
4
+ [paper_title, paper_authors, paper_url].
5
+
6
+ Conferences
7
+ -----------
8
+ NeurIPS: 2022, 2023
9
+ ICML: 2023, 2024
10
+ AISTATS: 2023, 2024
11
+ COLT: 2023, 2024
12
+ AAAI: 2023, 2024
13
+ EMNLP: 2023, 2024
14
+ CVPR: 2023, 2024
15
+ -----------
16
+
17
+ Disclaimer
18
+ -----------
19
+ The choice of conferences was sourced from here:
20
+ https://www.kaggle.com/discussions/getting-started/115799
21
+
22
+ The priority of including certain conferences and tracks was based on a 1st-year PhD's
23
+ judgment. Some very top conferences were excluded due to higher activation energy to
24
+ scrape data and/or the ignorance of the 1st-year PhD. Some notable exceptions include
25
+ ICLR, ICCV, ECCV, ACL, NAACL, and many others.
26
+ -----------
27
+ """
28
+
29
+ from collections import defaultdict
30
+ from functools import partial
31
+ import json
32
+ import os
33
+ import requests
34
+ import time
35
+
36
+ from bs4 import BeautifulSoup
37
+ from tqdm import tqdm
38
+
39
+
40
+ SAVE_DIR = "data/conference"
41
+
42
def scrape_nips(year):
    """Scrape the NeurIPS proceedings index for `year`.

    Returns a list of [paper_title, paper_authors, paper_url] 3-lists.
    """
    page = requests.get(f"https://papers.nips.cc/paper/{year}")
    parsed = BeautifulSoup(page.text, "html.parser")

    # Each paper is an <li> whose <a> holds title + link and <i> the authors.
    items = []
    for entry in parsed.find_all('li'):
        title = entry.a.get_text()
        authors = entry.i.get_text()
        if title != "" and authors != "":
            items.append([title, authors, entry.a['href']])
    return items
51
+
52
def scrape_mlr_proceedings(conference, year):
    """Scrape a PMLR proceedings volume (ICML / AISTATS / COLT).

    Returns a list of [paper_title, paper_authors, paper_url] 3-lists.
    Raises KeyError for an unsupported (conference, year) pair.
    """
    # PMLR volume identifier for each supported (conference, year) pair.
    volume_by_conf_year = {
        ("ICML", 2024): "v235",
        ("ICML", 2023): "v202",
        ("AISTATS", 2024): "v238",
        ("AISTATS", 2023): "v206",
        ("COLT", 2024): "v247",
        ("COLT", 2023): "v195",
    }

    volume = volume_by_conf_year[(conference, year)]
    page = requests.get(f"https://proceedings.mlr.press/{volume}")
    parsed = BeautifulSoup(page.text, "html.parser")

    papers = []
    for paper_div in parsed.find_all('div', class_="paper"):
        title = paper_div.find('p', class_="title").get_text()
        authors = paper_div.find('p', class_="details").find('span', class_="authors").get_text()
        link = paper_div.find('p', class_="links").find('a')['href']
        papers.append([title, authors, link])
    return papers
77
+
78
def scrape_aaai():
    """Scrape AAAI-23 and AAAI-24 technical tracks from ojs.aaai.org.

    Walks the first two archive pages to collect track links, filters them to
    the past two years, then collects [title, authors, url] 3-lists per track.
    Sleeps 60s between requests, so a full run takes on the order of 40 mins.
    """
    # Scrape the technical tracks of past two years ('23, '24)
    # Look at first two pages of archives that give links to tracks
    # Look at each track

    # Browser-like UA: the OJS server rejects default python-requests agents.
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3"
    }

    # First two pages
    track_links = []

    aaai_urls = [
        "https://ojs.aaai.org/index.php/AAAI/issue/archive",
        "https://ojs.aaai.org/index.php/AAAI/issue/archive/2",
    ]

    for aaai_url in aaai_urls:

        response = requests.get(aaai_url, headers=headers)
        soup = BeautifulSoup(response.text, "html.parser")

        # Track titles/links live in <h2><a class="title"> elements.
        tracks = [track.find('a', class_="title") for track in soup.find_all('h2')]
        track_links.extend(
            [(track.text.strip(), track['href']) for track in tracks if track is not None]
        )
        print(track_links)

        time.sleep(60)  # respect scraping limits

    # only look at past two years
    track_links = [track_link for track_link in track_links if "AAAI-24" in track_link[0] or "AAAI-23" in track_link[0]]
    print("track links: ", track_links)

    conference_items = []

    for track_link in tqdm(track_links):
        print(f"Going through track {track_link[0]} @ {track_link[1]} ")

        # Scrape tracks
        response = requests.get(track_link[1], headers=headers)
        soup = BeautifulSoup(response.text, "html.parser")

        articles = soup.find_all('div', class_="obj_article_summary")

        for article in articles:

            # First <a> in the summary holds both the title text and the link.
            aref = article.find('a')
            conference_items.append(
                [
                    aref.text.strip(),
                    article.find('div', class_="authors").text.strip(),
                    aref['href'],
                ]
            )

        time.sleep(60)  # respect scraping limits

    return conference_items
137
+
138
def scrape_emnlp(year):
    """Scrape the EMNLP accepted-main-conference page for `year`.

    Returns [title, authors, ''] 3-lists; the page exposes no per-paper URL.
    """
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3"
    }

    page = requests.get(
        f"https://{year}.emnlp.org/program/accepted_main_conference/",
        headers=headers,
    )
    parsed = BeautifulSoup(page.text, "html.parser")

    # Each <p> holds the title as its first child and the authors as its last.
    items = []
    for paragraph in parsed.find_all('p'):
        items.append([paragraph.contents[0].text, paragraph.contents[-1].text, ''])
    return items
152
+
153
def scrape_cvpr(year):
    """Scrape the CVPR open-access listing for `year`.

    Returns a list of [paper_title, paper_authors, paper_url] 3-lists.
    """
    cvpr_url = f"https://openaccess.thecvf.com/CVPR{year}?day=all"
    response = requests.get(cvpr_url)
    soup = BeautifulSoup(response.text, "html.parser")

    # Titles/links live in <dt class="ptitle">; authors live in the <dd>
    # entries that contain a <form> (other <dd>s hold unrelated content).
    # (Removed a dead intermediate assignment that was immediately overwritten.)
    dts = soup.find_all('dt', class_="ptitle")

    dds = soup.find_all('dd')
    authors = []
    for dd in dds:
        if dd.find('form') is not None:  # author entry
            authors.append(
                ', '.join([x.text for x in dd.find_all('a')])
            )

    # NOTE(review): zip() silently truncates if the title and author lists
    # differ in length — assumed 1:1 on this page; confirm if counts drift.
    conference_items = [[dt.text, author, dt.a['href']] for dt, author in zip(dts, authors)]
    return conference_items
172
+
173
def save_to_file(conference_items, filename):
    """Write `conference_items` to `filename` as JSON Lines (one item per line)."""
    with open(filename, 'w') as out:
        out.writelines(json.dumps(entry) + '\n' for entry in conference_items)
177
+
178
def load_from_file(filename):
    """Read a JSON Lines file and return its records as a list."""
    with open(filename, 'r') as src:
        return list(map(json.loads, src))
182
+
183
def main():
    """Scrape all configured conferences, resumably, writing one JSON-Lines
    file per conference into SAVE_DIR and logging progress to a temp file."""

    # Conference label -> zero-arg callable returning [title, authors, url] 3-lists.
    scrape_functions = {
        "NeurIPS-2022": partial(scrape_nips, 2022),
        "NeurIPS-2023": partial(scrape_nips, 2023),
        "ICML-2023": partial(scrape_mlr_proceedings, "ICML", 2023),
        "ICML-2024": partial(scrape_mlr_proceedings, "ICML", 2024),
        "AISTATS-2023": partial(scrape_mlr_proceedings, "AISTATS", 2023),
        "AISTATS-2024": partial(scrape_mlr_proceedings, "AISTATS", 2024),
        "COLT-2023": partial(scrape_mlr_proceedings, "COLT", 2023),
        "COLT-2024": partial(scrape_mlr_proceedings, "COLT", 2024),
        "AAAI": scrape_aaai,  # easier to scrape both years at once, takes ~40 mins
        "EMNLP-2023": partial(scrape_emnlp, 2023),
        "EMNLP-2024": partial(scrape_emnlp, 2024),
        "CVPR-2023": partial(scrape_cvpr, 2023),
        "CVPR-2024": partial(scrape_cvpr, 2024),
    }

    def load_progress():
        # A conference counts as already scraped when its .json output exists.
        if os.path.exists(SAVE_DIR):
            file_paths = os.listdir(SAVE_DIR)
            file_paths = [file_path for file_path in file_paths if file_path.endswith('.json')]
            file_paths = [file_path.split('.')[0] for file_path in file_paths]
            return set(file_paths)
        return set()

    def save_progress(conference, file_path):
        # Append the finished conference name to the progress file.
        with open(file_path, 'a') as f:
            f.write(conference + '\n')

    def log_progress(msg, conference, file_path):
        # Append a per-conference status line (success or error).
        with open(file_path, 'a') as f:
            f.write(conference + ': ' + msg + '\n')

    os.makedirs(SAVE_DIR, exist_ok=True)

    # Load previous progress
    scraped_conferences = load_progress()

    # Progress file for current scrape
    progress_file = "conference_scraper_progress.tmp"

    for conference, scrape_function in tqdm(scrape_functions.items()):

        if conference in scraped_conferences:
            print(f"Skipping {conference}, already scraped.")
            log_progress("Success!", conference, progress_file)
            continue

        try:

            print(f"Scraping {conference}")
            save_path = os.path.join(SAVE_DIR, f"{conference}.json")
            conference_items = scrape_function()
            save_to_file(conference_items, save_path)
            print(f"Saved {conference} to {str(save_path)}")
            save_progress(conference, progress_file)
            log_progress("Success!", conference, progress_file)

        except Exception as e:
            # A failed conference is logged and skipped so the rest still run;
            # re-running main() retries only the missing ones.
            print(f"Error scraping {conference}: {e}")
            log_progress(f"ERROR: {e}", conference, progress_file)
            continue

    # Remove progress file
    os.remove(progress_file)
249
+
250
def stats():
    """Print a per-file line count for every file in SAVE_DIR, plus a grand total."""
    grand_total = 0
    for entry in os.listdir(SAVE_DIR):
        with open(os.path.join(SAVE_DIR, entry), 'r') as fh:
            count = sum(1 for _ in fh)
        print(entry + ": " + str(count) + " lines")
        grand_total += count
    print("Total: " + str(grand_total))
258
+
259
+
260
if __name__ == "__main__":
    # Scrape all conferences (resumable), then report per-file line counts.
    main()
    stats()
data_pipeline/download_arxiv_kaggle.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pulls papers from arxiv."""
2
+ from collections import defaultdict
3
+ from functools import partial
4
+ from datetime import datetime
5
+ import heapq
6
+ import json
7
+ import os
8
+ from pathlib import Path
9
+ import pickle
10
+
11
+ from datasets import Dataset
12
+ import kaggle
13
+ import numpy as np
14
+ import pandas as pd
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from tqdm import tqdm
18
+ from transformers import AutoTokenizer, AutoModel
19
+
20
+
21
+ arxiv_fname = "arxiv-metadata-oai-snapshot.json"
22
+
23
def download_arxiv_data(path = Path(".")):
    """Download and unzip the Kaggle arxiv snapshot into `path`/data.

    Skips the download when a file matching `arxiv_fname` already exists.
    Returns the path to the snapshot JSON file.
    """
    data_path = path / "data"

    already_present = any(arxiv_fname in entry for entry in os.listdir(data_path))
    if already_present:
        print(f"Data already downloaded at {data_path/arxiv_fname}.")
    else:
        kaggle.api.dataset_download_cli("Cornell-University/arxiv", path=data_path, unzip=True)
    return data_path / arxiv_fname
33
+
34
def get_lbl_from_name(names):
    """Tuple (last_name, first_name, middle_name) => String 'first_name [middle_name] last_name'."""
    labels = []
    for name in names:
        last, first, middle = name[0], name[1], name[2]
        if middle == '':
            labels.append(first + ' ' + last)
        else:
            labels.append(first + ' ' + middle + ' ' + last)
    return labels
41
+
42
def filter_arxiv_for_ml(arxiv_path, obtain_summary=False, authors_of_interest=None):
    """Sifts through downloaded arxiv file to find ML-related papers.

    If `obtain_summary` is True, saves a pickled DataFrame to the same directory as
    the downloaded arxiv file with the name `arxiv_fname` + `-summary.pkl`.

    If `authors_of_interest` is not None, only save ML-related papers by those
    authors; if None, all ML-related papers are kept.  (BUG FIX: a None value
    previously raised TypeError on the `author in None` membership test.)
    """
    # NOTE(review): split('.') misbehaves if a parent directory name contains
    # a dot — assumed not to happen for the Kaggle snapshot path.
    ml_path = str(arxiv_path).split('.')[0]+'-ml.json'
    summary_path = str(arxiv_path).split('.')[0]+'-summary.pkl'

    ml_cats = ['cs.AI', 'cs.CL', 'cs.CV', 'cs.LG', 'stat.ML']

    # Skip entirely when every requested output already exists.
    if obtain_summary and Path(ml_path).exists() and Path(summary_path).exists():
        print(f"File {ml_path} with ML subset of arxiv already exists. Skipping.")
        print(f"Summary file {summary_path} already exists. Skipping.")
        return
    if not obtain_summary and Path(ml_path).exists():
        print(f"File {ml_path} with ML subset of arxiv already exists. Skipping.")
        return

    if obtain_summary:
        gdf = {'categories': [], 'lv_date': []}  # global data

    if authors_of_interest is not None:
        authors_of_interest = set(authors_of_interest)  # O(1) membership tests

    # Load the JSON file line by line
    with open(arxiv_path, 'r') as f1, open(ml_path, 'w') as f2:
        for line in tqdm(f1):
            # Parse each line as JSON
            try:
                entry_data = json.loads(line)
            except json.JSONDecodeError:
                # Skip lines that cannot be parsed as JSON
                continue

            # check categories and last version in entry data
            if (
                obtain_summary
                and 'categories' in entry_data
                and 'versions' in entry_data
                and len(entry_data['versions'])
                and 'created' in entry_data['versions'][-1]
            ):
                gdf['categories'].append(entry_data['categories'])
                gdf['lv_date'].append(entry_data['versions'][-1]['created'])

            # ml data: keep the line when it is ML-categorized and (when an
            # author filter is given) involves at least one author of interest.
            authors_on_paper = get_lbl_from_name(entry_data.get('authors_parsed', []))
            if ('categories' in entry_data
                and (any(cat in entry_data['categories'] for cat in ml_cats))
                and (authors_of_interest is None
                     or any(author in authors_of_interest for author in authors_on_paper))
            ):
                f2.write(line)

    if obtain_summary:
        gdf = pd.DataFrame(gdf)
        gdf['lv_date'] = pd.to_datetime(gdf['lv_date'])
        gdf = gdf.sort_values('lv_date', axis=0).reset_index(drop=True)

        # One sparse indicator column per individual category.
        # (Removed a no-op `cat_combo.split(' ')` statement whose result was discarded.)
        cats = set()
        for cat_combo in gdf['categories'].unique():
            cats.update(cat_combo.split(' '))
        print(f'Columnizing {len(cats)} categories. ')
        for cat in tqdm(cats):
            gdf[cat] = pd.arrays.SparseArray(gdf['categories'].str.contains(cat), fill_value=0, dtype=np.int8)

        # count number of categories item is associated with
        gdf['ncats'] = gdf['categories'].str.count(' ') + 1

        # write to pickle file
        with open(summary_path, 'wb') as f:
            pickle.dump(gdf, f)
117
+
118
def get_professors_and_relevant_papers(us_professors, k=8, cutoff=datetime(2022, 10, 1)):
    """
    Returns a dictionary mapping U.S. professor names to a list of indices
    corresponding to their most recent papers in `data/arxiv-metadata-oai-snapshot-ml.json`.
    This function is necessary to specify the papers we are interested in for each
    professor (e.g., the most recent papers after cutoff)

    Parameters:
    - us_professors: A list of U.S. professor names to match against.
    - k: The number of most recent papers to keep for each professor, based on
      the first version upload date.
    - cutoff (datetime): Only considers papers published after this date
      (default: October 1, 2022).

    Returns:
    - dict: A dictionary where keys are professor names and values are lists of
      (datetime, line_nbr) tuples (a min-heap) for their most recent papers.
    """
    # O(1) membership per author instead of an O(n) list scan per paper.
    us_professors = set(us_professors)

    # professors to tuple of (datetime, papers); a bounded min-heap per
    # professor keeps only the k most recent papers (oldest popped first).
    p2p = defaultdict(list)

    with open('data/arxiv-metadata-oai-snapshot-ml.json', 'r') as f:
        line_nbr = 1
        while True:
            line = f.readline()
            if not line: break

            try:
                ml_data = json.loads(line)
                paper_authors = get_lbl_from_name(ml_data['authors_parsed'])

                # filter the same way as in `conference_scraper.py`
                # ignore solo-authored papers and papers with more than 20 authors
                if len(paper_authors) == 1 or len(paper_authors) > 20:
                    continue

                try:
                    dt = datetime.strptime(ml_data["versions"][0]["created"], '%a, %d %b %Y %H:%M:%S %Z')
                    if dt < cutoff:
                        continue
                except (KeyError, ValueError) as e:
                    print(f"Failed to parse date: {ml_data}")
                    # NOTE(review): unparsable dates fall through with a
                    # pre-cutoff sentinel and are still heap-pushed (they get
                    # evicted first) — confirm this is intended.
                    dt = datetime(2000, 1, 1)  # before cutoff date

                # consider if professor is first-author since we now care about semantics
                for paper_author in paper_authors:
                    if paper_author in us_professors:
                        # make a connection
                        heapq.heappush(p2p[paper_author], (dt, line_nbr))
                        if len(p2p[paper_author]) > k:
                            heapq.heappop(p2p[paper_author])
            except Exception:
                # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt
                # are no longer swallowed; malformed lines are printed and skipped.
                print(f"{line}")
            line_nbr += 1
    return p2p
173
+
174
def gen(p2p):
    """Yield paper records for every line number referenced in `p2p`.

    Streams `data/arxiv-metadata-oai-snapshot-ml.json` once, emitting one dict
    per relevant line, in ascending line-number order.
    """
    # Union of all line numbers across the per-professor heaps, sorted.
    wanted = sorted({entry[1] for heap in p2p.values() for entry in heap})

    pos = 0
    with open('data/arxiv-metadata-oai-snapshot-ml.json', 'r') as src:
        for line_nbr, line in enumerate(src, start=1):
            if pos >= len(wanted):
                break
            if line_nbr == wanted[pos]:
                record = json.loads(line)
                yield {
                    "line_nbr": line_nbr,
                    "id": record["id"],
                    "title": record["title"],
                    "abstract": record["abstract"],
                    "authors": record["authors_parsed"],
                }
                pos += 1
198
+
199
+
200
class EmbeddingProcessor:
    """Computes L2-normalized sentence embeddings for paper title+abstract pairs."""

    def __init__(self, model_name: str, custom_model_name: str, device: str = "cuda"):
        """Load tokenizer from `model_name` and model weights from `custom_model_name`.

        NOTE(review): tokenizer and weights intentionally come from different
        checkpoints (base tokenizer + fine-tuned model) — confirm they share a
        vocabulary.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(custom_model_name)
        self.device = torch.device(device)
        self.model.to(self.device)
        torch.cuda.empty_cache()

    @staticmethod
    def mean_pooling(model_output, attention_mask):
        """Average token embeddings across the sequence, ignoring padded positions."""
        # First element of model_output contains all token embeddings
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        # Clamp avoids division by zero for all-padding rows.
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def get_embeddings(self, batch):
        """Return normalized embeddings (list of lists) for a batch dict with
        'title' and 'abstract' lists."""
        title_tkn, abstract_tkn = " [TITLE] ", " [ABSTRACT] "
        titles = batch["title"]
        abstracts = batch["abstract"]

        texts = [title_tkn + t + abstract_tkn + a for t, a in zip(titles, abstracts)]

        # Tokenize sentences
        encoded_input = self.tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
        encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}

        # Compute token embeddings (inference only, no gradient tracking)
        with torch.no_grad():
            model_output = self.model(**encoded_input)

        # Perform pooling
        embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])

        # Normalize embeddings
        embeddings = F.normalize(embeddings, p=2, dim=1)

        # Move embeddings to CPU and convert to list
        return embeddings.cpu().numpy().tolist()

    def process_dataset(self, dataset_path: str, save_path: str, batch_size: int = 128):
        """Load a HF dataset from disk, add an 'embeddings' column, save to `save_path`."""
        # Load dataset
        ds = Dataset.load_from_disk(dataset_path)

        # Compute embeddings and add as a new column
        ds_with_embeddings = ds.map(lambda x: {"embeddings": self.get_embeddings(x)}, batched=True, batch_size=batch_size)

        # Save the updated dataset (removed redundant `save_path = save_path` self-assignment)
        ds_with_embeddings.save_to_disk(save_path)
        print(f"Dataset with embeddings saved to {save_path}")
250
+
251
+
252
def main():
    """Downloads arxiv data and extract embeddings for papers.

    Pipeline: download Kaggle snapshot -> filter to ML papers by U.S.
    professors -> select each professor's recent papers -> build a HF dataset
    -> compute and save embeddings.
    """
    print("Downloading data...")
    arxiv_path = download_arxiv_data()
    # Professor names come from the RAG-verified list produced by
    # us_professor_verifier.py.
    with open('data/professor/us_professor.json', 'r') as f:
        authors_of_interest = json.load(f)
        authors_of_interest = [author['name'] for author in authors_of_interest]
    print("Filtering data for ML papers...")
    filter_arxiv_for_ml(arxiv_path, authors_of_interest=authors_of_interest)

    # professor to list of paper indices
    paper_data_path = "data/paper_embeddings/paper_data"
    print("Saving data to disk at " + paper_data_path)
    p2p = get_professors_and_relevant_papers(authors_of_interest)
    ds = Dataset.from_generator(partial(gen, p2p))
    ds.save_to_disk(paper_data_path)

    print("Extracting embeddings (use GPU if possible)...")
    # paper embeddings
    save_path = "data/paper_embeddings/all-mpnet-base-v2-embds"
    # Initialize the embedding processor with model names
    embedding_processor = EmbeddingProcessor(
        model_name='sentence-transformers/all-mpnet-base-v2',
        custom_model_name='salsabiilashifa11/sbert-paper'
    )
    # Process dataset and save with embeddings
    embedding_processor.process_dataset(paper_data_path, save_path, batch_size=128)
279
+
280
if __name__ == "__main__":
    # Entry point: run the full download -> filter -> embed pipeline.
    main()
data_pipeline/loaders.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+
4
def load_conference_papers(conference_dir='data/conference'):
    """Load every JSON-Lines record from the *.json files in `conference_dir`."""
    papers = []
    for fname in os.listdir(conference_dir):
        if fname.endswith('.json'):
            with open(os.path.join(conference_dir, fname), 'r') as fh:
                papers.extend(json.loads(row) for row in fh)
    return papers
17
+
18
def load_us_professor():
    """Returns a JSON list"""
    # Path is fixed relative to the repo root working directory.
    with open('data/professor/us_professor.json', 'r') as fh:
        return json.load(fh)
data_pipeline/requirements.txt ADDED
File without changes
data_pipeline/schools_scraper.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://medium.com/@donadviser/running-selenium-and-chrome-on-wsl2-cfabe7db4bbb
2
+
3
+ import os
4
+ import time
5
+
6
+ from bs4 import BeautifulSoup
7
+ from dotenv import load_dotenv, find_dotenv
8
+ from langchain_together import ChatTogether
9
+ from langchain_core.output_parsers import StrOutputParser
10
+ from langchain_core.prompts import ChatPromptTemplate
11
+ from langchain_core.runnables import RunnableLambda
12
+ from selenium import webdriver
13
+ from selenium.webdriver.chrome.service import Service
14
+ from selenium.webdriver.common.by import By
15
+ from selenium.webdriver.chrome.options import Options
16
+
17
+ _ = load_dotenv(find_dotenv()) # read local .env file
18
+
19
+
20
def get_service_and_chrome_options():
    """Build the Selenium Service and headless-Chrome Options for WSL2.

    TODO: specific to chromedriver location — assumes Chrome and chromedriver
    are unpacked under `~/chrome-linux64` and `~/chromedriver-linux64`.
    """
    # Define Chrome options
    chrome_options = Options()
    chrome_options.add_argument("--headless")
    chrome_options.add_argument("--no-sandbox")
    # Add more options here if needed

    # Define paths (removed an accidental duplicate expanduser("~") call)
    user_home_dir = os.path.expanduser("~")
    chrome_binary_path = os.path.join(user_home_dir, "chrome-linux64", "chrome")
    chromedriver_path = os.path.join(user_home_dir, "chromedriver-linux64", "chromedriver")

    # Set binary location and service
    chrome_options.binary_location = chrome_binary_path
    service = Service(chromedriver_path)

    return service, chrome_options
39
+
40
+
41
def retrieve_csrankings_content(dump_file="soup.tmp"):
    """Write times higher page to a dump file."""
    # https://medium.com/@donadviser/running-selenium-and-chrome-on-wsl2-cfabe7db4bbb
    # Using WSL2
    # NOTE(review): despite the function name, this fetches the Times Higher
    # Education "best universities in the US" page, not csrankings.org —
    # confirm whether the name or the URL is the intended one.

    service, chrome_options = get_service_and_chrome_options()

    # Initialize Chrome WebDriver
    with webdriver.Chrome(service=service, options=chrome_options) as browser:
        print("Get browser")
        browser.get("https://www.timeshighereducation.com/student/best-universities/best-universities-united-states")

        # Wait for the page to load
        print("Wait for the page to load")
        browser.implicitly_wait(10)

        print("Get html")
        # Retrieve the HTML content
        html_content = browser.page_source

        # Write HTML content to soup.txt
        with open(dump_file, "w") as f:
            f.write(html_content)
64
+
65
+
66
def extract_timeshigher_content(read_file="soup.tmp", dump_file="soup (1).tmp"):
    """Extract universities from a dump file."""
    with open(read_file, "r") as src:
        soup = BeautifulSoup(src.read(), "html.parser")

    # A table row's first <a> holds the university name; skip rows without one.
    names = []
    for row in soup.find_all('tr'):
        anchor = row.find('a')
        if anchor:
            names.append(anchor.get_text())

    # Deduplicate while preserving first-seen order.
    unique_names = list(dict.fromkeys(names))

    # Write universities line-by-line to a new file
    with open(dump_file, "w") as out:
        out.writelines(f"{name}\n" for name in unique_names)
85
+
86
+
87
def get_department_getter():
    """
    Returns a function that leverages LangChain and TogetherAI to get a list of
    department names in a university associated with machine learning.
    """
    # Prompt asks for a numbered list, then a final machine-parsable line
    # beginning with "Answer:" holding semicolon-separated department names.
    template_string = """\
You are an expert in PhD programs and know about \
specific departments at each university.\
You are helping to design a system that generates \
a list of professors that students interested in \
machine learning can apply to for their PhDs. \
Currently, recall is more important than precision. \
Include as many departments as possible, while \
maintaining relevancy. Which departments in {university} \
are associated with machine learning? Please format your \
answer as a numbered list. Afterwards, please generate a \
new line starting with \"Answer:\", followed by a concise \
list of department names generated, separated by
semicolons.\
"""

    prompt_template = ChatPromptTemplate.from_template(template_string)

    # # choose from our 50+ models here: https://docs.together.ai/docs/inference-models
    chat = ChatTogether(
        together_api_key=os.environ["TOGETHER_API_KEY"],
        model="meta-llama/Llama-3-70b-chat-hf",
        temperature=0.3
    )

    output_parser = StrOutputParser()

    def extract_function(text):
        """Returns the line that starts with `Answer:`"""
        if "Answer:" not in text:
            return "No `Answer:` found"
        return text.split("Answer:")[1].strip()

    # Chain: fill prompt -> chat model -> raw string -> extract "Answer:" tail.
    chain = prompt_template | chat | output_parser | RunnableLambda(extract_function)

    def get_department_info(uni):
        """Get department info from the university."""
        return chain.invoke({"university": uni})

    return get_department_info
132
+
133
+
134
def get_department_info(unis_file="soup (1).tmp", deps_file="departments.tsv"):
    """
    Get department info for all universities in `unis_file` and
    write it to `deps_file`.

    Each university is prompted 3 times for better recall; results are
    de-duplicated (order-preserving) and written as one
    `university<TAB>department` row per department.
    """
    department_getter = get_department_getter()
    with open(unis_file, "r") as fin, open(deps_file, "w") as fout:

        # Iterate through universities in `fin`
        for uni in fin.readlines():
            uni = uni.strip()

            deps = []
            # Prompt the LLM multiple times for better recall
            for i in range(3):
                depstr = department_getter(uni)
                time.sleep(3)  # Respect usage limits!
                try:
                    if depstr == "No `Answer:` found":
                        print(f"No departments found for {uni} on {i}'th prompt.")
                    else:
                        deps_ = [d.strip() for d in depstr.split(';')]
                        deps.extend(deps_)
                except Exception as e:
                    # BUG FIX: this message was a plain string, not an
                    # f-string, so "{uni}"/"{i}" were printed literally.
                    print(f"Exception for {uni} on {i}'th prompt: ")
                    print("Parsing string: ", depstr)
                    print(e)

            # Deduplicate deps list
            deps = list(dict.fromkeys(deps))

            # Write to tsv dump file
            for dep in deps:
                fout.write(f"{uni}\t{dep}\n")

            # Print string info
            print(f"{uni}: {deps}")
171
+
172
+
173
+ import requests
174
+
175
def get_faculty_list_potential_links_getter():
    """Returns a function that returns a list of links that may contain faculty lists.

    BUG FIX: the original returned neither the inner closure nor its result,
    so this getter always evaluated to None.
    """
    # Read credentials once so the closure does not hit os.environ per call.
    GOOGLE_API_KEY = os.environ['GOOGLE_API_KEY']
    GOOGLE_SEARCH_ENGINE_ID = os.environ['GOOGLE_SEARCH_ENGINE_ID']

    def get_faculty_list_potential_links(uni, dep):
        """Returns a {result title: link} dict of pages that may contain faculty lists."""
        search_query = f'{uni} {dep} faculty list'

        params = {
            'q': search_query, 'key': GOOGLE_API_KEY, 'cx': GOOGLE_SEARCH_ENGINE_ID
        }

        response = requests.get('https://www.googleapis.com/customsearch/v1', params=params)
        results = response.json()
        # NOTE(review): raises KeyError when the search returns no 'items' —
        # confirm whether callers prefer an empty dict instead.
        title2link = {item['title']: item['link'] for item in results['items']}
        return title2link

    return get_faculty_list_potential_links
192
+
193
+
194
+
195
+ # if __name__ == "__main__":
196
+ # get_department_info()
data_pipeline/us_professor_verifier.py ADDED
@@ -0,0 +1,513 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ import json
3
+ import os
4
+ import pickle
5
+ import requests
6
+ import time
7
+
8
+ from bs4 import BeautifulSoup
9
+ from dotenv import load_dotenv, find_dotenv
10
+ from langchain.prompts import PromptTemplate
11
+ from openai import OpenAI
12
+ import regex as re
13
+ from tqdm import tqdm
14
+
15
+ from data_pipeline.conference_scraper import get_authors
16
+
17
+
18
+ _ = load_dotenv(find_dotenv())
19
+
20
# Bing Web Search v7 endpoint and credentials.
SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search"
# Read once at import time; raises KeyError if BING_SEARCH_API_KEY is unset.
SUBSCRIPTION_KEY = os.environ["BING_SEARCH_API_KEY"]
# Subscription header plus a desktop-browser User-Agent for the search calls.
HEADERS = {
    "Ocp-Apim-Subscription-Key": SUBSCRIPTION_KEY,
    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3",
}
26
+
27
# Few-shot example injected into the prompt for the "is a professor" outcome.
EXAMPLE_PROFESSOR_JSON = {
    "is_professor": True,
    "title": "Assistant Professor",
    "department": "Computer Science",
    "university": "Stanford University",
    "us_university": True,
}

# Few-shot example for the "not a professor" outcome.
# NOTE(review): mixed-case name breaks the UPPER_SNAKE_CASE convention, but it
# is referenced elsewhere in this file, so the name is left unchanged here.
EXAMPLE_not_professor_JSON = {
    "is_professor": False,
    "occupation": "Graduate Student",
    "affiliation": "Carnegie Mellon University"
}
40
+
41
# Prompt template (filled via LangChain's PromptTemplate in get_prompt):
# asks the LLM to classify {person_name} from Bing search hits and reply
# with raw JSON only. Trailing backslashes inside the triple-quoted string
# are line continuations — they do NOT appear in the runtime string.
IS_PROFESSOR_TEMPLATE = """You are a helpful assistant tasked with determining if {person_name} is a machine learning \
professor. You have search results from the query "{person_name} machine learning professor". \
Based on the results, specify if {person_name} is a professor, and if so, provide \
their title, department, university, and whether their university is in the U.S. If not, give their occupation \
and affiliation. Note: multiple people may \
share the same name, so choose the one most likely in machine learning. Further, one person may have multiple \
positions. If this is the case and one of those positions include being a professor, specify they are a professor \
and provide their title, department, university, and whether their university is in the U.S.

Only return the raw JSON, no MarkDown!

If {person_name} **is** a professor, fill out:
- `is_professor`: true
- `title`: e.g., `Assistant Professor`, `Associate Professor`, `Professor` etc.
- `department`: Full name, e.g., `Computer Science` rather than `CS` and `Electrical Engineering` rather than `EE`.
- `university`: Full name, e.g., `California Institute of Technology` rather than `Caltech`
- `us_university`: `true` or `false`

Example:
{professor_json_template}

If {person_name} **is not** a professor, fill out:
- `is_professor`: false
- `occupation`: e.g., `Graduate Student`, `Researcher`, `Engineer`, `Scientist`
- `affiliation`: e.g., `Carnegie Mellon University`, `Deepmind`, `Apple`, `NVIDIA`

Example:
{not_professor_json_template}

Search results (formatted as a numbered list with link name and snippet). \
Again, only return the JSON, just with the dictionary and its fields.
{hits}"""
73
+
74
+ # import httpx
75
def bing_search(person_name, max_retries=0, wait_time=0.5):
    """Query Bing for "<person_name> machine learning professor".

    Retries up to `max_retries` additional times on HTTP errors, sleeping
    `wait_time` seconds between attempts, then raises.

    Returns:
        The parsed JSON response from the Bing Web Search API.
    """
    params = {
        "q": f"{person_name} machine learning professor",
        "count": 10,
        "offset": 0,
        "mkt": "en-US",
        "textFormat": "HTML",
    }

    attempt = 0
    while True:
        try:
            resp = requests.get(SEARCH_URL, headers=HEADERS, params=params)
            resp.raise_for_status()
            return resp.json()
        except requests.HTTPError as http_err:
            if attempt == max_retries:
                raise Exception(f"Max retries reached. Failed to get a valid response for {person_name}") from http_err
            print(f"An error occurred while searching {person_name}: {http_err}. Retrying in {wait_time} seconds ...")
            time.sleep(wait_time)
            attempt += 1
92
+
93
def process_search_results(search_results):
    """Clean up Bing search results into text form.

    Returns:
        A two-element list: [human-readable numbered "name: snippet" hits
        with HTML tags and non-ASCII stripped, numbered URL list].
    """
    pages = search_results["webPages"]["value"]

    # Human-readable part: "<rank>. [<name>]: [<snippet>]" per hit.
    readable_lines = []
    for rank, page in enumerate(pages, start=1):
        readable_lines.append("{0}. [{1}]: [{2}]".format(rank, page["name"], page["snippet"]))
    readable = "\n".join(readable_lines)

    # Drop HTML tags, then any remaining non-ASCII characters.
    text_only = BeautifulSoup(readable, "html.parser").get_text()
    text_only = re.sub(r'[^\x00-\x7F]+', '', text_only)

    # Matching numbered list of raw links.
    url_lines = ["{0}. {1}".format(rank, page["url"])
                 for rank, page in enumerate(pages, start=1)]

    return [text_only, "\n".join(url_lines)]
109
+
110
def get_prompt(person_name, top_hits):
    """Build the classification prompt for one person.

    Fills IS_PROFESSOR_TEMPLATE with the person's name, both JSON few-shot
    examples, and the newline-joined search hits.

    Returns:
        The fully rendered prompt string.
    """
    prompt_template = PromptTemplate(
        template=IS_PROFESSOR_TEMPLATE,
        input_variables=["person_name", "professor_json_template", "not_professor_json_template", "hits"],
    )
    return prompt_template.format(
        person_name=person_name,
        professor_json_template=json.dumps(EXAMPLE_PROFESSOR_JSON),
        not_professor_json_template=json.dumps(EXAMPLE_not_professor_JSON),
        hits="\n".join(top_hits),
    )
122
+
123
def run_chatgpt(prompt, client, model="gpt-4o-mini", system_prompt=None):
    """Send `prompt` to the chat-completions API and return the reply text.

    Uses deterministic settings (temperature 0.0, fixed seed) so repeated
    runs are reproducible. An optional system prompt is prepended when given.
    """
    conversation = [{"role": "user", "content": prompt}]
    if system_prompt:
        conversation.insert(0, {"role": "system", "content": system_prompt})

    completion = client.chat.completions.create(
        model=model,
        messages=conversation,
        temperature=0.0,
        seed=123,
    )
    return completion.choices[0].message.content
137
+
138
def check_json(profile):
    """Validate the shape of an LLM-produced profile dict.

    Raises ValueError on the first missing or ill-typed field; returns None
    when the profile is well-formed.
    """
    if not isinstance(profile, dict):
        raise ValueError("Profile must be a dictionary")

    if "is_professor" not in profile:
        raise ValueError("Profile must contain a 'is_professor' key")

    if profile["is_professor"]:
        # Professor profiles carry title/department/university/us_university.
        for field in ("title", "department", "university", "us_university"):
            if field not in profile:
                raise ValueError(f"Profile must contain a '{field}' key")
        # Exact bool check (not isinstance) kept from the original contract.
        if type(profile["us_university"]) is not bool:
            raise ValueError("Profile 'us_university' must be a boolean")
    else:
        # Non-professor profiles carry occupation/affiliation instead.
        if "occupation" not in profile:
            raise ValueError("Profile must contain an 'occupation' key")
        if "affiliation" not in profile:
            raise ValueError("Profile must contain an 'affiliation' key")
161
+
162
def save_json(profiles, file_path):
    """Serialize `profiles` to `file_path` as pretty-printed JSON.

    Creates missing parent directories; rewrites the whole file each call
    (incremental appends would be cheaper but are not done here).
    """
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, 'w') as fout:
        json.dump(profiles, fout, indent=4)
166
+
167
def load_json(file_path):
    """Read and return the JSON document stored at `file_path`."""
    with open(file_path, 'r') as fin:
        return json.load(fin)
170
+
171
def log_progress_to_file(progress_log, file_path):
    """Append progress messages to `file_path`, then a 10-dash separator line.

    Creates missing parent directories; the file is opened in append mode so
    successive checkpoints accumulate.
    """
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    separator = '\n' + '-' * 10 + '\n'
    with open(file_path, 'a') as fout:
        fout.write('\n'.join(progress_log) + separator)
176
+
177
def search_person(person_name, progress_log):
    """Bing-search a person and return the top 5 readable hits, or "" on failure.

    Appends a success/failure message to `progress_log` either way; failures
    are swallowed deliberately so a whole batch run is not aborted by one
    bad query.
    """
    try:
        raw_results = bing_search(person_name)
        readable, _urls = process_search_results(raw_results)
        top_hits = readable.split("\n")[:5]  # keep the top 5 hits only
        progress_log.append(f"Success: Search results for {person_name}.")
        return top_hits
    except Exception as e:
        print(f"Search exception for {person_name}: ", e)
        progress_log.append(f"Failure: Search exception for {person_name}: {e}")
        return ""
189
+
190
def extract_search_results(person_name, progress_log, client, us_professor_profiles, not_us_professor_profiles, top_hits):
    """Classify `person_name` from search hits via the LLM.

    On success the validated profile is appended to `us_professor_profiles`
    (US professor) or `not_us_professor_profiles` (everyone else); on any
    failure the error is printed and logged and nothing is appended.
    """
    try:
        llm_reply = run_chatgpt(get_prompt(person_name, top_hits), client)
        profile = {"name": person_name, **json.loads(llm_reply)}
        check_json(profile)
        bucket = (us_professor_profiles
                  if profile["is_professor"] and profile["us_university"]
                  else not_us_professor_profiles)
        bucket.append(profile)
    except Exception as e:
        print(f"LLM exception for {person_name}: ", e)
        progress_log.append(f"Failure: LLM exception for {person_name}: {e}")
206
+
207
def research_person(person_name, client, progress_log, us_professor_profiles, not_us_professor_profiles):
    """Research who this person is: search, then LLM-extract if the search worked.

    A failed search is signalled by search_person returning ""; in that case
    no LLM call is made.
    """
    hits = search_person(person_name, progress_log)
    if hits != "":
        extract_search_results(person_name, progress_log, client,
                               us_professor_profiles, not_us_professor_profiles, hits)
213
+
214
+
215
def get_authors(save_dir="data/conference", min_papers=3, ignore_first_author=True):
    """
    Reduce the list of authors to those with at least `min_papers` papers for
    which they are not first authors. Ignores solo-authored papers and papers
    with more than 20 authors.

    Filters authors so that we don't have to do RAG on every author, which is
    monetarily expensive. Feel free to edit if you have more resources.

    NOTE(review): this definition shadows the `get_authors` imported from
    data_pipeline.conference_scraper at the top of this file — confirm which
    of the two is intended to win.

    Args:
        save_dir: Directory of .json files, one JSON record per line, where
            record[1] is a comma-separated author string; "authors.txt" is
            also written here.
        min_papers: Minimum qualifying paper count for an author to be kept.
        ignore_first_author: If True, first authors get no credit for a paper.

    Returns:
        dict mapping author name -> qualifying paper count.
    """
    authors = defaultdict(int)
    for fname in os.listdir(save_dir):
        if not fname.endswith('.json'):
            continue

        with open(os.path.join(save_dir, fname), 'r') as file:
            for line in file:
                item = json.loads(line)
                paper_authors = [x.strip() for x in item[1].split(",")]

                # Ignore solo-authored papers and papers with more than 20 authors.
                if len(paper_authors) == 1 or len(paper_authors) > 20:
                    continue

                # Professors generally are not first authors, so optionally skip
                # the first slot. (Replaces the original index loop and its
                # redundant `len(paper_authors) > 0` check with a slice.)
                credited = paper_authors[1:] if ignore_first_author else paper_authors
                for author in credited:
                    authors[author] += 1

    authors = {k: v for k, v in authors.items() if v >= min_papers}
    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, "authors.txt"), 'w') as f:
        f.writelines(f"{k}\t{v}\n" for k, v in authors.items())
    return authors
250
+
251
def research_conference_profiles(save_freq=20):
    """Research each author as a stream.

    Searches every filtered conference author, classifies them with the LLM,
    and checkpoints progress/profiles to disk every `save_freq` people.

    NOTE: cannot deal w/ interrupts and continue from past progress.
    """
    authors = get_authors("data/conference")
    person_names = list(authors.keys())

    client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

    progress_log = []
    us_professor_profiles = []
    not_us_professor_profiles = []

    def log_save_print(i):
        # Checkpoint: flush the progress log and both profile buckets.
        log_progress_to_file(progress_log, 'logs/progress_log.tmp')
        save_json(us_professor_profiles, 'data/professor/us_professor.json')
        save_json(not_us_professor_profiles, 'data/professor/not_us_professor.json')
        print(f"Saved profiles to data/professor/us_professor.json and data/professor/not_us_professor.json after processing {i} people")

    i = 0  # BUG FIX: the final log_save_print raised NameError when person_names was empty
    for i, person_name in enumerate(person_names):
        research_person(person_name, client, progress_log, us_professor_profiles, not_us_professor_profiles)
        if i % save_freq == 0:
            log_save_print(i)

    log_save_print(i)
    print("Research complete.")
279
+
280
def batch_search_person(person_names, progress_log, save_freq=20):
    """Searches everyone given in `person_names`.

    Resumes from data/professor/search_results.json when present, appends
    [name, top_hits] records, checkpoints every `save_freq` queries, and
    rate-limits to at most ~3 queries per second.
    """
    # Might start and stop; pull from previous efforts.
    try:
        prev_researched_authors = load_json("data/professor/search_results.json")
    except (OSError, json.JSONDecodeError):  # BUG FIX: was a bare `except:`
        prev_researched_authors = []
    ignore_set = {x[0] for x in prev_researched_authors}
    data = prev_researched_authors

    # Filter out names we already have results for.
    unseen_person_names = [p for p in person_names if p not in ignore_set]
    if person_names:  # BUG FIX: avoid ZeroDivisionError on an empty input list
        print(f"Already researched {len(ignore_set)} / {len(person_names)} = {len(ignore_set) / len(person_names)} of the dataset")
    person_names = unseen_person_names

    # Continue search. (The per-item ignore_set re-check was dead code after
    # the filtering above and has been dropped.)
    for i in tqdm(range(len(person_names))):
        query_start = time.time()
        top_hits = search_person(person_names[i], progress_log)
        if top_hits != "":
            data.append([person_names[i], top_hits])

        if i % save_freq == 0:
            save_json(data, "data/professor/search_results.json")
            log_progress_to_file(progress_log, 'logs/progress_log.tmp')

        # 3 queries per second max.
        # BUG FIX: the original computed `elapsed - 0.334` (inverted), which
        # never throttled fast queries and slept extra after slow ones; sleep
        # the time REMAINING until 0.334 s has elapsed since query_start.
        wait_time = max((query_start + 0.334) - time.time(), 0.0)
        time.sleep(wait_time)

    save_json(data, "data/professor/search_results.json")
    log_progress_to_file(progress_log, 'logs/progress_log.tmp')
316
+
317
def write_batch_files(search_results_path,
                      prompt_data_path_prefix,
                      model="gpt-4o-mini",
                      max_tokens=1000,
                      temperature=0.0,
                      seed=123,
                      batch_size=1999,  # max_tokens * batch_size < 2M?
                      verbose=0):
    """Convert search results dump to jsonl for LLM batch request.

    Reads [name, hits] pairs from `search_results_path` and writes one
    OpenAI batch-request .jsonl file per `batch_size` prompts, named
    "{prompt_data_path_prefix}_{i}.jsonl".

    Returns:
        List of written batch file paths (empty when there are no prompts).
    """
    with open(search_results_path, "r") as f:
        search_results = json.load(f)

    prompt_datas = []
    for name, hits in search_results:
        prompt_datas.append({
            "custom_id": f"request-{name}",  # don't change, needed for decoding
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": model,
                "temperature": temperature,
                "seed": seed,
                "messages": [{"role": "user", "content": get_prompt(name, hits)}],
                "max_tokens": max_tokens,
            },
        })

    print(f"Number of prompts: {len(prompt_datas)}")
    if verbose > 0 and search_results:
        # BUG FIX: guard against an empty dump (the original read the loop
        # variable after the loop, a NameError when search_results == []).
        print(get_prompt(*search_results[-1]))

    batch_paths = []
    # BUG FIX: ceil-divide so an exact multiple of batch_size no longer
    # produces a trailing empty (and thus invalid) batch file.
    num_batches = (len(prompt_datas) + batch_size - 1) // batch_size
    for i in range(num_batches):
        prompt_data_path = f"{prompt_data_path_prefix}_{i}.jsonl"
        chunk = prompt_datas[i * batch_size:(i + 1) * batch_size]
        with open(prompt_data_path, "w") as f:
            f.writelines(json.dumps(p) + "\n" for p in chunk)
        batch_paths.append(prompt_data_path)

    return batch_paths
359
+
360
def send_batch_files(prompt_data_path_prefix, batch_paths, client, timeout=24*60*60):
    """Create and send the batch request to API endpoint.

    Uploads each jsonl file, creates a batch job for it, and polls (every
    40 s, up to `timeout` seconds) until completion. Persists the batch
    objects to "{prefix}_batches.pkl" and their ids to "{prefix}_ids.txt".

    Returns:
        List of (possibly still-running, if timed out) batch objects.
    """
    batches = []

    print("Batching and sending requests...")
    for batch_path in tqdm(batch_paths):
        # BUG FIX: close the upload handle (the original leaked a bare open()).
        with open(batch_path, "rb") as fin:
            batch_input_file = client.files.create(file=fin, purpose="batch")

        batch_input_file_id = batch_input_file.id
        print(f"Batch input file ID: {batch_input_file_id}")

        batch = client.batches.create(
            input_file_id=batch_input_file_id,
            endpoint="/v1/chat/completions",
            completion_window="24h",
            metadata={
                "description": "search extraction job"
            }
        )

        # Poll until the job completes or the timeout elapses.
        begin = time.time()
        while time.time() - begin < timeout:
            batch = client.batches.retrieve(batch.id)
            if batch.status == "completed":
                break
            time.sleep(40)
            print(f"Status ({time.time()-begin:.2f}): {batch.status}")  # BUG FIX: was `:2f`
        print("seconds elapsed: ", time.time() - begin)
        batches.append(batch)

    # Keeps track of the paths to the batch files
    with open(f"{prompt_data_path_prefix}_batches.pkl", "wb") as f:
        pickle.dump(batches, f)
    with open(f"{prompt_data_path_prefix}_ids.txt", "w") as f:
        f.write("\n".join([x.id for x in batches]))
    return batches
399
+
400
def retrieve_batch_output(client, batch_id):
    """OpenAI batch requests finish within 24 hrs.

    Returns the batch's output file text when the job is done, or the
    sentinel string "INCOMPLETE" while it is still running.
    """
    batch = client.batches.retrieve(batch_id)
    if batch.status != "completed":
        print("Batch process is still in progress.")
        print(batch)
        return "INCOMPLETE"
    return client.files.content(batch.output_file_id).text
409
+
410
def batch_process_llm_output(client, batches):
    """Retrieve finished batch outputs and sort the decoded profiles into the
    US-professor / other buckets, then dump both lists to data/professor/.

    Returns early (saving nothing) if any batch is still running.
    """
    # BUG FIX: use the client passed by the caller — the original immediately
    # overwrote the parameter with a fresh OpenAI(), ignoring the argument.

    outputs = []
    for batch in batches:
        output = retrieve_batch_output(client, batch.id)
        if output == "INCOMPLETE":
            return
        outputs.append(output)

    custom_id_idx = len("request-")  # where the name begins in "custom_id"

    # BUG FIX: accumulate across ALL batches — the original re-initialized
    # these lists inside the per-output loop, so only the last batch's
    # profiles survived to the final dump.
    progress_log = []
    us_professor_profiles = []
    not_us_professor_profiles = []

    for output in outputs:
        for json_obj in output.split('\n'):
            if json_obj == '':
                continue
            try:
                parsed_data = json.loads(json_obj)
                message_content = parsed_data["response"]["body"]["choices"][0]["message"]["content"]
                gpt_profile = {"name": parsed_data["custom_id"][custom_id_idx:]}
                gpt_profile.update(json.loads(message_content))
                check_json(gpt_profile)
                if gpt_profile["is_professor"] and gpt_profile["us_university"]:
                    us_professor_profiles.append(gpt_profile)
                else:
                    not_us_professor_profiles.append(gpt_profile)

                progress_log.append(f"Success: Parsed LLM output for {gpt_profile['name']}")
            except Exception as e:
                # Try to name the failing record; fall back to the raw line.
                try:
                    print(f"Failed to parse json object for custom-id `{parsed_data['custom_id']}`: {e}")
                    progress_log.append(f"Failed: Parsed LLM output for {gpt_profile['name']}: {e}")
                except Exception as e2:
                    print(f"Failed to parse json object `{json_obj}`: {e2}")
                    progress_log.append(f"Failed UNKNOWN: Parsed LLM output: {e2}")

    # NOTE(review): progress_log is collected but never flushed to disk here;
    # call log_progress_to_file(...) if a persistent record is wanted.
    with open("data/professor/us_professor.json", 'w') as file:
        json.dump(us_professor_profiles, file, indent=4)
    with open("data/professor/not_us_professor.json", 'w') as file:
        json.dump(not_us_professor_profiles, file, indent=4)
457
+
458
def main():
    """CLI entry point: run exactly one stage of the verifier pipeline,
    selected by --batch_search / --batch_analyze / --batch_retrieve."""
    import argparse

    parser = argparse.ArgumentParser(
        description="US Professor Verifier: Search or LLM-Analyze batch operations."
    )

    # Exactly one mode flag must be supplied.
    mode_group = parser.add_mutually_exclusive_group(required=True)
    mode_group.add_argument(
        '--batch_search',
        action='store_true',
        help='Batch search the authors.'
    )
    mode_group.add_argument(
        '--batch_analyze',
        action='store_true',
        help='Sends search results to LLM for analysis.'
    )
    mode_group.add_argument(
        '--batch_retrieve',
        action='store_true',
        help='Retrieve results from an LLM batch request, requires --batch_id'
    )

    parser.add_argument(
        '--batch_ids_path',
        type=str,
        help='The batch ID for retrieval'
    )

    args = parser.parse_args()

    prompt_data_path_prefix = "data/professor/prompt_data"

    if args.batch_search:
        authors = get_authors("data/conference")
        authors_list = list(authors.keys())
        print("Researching people...")
        batch_search_person(authors_list, [], save_freq=20)
    elif args.batch_analyze:
        client = OpenAI()
        batch_paths = write_batch_files("data/professor/search_results.json", prompt_data_path_prefix)
        send_batch_files(prompt_data_path_prefix, batch_paths, client)
    elif args.batch_retrieve:
        client = OpenAI()
        with open(f"{prompt_data_path_prefix}_batches.pkl", "rb") as f:
            batches = pickle.load(f)
        batch_process_llm_output(client, batches)
    else:
        # Unreachable given required=True on the group; kept as a guard.
        raise ValueError("Please specify --batch_search, --batch_analyze, or --batch_retrieve.")
510
+
511
+
512
# Script entry point — only runs when executed directly, not on import.
if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ python-dotenv
2
+ openai
3
+ langchain-together
4
+ lxml
5
+
6
+ einops
7
+ torch
8
+ datasets
9
+ transformers
10
+
11
+ datasets
12
+ transformers