"""Build and inspect a small arXiv paper dataset labelled by category.

The module scrapes the arXiv category taxonomy page, downloads paper metadata
per category through the ``arxiv`` client, stores the records as JSON and,
when run as a script, plots how often each category occurs in the stored
dataset.
"""
import json
import re
from collections import Counter
from pathlib import Path
from typing import Dict

import arxiv
import matplotlib.pyplot as plt
import pandas as pd
import requests
from bs4 import BeautifulSoup
from tqdm import tqdm
from tqdm.auto import tqdm  # noqa: F811 - rebinds `tqdm`; this is the one used below

TAXONOMY_URL = "https://arxiv.org/category_taxonomy"

# Compiled once; used by normalize_arxiv_id for every fetched result.
_VERSION_SUFFIX_RE = re.compile(r"v\d+$")
_NON_DIGIT_RE = re.compile(r"\D")


def fetch_arxiv_categories(url: str = TAXONOMY_URL, timeout: float = 30.0) -> Dict[str, str]:
    """Scrape the category taxonomy page into a ``code -> main class`` mapping.

    Args:
        url: Address of the taxonomy page to download.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` may block indefinitely on a stalled
            connection.

    Returns:
        A dict whose keys are the category codes found on the page and whose
        values are the text of the ``h2.accordion-head`` heading the code is
        listed under.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeout.
    """
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    soup = BeautifulSoup(resp.text, "html5lib")

    cat2main: Dict[str, str] = {}
    for head in soup.select("h2.accordion-head"):
        main_class = head.get_text(strip=True)
        body = head.find_next_sibling("div", class_="accordion-body")
        if body is None:
            continue
        for block in body.select("div.columns.divided"):
            h4 = block.find("h4")
            # Skip blocks without a usable heading instead of raising
            # IndexError on an empty <h4>.
            if h4 is None or not h4.contents:
                continue
            # The first child of the <h4> is taken as the category code.
            cat_code = h4.contents[0].text.strip()
            cat2main[cat_code] = main_class
    return cat2main


def normalize_arxiv_id(arxiv_id: str) -> str:
    """Reduce an arXiv entry id or URL to a digits-only key.

    The last ``/``-separated component is taken, a trailing version suffix
    (``v1``, ``v2``, ...) is dropped and every remaining non-digit character
    is removed, e.g. ``.../abs/2107.05580v1`` -> ``210705580``.

    NOTE(review): for old-style ids such as ``hep-th/9901001`` only
    ``9901001`` survives, so the archive prefix is not part of the key.
    The format is kept as is because it is the key used in the stored
    dataset.
    """
    raw_id = arxiv_id.split("/")[-1]  # 2107.05580v1
    raw_id = _VERSION_SUFFIX_RE.sub("", raw_id)  # 2107.05580
    return _NON_DIGIT_RE.sub("", raw_id)  # 210705580


def _dump_json_atomically(path: Path, payload) -> None:
    """Write *payload* as JSON to *path* via a temporary sibling file.

    The data is dumped to ``<name>.tmp`` first and then moved over the
    target, so an interruption during the dump cannot leave a truncated
    dataset file behind.
    """
    tmp_path = path.with_name(path.name + ".tmp")
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=4)
    tmp_path.replace(path)


def make_dataset(
    data_json_path,
    remain_cats_path,
    categories: Dict[str, str],
    num_class_samples: int,
) -> None:
    """Fetch papers for every pending category and persist the progress.

    Existing records are loaded from ``data_json_path`` (if present) and the
    list of categories still to be fetched from the first line of
    ``remain_cats_path`` (if present, whitespace separated); otherwise all
    keys of ``categories`` are pending. Whether the loop finishes or is
    interrupted, the collected records and the categories that were not
    fully processed are written back, so a later call resumes where this one
    stopped.

    Args:
        data_json_path: Path of the JSON file holding ``id -> record``.
        remain_cats_path: Path of the text file listing pending categories.
        categories: Mapping ``category code -> main class``; only codes
            present here are stored on a record.
        num_class_samples: ``max_results`` passed to each per-category search.
    """
    data_json_path = Path(data_json_path)
    remain_cats_path = Path(remain_cats_path)

    if data_json_path.exists():
        with open(data_json_path, "r", encoding="utf-8") as f:
            records = json.load(f)
    else:
        records = {}

    if remain_cats_path.exists():
        with open(remain_cats_path, "r", encoding="utf-8") as f:
            remain_cats = f.readline().split()
    else:
        remain_cats = list(categories)

    client = arxiv.Client(page_size=1000, delay_seconds=3, num_retries=3)

    processed_cats = set()
    try:
        for category in tqdm(remain_cats, 'Fetching Categories'):
            search = arxiv.Search(
                query=f"cat:{category}",
                max_results=num_class_samples,
                sort_by=arxiv.SortCriterion.Relevance,
                sort_order=arxiv.SortOrder.Descending,
            )

            # Records of the current category are staged here and merged only
            # once the whole category was fetched; a category interrupted
            # half-way therefore stays pending and is fetched again later.
            buffer = {}
            for result in client.results(search):
                required = (result.title, result.categories, result.entry_id, result.summary)
                if any(value is None for value in required):
                    continue

                arxiv_id = normalize_arxiv_id(result.entry_id)
                if arxiv_id in records:
                    continue

                # Sorted so the stored lists do not depend on set iteration
                # order and the JSON output is reproducible between runs.
                target_cats = sorted({cat for cat in result.categories if cat in categories})
                base_cats = sorted({categories[cat] for cat in target_cats})
                buffer[arxiv_id] = {
                    'title': result.title,
                    'abstract': result.summary,
                    'categories': target_cats,
                    'base_categories': base_cats,
                }

            records.update(buffer)
            processed_cats.add(category)
    finally:
        # Checkpoint: executed on success as well as on errors/interrupts.
        data_json_path.parent.mkdir(parents=True, exist_ok=True)
        _dump_json_atomically(data_json_path, records)

        new_remain_cats = [cat for cat in remain_cats if cat not in processed_cats]
        remain_cats_path.parent.mkdir(parents=True, exist_ok=True)
        with open(remain_cats_path, "w", encoding="utf-8") as f:
            f.write(" ".join(new_remain_cats))


def main() -> None:
    """Prepare the pending-category file and plot category counts of the dataset."""
    data_path = Path('dataset.json')
    remain_cats_path = Path('./logs/remain_cats.txt')
    remain_cats_path.parent.mkdir(parents=True, exist_ok=True)

    categories = fetch_arxiv_categories(TAXONOMY_URL)

    if not remain_cats_path.exists():
        with open(remain_cats_path, 'w', encoding='utf-8') as file:
            file.write(' '.join(categories))

    # Dataset building is currently disabled; enable the call below to
    # fetch (or continue fetching) papers before plotting.
    # make_dataset(data_path, remain_cats_path, categories, 2)

    if not data_path.exists():
        raise SystemExit(f"{data_path} not found - build it with make_dataset() first.")

    with open(data_path, "r", encoding="utf-8") as f:
        dataset = json.load(f)

    dataset = pd.DataFrame.from_dict(dataset, orient='index').reset_index()

    # An empty dataset yields a frame without a 'categories' column.
    per_paper_cats = dataset['categories'] if 'categories' in dataset.columns else []
    cnt = Counter(cat for cats in per_paper_cats for cat in cats)
    print(len(cnt))

    items = cnt.most_common()
    labels = [label for label, _ in items]
    values = [count for _, count in items]

    # One row per category; the height grows with the number of bars.
    plt.figure(figsize=(10, max(4, len(labels) * 0.4)))
    plt.barh(labels, values)
    plt.xlabel("Count")
    plt.ylabel("Object")
    plt.title("Counter result")
    plt.gca().invert_yaxis()  # most frequent category at the top
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    main()