# BERT / src / download_data.py
# Empfloo's picture
# Upload 12 files
# e829681 verified
from typing import Dict
from collections import Counter
import matplotlib.pyplot as plt
from tqdm import tqdm
import json
import re
from pathlib import Path
import arxiv
import pandas as pd
import requests
from bs4 import BeautifulSoup
from tqdm.auto import tqdm
TAXONOMY_URL = "https://arxiv.org/category_taxonomy"


def fetch_arxiv_categories(url: str = TAXONOMY_URL, timeout: float = 30.0) -> Dict[str, str]:
    """Scrape a category-taxonomy page into a ``{category code: main class}`` map.

    The page is expected to be laid out as ``h2.accordion-head`` headings,
    each followed by a sibling ``div.accordion-body`` that holds one
    ``div.columns.divided`` block per category; the category code is taken
    from the first child node of the block's ``<h4>``.
    NOTE(review): this mirrors the markup the selectors below assume —
    confirm against the live page if the result comes back empty.

    Args:
        url: Address of the taxonomy page. Defaults to ``TAXONOMY_URL``.
        timeout: Seconds to wait for the server before giving up. Without
            a timeout ``requests`` may block indefinitely on a stalled
            connection.

    Returns:
        Mapping from each category code found on the page to the text of
        the heading it was listed under.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not answer within *timeout*.
    """
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    soup = BeautifulSoup(resp.text, "html5lib")

    cat2main: Dict[str, str] = {}
    for head in soup.select("h2.accordion-head"):
        main_class = head.get_text(strip=True)
        body = head.find_next_sibling("div", class_="accordion-body")
        if body is None:
            continue
        for block in body.select("div.columns.divided"):
            h4 = block.find("h4")
            # An <h4> with no children would make contents[0] raise IndexError.
            if h4 is None or not h4.contents:
                continue
            cat_code = h4.contents[0].text.strip()
            # Skip blank codes so the mapping never gains an empty-string key.
            if cat_code:
                cat2main[cat_code] = main_class
    return cat2main
def normalize_arxiv_id(arxiv_id: str) -> str:
    """Reduce an arXiv identifier or entry URL to a digits-only key.

    Only the text after the last ``/`` is kept, a trailing version suffix
    (``v`` followed by digits) is dropped, and every remaining non-digit
    character is removed, e.g. ``".../abs/2107.05580v1"`` -> ``"210705580"``.

    NOTE(review): because only the last path segment is used, identifiers
    that differ solely in an earlier segment (``hep-th/9901001`` vs
    ``math/9901001``) produce the same key.
    """
    _, _, tail = arxiv_id.rpartition("/")
    unversioned = re.sub(r"v\d+$", "", tail)
    return re.sub(r"\D", "", unversioned)
def _replace_file(path, write_to) -> None:
    """Write a file through a sibling ``.tmp`` file, then rename it over *path*."""
    path = Path(path)
    tmp_path = path.with_name(path.name + ".tmp")
    with open(tmp_path, "w", encoding="utf-8") as f:
        write_to(f)
    # The rename happens only after a complete write, so an interruption
    # mid-write leaves the previous version of the file untouched.
    tmp_path.replace(path)


def make_dataset(data_json_path, remain_cats_path, categories: Dict[str, str], num_class_samples: int) -> None:
    """Download paper metadata per arXiv category, resuming from saved progress.

    Existing records are loaded from *data_json_path* (if present) and the
    list of categories still to fetch from the first line of
    *remain_cats_path* (if present; otherwise every key of *categories*).
    For each remaining category up to *num_class_samples* results are
    requested; results whose normalized id is not yet stored are added as
    ``{'title', 'abstract', 'categories', 'base_categories'}`` records.

    A category's records are merged into the dataset only once the whole
    category has been fetched. Whether the loop finishes or is interrupted
    by an exception, the dataset and the list of still-unprocessed
    categories are written back, so a later call continues where this one
    stopped. Any exception raised while fetching is re-raised after saving.

    Args:
        data_json_path: Path of the JSON file holding ``{id: record}``.
        remain_cats_path: Path of the text file holding the
            whitespace-separated categories that are still to be fetched.
        categories: Mapping from category code to its main class; result
            categories that are not keys of this mapping are ignored.
        num_class_samples: Maximum number of results requested per category.
    """
    if Path(data_json_path).exists():
        with open(data_json_path, "r", encoding="utf-8") as f:
            records = json.load(f)
    else:
        records = {}

    if Path(remain_cats_path).exists():
        with open(remain_cats_path, "r", encoding="utf-8") as f:
            remain_cats = f.readline().split()
    else:
        remain_cats = list(categories)

    client = arxiv.Client(
        page_size=1000,
        delay_seconds=3,
        num_retries=3,
    )

    processed_cats = set()
    try:
        for category in tqdm(remain_cats, 'Fetching Categories'):
            search = arxiv.Search(
                query=f"cat:{category}",
                max_results=num_class_samples,
                sort_by=arxiv.SortCriterion.Relevance,
                sort_order=arxiv.SortOrder.Descending,
            )

            # Collect into a per-category buffer so that a failure halfway
            # through a category leaves `records` without partial results.
            buffer = {}
            for result in client.results(search):
                if (
                    result.title is None
                    or result.categories is None
                    or result.entry_id is None
                    or result.summary is None
                ):
                    continue

                arxiv_id = normalize_arxiv_id(result.entry_id)
                if arxiv_id in records:
                    continue

                # Sorted lists keep the saved JSON identical between runs;
                # plain set iteration order varies with string hashing.
                target_cats = sorted(
                    {cat for cat in result.categories if cat in categories}
                )
                base_cats = sorted({categories[cat] for cat in target_cats})

                buffer[arxiv_id] = {
                    'title': result.title,
                    'abstract': result.summary,
                    'categories': target_cats,
                    'base_categories': base_cats,
                }

            records.update(buffer)
            processed_cats.add(category)
    finally:
        # Checkpoint: runs on success and on interruption alike.
        _replace_file(
            data_json_path,
            lambda f: json.dump(records, f, ensure_ascii=False, indent=4),
        )
        new_remain_cats = [cat for cat in remain_cats if cat not in processed_cats]
        _replace_file(
            remain_cats_path,
            lambda f: f.write(" ".join(new_remain_cats)),
        )
if __name__ == "__main__":
    # Script entry point: prepare the progress file, then plot how often
    # each category occurs in the already-downloaded dataset.
    data_path = Path('dataset.json')
    remain_cats_path = Path('./logs/remain_cats.txt')
    remain_cats_path.parent.mkdir(parents=True, exist_ok=True)

    categories = fetch_arxiv_categories("https://arxiv.org/category_taxonomy")

    # First run only: seed the progress file with every known category.
    if not remain_cats_path.exists():
        with open(remain_cats_path, 'w', encoding='utf-8') as file:
            file.write(' '.join(categories))

    # Downloading is currently disabled; re-enable the call below to fetch data.
    #make_dataset(data_path, remain_cats_path, categories, 2)

    with open(data_path, "r", encoding="utf-8") as f:
        raw_records = json.load(f)
    frame = pd.DataFrame.from_dict(raw_records, orient='index').reset_index()

    # Tally every category label across all records.
    cnt = Counter()
    for record_cats in frame['categories']:
        cnt.update(record_cats)
    print(len(cnt))

    # Most frequent first; the axis is inverted below so it ends up on top.
    ranked = cnt.most_common()
    labels = [label for label, _ in ranked]
    values = [count for _, count in ranked]

    fig_height = max(4, len(labels) * 0.4)
    plt.figure(figsize=(10, fig_height))
    plt.barh(labels, values)
    plt.xlabel("Count")
    plt.ylabel("Object")
    plt.title("Counter result")
    plt.gca().invert_yaxis()
    plt.tight_layout()
    plt.show()