"""Configuration module ``src/config.py`` of the paper-classifier project.

Author: Andrei Pavlov.
Revision: e0b0f3b ("Paper classifier app and model").
"""
from pathlib import Path
import re
# Directory that contains this module; every path below is built relative to it.
ROOT: Path = Path(__file__).parent
DATA_DIR: Path = ROOT.joinpath("data")
MODEL_DIR: Path = ROOT.joinpath("model")
# NOTE(review): the raw JSON dump lives directly under ROOT, not under
# DATA_DIR — confirm this is intended.
RAW_DATA_PATH: Path = ROOT.joinpath("arxivData.json")

# Presumably the seed used for shuffling / weight init — confirm at call sites.
SEED: int = 42

# Training knobs.
BATCH_SIZE: int = 16
NUM_EPOCHS: int = 10
LEARNING_RATE: float = 0.001
# presumably the maximum token length per input — TODO confirm against the tokenizer call
MAX_LENGTH: int = 512

# Held-out split fractions (presumably of the full dataset, remainder used
# for training — TODO confirm against the splitting code).
VAL_RATIO: float = 0.1
TEST_RATIO: float = 0.1
def _load_taxonomy(path):
tag_names = {}
for line in open(path):
line = line.strip()
if not line:
continue
regex_tag_and_name = re.match(r"^([\w.-]+)\s+\((.+)\)$", line)
if regex_tag_and_name:
tag_names[regex_tag_and_name.group(1)] = regex_tag_and_name.group(2)
return tag_names
# Tag -> human-readable name table, parsed once when this module is imported.
# NOTE(review): the import fails with an OSError if taxonomy.txt is missing from
# the directory containing this module.
TAG_NAMES = _load_taxonomy(ROOT.joinpath("taxonomy.txt"))
def get_tag_name(tag):
    """Return the human-readable name for a category tag such as ``"cs.AI"``.

    Lookup order:
      1. the exact tag in ``TAG_NAMES``;
      2. the part of the tag before its first ``"."`` (``"cs"`` for ``"cs.AI"``);
      3. the tag itself, unchanged, when neither is known.
    """
    if tag not in TAG_NAMES:
        # Fall back to the top-level prefix; a tag without a dot is its own prefix.
        head = tag.split(".", 1)[0] if "." in tag else tag
        return TAG_NAMES.get(head, tag)
    return TAG_NAMES[tag]