Spaces:
Sleeping
Sleeping
refactor data paths
Browse files- streamlit.py → USMLPhDRecommender.py +0 -0
- core/recommender.py +8 -12
- data_pipeline/conference_scraper.py +8 -8
- data_pipeline/config.py +45 -0
- data_pipeline/loaders.py +0 -22
- data_pipeline/paper_embeddings_extractor.py +33 -36
- data_pipeline/schools_scraper.py +0 -196
- data_pipeline/us_professor_verifier.py +30 -34
streamlit.py → USMLPhDRecommender.py
RENAMED
|
File without changes
|
core/recommender.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
| 1 |
from collections import Counter, defaultdict
|
| 2 |
import json
|
| 3 |
-
from operator import itemgetter
|
| 4 |
-
from typing import List
|
| 5 |
|
| 6 |
from datasets import Dataset
|
| 7 |
import torch
|
| 8 |
import torch.nn.functional as F
|
| 9 |
from transformers import AutoTokenizer, AutoModel
|
| 10 |
|
|
|
|
|
|
|
| 11 |
|
| 12 |
class EmbeddingProcessor:
|
| 13 |
def __init__(self,
|
|
@@ -63,23 +63,19 @@ class EmbeddingProcessor:
|
|
| 63 |
ds_with_embeddings.save_to_disk(save_path)
|
| 64 |
print(f"Dataset with embeddings saved to {save_path}")
|
| 65 |
|
| 66 |
-
import os
|
| 67 |
|
| 68 |
class Recommender:
|
| 69 |
def __init__(self,
|
| 70 |
embedding_processor: EmbeddingProcessor,
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
frontend_us_professor_path: str =
|
| 74 |
):
|
| 75 |
self.embedding_processor = embedding_processor
|
| 76 |
-
self.ita = Dataset.load_from_disk(
|
| 77 |
-
self.embds = torch.load(
|
| 78 |
-
|
| 79 |
-
# with open(frontend_id2professor_path, 'r') as f:
|
| 80 |
-
# self.id2professors = json.load(f)
|
| 81 |
with open(frontend_us_professor_path, 'r') as f:
|
| 82 |
-
# dictionary with professor names as keys and their metadata as values
|
| 83 |
self.us_professor_profiles = json.load(f)
|
| 84 |
|
| 85 |
def get_top_k(self, query: str, top_k: int = 5):
|
|
|
|
| 1 |
from collections import Counter, defaultdict
|
| 2 |
import json
|
|
|
|
|
|
|
| 3 |
|
| 4 |
from datasets import Dataset
|
| 5 |
import torch
|
| 6 |
import torch.nn.functional as F
|
| 7 |
from transformers import AutoTokenizer, AutoModel
|
| 8 |
|
| 9 |
+
from data_pipeline.config import DataPaths
|
| 10 |
+
|
| 11 |
|
| 12 |
class EmbeddingProcessor:
|
| 13 |
def __init__(self,
|
|
|
|
| 63 |
ds_with_embeddings.save_to_disk(save_path)
|
| 64 |
print(f"Dataset with embeddings saved to {save_path}")
|
| 65 |
|
|
|
|
| 66 |
|
| 67 |
class Recommender:
|
| 68 |
def __init__(self,
|
| 69 |
embedding_processor: EmbeddingProcessor,
|
| 70 |
+
ita_path: str = DataPaths.FRONTEND_ITA_PATH,
|
| 71 |
+
weights_path: str = DataPaths.FRONTEND_WEIGHTS_PATH,
|
| 72 |
+
frontend_us_professor_path: str = DataPaths.FRONTEND_PROF_PATH,
|
| 73 |
):
|
| 74 |
self.embedding_processor = embedding_processor
|
| 75 |
+
self.ita = Dataset.load_from_disk(ita_path)
|
| 76 |
+
self.embds = torch.load(weights_path, weights_only=True)
|
| 77 |
+
# dictionary with professor names as keys and their metadata as values
|
|
|
|
|
|
|
| 78 |
with open(frontend_us_professor_path, 'r') as f:
|
|
|
|
| 79 |
self.us_professor_profiles = json.load(f)
|
| 80 |
|
| 81 |
def get_top_k(self, query: str, top_k: int = 5):
|
data_pipeline/conference_scraper.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""Scrape data from some famous ML conferences and saves into
|
| 2 |
|
| 3 |
Every scrape function returns a list of 3-lists of the form
|
| 4 |
[paper_title, paper_authors, paper_url].
|
|
@@ -36,8 +36,8 @@ import time
|
|
| 36 |
from bs4 import BeautifulSoup
|
| 37 |
from tqdm import tqdm
|
| 38 |
|
|
|
|
| 39 |
|
| 40 |
-
SAVE_DIR = "data/conference"
|
| 41 |
|
| 42 |
def scrape_nips(year):
|
| 43 |
nips_url = f"https://papers.nips.cc/paper/{year}"
|
|
@@ -199,8 +199,8 @@ def main():
|
|
| 199 |
}
|
| 200 |
|
| 201 |
def load_progress():
|
| 202 |
-
if os.path.exists(
|
| 203 |
-
file_paths = os.listdir(
|
| 204 |
file_paths = [file_path for file_path in file_paths if file_path.endswith('.json')]
|
| 205 |
file_paths = [file_path.split('.')[0] for file_path in file_paths]
|
| 206 |
return set(file_paths)
|
|
@@ -214,7 +214,7 @@ def main():
|
|
| 214 |
with open(file_path, 'a') as f:
|
| 215 |
f.write(conference + ': ' + msg + '\n')
|
| 216 |
|
| 217 |
-
os.makedirs(
|
| 218 |
|
| 219 |
# Load previous progress
|
| 220 |
scraped_conferences = load_progress()
|
|
@@ -232,7 +232,7 @@ def main():
|
|
| 232 |
try:
|
| 233 |
|
| 234 |
print(f"Scraping {conference}")
|
| 235 |
-
save_path = os.path.join(
|
| 236 |
conference_items = scrape_function()
|
| 237 |
save_to_file(conference_items, save_path)
|
| 238 |
print(f"Saved {conference} to {str(save_path)}")
|
|
@@ -249,8 +249,8 @@ def main():
|
|
| 249 |
|
| 250 |
def stats():
|
| 251 |
total = 0
|
| 252 |
-
for fname in os.listdir(
|
| 253 |
-
with open(os.path.join(
|
| 254 |
num_lines = sum(1 for _ in file)
|
| 255 |
print(fname + ": " + str(num_lines) + " lines")
|
| 256 |
total += num_lines
|
|
|
|
| 1 |
+
"""Scrape data from some famous ML conferences and saves into `DataPaths.CONFERENCE_DIR`.
|
| 2 |
|
| 3 |
Every scrape function returns a list of 3-lists of the form
|
| 4 |
[paper_title, paper_authors, paper_url].
|
|
|
|
| 36 |
from bs4 import BeautifulSoup
|
| 37 |
from tqdm import tqdm
|
| 38 |
|
| 39 |
+
from data_pipeline.config import DataPaths
|
| 40 |
|
|
|
|
| 41 |
|
| 42 |
def scrape_nips(year):
|
| 43 |
nips_url = f"https://papers.nips.cc/paper/{year}"
|
|
|
|
| 199 |
}
|
| 200 |
|
| 201 |
def load_progress():
|
| 202 |
+
if os.path.exists(DataPaths.CONFERENCE_DIR):
|
| 203 |
+
file_paths = os.listdir(DataPaths.CONFERENCE_DIR)
|
| 204 |
file_paths = [file_path for file_path in file_paths if file_path.endswith('.json')]
|
| 205 |
file_paths = [file_path.split('.')[0] for file_path in file_paths]
|
| 206 |
return set(file_paths)
|
|
|
|
| 214 |
with open(file_path, 'a') as f:
|
| 215 |
f.write(conference + ': ' + msg + '\n')
|
| 216 |
|
| 217 |
+
os.makedirs(DataPaths.CONFERENCE_DIR, exist_ok=True)
|
| 218 |
|
| 219 |
# Load previous progress
|
| 220 |
scraped_conferences = load_progress()
|
|
|
|
| 232 |
try:
|
| 233 |
|
| 234 |
print(f"Scraping {conference}")
|
| 235 |
+
save_path = os.path.join(DataPaths.CONFERENCE_DIR, f"{conference}.json")
|
| 236 |
conference_items = scrape_function()
|
| 237 |
save_to_file(conference_items, save_path)
|
| 238 |
print(f"Saved {conference} to {str(save_path)}")
|
|
|
|
| 249 |
|
| 250 |
def stats():
|
| 251 |
total = 0
|
| 252 |
+
for fname in os.listdir(DataPaths.CONFERENCE_DIR):
|
| 253 |
+
with open(os.path.join(DataPaths.CONFERENCE_DIR, fname), 'r') as file:
|
| 254 |
num_lines = sum(1 for _ in file)
|
| 255 |
print(fname + ": " + str(num_lines) + " lines")
|
| 256 |
total += num_lines
|
data_pipeline/config.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
class DataPaths:
|
| 4 |
+
|
| 5 |
+
BASE_DIR = "data"
|
| 6 |
+
LOG_DIR = "logs"
|
| 7 |
+
|
| 8 |
+
PROGRESS_LOG_PATH = os.path.join(LOG_DIR, 'progress_log.tmp')
|
| 9 |
+
|
| 10 |
+
CONFERENCE_DIR = os.path.join(BASE_DIR, 'conference')
|
| 11 |
+
AUTHORS_PATH = os.path.join(CONFERENCE_DIR, 'authors.txt')
|
| 12 |
+
|
| 13 |
+
PROF_DIR = os.path.join(BASE_DIR, 'professor')
|
| 14 |
+
SEARCH_RESULTS_PATH = os.path.join(PROF_DIR, 'search_results.json')
|
| 15 |
+
US_PROF_PATH = os.path.join(PROF_DIR, 'us_professor.json')
|
| 16 |
+
NOT_US_PROF_PATH = os.path.join(PROF_DIR, 'not_us_professor.json')
|
| 17 |
+
PROMPT_DATA_PREFIX = str(os.path.join(PROF_DIR, 'prompt_data'))
|
| 18 |
+
|
| 19 |
+
ARXIV_FNAME = 'arxiv-metadata-oai-snapshot.json'
|
| 20 |
+
ARXIV_PATH = os.path.join(BASE_DIR, ARXIV_FNAME)
|
| 21 |
+
ML_ARXIV_PATH = os.path.join(BASE_DIR, 'arxiv-metadata-oai-snapshot-ml.json')
|
| 22 |
+
|
| 23 |
+
PAPER_DIR = os.path.join(BASE_DIR, "paper_embeddings")
|
| 24 |
+
EMBD_MODEL = "all-mpnet-base-v2-embds"
|
| 25 |
+
EMBD_PATH = os.path.join(PAPER_DIR, EMBD_MODEL)
|
| 26 |
+
PAPER_DATA_PATH = os.path.join(PAPER_DIR, "paper_data")
|
| 27 |
+
|
| 28 |
+
FRONTEND_DIR = os.path.join(BASE_DIR, 'frontend_data')
|
| 29 |
+
FRONTEND_PROF_PATH = os.path.join(FRONTEND_DIR, 'us_professor.json')
|
| 30 |
+
FRONTEND_EMBD_PATH = os.path.join(FRONTEND_DIR, EMBD_MODEL) # contains id, title, author, weights
|
| 31 |
+
FRONTEND_ITA_PATH = os.path.join(FRONTEND_EMBD_PATH, 'id_title_author')
|
| 32 |
+
FRONTEND_WEIGHTS_PATH = os.path.join(FRONTEND_EMBD_PATH, 'weights.pt')
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# create FRONTEND_DIR PROF_DIR CONFERENCE_DIR
|
| 36 |
+
|
| 37 |
+
@staticmethod
|
| 38 |
+
def ensure_directories():
|
| 39 |
+
# Create the directories if they do not exist
|
| 40 |
+
os.makedirs(DataPaths.RAW_DATA_DIR, exist_ok=True)
|
| 41 |
+
os.makedirs(DataPaths.PROCESSED_DATA_DIR, exist_ok=True)
|
| 42 |
+
os.makedirs(DataPaths.MODEL_OUTPUT_DIR, exist_ok=True)
|
| 43 |
+
|
| 44 |
+
# Call this function early in your pipeline
|
| 45 |
+
DataPaths.ensure_directories()
|
data_pipeline/loaders.py
DELETED
|
@@ -1,22 +0,0 @@
|
|
| 1 |
-
import json
|
| 2 |
-
import os
|
| 3 |
-
|
| 4 |
-
def load_conference_papers(conference_dir='data/conference'):
|
| 5 |
-
papers = []
|
| 6 |
-
files = os.listdir(conference_dir)
|
| 7 |
-
for file in files:
|
| 8 |
-
if not file.endswith('.json'):
|
| 9 |
-
continue
|
| 10 |
-
with open(os.path.join(conference_dir, file), 'r') as f:
|
| 11 |
-
while True:
|
| 12 |
-
line = f.readline()
|
| 13 |
-
if not line: break
|
| 14 |
-
paper = json.loads(line)
|
| 15 |
-
papers.append(paper)
|
| 16 |
-
return papers
|
| 17 |
-
|
| 18 |
-
def load_us_professor():
|
| 19 |
-
"""Returns a JSON list"""
|
| 20 |
-
with open('data/professor/us_professor.json', 'r') as f:
|
| 21 |
-
us_professors = json.load(f)
|
| 22 |
-
return us_professors
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data_pipeline/paper_embeddings_extractor.py
CHANGED
|
@@ -16,20 +16,20 @@ import torch
|
|
| 16 |
from tqdm import tqdm
|
| 17 |
|
| 18 |
from core.recommender import EmbeddingProcessor
|
| 19 |
-
|
| 20 |
|
| 21 |
arxiv_fname = "arxiv-metadata-oai-snapshot.json"
|
| 22 |
|
| 23 |
-
def download_arxiv_data(
|
| 24 |
-
"""Downloads and unzips arxiv dataset from Kaggle into
|
| 25 |
dataset = "Cornell-University/arxiv"
|
| 26 |
-
data_path =
|
| 27 |
|
| 28 |
-
if not any([
|
| 29 |
kaggle.api.dataset_download_cli(dataset, path=data_path, unzip=True)
|
| 30 |
else:
|
| 31 |
-
print(f"Data already downloaded at {
|
| 32 |
-
return
|
| 33 |
|
| 34 |
def get_lbl_from_name(names):
|
| 35 |
"""Tuple (last_name, first_name, middle_name) => String 'first_name [middle_name] last_name'."""
|
|
@@ -39,7 +39,7 @@ def get_lbl_from_name(names):
|
|
| 39 |
for name in names
|
| 40 |
]
|
| 41 |
|
| 42 |
-
def filter_arxiv_for_ml(
|
| 43 |
"""Sifts through downloaded arxiv file to find ML-related papers.
|
| 44 |
|
| 45 |
If `obtain_summary` is True, saves a pickled DataFrame to the same directory as
|
|
@@ -47,8 +47,8 @@ def filter_arxiv_for_ml(arxiv_path, obtain_summary=False, authors_of_interest=No
|
|
| 47 |
|
| 48 |
If `authors_of_interest` is not None, only save ML-related papers by those authors.
|
| 49 |
"""
|
| 50 |
-
ml_path = str(
|
| 51 |
-
summary_path = str(
|
| 52 |
|
| 53 |
ml_cats = ['cs.AI', 'cs.CL', 'cs.CV', 'cs.LG', 'stat.ML']
|
| 54 |
|
|
@@ -67,7 +67,7 @@ def filter_arxiv_for_ml(arxiv_path, obtain_summary=False, authors_of_interest=No
|
|
| 67 |
authors_of_interest = set(authors_of_interest)
|
| 68 |
|
| 69 |
# Load the JSON file line by line
|
| 70 |
-
with open(
|
| 71 |
for line in tqdm(f1):
|
| 72 |
# Parse each line as JSON
|
| 73 |
try:
|
|
@@ -118,7 +118,7 @@ def filter_arxiv_for_ml(arxiv_path, obtain_summary=False, authors_of_interest=No
|
|
| 118 |
def get_professors_and_relevant_papers(us_professors, k=8, cutoff=datetime(2022, 10, 1)):
|
| 119 |
"""
|
| 120 |
Returns a dictionary mapping U.S. professor names to a list of indices
|
| 121 |
-
corresponding to their most recent papers in
|
| 122 |
This function is necessary to specify the papers we are interested in for each
|
| 123 |
professor (e.g., the most recent papers after cutoff)
|
| 124 |
|
|
@@ -136,7 +136,7 @@ def get_professors_and_relevant_papers(us_professors, k=8, cutoff=datetime(2022,
|
|
| 136 |
# professors to tuple of (datetime, arxiv_id)
|
| 137 |
p2p = defaultdict(list)
|
| 138 |
|
| 139 |
-
with open(
|
| 140 |
while True:
|
| 141 |
line = f.readline()
|
| 142 |
if not line: break
|
|
@@ -174,7 +174,7 @@ def gen(p2p):
|
|
| 174 |
relevant_ids = set()
|
| 175 |
for value in values:
|
| 176 |
relevant_ids.update([v[1] for v in value])
|
| 177 |
-
with open(
|
| 178 |
while True:
|
| 179 |
line = f.readline()
|
| 180 |
if not line: break
|
|
@@ -193,7 +193,7 @@ def gen(p2p):
|
|
| 193 |
def save_paper_to_professor(p2p, save_path):
|
| 194 |
"""Returns a dictionary mapping an Arxiv ID to U.S. professor names
|
| 195 |
|
| 196 |
-
`p2p`: mapping from professor to list of paper indices in
|
| 197 |
`ds`: dataset with Arxiv links and line_nbr
|
| 198 |
"""
|
| 199 |
|
|
@@ -215,42 +215,39 @@ def main():
|
|
| 215 |
|
| 216 |
### Download and filter for ML papers written by U.S. professors ###
|
| 217 |
print("Downloading data...")
|
| 218 |
-
|
| 219 |
-
with open(
|
| 220 |
authors_of_interest = json.load(f)
|
| 221 |
authors_of_interest = [author['name'] for author in authors_of_interest]
|
| 222 |
print("Filtering data for ML papers...")
|
| 223 |
-
filter_arxiv_for_ml(
|
| 224 |
|
| 225 |
### Create a dataset containing paper info, e.g., title, abstract, authors, etc. ###
|
| 226 |
-
|
| 227 |
-
print("Saving paper data to disk at " + paper_data_path)
|
| 228 |
p2p = get_professors_and_relevant_papers(authors_of_interest)
|
| 229 |
ds = Dataset.from_generator(partial(gen, p2p))
|
| 230 |
-
ds.save_to_disk(
|
| 231 |
-
|
| 232 |
-
#
|
| 233 |
-
|
| 234 |
-
#
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
#
|
| 240 |
-
|
| 241 |
-
# embedding_processor.process_dataset(paper_data_path, embds_save_path, batch_size=128)
|
| 242 |
|
| 243 |
### Create front-end data ###
|
| 244 |
|
| 245 |
# Filter ds for paper title, id, authors, and embedding
|
| 246 |
-
embds = Dataset.load_from_disk(
|
| 247 |
-
embds_frontend_save_path = "data/frontend_data/all-mpnet-base-v2-embds"
|
| 248 |
|
| 249 |
# save id and title to disk
|
| 250 |
-
embds.select_columns(['id', 'title', 'authors']).save_to_disk(
|
| 251 |
# save embeddings as torch tensor
|
| 252 |
embds_weights = torch.Tensor(embds['embeddings'])
|
| 253 |
-
torch.save(embds_weights,
|
| 254 |
|
| 255 |
if __name__ == "__main__":
|
| 256 |
main()
|
|
|
|
| 16 |
from tqdm import tqdm
|
| 17 |
|
| 18 |
from core.recommender import EmbeddingProcessor
|
| 19 |
+
from data_pipeline.config import DataPaths
|
| 20 |
|
| 21 |
arxiv_fname = "arxiv-metadata-oai-snapshot.json"
|
| 22 |
|
| 23 |
+
def download_arxiv_data():
|
| 24 |
+
"""Downloads and unzips arxiv dataset from Kaggle into `data` directory."""
|
| 25 |
dataset = "Cornell-University/arxiv"
|
| 26 |
+
data_path = DataPaths.BASE_DIR
|
| 27 |
|
| 28 |
+
if not any([DataPaths.ARXIV_FNAME in file for file in os.listdir(data_path)]):
|
| 29 |
kaggle.api.dataset_download_cli(dataset, path=data_path, unzip=True)
|
| 30 |
else:
|
| 31 |
+
print(f"Data already downloaded at {DataPaths.ARXIV_FNAME}.")
|
| 32 |
+
return DataPaths.ARXIV_FNAME
|
| 33 |
|
| 34 |
def get_lbl_from_name(names):
|
| 35 |
"""Tuple (last_name, first_name, middle_name) => String 'first_name [middle_name] last_name'."""
|
|
|
|
| 39 |
for name in names
|
| 40 |
]
|
| 41 |
|
| 42 |
+
def filter_arxiv_for_ml(obtain_summary=False, authors_of_interest=None):
|
| 43 |
"""Sifts through downloaded arxiv file to find ML-related papers.
|
| 44 |
|
| 45 |
If `obtain_summary` is True, saves a pickled DataFrame to the same directory as
|
|
|
|
| 47 |
|
| 48 |
If `authors_of_interest` is not None, only save ML-related papers by those authors.
|
| 49 |
"""
|
| 50 |
+
ml_path = str(DataPaths.ARXIV_PATH).split('.')[0]+'-ml.json'
|
| 51 |
+
summary_path = str(DataPaths.ARXIV_PATH).split('.')[0]+'-summary.pkl'
|
| 52 |
|
| 53 |
ml_cats = ['cs.AI', 'cs.CL', 'cs.CV', 'cs.LG', 'stat.ML']
|
| 54 |
|
|
|
|
| 67 |
authors_of_interest = set(authors_of_interest)
|
| 68 |
|
| 69 |
# Load the JSON file line by line
|
| 70 |
+
with open(DataPaths.ARXIV_PATH, 'r') as f1, open(ml_path, 'w') as f2:
|
| 71 |
for line in tqdm(f1):
|
| 72 |
# Parse each line as JSON
|
| 73 |
try:
|
|
|
|
| 118 |
def get_professors_and_relevant_papers(us_professors, k=8, cutoff=datetime(2022, 10, 1)):
|
| 119 |
"""
|
| 120 |
Returns a dictionary mapping U.S. professor names to a list of indices
|
| 121 |
+
corresponding to their most recent papers in DataPaths.ML_ARXIV_PATH.
|
| 122 |
This function is necessary to specify the papers we are interested in for each
|
| 123 |
professor (e.g., the most recent papers after cutoff)
|
| 124 |
|
|
|
|
| 136 |
# professors to tuple of (datetime, arxiv_id)
|
| 137 |
p2p = defaultdict(list)
|
| 138 |
|
| 139 |
+
with open(DataPaths.ML_ARXIV_PATH, 'r') as f:
|
| 140 |
while True:
|
| 141 |
line = f.readline()
|
| 142 |
if not line: break
|
|
|
|
| 174 |
relevant_ids = set()
|
| 175 |
for value in values:
|
| 176 |
relevant_ids.update([v[1] for v in value])
|
| 177 |
+
with open(DataPaths.ML_ARXIV_PATH, 'r') as f:
|
| 178 |
while True:
|
| 179 |
line = f.readline()
|
| 180 |
if not line: break
|
|
|
|
| 193 |
def save_paper_to_professor(p2p, save_path):
|
| 194 |
"""Returns a dictionary mapping an Arxiv ID to U.S. professor names
|
| 195 |
|
| 196 |
+
`p2p`: mapping from professor to list of paper indices in DataPaths.ML_ARXIV_PATH
|
| 197 |
`ds`: dataset with Arxiv links and line_nbr
|
| 198 |
"""
|
| 199 |
|
|
|
|
| 215 |
|
| 216 |
### Download and filter for ML papers written by U.S. professors ###
|
| 217 |
print("Downloading data...")
|
| 218 |
+
download_arxiv_data()
|
| 219 |
+
with open(DataPaths.US_PROF_PATH, 'r') as f:
|
| 220 |
authors_of_interest = json.load(f)
|
| 221 |
authors_of_interest = [author['name'] for author in authors_of_interest]
|
| 222 |
print("Filtering data for ML papers...")
|
| 223 |
+
filter_arxiv_for_ml(authors_of_interest=authors_of_interest)
|
| 224 |
|
| 225 |
### Create a dataset containing paper info, e.g., title, abstract, authors, etc. ###
|
| 226 |
+
print("Saving paper data to disk at " + DataPaths.PAPER_DATA_PATH)
|
|
|
|
| 227 |
p2p = get_professors_and_relevant_papers(authors_of_interest)
|
| 228 |
ds = Dataset.from_generator(partial(gen, p2p))
|
| 229 |
+
ds.save_to_disk(DataPaths.PAPER_DATA_PATH)
|
| 230 |
+
|
| 231 |
+
### Extract paper embeddings ###
|
| 232 |
+
print("Extracting embeddings (use GPU if possible)...")
|
| 233 |
+
# Initialize the embedding processor with model names
|
| 234 |
+
embedding_processor = EmbeddingProcessor(
|
| 235 |
+
model_name='sentence-transformers/all-mpnet-base-v2',
|
| 236 |
+
custom_model_name='salsabiilashifa11/sbert-paper'
|
| 237 |
+
)
|
| 238 |
+
# Process dataset and save with embeddings
|
| 239 |
+
embedding_processor.process_dataset(DataPaths.PAPER_DATA_PATH, DataPaths.EMBD_PATH, batch_size=128)
|
|
|
|
| 240 |
|
| 241 |
### Create front-end data ###
|
| 242 |
|
| 243 |
# Filter ds for paper title, id, authors, and embedding
|
| 244 |
+
embds = Dataset.load_from_disk(DataPaths.EMBD_PATH)
|
|
|
|
| 245 |
|
| 246 |
# save id and title to disk
|
| 247 |
+
embds.select_columns(['id', 'title', 'authors']).save_to_disk(DataPaths.FRONTEND_ITA_PATH)
|
| 248 |
# save embeddings as torch tensor
|
| 249 |
embds_weights = torch.Tensor(embds['embeddings'])
|
| 250 |
+
torch.save(embds_weights, DataPaths.FRONTEND_WEIGHTS_PATH)
|
| 251 |
|
| 252 |
if __name__ == "__main__":
|
| 253 |
main()
|
data_pipeline/schools_scraper.py
DELETED
|
@@ -1,196 +0,0 @@
|
|
| 1 |
-
# https://medium.com/@donadviser/running-selenium-and-chrome-on-wsl2-cfabe7db4bbb
|
| 2 |
-
|
| 3 |
-
import os
|
| 4 |
-
import time
|
| 5 |
-
|
| 6 |
-
from bs4 import BeautifulSoup
|
| 7 |
-
from dotenv import load_dotenv, find_dotenv
|
| 8 |
-
from langchain_together import ChatTogether
|
| 9 |
-
from langchain_core.output_parsers import StrOutputParser
|
| 10 |
-
from langchain_core.prompts import ChatPromptTemplate
|
| 11 |
-
from langchain_core.runnables import RunnableLambda
|
| 12 |
-
from selenium import webdriver
|
| 13 |
-
from selenium.webdriver.chrome.service import Service
|
| 14 |
-
from selenium.webdriver.common.by import By
|
| 15 |
-
from selenium.webdriver.chrome.options import Options
|
| 16 |
-
|
| 17 |
-
_ = load_dotenv(find_dotenv()) # read local .env file
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def get_service_and_chrome_options():
|
| 21 |
-
"""TODO: specific to chromedriver location."""
|
| 22 |
-
# Define Chrome options
|
| 23 |
-
chrome_options = Options()
|
| 24 |
-
chrome_options.add_argument("--headless")
|
| 25 |
-
chrome_options.add_argument("--no-sandbox")
|
| 26 |
-
# Add more options here if needed
|
| 27 |
-
|
| 28 |
-
# Define paths
|
| 29 |
-
user_home_dir = os.path.expanduser("~")
|
| 30 |
-
user_home_dir = os.path.expanduser("~")
|
| 31 |
-
chrome_binary_path = os.path.join(user_home_dir, "chrome-linux64", "chrome")
|
| 32 |
-
chromedriver_path = os.path.join(user_home_dir, "chromedriver-linux64", "chromedriver")
|
| 33 |
-
|
| 34 |
-
# Set binary location and service
|
| 35 |
-
chrome_options.binary_location = chrome_binary_path
|
| 36 |
-
service = Service(chromedriver_path)
|
| 37 |
-
|
| 38 |
-
return service, chrome_options
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
def retrieve_csrankings_content(dump_file="soup.tmp"):
|
| 42 |
-
"""Write times higher page to a dump file."""
|
| 43 |
-
# https://medium.com/@donadviser/running-selenium-and-chrome-on-wsl2-cfabe7db4bbb
|
| 44 |
-
# Using WSL2
|
| 45 |
-
|
| 46 |
-
service, chrome_options = get_service_and_chrome_options()
|
| 47 |
-
|
| 48 |
-
# Initialize Chrome WebDriver
|
| 49 |
-
with webdriver.Chrome(service=service, options=chrome_options) as browser:
|
| 50 |
-
print("Get browser")
|
| 51 |
-
browser.get("https://www.timeshighereducation.com/student/best-universities/best-universities-united-states")
|
| 52 |
-
|
| 53 |
-
# Wait for the page to load
|
| 54 |
-
print("Wait for the page to load")
|
| 55 |
-
browser.implicitly_wait(10)
|
| 56 |
-
|
| 57 |
-
print("Get html")
|
| 58 |
-
# Retrieve the HTML content
|
| 59 |
-
html_content = browser.page_source
|
| 60 |
-
|
| 61 |
-
# Write HTML content to soup.txt
|
| 62 |
-
with open(dump_file, "w") as f:
|
| 63 |
-
f.write(html_content)
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
def extract_timeshigher_content(read_file="soup.tmp", dump_file="soup (1).tmp"):
|
| 67 |
-
"""Extract universities from a dump file."""
|
| 68 |
-
with open(read_file, "r") as f:
|
| 69 |
-
html_content = f.read()
|
| 70 |
-
|
| 71 |
-
# Parse the HTML content
|
| 72 |
-
soup = BeautifulSoup(html_content, "html.parser")
|
| 73 |
-
|
| 74 |
-
# Find universities
|
| 75 |
-
university_table = soup.find_all('tr')
|
| 76 |
-
universities = [tr.find('a').get_text() for tr in university_table if tr.find('a')]
|
| 77 |
-
|
| 78 |
-
# Remove duplicates while keeping the order
|
| 79 |
-
universities = list(dict.fromkeys(universities))
|
| 80 |
-
|
| 81 |
-
# Write universities line-by-line to a new file
|
| 82 |
-
with open(dump_file, "w") as f:
|
| 83 |
-
for uni in universities:
|
| 84 |
-
f.write(f"{uni}\n")
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
def get_department_getter():
|
| 88 |
-
"""
|
| 89 |
-
Returns a function that leverages LangChain and TogetherAI to get a list of
|
| 90 |
-
department names in a university associated with machine learning.
|
| 91 |
-
"""
|
| 92 |
-
template_string = """\
|
| 93 |
-
You are an expert in PhD programs and know about \
|
| 94 |
-
specific departments at each university.\
|
| 95 |
-
You are helping to design a system that generates \
|
| 96 |
-
a list of professors that students interested in \
|
| 97 |
-
machine learning can apply to for their PhDs. \
|
| 98 |
-
Currently, recall is more important than precision. \
|
| 99 |
-
Include as many departments as possible, while \
|
| 100 |
-
maintaining relevancy. Which departments in {university} \
|
| 101 |
-
are associated with machine learning? Please format your \
|
| 102 |
-
answer as a numbered list. Afterwards, please generate a \
|
| 103 |
-
new line starting with \"Answer:\", followed by a concise \
|
| 104 |
-
list of department names generated, separated by
|
| 105 |
-
semicolons.\
|
| 106 |
-
"""
|
| 107 |
-
|
| 108 |
-
prompt_template = ChatPromptTemplate.from_template(template_string)
|
| 109 |
-
|
| 110 |
-
# # choose from our 50+ models here: https://docs.together.ai/docs/inference-models
|
| 111 |
-
chat = ChatTogether(
|
| 112 |
-
together_api_key=os.environ["TOGETHER_API_KEY"],
|
| 113 |
-
model="meta-llama/Llama-3-70b-chat-hf",
|
| 114 |
-
temperature=0.3
|
| 115 |
-
)
|
| 116 |
-
|
| 117 |
-
output_parser = StrOutputParser()
|
| 118 |
-
|
| 119 |
-
def extract_function(text):
|
| 120 |
-
"""Returns the line that starts with `Answer:`"""
|
| 121 |
-
if "Answer:" not in text:
|
| 122 |
-
return "No `Answer:` found"
|
| 123 |
-
return text.split("Answer:")[1].strip()
|
| 124 |
-
|
| 125 |
-
chain = prompt_template | chat | output_parser | RunnableLambda(extract_function)
|
| 126 |
-
|
| 127 |
-
def get_department_info(uni):
|
| 128 |
-
"""Get department info from the university."""
|
| 129 |
-
return chain.invoke({"university": uni})
|
| 130 |
-
|
| 131 |
-
return get_department_info
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
def get_department_info(unis_file="soup (1).tmp", deps_file="departments.tsv"):
|
| 135 |
-
"""
|
| 136 |
-
Get department info for all universities in `unis_file` and
|
| 137 |
-
write it to `deps_file`."""
|
| 138 |
-
|
| 139 |
-
department_getter = get_department_getter()
|
| 140 |
-
with open(unis_file, "r") as fin, open(deps_file, "w") as fout:
|
| 141 |
-
|
| 142 |
-
# Iterate through universities in `fin`
|
| 143 |
-
for uni in fin.readlines():
|
| 144 |
-
uni = uni.strip()
|
| 145 |
-
|
| 146 |
-
deps = []
|
| 147 |
-
# Prompt the LLM multiple times for better recall
|
| 148 |
-
for i in range(3):
|
| 149 |
-
depstr = department_getter(uni)
|
| 150 |
-
time.sleep(3) # Respect usage limits!
|
| 151 |
-
try:
|
| 152 |
-
if depstr == "No `Answer:` found":
|
| 153 |
-
print(f"No departments found for {uni} on {i}'th prompt.")
|
| 154 |
-
else:
|
| 155 |
-
deps_ = [d.strip() for d in depstr.split(';')]
|
| 156 |
-
deps.extend(deps_)
|
| 157 |
-
except Exception as e:
|
| 158 |
-
print("Exception for {uni} on {i}'th prompt: ")
|
| 159 |
-
print("Parsing string: ", depstr)
|
| 160 |
-
print(e)
|
| 161 |
-
|
| 162 |
-
# Deduplicate deps list
|
| 163 |
-
deps = list(dict.fromkeys(deps))
|
| 164 |
-
|
| 165 |
-
# Write to tsv dump file
|
| 166 |
-
for dep in deps:
|
| 167 |
-
fout.write(f"{uni}\t{dep}\n")
|
| 168 |
-
|
| 169 |
-
# Print string info
|
| 170 |
-
print(f"{uni}: {deps}")
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
import requests
|
| 174 |
-
|
| 175 |
-
def get_faculty_list_potential_links_getter():
|
| 176 |
-
"""Returns a function that returns a list of links that may contain faculty lists."""
|
| 177 |
-
GOOGLE_API_KEY = os.environ['GOOGLE_API_KEY']
|
| 178 |
-
GOOGLE_SEARCH_ENGINE_ID = os.environ['GOOGLE_SEARCH_ENGINE_ID']
|
| 179 |
-
|
| 180 |
-
def get_faculty_list_potential_links(uni, dep):
|
| 181 |
-
"""Returns a list of links that may contain faculty lists."""
|
| 182 |
-
search_query = f'{uni} {dep} faculty list'
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
params = {
|
| 186 |
-
'q': search_query, 'key': GOOGLE_API_KEY, 'cx': GOOGLE_SEARCH_ENGINE_ID
|
| 187 |
-
}
|
| 188 |
-
|
| 189 |
-
response = requests.get('https://www.googleapis.com/customsearch/v1', params=params)
|
| 190 |
-
results = response.json()
|
| 191 |
-
title2link = {item['title']: item['link'] for item in results['items']}
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
# if __name__ == "__main__":
|
| 196 |
-
# get_department_info()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data_pipeline/us_professor_verifier.py
CHANGED
|
@@ -13,6 +13,7 @@ import regex as re
|
|
| 13 |
from tqdm import tqdm
|
| 14 |
|
| 15 |
from data_pipeline.conference_scraper import get_authors
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
_ = load_dotenv(find_dotenv())
|
|
@@ -161,7 +162,7 @@ def check_json(profile):
|
|
| 161 |
|
| 162 |
def save_json(profiles, file_path):
|
| 163 |
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
| 164 |
-
with open(file_path, 'w') as file: #
|
| 165 |
json.dump(profiles, file, indent=4)
|
| 166 |
|
| 167 |
def load_json(file_path):
|
|
@@ -212,7 +213,7 @@ def research_person(person_name, client, progress_log, us_professor_profiles, no
|
|
| 212 |
extract_search_results(person_name, progress_log, client, us_professor_profiles, not_us_professor_profiles, top_hits)
|
| 213 |
|
| 214 |
|
| 215 |
-
def get_authors(
|
| 216 |
"""
|
| 217 |
Reduce the list of authors to those with at least `min_papers` papers for
|
| 218 |
which they are not first authors. Ignores solo-authored papers and papers
|
|
@@ -222,11 +223,11 @@ def get_authors(save_dir="data/conference", min_papers=3, ignore_first_author=Tr
|
|
| 222 |
monetarily expensive. Feel free to edit if you have more resources.
|
| 223 |
"""
|
| 224 |
authors = defaultdict(int)
|
| 225 |
-
for fname in os.listdir(
|
| 226 |
if not fname.endswith('.json'):
|
| 227 |
continue
|
| 228 |
|
| 229 |
-
with open(os.path.join(
|
| 230 |
for line in file:
|
| 231 |
item = json.loads(line)
|
| 232 |
paper_authors = [x.strip() for x in item[1].split(",")]
|
|
@@ -242,8 +243,8 @@ def get_authors(save_dir="data/conference", min_papers=3, ignore_first_author=Tr
|
|
| 242 |
authors[paper_authors[i]] += 1
|
| 243 |
|
| 244 |
authors = {k: v for k, v in authors.items() if v >= min_papers}
|
| 245 |
-
os.makedirs(
|
| 246 |
-
with open(
|
| 247 |
for k, v in authors.items():
|
| 248 |
f.write(f"{k}\t{v}\n")
|
| 249 |
return authors
|
|
@@ -254,7 +255,7 @@ def research_conference_profiles(save_freq=20):
|
|
| 254 |
NOTE: cannot deal w/ interrupts and continue from past progress.
|
| 255 |
"""
|
| 256 |
|
| 257 |
-
authors = get_authors(
|
| 258 |
person_names = list(authors.keys())
|
| 259 |
|
| 260 |
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
|
|
@@ -264,10 +265,10 @@ def research_conference_profiles(save_freq=20):
|
|
| 264 |
not_us_professor_profiles = []
|
| 265 |
|
| 266 |
def log_save_print(progress_log, us_professor_profiles, not_us_professor_profiles, i):
|
| 267 |
-
log_progress_to_file(progress_log,
|
| 268 |
-
save_json(us_professor_profiles,
|
| 269 |
-
save_json(not_us_professor_profiles,
|
| 270 |
-
print(f"Saved profiles to
|
| 271 |
|
| 272 |
for i in range(len(person_names)):
|
| 273 |
research_person(person_names[i], client, progress_log, us_professor_profiles, not_us_professor_profiles)
|
|
@@ -281,7 +282,7 @@ def batch_search_person(person_names, progress_log, save_freq=20):
|
|
| 281 |
"""Searches everyone given in `person_names`."""
|
| 282 |
# might start and stop, pull from previous efforts
|
| 283 |
try:
|
| 284 |
-
prev_researched_authors = load_json(
|
| 285 |
except:
|
| 286 |
prev_researched_authors = []
|
| 287 |
ignore_set = set([x[0] for x in prev_researched_authors])
|
|
@@ -304,18 +305,17 @@ def batch_search_person(person_names, progress_log, save_freq=20):
|
|
| 304 |
data.append([person_names[i], top_hits])
|
| 305 |
|
| 306 |
if i % save_freq == 0:
|
| 307 |
-
save_json(data,
|
| 308 |
-
log_progress_to_file(progress_log,
|
| 309 |
|
| 310 |
# 3 queries per second max
|
| 311 |
wait_time = max(time.time() - (query_start + 0.334), 0.0)
|
| 312 |
time.sleep(wait_time)
|
| 313 |
|
| 314 |
-
save_json(data,
|
| 315 |
-
log_progress_to_file(progress_log,
|
| 316 |
|
| 317 |
def write_batch_files(search_results_path,
|
| 318 |
-
prompt_data_path_prefix,
|
| 319 |
model="gpt-4o-mini",
|
| 320 |
max_tokens=1000,
|
| 321 |
temperature=0.0,
|
|
@@ -348,7 +348,7 @@ def write_batch_files(search_results_path,
|
|
| 348 |
|
| 349 |
batch_paths = []
|
| 350 |
for i in range(0, len(prompt_datas) // batch_size + 1):
|
| 351 |
-
prompt_data_path = f"{
|
| 352 |
batch_range = i * batch_size, (min(len(prompt_datas), (i + 1) * batch_size))
|
| 353 |
with open(prompt_data_path, "w") as f:
|
| 354 |
for prompt_data in prompt_datas[batch_range[0]:batch_range[1]]:
|
|
@@ -357,7 +357,7 @@ def write_batch_files(search_results_path,
|
|
| 357 |
|
| 358 |
return batch_paths
|
| 359 |
|
| 360 |
-
def send_batch_files(
|
| 361 |
"""Create and send the batch request to API endpoint."""
|
| 362 |
batches = []
|
| 363 |
|
|
@@ -391,10 +391,8 @@ def send_batch_files(prompt_data_path_prefix, batch_paths, client, timeout=24*60
|
|
| 391 |
batches.append(batch)
|
| 392 |
|
| 393 |
# Keeps track of the paths to the batch files
|
| 394 |
-
with open(f"{
|
| 395 |
pickle.dump(batches, f)
|
| 396 |
-
with open(f"{prompt_data_path_prefix}_ids.txt", "w") as f:
|
| 397 |
-
f.write("\n".join([x.id for x in batches]))
|
| 398 |
return batches
|
| 399 |
|
| 400 |
def retrieve_batch_output(client, batch_id):
|
|
@@ -450,14 +448,14 @@ def batch_process_llm_output(client, batches):
|
|
| 450 |
print(f"Failed to parse json object `{json_obj}`: {e2}")
|
| 451 |
progress_log.append(f"Failed UNKNOWN: Parsed LLM output: {e2}")
|
| 452 |
|
| 453 |
-
with open(
|
| 454 |
json.dump(us_professor_profiles, file, indent=4)
|
| 455 |
|
| 456 |
-
with open(
|
| 457 |
json.dump(not_us_professor_profiles, file, indent=4)
|
| 458 |
|
| 459 |
-
def
|
| 460 |
-
with open(
|
| 461 |
us_professor_profiles = json.load(file)
|
| 462 |
|
| 463 |
professors_dict = {
|
|
@@ -469,7 +467,7 @@ def create_frontend_data(us_professor_profiles_path="data/professor/us_professor
|
|
| 469 |
for professor in us_professor_profiles
|
| 470 |
}
|
| 471 |
|
| 472 |
-
with open(
|
| 473 |
json.dump(professors_dict, file)
|
| 474 |
|
| 475 |
def main():
|
|
@@ -505,24 +503,22 @@ def main():
|
|
| 505 |
|
| 506 |
args = parser.parse_args()
|
| 507 |
|
| 508 |
-
prompt_data_path_prefix = "data/professor/prompt_data"
|
| 509 |
-
|
| 510 |
if args.batch_search:
|
| 511 |
-
authors = get_authors(
|
| 512 |
authors_list = list(authors.keys())
|
| 513 |
print("Researching people...")
|
| 514 |
progress_log = []
|
| 515 |
batch_search_person(authors_list, progress_log, save_freq=20)
|
| 516 |
elif args.batch_analyze:
|
| 517 |
client = OpenAI()
|
| 518 |
-
batch_paths = write_batch_files(
|
| 519 |
-
send_batch_files(
|
| 520 |
elif args.batch_retrieve:
|
| 521 |
client = OpenAI()
|
| 522 |
-
with open(f"{
|
| 523 |
batches = pickle.load(f)
|
| 524 |
batch_process_llm_output(client, batches)
|
| 525 |
-
|
| 526 |
else:
|
| 527 |
raise ValueError("Please specify --batch_search, --batch_analyze, or --batch_retrieve.")
|
| 528 |
|
|
|
|
| 13 |
from tqdm import tqdm
|
| 14 |
|
| 15 |
from data_pipeline.conference_scraper import get_authors
|
| 16 |
+
from data_pipeline.config import DataPaths
|
| 17 |
|
| 18 |
|
| 19 |
_ = load_dotenv(find_dotenv())
|
|
|
|
| 162 |
|
| 163 |
def save_json(profiles, file_path):
|
| 164 |
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
| 165 |
+
with open(file_path, 'w') as file: # TODO: in the future use append mode
|
| 166 |
json.dump(profiles, file, indent=4)
|
| 167 |
|
| 168 |
def load_json(file_path):
|
|
|
|
| 213 |
extract_search_results(person_name, progress_log, client, us_professor_profiles, not_us_professor_profiles, top_hits)
|
| 214 |
|
| 215 |
|
| 216 |
+
def get_authors(min_papers=3, ignore_first_author=True):
|
| 217 |
"""
|
| 218 |
Reduce the list of authors to those with at least `min_papers` papers for
|
| 219 |
which they are not first authors. Ignores solo-authored papers and papers
|
|
|
|
| 223 |
monetarily expensive. Feel free to edit if you have more resources.
|
| 224 |
"""
|
| 225 |
authors = defaultdict(int)
|
| 226 |
+
for fname in os.listdir(DataPaths.CONFERENCE_DIR):
|
| 227 |
if not fname.endswith('.json'):
|
| 228 |
continue
|
| 229 |
|
| 230 |
+
with open(os.path.join(DataPaths.CONFERENCE_DIR, fname), 'r') as file:
|
| 231 |
for line in file:
|
| 232 |
item = json.loads(line)
|
| 233 |
paper_authors = [x.strip() for x in item[1].split(",")]
|
|
|
|
| 243 |
authors[paper_authors[i]] += 1
|
| 244 |
|
| 245 |
authors = {k: v for k, v in authors.items() if v >= min_papers}
|
| 246 |
+
os.makedirs(DataPaths.CONFERENCE_DIR, exist_ok=True)
|
| 247 |
+
with open(DataPaths.AUTHORS_PATH, 'w') as f:
|
| 248 |
for k, v in authors.items():
|
| 249 |
f.write(f"{k}\t{v}\n")
|
| 250 |
return authors
|
|
|
|
| 255 |
NOTE: cannot deal w/ interrupts and continue from past progress.
|
| 256 |
"""
|
| 257 |
|
| 258 |
+
authors = get_authors()
|
| 259 |
person_names = list(authors.keys())
|
| 260 |
|
| 261 |
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
|
|
|
|
| 265 |
not_us_professor_profiles = []
|
| 266 |
|
| 267 |
def log_save_print(progress_log, us_professor_profiles, not_us_professor_profiles, i):
|
| 268 |
+
log_progress_to_file(progress_log, DataPaths.PROGRESS_LOG_PATH)
|
| 269 |
+
save_json(us_professor_profiles, DataPaths.US_PROF_PATH)
|
| 270 |
+
save_json(not_us_professor_profiles, DataPaths.NOT_US_PROF_PATH)
|
| 271 |
+
print(f"Saved profiles to {DataPaths.US_PROF_PATH} and {DataPaths.NOT_US_PROF_PATH} after processing {i} people")
|
| 272 |
|
| 273 |
for i in range(len(person_names)):
|
| 274 |
research_person(person_names[i], client, progress_log, us_professor_profiles, not_us_professor_profiles)
|
|
|
|
| 282 |
"""Searches everyone given in `person_names`."""
|
| 283 |
# might start and stop, pull from previous efforts
|
| 284 |
try:
|
| 285 |
+
prev_researched_authors = load_json(DataPaths.SEARCH_RESULTS_PATH)
|
| 286 |
except:
|
| 287 |
prev_researched_authors = []
|
| 288 |
ignore_set = set([x[0] for x in prev_researched_authors])
|
|
|
|
| 305 |
data.append([person_names[i], top_hits])
|
| 306 |
|
| 307 |
if i % save_freq == 0:
|
| 308 |
+
save_json(data, DataPaths.SEARCH_RESULTS_PATH)
|
| 309 |
+
log_progress_to_file(progress_log, DataPaths.PROGRESS_LOG_PATH)
|
| 310 |
|
| 311 |
# 3 queries per second max
|
| 312 |
wait_time = max(time.time() - (query_start + 0.334), 0.0)
|
| 313 |
time.sleep(wait_time)
|
| 314 |
|
| 315 |
+
save_json(data, DataPaths.SEARCH_RESULTS_PATH)
|
| 316 |
+
log_progress_to_file(progress_log, DataPaths.PROGRESS_LOG_PATH)
|
| 317 |
|
| 318 |
def write_batch_files(search_results_path,
|
|
|
|
| 319 |
model="gpt-4o-mini",
|
| 320 |
max_tokens=1000,
|
| 321 |
temperature=0.0,
|
|
|
|
| 348 |
|
| 349 |
batch_paths = []
|
| 350 |
for i in range(0, len(prompt_datas) // batch_size + 1):
|
| 351 |
+
prompt_data_path = f"{DataPaths.PROMPT_DATA_PREFIX}_{i:04d}.jsonl"
|
| 352 |
batch_range = i * batch_size, (min(len(prompt_datas), (i + 1) * batch_size))
|
| 353 |
with open(prompt_data_path, "w") as f:
|
| 354 |
for prompt_data in prompt_datas[batch_range[0]:batch_range[1]]:
|
|
|
|
| 357 |
|
| 358 |
return batch_paths
|
| 359 |
|
| 360 |
+
def send_batch_files(batch_paths, client, timeout=24*60*60):
|
| 361 |
"""Create and send the batch request to API endpoint."""
|
| 362 |
batches = []
|
| 363 |
|
|
|
|
| 391 |
batches.append(batch)
|
| 392 |
|
| 393 |
# Keeps track of the paths to the batch files
|
| 394 |
+
with open(f"{DataPaths.PROMPT_DATA_PREFIX}_batches.pkl", "wb") as f:
|
| 395 |
pickle.dump(batches, f)
|
|
|
|
|
|
|
| 396 |
return batches
|
| 397 |
|
| 398 |
def retrieve_batch_output(client, batch_id):
|
|
|
|
| 448 |
print(f"Failed to parse json object `{json_obj}`: {e2}")
|
| 449 |
progress_log.append(f"Failed UNKNOWN: Parsed LLM output: {e2}")
|
| 450 |
|
| 451 |
+
with open(DataPaths.US_PROF_PATH, 'w') as file:
|
| 452 |
json.dump(us_professor_profiles, file, indent=4)
|
| 453 |
|
| 454 |
+
with open(DataPaths.NOT_US_PROF_PATH, 'w') as file:
|
| 455 |
json.dump(not_us_professor_profiles, file, indent=4)
|
| 456 |
|
| 457 |
+
def create_professor_frontend_data():
|
| 458 |
+
with open(DataPaths.US_PROF_PATH, 'r') as file:
|
| 459 |
us_professor_profiles = json.load(file)
|
| 460 |
|
| 461 |
professors_dict = {
|
|
|
|
| 467 |
for professor in us_professor_profiles
|
| 468 |
}
|
| 469 |
|
| 470 |
+
with open(DataPaths.FRONTEND_PROF_PATH, 'w') as file:
|
| 471 |
json.dump(professors_dict, file)
|
| 472 |
|
| 473 |
def main():
|
|
|
|
| 503 |
|
| 504 |
args = parser.parse_args()
|
| 505 |
|
|
|
|
|
|
|
| 506 |
if args.batch_search:
|
| 507 |
+
authors = get_authors()
|
| 508 |
authors_list = list(authors.keys())
|
| 509 |
print("Researching people...")
|
| 510 |
progress_log = []
|
| 511 |
batch_search_person(authors_list, progress_log, save_freq=20)
|
| 512 |
elif args.batch_analyze:
|
| 513 |
client = OpenAI()
|
| 514 |
+
batch_paths = write_batch_files(DataPaths.SEARCH_RESULTS_PATH)
|
| 515 |
+
send_batch_files(batch_paths, client)
|
| 516 |
elif args.batch_retrieve:
|
| 517 |
client = OpenAI()
|
| 518 |
+
with open(f"{DataPaths.PROMPT_DATA_PREFIX}_batches.pkl", "rb") as f:
|
| 519 |
batches = pickle.load(f)
|
| 520 |
batch_process_llm_output(client, batches)
|
| 521 |
+
create_professor_frontend_data()
|
| 522 |
else:
|
| 523 |
raise ValueError("Please specify --batch_search, --batch_analyze, or --batch_retrieve.")
|
| 524 |
|