Spaces:
Sleeping
Sleeping
Commit ·
6c2a7c2
0
Parent(s):
first commit
Browse files- .gitignore +170 -0
- README.md +48 -0
- data_pipeline/__init__.py +0 -0
- data_pipeline/conference_scraper.py +262 -0
- data_pipeline/download_arxiv_kaggle.py +281 -0
- data_pipeline/loaders.py +22 -0
- data_pipeline/requirements.txt +0 -0
- data_pipeline/schools_scraper.py +196 -0
- data_pipeline/us_professor_verifier.py +513 -0
- requirements.txt +12 -0
.gitignore
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.env
|
| 2 |
+
data/*
|
| 3 |
+
runs/*
|
| 4 |
+
logs/*
|
| 5 |
+
nbs/*
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# Byte-compiled / optimized / DLL files
|
| 10 |
+
__pycache__/
|
| 11 |
+
*.py[cod]
|
| 12 |
+
*$py.class
|
| 13 |
+
|
| 14 |
+
# C extensions
|
| 15 |
+
*.so
|
| 16 |
+
|
| 17 |
+
# Distribution / packaging
|
| 18 |
+
.Python
|
| 19 |
+
build/
|
| 20 |
+
develop-eggs/
|
| 21 |
+
dist/
|
| 22 |
+
downloads/
|
| 23 |
+
eggs/
|
| 24 |
+
.eggs/
|
| 25 |
+
lib/
|
| 26 |
+
lib64/
|
| 27 |
+
parts/
|
| 28 |
+
sdist/
|
| 29 |
+
var/
|
| 30 |
+
wheels/
|
| 31 |
+
share/python-wheels/
|
| 32 |
+
*.egg-info/
|
| 33 |
+
.installed.cfg
|
| 34 |
+
*.egg
|
| 35 |
+
MANIFEST
|
| 36 |
+
|
| 37 |
+
# PyInstaller
|
| 38 |
+
# Usually these files are written by a python script from a template
|
| 39 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 40 |
+
*.manifest
|
| 41 |
+
*.spec
|
| 42 |
+
|
| 43 |
+
# Installer logs
|
| 44 |
+
pip-log.txt
|
| 45 |
+
pip-delete-this-directory.txt
|
| 46 |
+
|
| 47 |
+
# Unit test / coverage reports
|
| 48 |
+
htmlcov/
|
| 49 |
+
.tox/
|
| 50 |
+
.nox/
|
| 51 |
+
.coverage
|
| 52 |
+
.coverage.*
|
| 53 |
+
.cache
|
| 54 |
+
nosetests.xml
|
| 55 |
+
coverage.xml
|
| 56 |
+
*.cover
|
| 57 |
+
*.py,cover
|
| 58 |
+
.hypothesis/
|
| 59 |
+
.pytest_cache/
|
| 60 |
+
cover/
|
| 61 |
+
|
| 62 |
+
# Translations
|
| 63 |
+
*.mo
|
| 64 |
+
*.pot
|
| 65 |
+
|
| 66 |
+
# Django stuff:
|
| 67 |
+
*.log
|
| 68 |
+
local_settings.py
|
| 69 |
+
db.sqlite3
|
| 70 |
+
db.sqlite3-journal
|
| 71 |
+
|
| 72 |
+
# Flask stuff:
|
| 73 |
+
instance/
|
| 74 |
+
.webassets-cache
|
| 75 |
+
|
| 76 |
+
# Scrapy stuff:
|
| 77 |
+
.scrapy
|
| 78 |
+
|
| 79 |
+
# Sphinx documentation
|
| 80 |
+
docs/_build/
|
| 81 |
+
|
| 82 |
+
# PyBuilder
|
| 83 |
+
.pybuilder/
|
| 84 |
+
target/
|
| 85 |
+
|
| 86 |
+
# Jupyter Notebook
|
| 87 |
+
.ipynb_checkpoints
|
| 88 |
+
|
| 89 |
+
# IPython
|
| 90 |
+
profile_default/
|
| 91 |
+
ipython_config.py
|
| 92 |
+
|
| 93 |
+
# pyenv
|
| 94 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 95 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 96 |
+
# .python-version
|
| 97 |
+
|
| 98 |
+
# pipenv
|
| 99 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 100 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 101 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 102 |
+
# install all needed dependencies.
|
| 103 |
+
#Pipfile.lock
|
| 104 |
+
|
| 105 |
+
# poetry
|
| 106 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 107 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 108 |
+
# commonly ignored for libraries.
|
| 109 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 110 |
+
#poetry.lock
|
| 111 |
+
|
| 112 |
+
# pdm
|
| 113 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 114 |
+
#pdm.lock
|
| 115 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 116 |
+
# in version control.
|
| 117 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
| 118 |
+
.pdm.toml
|
| 119 |
+
.pdm-python
|
| 120 |
+
.pdm-build/
|
| 121 |
+
|
| 122 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 123 |
+
__pypackages__/
|
| 124 |
+
|
| 125 |
+
# Celery stuff
|
| 126 |
+
celerybeat-schedule
|
| 127 |
+
celerybeat.pid
|
| 128 |
+
|
| 129 |
+
# SageMath parsed files
|
| 130 |
+
*.sage.py
|
| 131 |
+
|
| 132 |
+
# Environments
|
| 133 |
+
.env
|
| 134 |
+
.venv
|
| 135 |
+
env/
|
| 136 |
+
venv/
|
| 137 |
+
ENV/
|
| 138 |
+
env.bak/
|
| 139 |
+
venv.bak/
|
| 140 |
+
|
| 141 |
+
# Spyder project settings
|
| 142 |
+
.spyderproject
|
| 143 |
+
.spyproject
|
| 144 |
+
|
| 145 |
+
# Rope project settings
|
| 146 |
+
.ropeproject
|
| 147 |
+
|
| 148 |
+
# mkdocs documentation
|
| 149 |
+
/site
|
| 150 |
+
|
| 151 |
+
# mypy
|
| 152 |
+
.mypy_cache/
|
| 153 |
+
.dmypy.json
|
| 154 |
+
dmypy.json
|
| 155 |
+
|
| 156 |
+
# Pyre type checker
|
| 157 |
+
.pyre/
|
| 158 |
+
|
| 159 |
+
# pytype static type analyzer
|
| 160 |
+
.pytype/
|
| 161 |
+
|
| 162 |
+
# Cython debug symbols
|
| 163 |
+
cython_debug/
|
| 164 |
+
|
| 165 |
+
# PyCharm
|
| 166 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 167 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 168 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 169 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 170 |
+
#.idea/
|
README.md
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# U.S. ML PhD Recomendation System
|
| 2 |
+
|
| 3 |
+
Disclaimer: results are not 100% accurate and there is likely some bias to how papers / professors are filtered.
|
| 4 |
+
|
| 5 |
+
### Data Pipeline
|
| 6 |
+
|
| 7 |
+
First, a list of authors are gathered from recent conference proceedings. A batched RAG pipeline is used to determine which persons are U.S. professors (unsure how accurate the LLM here is). This can be reproduced as follows:
|
| 8 |
+
|
| 9 |
+
#### Repeat research until satisfactory
|
| 10 |
+
|
| 11 |
+
```python
|
| 12 |
+
# Scrape top conferences for potential U.S.-based professors, ~45 mins
|
| 13 |
+
python -m data_pipeline.conference_scraper
|
| 14 |
+
```
|
| 15 |
+
**Selected conferences**
|
| 16 |
+
- NeurIPS: 2022, 2023
|
| 17 |
+
- ICML: 2023, 2024
|
| 18 |
+
- AISTATS: 2023, 2024
|
| 19 |
+
- COLT: 2023, 2024
|
| 20 |
+
- AAAI: 2023, 2024
|
| 21 |
+
- EMNLP: 2023, 2024
|
| 22 |
+
- CVPR: 2023, 2024
|
| 23 |
+
|
| 24 |
+
```python
|
| 25 |
+
# Search authors and locally store search results. Uses Bing web search API.
|
| 26 |
+
python -m data_pipeline.us_professor_verifier --batch_search
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
NOTE 1: you may encounter caught exceptions due to HTTPError or invalid JSON outputs from the LLM. Would suggest to run the above multiple times until results are satisfactory.
|
| 30 |
+
|
| 31 |
+
NOTE 2: This pipeline does not handle name collisions, name changes, initials.
|
| 32 |
+
|
| 33 |
+
#### Create file containing U.S. professor data
|
| 34 |
+
|
| 35 |
+
```python
|
| 36 |
+
# Use locally stored search results as input to an LLM.
|
| 37 |
+
# Sends as batches, each one waiting for the previous to finish.
|
| 38 |
+
python -m data_pipeline.us_professor_verifier --batch_analyze
|
| 39 |
+
# After some time (at most 24 hrs per batch, ~5 batches), the batch results become available for retrieval.
|
| 40 |
+
# Took ~1 hr for me
|
| 41 |
+
python -m data_pipeline.us_professor_verifier --batch_retrieve
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
#### Extract embeddings for the relevant papers
|
| 45 |
+
```python
|
| 46 |
+
# Fetch arxiv data and extract embeddings
|
| 47 |
+
python -m data_pipeline.download_arxiv_kaggle
|
| 48 |
+
```
|
data_pipeline/__init__.py
ADDED
|
File without changes
|
data_pipeline/conference_scraper.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Scrape data from some famous ML conferences and saves into data/conference.
|
| 2 |
+
|
| 3 |
+
Every scrape function returns a list of 3-lists of the form
|
| 4 |
+
[paper_title, paper_authors, paper_url].
|
| 5 |
+
|
| 6 |
+
Conferences
|
| 7 |
+
-----------
|
| 8 |
+
NeurIPS: 2022, 2023
|
| 9 |
+
ICML: 2023, 2024
|
| 10 |
+
AISTATS: 2023, 2024
|
| 11 |
+
COLT: 2023, 2024
|
| 12 |
+
AAAI: 2023, 2024
|
| 13 |
+
EMNLP: 2023, 2024
|
| 14 |
+
CVPR: 2023, 2024
|
| 15 |
+
-----------
|
| 16 |
+
|
| 17 |
+
Disclaimer
|
| 18 |
+
-----------
|
| 19 |
+
The choice of conferences was sourced from here:
|
| 20 |
+
https://www.kaggle.com/discussions/getting-started/115799
|
| 21 |
+
|
| 22 |
+
The priority of including certain conferences and tracks was based on a 1st-year PhD's
|
| 23 |
+
judgment. Some very top conferences were excluded due to higher activation energy to
|
| 24 |
+
scrape data and/or the ignorance of the 1st-year PhD. Some notable exceptions include
|
| 25 |
+
ICLR, ICCV, ECCV, ACL, NAACL, and many others.
|
| 26 |
+
-----------
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
from collections import defaultdict
|
| 30 |
+
from functools import partial
|
| 31 |
+
import json
|
| 32 |
+
import os
|
| 33 |
+
import requests
|
| 34 |
+
import time
|
| 35 |
+
|
| 36 |
+
from bs4 import BeautifulSoup
|
| 37 |
+
from tqdm import tqdm
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
SAVE_DIR = "data/conference"
|
| 41 |
+
|
| 42 |
+
def scrape_nips(year):
|
| 43 |
+
nips_url = f"https://papers.nips.cc/paper/{year}"
|
| 44 |
+
response = requests.get(nips_url)
|
| 45 |
+
soup = BeautifulSoup(response.text, "html.parser")
|
| 46 |
+
|
| 47 |
+
conference_items = soup.find_all('li')
|
| 48 |
+
conference_items = [[ci.a.get_text(), ci.i.get_text(), ci.a['href']] for ci in conference_items]
|
| 49 |
+
conference_items = [ci for ci in conference_items if ci[0]!="" and ci[1]!=""]
|
| 50 |
+
return conference_items
|
| 51 |
+
|
| 52 |
+
def scrape_mlr_proceedings(conference, year):
|
| 53 |
+
|
| 54 |
+
cy2v = {
|
| 55 |
+
("ICML", 2024): "v235",
|
| 56 |
+
("ICML", 2023): "v202",
|
| 57 |
+
("AISTATS", 2024): "v238",
|
| 58 |
+
("AISTATS", 2023): "v206",
|
| 59 |
+
("COLT", 2024): "v247",
|
| 60 |
+
("COLT", 2023): "v195",
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
conference_url = f"https://proceedings.mlr.press/{cy2v[(conference, year)]}"
|
| 64 |
+
response = requests.get(conference_url)
|
| 65 |
+
soup = BeautifulSoup(response.text, "html.parser")
|
| 66 |
+
|
| 67 |
+
conference_items = soup.find_all('div', class_="paper")
|
| 68 |
+
conference_items = [
|
| 69 |
+
[
|
| 70 |
+
ci.find('p', class_="title").get_text(),
|
| 71 |
+
ci.find('p', class_="details").find('span', class_="authors").get_text(),
|
| 72 |
+
ci.find('p', class_="links").find('a')['href']
|
| 73 |
+
]
|
| 74 |
+
for ci in conference_items
|
| 75 |
+
]
|
| 76 |
+
return conference_items
|
| 77 |
+
|
| 78 |
+
def scrape_aaai():
|
| 79 |
+
# Scrape the technical tracks of past two years ('23, '24)
|
| 80 |
+
# Look at first two pages of archives that give links to tracks
|
| 81 |
+
# Look at each track
|
| 82 |
+
|
| 83 |
+
headers = {
|
| 84 |
+
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3"
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
# First two pages
|
| 88 |
+
track_links = []
|
| 89 |
+
|
| 90 |
+
aaai_urls = [
|
| 91 |
+
"https://ojs.aaai.org/index.php/AAAI/issue/archive",
|
| 92 |
+
"https://ojs.aaai.org/index.php/AAAI/issue/archive/2",
|
| 93 |
+
]
|
| 94 |
+
|
| 95 |
+
for aaai_url in aaai_urls:
|
| 96 |
+
|
| 97 |
+
response = requests.get(aaai_url, headers=headers)
|
| 98 |
+
soup = BeautifulSoup(response.text, "html.parser")
|
| 99 |
+
|
| 100 |
+
tracks = [track.find('a', class_="title") for track in soup.find_all('h2')]
|
| 101 |
+
track_links.extend(
|
| 102 |
+
[(track.text.strip(), track['href']) for track in tracks if track is not None]
|
| 103 |
+
)
|
| 104 |
+
print(track_links)
|
| 105 |
+
|
| 106 |
+
time.sleep(60) # respect scraping limits
|
| 107 |
+
|
| 108 |
+
# only look at past two years
|
| 109 |
+
track_links = [track_link for track_link in track_links if "AAAI-24" in track_link[0] or "AAAI-23" in track_link[0]]
|
| 110 |
+
print("track links: ", track_links)
|
| 111 |
+
|
| 112 |
+
conference_items = []
|
| 113 |
+
|
| 114 |
+
for track_link in tqdm(track_links):
|
| 115 |
+
print(f"Going through track {track_link[0]} @ {track_link[1]} ")
|
| 116 |
+
|
| 117 |
+
# Scrape tracks
|
| 118 |
+
response = requests.get(track_link[1], headers=headers)
|
| 119 |
+
soup = BeautifulSoup(response.text, "html.parser")
|
| 120 |
+
|
| 121 |
+
articles = soup.find_all('div', class_="obj_article_summary")
|
| 122 |
+
|
| 123 |
+
for article in articles:
|
| 124 |
+
|
| 125 |
+
aref = article.find('a')
|
| 126 |
+
conference_items.append(
|
| 127 |
+
[
|
| 128 |
+
aref.text.strip(),
|
| 129 |
+
article.find('div', class_="authors").text.strip(),
|
| 130 |
+
aref['href'],
|
| 131 |
+
]
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
time.sleep(60) # respect scraping limits
|
| 135 |
+
|
| 136 |
+
return conference_items
|
| 137 |
+
|
| 138 |
+
def scrape_emnlp(year):
|
| 139 |
+
|
| 140 |
+
emnlp_url = f"https://{year}.emnlp.org/program/accepted_main_conference/"
|
| 141 |
+
|
| 142 |
+
headers = {
|
| 143 |
+
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3"
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
response = requests.get(emnlp_url, headers=headers)
|
| 147 |
+
soup = BeautifulSoup(response.text, "html.parser")
|
| 148 |
+
|
| 149 |
+
ps = soup.find_all('p')
|
| 150 |
+
conference_items = [[p.contents[0].text, p.contents[-1].text, ''] for p in ps]
|
| 151 |
+
return conference_items
|
| 152 |
+
|
| 153 |
+
def scrape_cvpr(year):
|
| 154 |
+
cvpr_url = f"https://openaccess.thecvf.com/CVPR{year}?day=all"
|
| 155 |
+
response = requests.get(cvpr_url)
|
| 156 |
+
soup = BeautifulSoup(response.text, "html.parser")
|
| 157 |
+
|
| 158 |
+
# Separately extract title/link and authors
|
| 159 |
+
dts = soup.find_all('dt', class_="ptitle")
|
| 160 |
+
conference_items = [(dt.text, '', dt.a['href']) for dt in dts]
|
| 161 |
+
|
| 162 |
+
dds = soup.find_all('dd')
|
| 163 |
+
authors = []
|
| 164 |
+
for dd in dds:
|
| 165 |
+
if dd.find('form') is not None: # author entry
|
| 166 |
+
authors.append(
|
| 167 |
+
', '.join([x.text for x in dd.find_all('a')])
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
conference_items = [[dt.text, author, dt.a['href']] for dt, author in zip(dts, authors)]
|
| 171 |
+
return conference_items
|
| 172 |
+
|
| 173 |
+
def save_to_file(conference_items, filename):
|
| 174 |
+
with open(filename, 'w') as f:
|
| 175 |
+
for item in conference_items:
|
| 176 |
+
f.write(json.dumps(item) + '\n')
|
| 177 |
+
|
| 178 |
+
def load_from_file(filename):
|
| 179 |
+
with open(filename, 'r') as f:
|
| 180 |
+
conference_items = [json.loads(line) for line in f]
|
| 181 |
+
return conference_items
|
| 182 |
+
|
| 183 |
+
def main():
|
| 184 |
+
|
| 185 |
+
scrape_functions = {
|
| 186 |
+
"NeurIPS-2022": partial(scrape_nips, 2022),
|
| 187 |
+
"NeurIPS-2023": partial(scrape_nips, 2023),
|
| 188 |
+
"ICML-2023": partial(scrape_mlr_proceedings, "ICML", 2023),
|
| 189 |
+
"ICML-2024": partial(scrape_mlr_proceedings, "ICML", 2024),
|
| 190 |
+
"AISTATS-2023": partial(scrape_mlr_proceedings, "AISTATS", 2023),
|
| 191 |
+
"AISTATS-2024": partial(scrape_mlr_proceedings, "AISTATS", 2024),
|
| 192 |
+
"COLT-2023": partial(scrape_mlr_proceedings, "COLT", 2023),
|
| 193 |
+
"COLT-2024": partial(scrape_mlr_proceedings, "COLT", 2024),
|
| 194 |
+
"AAAI": scrape_aaai, # easier to scrape both years at once, takes ~40 mins
|
| 195 |
+
"EMNLP-2023": partial(scrape_emnlp, 2023),
|
| 196 |
+
"EMNLP-2024": partial(scrape_emnlp, 2024),
|
| 197 |
+
"CVPR-2023": partial(scrape_cvpr, 2023),
|
| 198 |
+
"CVPR-2024": partial(scrape_cvpr, 2024),
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
def load_progress():
|
| 202 |
+
if os.path.exists(SAVE_DIR):
|
| 203 |
+
file_paths = os.listdir(SAVE_DIR)
|
| 204 |
+
file_paths = [file_path for file_path in file_paths if file_path.endswith('.json')]
|
| 205 |
+
file_paths = [file_path.split('.')[0] for file_path in file_paths]
|
| 206 |
+
return set(file_paths)
|
| 207 |
+
return set()
|
| 208 |
+
|
| 209 |
+
def save_progress(conference, file_path):
|
| 210 |
+
with open(file_path, 'a') as f:
|
| 211 |
+
f.write(conference + '\n')
|
| 212 |
+
|
| 213 |
+
def log_progress(msg, conference, file_path):
|
| 214 |
+
with open(file_path, 'a') as f:
|
| 215 |
+
f.write(conference + ': ' + msg + '\n')
|
| 216 |
+
|
| 217 |
+
os.makedirs(SAVE_DIR, exist_ok=True)
|
| 218 |
+
|
| 219 |
+
# Load previous progress
|
| 220 |
+
scraped_conferences = load_progress()
|
| 221 |
+
|
| 222 |
+
# Progress file for current scrape
|
| 223 |
+
progress_file = "conference_scraper_progress.tmp"
|
| 224 |
+
|
| 225 |
+
for conference, scrape_function in tqdm(scrape_functions.items()):
|
| 226 |
+
|
| 227 |
+
if conference in scraped_conferences:
|
| 228 |
+
print(f"Skipping {conference}, already scraped.")
|
| 229 |
+
log_progress("Success!", conference, progress_file)
|
| 230 |
+
continue
|
| 231 |
+
|
| 232 |
+
try:
|
| 233 |
+
|
| 234 |
+
print(f"Scraping {conference}")
|
| 235 |
+
save_path = os.path.join(SAVE_DIR, f"{conference}.json")
|
| 236 |
+
conference_items = scrape_function()
|
| 237 |
+
save_to_file(conference_items, save_path)
|
| 238 |
+
print(f"Saved {conference} to {str(save_path)}")
|
| 239 |
+
save_progress(conference, progress_file)
|
| 240 |
+
log_progress("Success!", conference, progress_file)
|
| 241 |
+
|
| 242 |
+
except Exception as e:
|
| 243 |
+
print(f"Error scraping {conference}: {e}")
|
| 244 |
+
log_progress(f"ERROR: {e}", conference, progress_file)
|
| 245 |
+
continue
|
| 246 |
+
|
| 247 |
+
# Remove progress file
|
| 248 |
+
os.remove(progress_file)
|
| 249 |
+
|
| 250 |
+
def stats():
|
| 251 |
+
total = 0
|
| 252 |
+
for fname in os.listdir(SAVE_DIR):
|
| 253 |
+
with open(os.path.join(SAVE_DIR, fname), 'r') as file:
|
| 254 |
+
num_lines = sum(1 for _ in file)
|
| 255 |
+
print(fname + ": " + str(num_lines) + " lines")
|
| 256 |
+
total += num_lines
|
| 257 |
+
print("Total: " + str(total))
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
if __name__ == "__main__":
|
| 261 |
+
main()
|
| 262 |
+
stats()
|
data_pipeline/download_arxiv_kaggle.py
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pulls papers from arxiv."""
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
from functools import partial
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
import heapq
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
import pickle
|
| 10 |
+
|
| 11 |
+
from datasets import Dataset
|
| 12 |
+
import kaggle
|
| 13 |
+
import numpy as np
|
| 14 |
+
import pandas as pd
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
from transformers import AutoTokenizer, AutoModel
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
arxiv_fname = "arxiv-metadata-oai-snapshot.json"
|
| 22 |
+
|
| 23 |
+
def download_arxiv_data(path = Path(".")):
|
| 24 |
+
"""Downloads and unzips arxiv dataset from Kaggle into the `data` subdirectory of `path`."""
|
| 25 |
+
dataset = "Cornell-University/arxiv"
|
| 26 |
+
data_path = path/"data"
|
| 27 |
+
|
| 28 |
+
if not any([arxiv_fname in file for file in os.listdir(data_path)]):
|
| 29 |
+
kaggle.api.dataset_download_cli(dataset, path=data_path, unzip=True)
|
| 30 |
+
else:
|
| 31 |
+
print(f"Data already downloaded at {data_path/arxiv_fname}.")
|
| 32 |
+
return data_path/arxiv_fname
|
| 33 |
+
|
| 34 |
+
def get_lbl_from_name(names):
|
| 35 |
+
"""Tuple (last_name, first_name, middle_name) => String 'first_name [middle_name] last_name'."""
|
| 36 |
+
return [
|
| 37 |
+
name[1] + ' ' + name[0] if name[2] == '' \
|
| 38 |
+
else name[1] + ' ' + name[2] + ' ' + name[0]
|
| 39 |
+
for name in names
|
| 40 |
+
]
|
| 41 |
+
|
| 42 |
+
def filter_arxiv_for_ml(arxiv_path, obtain_summary=False, authors_of_interest=None):
|
| 43 |
+
"""Sifts through downloaded arxiv file to find ML-related papers.
|
| 44 |
+
|
| 45 |
+
If `obtain_summary` is True, saves a pickled DataFrame to the same directory as
|
| 46 |
+
the downloaded arxiv file with the name `arxiv_fname` + `-summary.pkl`.
|
| 47 |
+
|
| 48 |
+
If `authors_of_interest` is not None, only save ML-related papers by those authors.
|
| 49 |
+
"""
|
| 50 |
+
ml_path = str(arxiv_path).split('.')[0]+'-ml.json'
|
| 51 |
+
summary_path = str(arxiv_path).split('.')[0]+'-summary.pkl'
|
| 52 |
+
|
| 53 |
+
ml_cats = ['cs.AI', 'cs.CL', 'cs.CV', 'cs.LG', 'stat.ML']
|
| 54 |
+
|
| 55 |
+
if obtain_summary and Path(ml_path).exists() and Path(summary_path).exists():
|
| 56 |
+
print(f"File {ml_path} with ML subset of arxiv already exists. Skipping.")
|
| 57 |
+
print(f"Summary file {summary_path} already exists. Skipping.")
|
| 58 |
+
return
|
| 59 |
+
if not obtain_summary and Path(ml_path).exists():
|
| 60 |
+
print(f"File {ml_path} with ML subset of arxiv already exists. Skipping.")
|
| 61 |
+
return
|
| 62 |
+
|
| 63 |
+
if obtain_summary:
|
| 64 |
+
gdf = {'categories': [], 'lv_date': []} # global data
|
| 65 |
+
|
| 66 |
+
if authors_of_interest:
|
| 67 |
+
authors_of_interest = set(authors_of_interest)
|
| 68 |
+
|
| 69 |
+
# Load the JSON file line by line
|
| 70 |
+
with open(arxiv_path, 'r') as f1, open(ml_path, 'w') as f2:
|
| 71 |
+
for line in tqdm(f1):
|
| 72 |
+
# Parse each line as JSON
|
| 73 |
+
try:
|
| 74 |
+
entry_data = json.loads(line)
|
| 75 |
+
except json.JSONDecodeError:
|
| 76 |
+
# Skip lines that cannot be parsed as JSON
|
| 77 |
+
continue
|
| 78 |
+
|
| 79 |
+
# check categories and last version in entry data
|
| 80 |
+
if (
|
| 81 |
+
obtain_summary
|
| 82 |
+
and 'categories' in entry_data
|
| 83 |
+
and 'versions' in entry_data
|
| 84 |
+
and len(entry_data['versions'])
|
| 85 |
+
and 'created' in entry_data['versions'][-1]
|
| 86 |
+
):
|
| 87 |
+
gdf['categories'].append(entry_data['categories'])
|
| 88 |
+
gdf['lv_date'].append(entry_data['versions'][-1]['created'])
|
| 89 |
+
|
| 90 |
+
# ml data
|
| 91 |
+
authors_on_paper = get_lbl_from_name(entry_data['authors_parsed'])
|
| 92 |
+
if ('categories' in entry_data
|
| 93 |
+
and (any(cat in entry_data['categories'] for cat in ml_cats))
|
| 94 |
+
and (any(author in authors_of_interest for author in authors_on_paper))
|
| 95 |
+
):
|
| 96 |
+
f2.write(line)
|
| 97 |
+
|
| 98 |
+
if obtain_summary:
|
| 99 |
+
gdf = pd.DataFrame(gdf)
|
| 100 |
+
gdf['lv_date'] = pd.to_datetime(gdf['lv_date'])
|
| 101 |
+
gdf = gdf.sort_values('lv_date', axis=0).reset_index(drop=True)
|
| 102 |
+
|
| 103 |
+
cats = set()
|
| 104 |
+
for cat_combo in gdf['categories'].unique():
|
| 105 |
+
cat_combo.split(' ')
|
| 106 |
+
cats.update(cat_combo.split(' '))
|
| 107 |
+
print(f'Columnizing {len(cats)} categories. ')
|
| 108 |
+
for cat in tqdm(cats):
|
| 109 |
+
gdf[cat] = pd.arrays.SparseArray(gdf['categories'].str.contains(cat), fill_value=0, dtype=np.int8)
|
| 110 |
+
|
| 111 |
+
# count number of categories item is associated with
|
| 112 |
+
gdf['ncats'] = gdf['categories'].str.count(' ') + 1
|
| 113 |
+
|
| 114 |
+
# write to pickle file
|
| 115 |
+
with open(summary_path, 'wb') as f:
|
| 116 |
+
pickle.dump(gdf, f)
|
| 117 |
+
|
| 118 |
+
def get_professors_and_relevant_papers(us_professors, k=8, cutoff=datetime(2022, 10, 1)):
|
| 119 |
+
"""
|
| 120 |
+
Returns a dictionary mapping U.S. professor names to a list of indices
|
| 121 |
+
corresponding to their most recent papers in `data/arxiv-metadata-oai-snapshot-ml.json`.
|
| 122 |
+
This function is necessary to specify the papers we are interested in for each
|
| 123 |
+
professor (e.g., the most recent papers after cutoff)
|
| 124 |
+
|
| 125 |
+
Parameters:
|
| 126 |
+
- us_professors: A list of U.S. professor names to match against.
|
| 127 |
+
- k: The number of most recent papers to keep for each professor, based on
|
| 128 |
+
the first version upload date.
|
| 129 |
+
- cutoff (datetime): Only considers papers published after this date
|
| 130 |
+
(default: October 1, 2022).
|
| 131 |
+
|
| 132 |
+
Returns:
|
| 133 |
+
- dict: A dictionary where keys are professor names and values are lists of
|
| 134 |
+
indices corresponding to their most recent papers.
|
| 135 |
+
"""
|
| 136 |
+
# professors to tuple of (datetime, papers)
|
| 137 |
+
p2p = defaultdict(list)
|
| 138 |
+
|
| 139 |
+
with open('data/arxiv-metadata-oai-snapshot-ml.json', 'r') as f:
|
| 140 |
+
line_nbr = 1
|
| 141 |
+
while True:
|
| 142 |
+
line = f.readline()
|
| 143 |
+
if not line: break
|
| 144 |
+
|
| 145 |
+
try:
|
| 146 |
+
ml_data = json.loads(line)
|
| 147 |
+
paper_authors = get_lbl_from_name(ml_data['authors_parsed'])
|
| 148 |
+
|
| 149 |
+
# filter the same way as in `conference_scraper.py`
|
| 150 |
+
# ignore solo-authored papers and papers with more than 20 authors
|
| 151 |
+
if len(paper_authors) == 1 or len(paper_authors) > 20:
|
| 152 |
+
continue
|
| 153 |
+
|
| 154 |
+
try:
|
| 155 |
+
dt = datetime.strptime(ml_data["versions"][0]["created"], '%a, %d %b %Y %H:%M:%S %Z')
|
| 156 |
+
if dt < cutoff:
|
| 157 |
+
continue
|
| 158 |
+
except (KeyError, ValueError) as e:
|
| 159 |
+
print(f"Failed to parse date: {ml_data}")
|
| 160 |
+
dt = datetime(2000, 1, 1) # before cutoff date
|
| 161 |
+
|
| 162 |
+
# consider if professor is first-author since we now care about semantics
|
| 163 |
+
for paper_author in paper_authors:
|
| 164 |
+
if paper_author in us_professors:
|
| 165 |
+
# make a connection
|
| 166 |
+
heapq.heappush(p2p[paper_author], (dt, line_nbr))
|
| 167 |
+
if len(p2p[paper_author]) > k:
|
| 168 |
+
heapq.heappop(p2p[paper_author])
|
| 169 |
+
except:
|
| 170 |
+
print(f"{line}")
|
| 171 |
+
line_nbr += 1
|
| 172 |
+
return p2p
|
| 173 |
+
|
| 174 |
+
def gen(p2p):
|
| 175 |
+
values = p2p.values()
|
| 176 |
+
relevant_lines = set()
|
| 177 |
+
for value in values:
|
| 178 |
+
relevant_lines.update([v[1] for v in value])
|
| 179 |
+
relevant_lines = sorted(list(relevant_lines))
|
| 180 |
+
|
| 181 |
+
idx = 0
|
| 182 |
+
with open('data/arxiv-metadata-oai-snapshot-ml.json', 'r') as f:
|
| 183 |
+
line_nbr = 1
|
| 184 |
+
while idx < len(relevant_lines):
|
| 185 |
+
line = f.readline()
|
| 186 |
+
if not line: break
|
| 187 |
+
|
| 188 |
+
if line_nbr == relevant_lines[idx]:
|
| 189 |
+
data = json.loads(line)
|
| 190 |
+
yield {"line_nbr": line_nbr,
|
| 191 |
+
"id": data["id"],
|
| 192 |
+
"title": data["title"],
|
| 193 |
+
"abstract": data["abstract"],
|
| 194 |
+
"authors": data["authors_parsed"]}
|
| 195 |
+
idx += 1
|
| 196 |
+
|
| 197 |
+
line_nbr += 1
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class EmbeddingProcessor:
    """Computes sentence embeddings for paper title+abstract text using a
    HuggingFace tokenizer/model pair with masked mean pooling."""

    def __init__(self, model_name: str, custom_model_name: str, device: str = "cuda"):
        """
        Args:
            model_name: HF hub id used to load the tokenizer.
            custom_model_name: HF hub id used to load the (fine-tuned) model.
            device: torch device string, e.g. "cuda" or "cpu".
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(custom_model_name)
        self.device = torch.device(device)
        self.model.to(self.device)
        torch.cuda.empty_cache()  # release any cached GPU memory before encoding

    @staticmethod
    def mean_pooling(model_output, attention_mask):
        """Average token embeddings while ignoring padding positions.

        Args:
            model_output: model forward output; element 0 holds the token
                embeddings of shape (batch, seq_len, hidden).
            attention_mask: (batch, seq_len) mask with 1 for real tokens.

        Returns:
            (batch, hidden) tensor of masked mean embeddings.
        """
        # First element of model_output contains all token embeddings
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        # clamp avoids division by zero for all-padding rows
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def get_embeddings(self, batch):
        """Encode a batch dict with "title" and "abstract" lists into
        L2-normalized embedding lists (python lists, on CPU)."""
        title_tkn, abstract_tkn = " [TITLE] ", " [ABSTRACT] "
        texts = [title_tkn + t + abstract_tkn + a
                 for t, a in zip(batch["title"], batch["abstract"])]

        # Tokenize sentences and move tensors to the target device
        encoded_input = self.tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
        encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}

        # Compute token embeddings without gradients
        with torch.no_grad():
            model_output = self.model(**encoded_input)

        # Perform pooling, then normalize to unit length
        embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
        embeddings = F.normalize(embeddings, p=2, dim=1)

        # Move embeddings to CPU and convert to list
        return embeddings.cpu().numpy().tolist()

    def process_dataset(self, dataset_path: str, save_path: str, batch_size: int = 128):
        """Load a saved HF dataset, add an "embeddings" column, and save it.

        Args:
            dataset_path: path of a dataset saved with `save_to_disk`.
            save_path: destination directory for the augmented dataset.
            batch_size: number of rows encoded per forward pass.
        """
        ds = Dataset.load_from_disk(dataset_path)

        # Compute embeddings and add as a new column
        ds_with_embeddings = ds.map(
            lambda x: {"embeddings": self.get_embeddings(x)},
            batched=True, batch_size=batch_size)

        # Save the updated dataset (removed a redundant `save_path = save_path`)
        ds_with_embeddings.save_to_disk(save_path)
        print(f"Dataset with embeddings saved to {save_path}")
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def main():
    """Downloads arxiv data and extract embeddings for papers."""
    print("Downloading data...")
    arxiv_path = download_arxiv_data()

    # Load the verified US-professor list and keep only the names.
    with open('data/professor/us_professor.json', 'r') as fh:
        professor_records = json.load(fh)
    authors_of_interest = [record['name'] for record in professor_records]

    print("Filtering data for ML papers...")
    filter_arxiv_for_ml(arxiv_path, authors_of_interest=authors_of_interest)

    # Build professor -> paper-line mapping and materialize the paper dataset.
    paper_data_path = "data/paper_embeddings/paper_data"
    print("Saving data to disk at " + paper_data_path)
    p2p = get_professors_and_relevant_papers(authors_of_interest)
    ds = Dataset.from_generator(partial(gen, p2p))
    ds.save_to_disk(paper_data_path)

    print("Extracting embeddings (use GPU if possible)...")
    # Destination for the paper embeddings dataset.
    save_path = "data/paper_embeddings/all-mpnet-base-v2-embds"
    # Tokenizer from the base model, weights from the fine-tuned one.
    processor = EmbeddingProcessor(
        model_name='sentence-transformers/all-mpnet-base-v2',
        custom_model_name='salsabiilashifa11/sbert-paper'
    )
    # Process dataset and save with embeddings.
    processor.process_dataset(paper_data_path, save_path, batch_size=128)

if __name__ == "__main__":
    main()
|
data_pipeline/loaders.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
def load_conference_papers(conference_dir='data/conference'):
    """Load all conference papers from JSON-lines files in `conference_dir`.

    Args:
        conference_dir: directory containing `.json` files with one JSON
            object per line; files without a `.json` suffix are skipped.

    Returns:
        A list of the parsed paper objects, in file/line order.
    """
    papers = []
    for file in os.listdir(conference_dir):
        if not file.endswith('.json'):
            continue
        with open(os.path.join(conference_dir, file), 'r') as f:
            # Iterate the file lazily instead of a manual readline() loop.
            for line in f:
                papers.append(json.loads(line))
    return papers
|
| 17 |
+
|
| 18 |
+
def load_us_professor():
    """Returns a JSON list"""
    path = 'data/professor/us_professor.json'
    with open(path, 'r') as fp:
        return json.load(fp)
|
data_pipeline/requirements.txt
ADDED
|
File without changes
|
data_pipeline/schools_scraper.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://medium.com/@donadviser/running-selenium-and-chrome-on-wsl2-cfabe7db4bbb
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
|
| 6 |
+
from bs4 import BeautifulSoup
|
| 7 |
+
from dotenv import load_dotenv, find_dotenv
|
| 8 |
+
from langchain_together import ChatTogether
|
| 9 |
+
from langchain_core.output_parsers import StrOutputParser
|
| 10 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 11 |
+
from langchain_core.runnables import RunnableLambda
|
| 12 |
+
from selenium import webdriver
|
| 13 |
+
from selenium.webdriver.chrome.service import Service
|
| 14 |
+
from selenium.webdriver.common.by import By
|
| 15 |
+
from selenium.webdriver.chrome.options import Options
|
| 16 |
+
|
| 17 |
+
_ = load_dotenv(find_dotenv()) # read local .env file
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_service_and_chrome_options():
    """TODO: specific to chromedriver location.

    Builds headless Chrome options and a Service pointing at the
    chrome/chromedriver binaries unpacked under the user's home directory.

    Returns:
        (service, chrome_options) tuple suitable for `webdriver.Chrome`.
    """
    # Define Chrome options
    chrome_options = Options()
    chrome_options.add_argument("--headless")
    chrome_options.add_argument("--no-sandbox")
    # Add more options here if needed

    # Define paths (removed a duplicated expanduser("~") call)
    user_home_dir = os.path.expanduser("~")
    chrome_binary_path = os.path.join(user_home_dir, "chrome-linux64", "chrome")
    chromedriver_path = os.path.join(user_home_dir, "chromedriver-linux64", "chromedriver")

    # Set binary location and service
    chrome_options.binary_location = chrome_binary_path
    service = Service(chromedriver_path)

    return service, chrome_options
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def retrieve_csrankings_content(dump_file="soup.tmp"):
    """Write times higher page to a dump file."""
    # https://medium.com/@donadviser/running-selenium-and-chrome-on-wsl2-cfabe7db4bbb
    # Using WSL2
    service, chrome_options = get_service_and_chrome_options()

    # Initialize Chrome WebDriver; context manager guarantees browser teardown.
    with webdriver.Chrome(service=service, options=chrome_options) as browser:
        print("Get browser")
        browser.get("https://www.timeshighereducation.com/student/best-universities/best-universities-united-states")

        print("Wait for the page to load")
        browser.implicitly_wait(10)

        print("Get html")
        page_html = browser.page_source

        # Persist the raw HTML so parsing can happen offline later.
        with open(dump_file, "w") as out:
            out.write(page_html)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def extract_timeshigher_content(read_file="soup.tmp", dump_file="soup (1).tmp"):
    """Extract universities from a dump file."""
    with open(read_file, "r") as src:
        markup = src.read()

    # Parse the saved HTML page.
    soup = BeautifulSoup(markup, "html.parser")

    # Each university name is the anchor text inside a table row.
    rows = soup.find_all('tr')
    universities = [row.find('a').get_text() for row in rows if row.find('a')]

    # dict.fromkeys deduplicates while preserving first-seen order.
    universities = list(dict.fromkeys(universities))

    # Dump one university name per line.
    with open(dump_file, "w") as out:
        for uni in universities:
            out.write(f"{uni}\n")
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def get_department_getter():
    """
    Returns a function that leverages LangChain and TogetherAI to get a list of
    department names in a university associated with machine learning.
    """
    # Prompt: ask for ML-related departments as a numbered list, then a final
    # "Answer:" line with a semicolon-separated summary that is easy to parse.
    template_string = """\
You are an expert in PhD programs and know about \
specific departments at each university.\
You are helping to design a system that generates \
a list of professors that students interested in \
machine learning can apply to for their PhDs. \
Currently, recall is more important than precision. \
Include as many departments as possible, while \
maintaining relevancy. Which departments in {university} \
are associated with machine learning? Please format your \
answer as a numbered list. Afterwards, please generate a \
new line starting with \"Answer:\", followed by a concise \
list of department names generated, separated by
semicolons.\
"""

    prompt_template = ChatPromptTemplate.from_template(template_string)

    # # choose from our 50+ models here: https://docs.together.ai/docs/inference-models
    # NOTE(review): low temperature keeps the answers mostly deterministic,
    # but the pipeline still prompts multiple times for recall (see caller).
    chat = ChatTogether(
        together_api_key=os.environ["TOGETHER_API_KEY"],
        model="meta-llama/Llama-3-70b-chat-hf",
        temperature=0.3
    )

    output_parser = StrOutputParser()

    def extract_function(text):
        """Returns the line that starts with `Answer:`"""
        if "Answer:" not in text:
            # Sentinel string checked by the caller (get_department_info).
            return "No `Answer:` found"
        return text.split("Answer:")[1].strip()

    # Pipeline: fill template -> LLM -> plain string -> extract "Answer:" tail.
    chain = prompt_template | chat | output_parser | RunnableLambda(extract_function)

    def get_department_info(uni):
        """Get department info from the university."""
        return chain.invoke({"university": uni})

    return get_department_info
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def get_department_info(unis_file="soup (1).tmp", deps_file="departments.tsv"):
    """
    Get department info for all universities in `unis_file` and
    write it to `deps_file` as tab-separated (university, department) rows.
    """
    department_getter = get_department_getter()
    with open(unis_file, "r") as fin, open(deps_file, "w") as fout:

        # Iterate through universities in `fin`
        for uni in fin.readlines():
            uni = uni.strip()

            deps = []
            # Prompt the LLM multiple times for better recall
            for i in range(3):
                depstr = department_getter(uni)
                time.sleep(3)  # Respect usage limits!
                try:
                    if depstr == "No `Answer:` found":
                        print(f"No departments found for {uni} on {i}'th prompt.")
                    else:
                        deps_ = [d.strip() for d in depstr.split(';')]
                        deps.extend(deps_)
                except Exception as e:
                    # BUG FIX: original print was missing the f-prefix, so
                    # "{uni}" and "{i}" were printed literally.
                    print(f"Exception for {uni} on {i}'th prompt: ")
                    print("Parsing string: ", depstr)
                    print(e)

            # Deduplicate deps list (order-preserving)
            deps = list(dict.fromkeys(deps))

            # Write to tsv dump file
            for dep in deps:
                fout.write(f"{uni}\t{dep}\n")

            # Print string info
            print(f"{uni}: {deps}")
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
import requests
|
| 174 |
+
|
| 175 |
+
def get_faculty_list_potential_links_getter():
    """Returns a function that returns a list of links that may contain faculty lists."""
    # Read credentials once, at getter-construction time.
    GOOGLE_API_KEY = os.environ['GOOGLE_API_KEY']
    GOOGLE_SEARCH_ENGINE_ID = os.environ['GOOGLE_SEARCH_ENGINE_ID']

    def get_faculty_list_potential_links(uni, dep):
        """Returns a dict mapping result titles to links that may contain faculty lists."""
        search_query = f'{uni} {dep} faculty list'

        params = {
            'q': search_query, 'key': GOOGLE_API_KEY, 'cx': GOOGLE_SEARCH_ENGINE_ID
        }

        response = requests.get('https://www.googleapis.com/customsearch/v1', params=params)
        results = response.json()
        title2link = {item['title']: item['link'] for item in results['items']}
        # BUG FIX: the computed mapping was previously built and discarded.
        return title2link

    # BUG FIX: the closure was previously constructed but never returned,
    # so the getter returned None despite its docstring.
    return get_faculty_list_potential_links
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
# if __name__ == "__main__":
|
| 196 |
+
# get_department_info()
|
data_pipeline/us_professor_verifier.py
ADDED
|
@@ -0,0 +1,513 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import pickle
|
| 5 |
+
import requests
|
| 6 |
+
import time
|
| 7 |
+
|
| 8 |
+
from bs4 import BeautifulSoup
|
| 9 |
+
from dotenv import load_dotenv, find_dotenv
|
| 10 |
+
from langchain.prompts import PromptTemplate
|
| 11 |
+
from openai import OpenAI
|
| 12 |
+
import regex as re
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
from data_pipeline.conference_scraper import get_authors
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
_ = load_dotenv(find_dotenv())
|
| 19 |
+
|
| 20 |
+
# Bing Web Search v7 endpoint; subscription key is read from the environment
# at import time (module fails fast if BING_SEARCH_API_KEY is unset).
SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search"
SUBSCRIPTION_KEY = os.environ["BING_SEARCH_API_KEY"]
# Browser-like User-Agent alongside the API key header.
HEADERS = {
    "Ocp-Apim-Subscription-Key": SUBSCRIPTION_KEY,
    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3",
}

# Few-shot example embedded in the prompt for the "is a professor" outcome.
EXAMPLE_PROFESSOR_JSON = {
    "is_professor": True,
    "title": "Assistant Professor",
    "department": "Computer Science",
    "university": "Stanford University",
    "us_university": True,
}

# Few-shot example for the "is not a professor" outcome.
EXAMPLE_not_professor_JSON = {
    "is_professor": False,
    "occupation": "Graduate Student",
    "affiliation": "Carnegie Mellon University"
}

# Prompt template filled by get_prompt() with the person's name, the two
# JSON examples above, and the formatted search hits.
IS_PROFESSOR_TEMPLATE = """You are a helpful assistant tasked with determining if {person_name} is a machine learning \
professor. You have search results from the query "{person_name} machine learning professor". \
Based on the results, specify if {person_name} is a professor, and if so, provide \
their title, department, university, and whether their university is in the U.S. If not, give their occupation \
and affiliation. Note: multiple people may \
share the same name, so choose the one most likely in machine learning. Further, one person may have multiple \
positions. If this is the case and one of those positions include being a professor, specify they are a professor \
and provide their title, department, university, and whether their university is in the U.S.

Only return the raw JSON, no MarkDown!

If {person_name} **is** a professor, fill out:
- `is_professor`: true
- `title`: e.g., `Assistant Professor`, `Associate Professor`, `Professor` etc.
- `department`: Full name, e.g., `Computer Science` rather than `CS` and `Electrical Engineering` rather than `EE`.
- `university`: Full name, e.g., `California Institute of Technology` rather than `Caltech`
- `us_university`: `true` or `false`

Example:
{professor_json_template}

If {person_name} **is not** a professor, fill out:
- `is_professor`: false
- `occupation`: e.g., `Graduate Student`, `Researcher`, `Engineer`, `Scientist`
- `affiliation`: e.g., `Carnegie Mellon University`, `Deepmind`, `Apple`, `NVIDIA`

Example:
{not_professor_json_template}

Search results (formatted as a numbered list with link name and snippet). \
Again, only return the JSON, just with the dictionary and its fields.
{hits}"""
|
| 73 |
+
|
| 74 |
+
# import httpx
|
| 75 |
+
def bing_search(person_name, max_retries=0, wait_time=0.5):
    """Performs the bing search `person_name` machine learning professor."""
    query = "{} machine learning professor".format(person_name)
    params = {"q": query, "count": 10, "offset": 0, "mkt": "en-US", "textFormat": "HTML"}

    attempt = 0
    while attempt <= max_retries:
        try:
            response = requests.get(SEARCH_URL, headers=HEADERS, params=params)
            response.raise_for_status()
            return response.json()
        except requests.HTTPError as http_err:
            # Out of retries: surface a descriptive error chained to the cause.
            if attempt == max_retries:
                raise Exception(f"Max retries reached. Failed to get a valid response for {person_name}") from http_err
            print(f"An error occurred while searching {person_name}: {http_err}. Retrying in {wait_time} seconds ...")
            time.sleep(wait_time)
        attempt += 1

    return ""  # unreachable: the loop always returns or raises
|
| 92 |
+
|
| 93 |
+
def process_search_results(search_results):
    """Cleans up bing search results."""
    pages = search_results["webPages"]["value"]

    # What people see: numbered "[name]: [snippet]" entries.
    readable_results = "\n".join(
        "{0}. [{1}]: [{2}]".format(i + 1, page["name"], page["snippet"])
        for i, page in enumerate(pages))
    # Strip HTML tags, then drop any non-ASCII characters.
    cleaned_readable_results = BeautifulSoup(readable_results, "html.parser").get_text()
    cleaned_readable_results = re.sub(r'[^\x00-\x7F]+', '', cleaned_readable_results)

    # Numbered list of raw result URLs.
    url_results = "\n".join(
        "{0}. {1}".format(i + 1, page["url"]) for i, page in enumerate(pages))

    # [human-readable text, links]
    return [cleaned_readable_results, url_results]
|
| 109 |
+
|
| 110 |
+
def get_prompt(person_name, top_hits):
    """Fill the is-professor template with the person's name, the JSON
    examples, and the newline-joined search hits."""
    template = PromptTemplate(
        input_variables=["person_name", "professor_json_template", "not_professor_json_template", "hits"],
        template=IS_PROFESSOR_TEMPLATE,
    )
    return template.format(
        person_name=person_name,
        professor_json_template=json.dumps(EXAMPLE_PROFESSOR_JSON),
        not_professor_json_template=json.dumps(EXAMPLE_not_professor_JSON),
        hits="\n".join(top_hits),
    )
|
| 122 |
+
|
| 123 |
+
def run_chatgpt(prompt, client, model="gpt-4o-mini", system_prompt=None):
    """Send a single-turn chat request and return the reply text."""
    messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
    messages.append({"role": "user", "content": prompt})

    # temperature=0 + fixed seed keeps the extraction (near-)deterministic.
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=0.0,
        seed=123,
    )

    # Return response
    return response.choices[0].message.content
|
| 137 |
+
|
| 138 |
+
def check_json(profile):
    """Validate an LLM-produced profile dict.

    Professor profiles require title/department/university/us_university
    (the latter a bool); non-professor profiles require occupation/affiliation.

    Raises:
        ValueError: if `profile` is not a dict or a required key is missing
            or mistyped.
    """
    if not isinstance(profile, dict):
        raise ValueError("Profile must be a dictionary")

    if "is_professor" not in profile:
        raise ValueError("Profile must contain a 'is_professor' key")

    if profile["is_professor"]:
        # All four descriptive fields are required for professors.
        for key in ("title", "department", "university", "us_university"):
            if key not in profile:
                raise ValueError(f"Profile must contain a '{key}' key")
        # isinstance is the idiomatic type check (was `type(...) is not bool`).
        if not isinstance(profile["us_university"], bool):
            raise ValueError("Profile 'us_university' must be a boolean")
    else:
        if "occupation" not in profile:
            raise ValueError("Profile must contain an 'occupation' key")
        if "affiliation" not in profile:
            raise ValueError("Profile must contain an 'affiliation' key")
|
| 161 |
+
|
| 162 |
+
def save_json(profiles, file_path):
    """Write `profiles` as pretty-printed JSON, creating parent dirs as needed."""
    parent = os.path.dirname(file_path)
    # BUG FIX: os.makedirs("") raises FileNotFoundError when file_path is a
    # bare filename; only create directories when a parent path exists.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(file_path, 'w') as file:  # appending just the new ones would be better
        json.dump(profiles, file, indent=4)
|
| 166 |
+
|
| 167 |
+
def load_json(file_path):
    """Read and return the JSON content of `file_path`."""
    with open(file_path, 'r') as fh:
        data = json.load(fh)
    return data
|
| 170 |
+
|
| 171 |
+
def log_progress_to_file(progress_log, file_path):
    """Append the progress messages plus a dashed separator to `file_path`."""
    parent = os.path.dirname(file_path)
    # BUG FIX: guard against an empty dirname for bare filenames
    # (os.makedirs("") raises FileNotFoundError).
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(file_path, 'a') as file:
        file.write('\n'.join(progress_log))
        file.write('\n' + '-' * 10 + '\n')
|
| 176 |
+
|
| 177 |
+
def search_person(person_name, progress_log):
    """Completes a bing search for the person."""
    try:
        raw_results = bing_search(person_name)
        readable, _links = process_search_results(raw_results)
        top_hits = readable.split("\n")[:5]  # keep the top 5 results
        progress_log.append(f"Success: Search results for {person_name}.")
        return top_hits
    except Exception as e:
        # On any failure, log it and signal with an empty string.
        print(f"Search exception for {person_name}: ", e)
        progress_log.append(f"Failure: Search exception for {person_name}: {e}")
        return ""
|
| 189 |
+
|
| 190 |
+
def extract_search_results(person_name, progress_log, client, us_professor_profiles, not_us_professor_profiles, top_hits):
    """Use LLM to extract data from search results."""
    try:
        prompt = get_prompt(person_name, top_hits)
        gpt_output = run_chatgpt(prompt, client)  # ask the LLM to structure the hits
        parsed = json.loads(gpt_output)
        profile = {"name": person_name, **parsed}
        check_json(profile)
        # Route validated profiles into the matching bucket.
        if profile["is_professor"] and profile["us_university"]:
            us_professor_profiles.append(profile)
        else:
            not_us_professor_profiles.append(profile)
    except Exception as e:
        print(f"LLM exception for {person_name}: ", e)
        progress_log.append(f"Failure: LLM exception for {person_name}: {e}")
|
| 206 |
+
|
| 207 |
+
def research_person(person_name, client, progress_log, us_professor_profiles, not_us_professor_profiles):
    """Research who this person is and save results."""
    top_hits = search_person(person_name, progress_log)
    # An empty string signals that the web search itself failed.
    if top_hits != "":
        extract_search_results(person_name, progress_log, client, us_professor_profiles, not_us_professor_profiles, top_hits)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def get_authors(save_dir="data/conference", min_papers=3, ignore_first_author=True):
    """
    Reduce the list of authors to those with at least `min_papers` papers for
    which they are not first authors. Ignores solo-authored papers and papers
    with more than 20 authors.

    Filters authors so that we don't have to do RAG on every author, which is
    monetarily expensive. Feel free to edit if you have more resources.

    Args:
        save_dir: directory of JSON-lines conference files; also receives the
            `authors.txt` dump (one "name\\tcount" per line).
        min_papers: minimum qualifying paper count to keep an author.
        ignore_first_author: when True (default), first authors get no credit.

    Returns:
        dict mapping author name -> qualifying paper count.
    """
    authors = defaultdict(int)
    for fname in os.listdir(save_dir):
        if not fname.endswith('.json'):
            continue

        with open(os.path.join(save_dir, fname), 'r') as file:
            for line in file:
                item = json.loads(line)
                paper_authors = [x.strip() for x in item[1].split(",")]

                # ignore solo-authored papers and papers with more than 20 authors
                if len(paper_authors) == 1 or len(paper_authors) > 20:
                    continue

                # professors generally are not first authors, so the first
                # author is only credited when explicitly requested.
                # (Removed a redundant `len(paper_authors) > 0` check:
                # papers reaching this point always have >= 2 authors.)
                if not ignore_first_author:
                    authors[paper_authors[0]] += 1
                for coauthor in paper_authors[1:]:
                    authors[coauthor] += 1

    # Keep only sufficiently prolific authors and dump them to disk.
    authors = {k: v for k, v in authors.items() if v >= min_papers}
    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, "authors.txt"), 'w') as f:
        for k, v in authors.items():
            f.write(f"{k}\t{v}\n")
    return authors
|
| 250 |
+
|
| 251 |
+
def research_conference_profiles(save_freq=20):
    """Research each author as a stream.

    NOTE: cannot deal w/ interrupts and continue from past progress.

    Args:
        save_freq: checkpoint (log + JSON dump) every `save_freq` authors.
    """
    authors = get_authors("data/conference")
    person_names = list(authors.keys())

    client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

    progress_log = []
    us_professor_profiles = []
    not_us_professor_profiles = []

    def log_save_print(progress_log, us_professor_profiles, not_us_professor_profiles, i):
        # Persist log and both profile buckets, then report progress.
        log_progress_to_file(progress_log, 'logs/progress_log.tmp')
        save_json(us_professor_profiles, 'data/professor/us_professor.json')
        save_json(not_us_professor_profiles, 'data/professor/not_us_professor.json')
        print(f"Saved profiles to data/professor/us_professor.json and data/professor/not_us_professor.json after processing {i} people")

    # BUG FIX: initialize i so the final checkpoint below does not raise
    # NameError when person_names is empty.
    i = 0
    for i in range(len(person_names)):
        research_person(person_names[i], client, progress_log, us_professor_profiles, not_us_professor_profiles)
        if i % save_freq == 0:
            log_save_print(progress_log, us_professor_profiles, not_us_professor_profiles, i)

    log_save_print(progress_log, us_professor_profiles, not_us_professor_profiles, i)
    print("Research complete.")
|
| 279 |
+
|
| 280 |
+
def batch_search_person(person_names, progress_log, save_freq=20):
    """Searches everyone given in `person_names`.

    Resumable: previously collected results are loaded from
    data/professor/search_results.json and their names are skipped.
    """
    # might start and stop, pull from previous efforts
    try:
        prev_researched_authors = load_json("data/professor/search_results.json")
    except (FileNotFoundError, json.JSONDecodeError):
        # Narrowed from a bare `except:`; missing/corrupt dump means fresh start.
        prev_researched_authors = []
    ignore_set = set(x[0] for x in prev_researched_authors)
    data = prev_researched_authors

    # Drop names already researched in a previous run.
    unseen_person_names = [name for name in person_names if name not in ignore_set]
    print(f"Already researched {len(ignore_set)} / {len(person_names)} = {len(ignore_set) / len(person_names)} of the dataset")
    person_names = unseen_person_names

    # continue search (removed the dead in-loop ignore_set re-check:
    # person_names was already filtered above)
    for i in tqdm(range(len(person_names))):
        query_start = time.time()
        top_hits = search_person(person_names[i], progress_log)
        if top_hits != "":
            data.append([person_names[i], top_hits])

        if i % save_freq == 0:
            save_json(data, "data/professor/search_results.json")
            log_progress_to_file(progress_log, 'logs/progress_log.tmp')

        # 3 queries per second max.
        # BUG FIX: sleep the *remaining* time until 0.334s has elapsed since
        # query_start; the original computed (elapsed - 0.334), i.e. it slept
        # longer after slow queries and not at all after fast ones.
        wait_time = max((query_start + 0.334) - time.time(), 0.0)
        time.sleep(wait_time)

    save_json(data, "data/professor/search_results.json")
    log_progress_to_file(progress_log, 'logs/progress_log.tmp')
|
| 316 |
+
|
| 317 |
+
def write_batch_files(search_results_path,
                      prompt_data_path_prefix,
                      model="gpt-4o-mini",
                      max_tokens=1000,
                      temperature=0.0,
                      seed=123,
                      batch_size=1999,  # max_tokens * batch_size < 2M?
                      verbose=0):
    """Convert search results dump to jsonl for LLM batch request.

    Reads the `[name, hits]` pairs from `search_results_path`, renders one
    chat-completion request per pair, and writes them in chunks of
    `batch_size` to `{prompt_data_path_prefix}_{i}.jsonl`.

    Returns:
        List of the jsonl paths written (empty if there are no prompts).
    """
    with open(search_results_path, "r") as f:
        search_results = json.load(f)

    prompt_datas = []
    for search_result in search_results:
        prompt_data = {
            "custom_id": f"request-{search_result[0]}",  # don't change, needed for decoding
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": model,
                "temperature": temperature,
                "seed": seed,
                "messages": [{"role": "user", "content": get_prompt(search_result[0], search_result[1])}],
                "max_tokens": max_tokens
            }
        }
        prompt_datas.append(prompt_data)

    print(f"Number of prompts: {len(prompt_datas)}")
    # Guard the sample print: with no results, `search_result` is undefined.
    if verbose > 0 and search_results:
        print(get_prompt(search_result[0], search_result[1]))

    batch_paths = []
    # Ceil-divide so an exact multiple of batch_size (or zero prompts) no
    # longer produces a trailing empty .jsonl (the old `// batch_size + 1`
    # loop did).
    num_batches = (len(prompt_datas) + batch_size - 1) // batch_size
    for i in range(num_batches):
        prompt_data_path = f"{prompt_data_path_prefix}_{i}.jsonl"
        start, stop = i * batch_size, min(len(prompt_datas), (i + 1) * batch_size)
        with open(prompt_data_path, "w") as f:
            for prompt_data in prompt_datas[start:stop]:
                f.write(json.dumps(prompt_data) + "\n")
        batch_paths.append(prompt_data_path)

    return batch_paths
|
| 359 |
+
|
| 360 |
+
def send_batch_files(prompt_data_path_prefix, batch_paths, client, timeout=24*60*60):
    """Create and send the batch request to API endpoint.

    Uploads each jsonl file in `batch_paths`, creates a batch job for it,
    and polls every 40 s until the job completes or `timeout` seconds pass.
    Batch handles are pickled to `{prompt_data_path_prefix}_batches.pkl` and
    their ids written to `{prompt_data_path_prefix}_ids.txt`.

    Returns:
        List of batch objects, one per input file, in submission order.
    """
    batches = []

    print("Batching and sending requests...")
    for batch_path in tqdm(batch_paths):
        # `with` closes the upload handle (the old bare open() leaked it).
        with open(batch_path, "rb") as batch_file:
            batch_input_file = client.files.create(
                file=batch_file,
                purpose="batch"
            )

        batch_input_file_id = batch_input_file.id
        print(f"Batch input file ID: {batch_input_file_id}")

        batch = client.batches.create(
            input_file_id=batch_input_file_id,
            endpoint="/v1/chat/completions",
            completion_window="24h",
            metadata={
                "description": "search extraction job"
            }
        )

        begin = time.time()
        while time.time() - begin < timeout:
            batch = client.batches.retrieve(batch.id)
            if batch.status == "completed":
                break
            time.sleep(40)
            # `:.2f` = two decimal places; the old `:2f` was a width-2 spec.
            print(f"Status ({time.time()-begin:.2f}): {batch.status}")
        print("seconds elapsed: ", time.time() - begin)
        batches.append(batch)

    # Keeps track of the paths to the batch files
    with open(f"{prompt_data_path_prefix}_batches.pkl", "wb") as f:
        pickle.dump(batches, f)
    with open(f"{prompt_data_path_prefix}_ids.txt", "w") as f:
        f.write("\n".join([x.id for x in batches]))
    return batches
|
| 399 |
+
|
| 400 |
+
def retrieve_batch_output(client, batch_id):
    """Fetch the output text of a finished batch job.

    Returns the raw output-file text once the batch has completed; until
    then, prints a status notice and returns the sentinel "INCOMPLETE".
    (OpenAI batch requests finish within 24 hrs.)
    """
    retrieved_batch = client.batches.retrieve(batch_id)
    if retrieved_batch.status != "completed":
        print("Batch process is still in progress.")
        print(retrieved_batch)
        return "INCOMPLETE"
    return client.files.content(retrieved_batch.output_file_id).text
|
| 409 |
+
|
| 410 |
+
def batch_process_llm_output(client, batches):
    """Collect finished batch outputs and split profiles into US / non-US files.

    Downloads each batch's jsonl output, parses every LLM response into a
    profile dict (name recovered from the request's custom_id), and writes the
    combined results to data/professor/us_professor.json and
    data/professor/not_us_professor.json. Returns early, writing nothing, if
    any batch is still in progress.

    Args:
        client: OpenAI client used for retrieval.
        batches: batch objects previously returned by send_batch_files.
    """
    # Use the caller's client. (The old code rebuilt `client = OpenAI()`
    # here, silently discarding the one passed in.)
    outputs = []
    for batch in batches:
        output = retrieve_batch_output(client, batch.id)
        if output == "INCOMPLETE":
            return
        outputs.append(output)

    custom_id_idx = len("request-")  # where the name begins in "custom_id"

    # Accumulate across ALL batches. (The old code reset these inside the
    # per-output loop, so only the last batch's profiles survived.)
    progress_log = []
    us_professor_profiles = []
    not_us_professor_profiles = []

    for output in outputs:
        for json_obj in output.split('\n'):
            if json_obj == '':
                continue

            try:
                parsed_data = json.loads(json_obj)
                message_content = parsed_data["response"]["body"]["choices"][0]["message"]["content"]
                gpt_json = json.loads(message_content)
                gpt_profile = {"name": parsed_data["custom_id"][custom_id_idx:]}
                gpt_profile.update(gpt_json)
                check_json(gpt_profile)
                if gpt_profile["is_professor"] and gpt_profile["us_university"]:
                    us_professor_profiles.append(gpt_profile)
                else:
                    not_us_professor_profiles.append(gpt_profile)

                progress_log.append(f"Success: Parsed LLM output for {gpt_profile['name']}")
            except Exception as e:
                # Best-effort error reporting: the custom_id/name may not
                # exist if parsing failed early, hence the nested fallback.
                try:
                    print(f"Failed to parse json object for custom-id `{parsed_data['custom_id']}`: {e}")
                    progress_log.append(f"Failed: Parsed LLM output for {gpt_profile['name']}: {e}")
                except Exception as e2:
                    print(f"Failed to parse json object `{json_obj}`: {e2}")
                    progress_log.append(f"Failed UNKNOWN: Parsed LLM output: {e2}")

    with open("data/professor/us_professor.json", 'w') as file:
        json.dump(us_professor_profiles, file, indent=4)
    with open("data/professor/not_us_professor.json", 'w') as file:
        json.dump(not_us_professor_profiles, file, indent=4)
|
| 457 |
+
|
| 458 |
+
def main():
    """CLI entry point: run exactly one batch stage (search / analyze / retrieve)."""
    import argparse

    parser = argparse.ArgumentParser(
        description="US Professor Verifier: Search or LLM-Analyze batch operations."
    )

    # Exactly one mode flag must be supplied.
    mode_group = parser.add_mutually_exclusive_group(required=True)
    mode_group.add_argument(
        '--batch_search',
        action='store_true',
        help='Batch search the authors.'
    )
    mode_group.add_argument(
        '--batch_analyze',
        action='store_true',
        help='Sends search results to LLM for analysis.'
    )
    mode_group.add_argument(
        '--batch_retrieve',
        action='store_true',
        help='Retrieve results from an LLM batch request, requires --batch_id'
    )

    parser.add_argument(
        '--batch_ids_path',
        type=str,
        help='The batch ID for retrieval'
    )

    args = parser.parse_args()

    prompt_data_path_prefix = "data/professor/prompt_data"

    if args.batch_search:
        author_names = list(get_authors("data/conference").keys())
        print("Researching people...")
        batch_search_person(author_names, [], save_freq=20)
    elif args.batch_analyze:
        api_client = OpenAI()
        jsonl_paths = write_batch_files("data/professor/search_results.json", prompt_data_path_prefix)
        send_batch_files(prompt_data_path_prefix, jsonl_paths, api_client)
    elif args.batch_retrieve:
        api_client = OpenAI()
        with open(f"{prompt_data_path_prefix}_batches.pkl", "rb") as f:
            saved_batches = pickle.load(f)
        batch_process_llm_output(api_client, saved_batches)
    else:
        # Unreachable in practice (the group is required), kept for safety.
        raise ValueError("Please specify --batch_search, --batch_analyze, or --batch_retrieve.")
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
# Script entry point: dispatch to the selected batch operation.
if __name__ == "__main__":
    main()
|
requirements.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
python-dotenv
|
| 2 |
+
openai
|
| 3 |
+
langchain-together
|
| 4 |
+
lxml
|
| 5 |
+
|
| 6 |
+
einops
|
| 7 |
+
torch
|
| 8 |
+
datasets
|
| 9 |
+
transformers
|
| 10 |
+
|
| 11 |
+
datasets
|
| 12 |
+
transformers
|