| from __future__ import annotations |
|
|
| import logging |
| from typing import Optional |
| import string |
| import nltk |
| import arxiv |
|
|
| logger = logging.getLogger(__name__) |
|
|
def extract_title_abst(arxiv_id: str) -> Optional[str]:
    """Return the title and abstract of an arXiv paper as a single string.

    The paper is looked up with ``arxiv.Search(id_list=[arxiv_id])`` and the
    first result's ``title`` and ``summary`` are joined with one space.

    Args:
        arxiv_id: Identifier passed to the arXiv search as the only entry of
            ``id_list``.

    Returns:
        ``"<title> <summary>"`` on success, or ``None`` if the lookup or the
        concatenation fails for any reason (best-effort behaviour).
    """
    try:
        paper = next(arxiv.Search(id_list=[arxiv_id]).results())
        return paper.title + ' ' + paper.summary
    except Exception:
        # Catch Exception rather than using a bare ``except`` so that
        # KeyboardInterrupt / SystemExit still propagate. The failure is
        # logged instead of being dropped silently; callers keep getting None.
        logging.getLogger(__name__).warning(
            "Could not fetch title/abstract for arXiv id %r",
            arxiv_id,
            exc_info=True,
        )
        return None
|
|
def doc_to_ids(
    doc: Optional[str],
    word_to_id_: dict[str, int],
    stemming: bool,
    lower: bool = True,
) -> list[int]:
    """Convert a document into the list of vocabulary ids of its known words.

    The text is optionally lower-cased, stripped of every character in
    ``string.punctuation``, tokenized with ``nltk.word_tokenize`` and
    optionally stemmed with NLTK's ``PorterStemmer``. Tokens that are not
    keys of ``word_to_id_`` are dropped.

    Args:
        doc: Raw text; ``None`` or an empty string yields an empty list.
        word_to_id_: Mapping from (processed) token to integer id.
        stemming: If true, each token is Porter-stemmed before lookup.
        lower: If true (default), the text is lower-cased first.

    Returns:
        Ids of the in-vocabulary tokens, in document order.
    """
    # Guard clause: nothing to tokenize, so avoid touching NLTK at all.
    if not doc:
        return []

    if lower:
        doc = doc.lower()
    # Delete all punctuation characters in a single pass.
    doc = doc.translate(str.maketrans("", "", string.punctuation))
    words = nltk.word_tokenize(doc)

    if stemming:
        # Imported lazily so the stemmer is only loaded when it is needed.
        from nltk.stem.porter import PorterStemmer

        # NOTE(review): PorterStemmer.stem appears to lower-case its input by
        # default, so lower=False may not preserve case here — confirm.
        porter = PorterStemmer()
        words = [porter.stem(word) for word in words]

    return [word_to_id_[word] for word in words if word in word_to_id_]