| from typing import Dict
|
| from collections import Counter
|
| import matplotlib.pyplot as plt
|
| from tqdm import tqdm
|
| import json
|
| import re
|
| from pathlib import Path
|
| import arxiv
|
| import pandas as pd
|
| import requests
|
| from bs4 import BeautifulSoup
|
| from tqdm.auto import tqdm
|
|
|
|
|
# Public page listing every arXiv category grouped under top-level headings.
TAXONOMY_URL = "https://arxiv.org/category_taxonomy"


def fetch_arxiv_categories(url: str = TAXONOMY_URL, *, timeout: float = 30.0) -> Dict[str, str]:
    """Scrape the arXiv taxonomy page into a ``{category_code: group_heading}`` map.

    Each ``h2.accordion-head`` heading names a group. Each
    ``div.columns.divided`` block inside the heading's next
    ``div.accordion-body`` sibling contributes one category code: the
    stripped text of the first child of that block's ``h4``.

    Args:
        url: Page to scrape. Defaults to ``TAXONOMY_URL``.
        timeout: Seconds ``requests`` waits for the server before raising.
            Without a timeout, a stalled connection would block forever.

    Returns:
        Mapping from category code to the heading text it appears under.
        Blocks with a missing or empty ``h4`` are skipped.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    soup = BeautifulSoup(resp.text, "html5lib")

    cat2main: Dict[str, str] = {}
    for head in soup.select("h2.accordion-head"):
        main_class = head.get_text(strip=True)

        body = head.find_next_sibling("div", class_="accordion-body")
        if body is None:
            continue

        for block in body.select("div.columns.divided"):
            h4 = block.find("h4")
            if h4 is None or not h4.contents:
                continue

            first = h4.contents[0]
            # A bare text node is a str subclass. Older bs4 releases do not
            # provide .text on text nodes, so only call get_text() on tags.
            text = first if isinstance(first, str) else first.get_text()
            cat_code = text.strip()
            if cat_code:
                cat2main[cat_code] = main_class

    return cat2main
|
|
|
# Old-style arXiv identifiers (used before April 2007) look like
# "hep-th/9901001". The archive may carry a subject-class suffix, as in
# "math.GT/0309136". The 7-digit number is only unique within its archive,
# so the archive must be kept to avoid collisions.
_OLD_STYLE_ID_RE = re.compile(r"([a-z][a-z\-]*)(?:\.[A-Za-z]{2})?/(\d{7})(?:v\d+)?$")


def normalize_arxiv_id(arxiv_id: str) -> str:
    """Turn an arXiv id or entry URL into a version-less key.

    New-style ids keep the original format: the last path segment with any
    ``vN`` version suffix removed and all non-digits dropped. For example,
    ``".../abs/2301.01234v2"`` becomes ``"230101234"``.

    Old-style ids return ``"archive/number"``. For example,
    ``".../abs/hep-th/9901001v1"`` becomes ``"hep-th/9901001"``. Reducing
    these to digits alone would merge different papers from different
    archives. Any subject-class suffix on the archive is dropped.

    Note: records saved by earlier versions of this function used
    digits-only keys for old-style papers. Those keys are not migrated.
    """
    old_style = _OLD_STYLE_ID_RE.search(arxiv_id)
    if old_style is not None:
        archive, number = old_style.groups()
        return f"{archive}/{number}"

    raw_id = arxiv_id.split("/")[-1]
    raw_id = re.sub(r"v\d+$", "", raw_id)
    return re.sub(r"\D", "", raw_id)
|
|
|
def _load_records(data_json_path) -> Dict[str, dict]:
    """Return the records saved by a previous run, or ``{}`` if none exist."""
    path = Path(data_json_path)
    if not path.exists():
        return {}
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _load_remaining_categories(remain_cats_path, categories: Dict[str, str]) -> list:
    """Return the categories still to be fetched.

    If the progress file exists, its first line is split on whitespace, so
    an empty file means nothing is left to do. Otherwise every key of
    ``categories`` is pending.
    """
    path = Path(remain_cats_path)
    if not path.exists():
        return list(categories)
    with open(path, "r", encoding="utf-8") as f:
        return f.readline().split()


def _write_via_temp(path, dump) -> None:
    """Write a UTF-8 text file through a sibling ``.tmp`` file, then swap it in.

    ``dump`` receives the open temp-file handle. The final step is
    ``Path.replace`` (an atomic rename on POSIX), so an interruption
    mid-write leaves the previous file intact instead of truncated.
    """
    target = Path(path)
    tmp = target.with_name(target.name + ".tmp")
    with open(tmp, "w", encoding="utf-8") as f:
        dump(f)
    tmp.replace(target)


def _fetch_category_records(client, category: str, categories: Dict[str, str],
                            known_ids, num_class_samples: int) -> Dict[str, dict]:
    """Fetch up to ``num_class_samples`` papers for one category.

    Returns new records keyed by ``normalize_arxiv_id``. Papers are skipped
    when their id is already in ``known_ids`` or when they lack a title,
    category list, entry id or abstract.
    """
    search = arxiv.Search(
        query=f"cat:{category}",
        max_results=num_class_samples,
        sort_by=arxiv.SortCriterion.Relevance,
        sort_order=arxiv.SortOrder.Descending,
    )

    new_records: Dict[str, dict] = {}
    for result in client.results(search):
        required = (result.title, result.categories, result.entry_id, result.summary)
        if any(field is None for field in required):
            continue

        arxiv_id = normalize_arxiv_id(result.entry_id)
        if arxiv_id in known_ids:
            continue

        # dict.fromkeys removes duplicates while keeping arXiv's listing
        # order. Set iteration order depends on per-process string hashing,
        # so it would change from run to run.
        target_cats = list(dict.fromkeys(cat for cat in result.categories if cat in categories))
        base_cats = list(dict.fromkeys(categories[cat] for cat in target_cats))

        new_records[arxiv_id] = {
            'title': result.title,
            'abstract': result.summary,
            'categories': target_cats,
            'base_categories': base_cats,
        }
    return new_records


def make_dataset(data_json_path, remain_cats_path, categories: Dict[str, str], num_class_samples: int) -> None:
    """Crawl arXiv one category at a time into a resumable JSON dataset.

    Two files record progress:

    - ``data_json_path`` holds the records collected so far, keyed by
      ``normalize_arxiv_id``.
    - ``remain_cats_path`` holds the categories not yet fetched, separated
      by whitespace.

    Both are rewritten in a ``finally`` clause, so everything up to the last
    fully fetched category survives an error or Ctrl-C. The next call then
    resumes from the remaining categories.

    Args:
        data_json_path: JSON file of records. It is loaded if present.
        remain_cats_path: Progress file. If absent, every key of
            ``categories`` is pending.
        categories: Mapping from category code to the value stored in each
            record's ``base_categories``. Paper categories not in this
            mapping are dropped.
        num_class_samples: Maximum number of results requested per category.
    """
    records = _load_records(data_json_path)
    remain_cats = _load_remaining_categories(remain_cats_path, categories)

    client = arxiv.Client(page_size=1000, delay_seconds=3, num_retries=3)

    processed_cats = set()
    try:
        for category in tqdm(remain_cats, 'Fetching Categories'):
            records.update(
                _fetch_category_records(client, category, categories, records, num_class_samples)
            )
            processed_cats.add(category)
    finally:
        # Save the records before the progress list. If the second write
        # never happens, the next run only re-fetches categories whose
        # papers are already saved, and those are skipped by id.
        _write_via_temp(
            data_json_path,
            lambda f: json.dump(records, f, ensure_ascii=False, indent=4),
        )
        new_remain_cats = [cat for cat in remain_cats if cat not in processed_cats]
        _write_via_temp(remain_cats_path, lambda f: f.write(" ".join(new_remain_cats)))
|
|
|
|
|
def count_categories(records: Dict[str, dict]) -> Counter:
    """Count how often each category code appears across ``records``.

    Every entry of each record's ``'categories'`` list adds one count, so a
    cross-listed paper is counted under each of its categories. An empty
    ``records`` yields an empty Counter.
    """
    return Counter(cat for record in records.values() for cat in record['categories'])


if __name__ == "__main__":
    data_path = Path('dataset.json')
    remain_cats_path = Path('./logs/remain_cats.txt')
    # NOTE(review): per-category result budget for make_dataset. The
    # original script never called make_dataset, so this is a chosen
    # default. Tune as needed.
    num_class_samples = 1000

    remain_cats_path.parent.mkdir(parents=True, exist_ok=True)
    categories = fetch_arxiv_categories(TAXONOMY_URL)

    # make_dataset is resumable. It only fetches categories still listed in
    # remain_cats_path (all of them when that file is missing), so re-running
    # after a finished crawl fetches nothing.
    make_dataset(data_path, remain_cats_path, categories, num_class_samples)

    with open(data_path, "r", encoding="utf-8") as f:
        dataset = json.load(f)

    cnt = count_categories(dataset)
    print(len(cnt))

    items = cnt.most_common()
    labels = [label for label, _ in items]
    values = [value for _, value in items]

    # Horizontal bars, tallest first; figure height grows with the number
    # of categories so the labels stay readable.
    plt.figure(figsize=(10, max(4, len(labels) * 0.4)))
    plt.barh(labels, values)
    plt.xlabel("Count")
    plt.ylabel("Object")
    plt.title("Counter result")
    plt.gca().invert_yaxis()
    plt.tight_layout()
    plt.show()