File size: 886 Bytes
e0b0f3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from pathlib import Path
import re

# Filesystem layout, all resolved relative to the directory containing this module.
ROOT = Path(__file__).parent
DATA_DIR, MODEL_DIR = ROOT.joinpath("data"), ROOT.joinpath("model")
RAW_DATA_PATH = ROOT.joinpath("arxivData.json")

# Run / training settings. Meanings below are inferred from the names —
# NOTE(review): confirm against the code that consumes them.
SEED = 42  # presumably the RNG seed for reproducible runs
BATCH_SIZE, NUM_EPOCHS = 16, 10
# Held-out fractions for validation and test; presumably the remainder
# is used for training — confirm against the split code.
VAL_RATIO, TEST_RATIO = 0.1, 0.1
LEARNING_RATE = 0.001
MAX_LENGTH = 512  # presumably a maximum sequence length — confirm at call site


def _load_taxonomy(path):
    tag_names = {}
    for line in open(path):
        line = line.strip()
        if not line:
            continue

        regex_tag_and_name = re.match(r"^([\w.-]+)\s+\((.+)\)$", line)
        if regex_tag_and_name:
            tag_names[regex_tag_and_name.group(1)] = regex_tag_and_name.group(2)

    return tag_names


# Tag -> human-readable name table, parsed once at import time from the
# taxonomy.txt that sits in the same directory as this module.
TAG_NAMES = _load_taxonomy(ROOT.joinpath("taxonomy.txt"))


def get_tag_name(tag, tag_names=None):
    """Return the human-readable name for a category tag.

    Lookup order:
      1. the exact tag (e.g. ``"cs.AI"``);
      2. its top-level prefix, i.e. the text before the first ``"."``
         (e.g. ``"cs.LG"`` -> ``"cs"``);
      3. otherwise the tag itself, unchanged.

    Args:
        tag: Category tag string.
        tag_names: Optional ``{tag: name}`` mapping to resolve against.
            Defaults to the module-level ``TAG_NAMES``; pass an explicit
            mapping to use a different taxonomy.

    Returns:
        The mapped name, or ``tag`` when neither the tag nor its prefix
        is present in the mapping.
    """
    # `is None` (not truthiness) so an explicitly passed empty mapping is honored.
    names = TAG_NAMES if tag_names is None else tag_names
    if tag in names:
        return names[tag]
    # split(".", 1)[0] is the whole tag when there is no dot, so no separate
    # branch is needed; in that case the lookup simply falls back to `tag`.
    prefix = tag.split(".", 1)[0]
    return names.get(prefix, tag)