|
|
""" |
|
|
Configuration file with exact hyperparameters from the LREC-COLING 2024 paper. |
|
|
""" |
|
|
|
|
|
import os |
|
|
from dataclasses import dataclass |
|
|
from typing import Optional |
|
|
|
|
|
@dataclass
class Config:
    """Central hyperparameter / path configuration.

    Field defaults mirror the exact hyperparameters reported in the
    LREC-COLING 2024 paper (see module docstring). Instantiating the
    config also creates ``checkpoint_dir`` and ``log_dir`` on disk
    (side effect of ``__post_init__``).

    Field order is part of the public interface (positional
    ``__init__`` arguments) and must not be rearranged.
    """

    # --- Paths ---
    data_dir: str = "data"
    db_path: str = "data/ideograph.db"
    font_dir: str = "data/font"
    real_data_dir: str = "data/real"
    # Optional path to pretrained ResNet weights; None means no checkpoint
    # is loaded. (Annotation fixed: was `str = None`.)
    resnet_weights: Optional[str] = None
    checkpoint_dir: str = "checkpoints"
    log_dir: str = "logs"

    # --- Model ---
    roberta_model: str = "ethanyt/guwenbert-base"
    image_size: int = 64          # glyph images are image_size x image_size
    vocab_size: int = 23292
    hidden_dim: int = 768
    resnet_out_dim: int = 2048
    num_deconv_layers: int = 5

    # --- Training ---
    batch_size: int = 256
    use_weighted_sampling_for_eval: bool = False
    num_epochs: int = 30
    curriculum_epochs: int = 10   # epochs spent in the curriculum phase
    learning_rate: float = 0.0001
    min_lr: float = 1e-5
    # Loss-weighting coefficient — presumably balances the loss terms;
    # TODO confirm against the training code.
    alpha: float = 100.0
    optimizer: str = "adam"
    weight_decay: float = 0.0

    # --- Sequence / masking ---
    max_seq_length: int = 50
    num_masks_min: int = 1
    num_masks_max: int = 5

    # Minimum count of black pixels for a glyph image to be considered valid.
    min_black_pixels: int = 510

    # --- Image augmentation ---
    rotation_degrees: float = 5.0
    translation_percent: float = 0.05
    scale_percent: float = 0.10
    brightness_range: tuple = (0.7, 1.3)
    contrast_range: tuple = (0.2, 1.0)
    blur_kernel_range: tuple = (2, 10)
    blur_sigma_range: tuple = (1.0, 10.0)

    # --- Small-mask augmentation ---
    num_small_masks_min: int = 1
    num_small_masks_max: int = 20

    # --- Evaluation ---
    num_eval_samples: int = 30
    # None sentinel instead of a mutable default list (dataclasses forbid
    # mutable defaults); filled in by __post_init__.
    # (Annotation fixed: was `list = None`.)
    top_k_values: Optional[list] = None

    # --- Runtime ---
    device: str = "cuda"
    num_workers: int = 4
    pin_memory: bool = True

    # --- Mixed precision / accumulation ---
    use_amp: bool = True
    gradient_accumulation_steps: int = 1

    # --- Reproducibility ---
    seed: int = 42

    # --- TensorBoard ---
    # NOTE(review): tensorboard_log_dir is not created by __post_init__
    # (only log_dir is) — presumably the logger creates it; verify.
    tensorboard_log_dir: str = "logs/tensorboard"
    tensorboard_enabled: bool = True
    tensorboard_log_images_interval: int = 5

    def __post_init__(self):
        """Fill mutable defaults and create output directories."""
        if self.top_k_values is None:
            self.top_k_values = [5, 10, 20]

        # Ensure output directories exist up front so later saves don't fail.
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        os.makedirs(self.log_dir, exist_ok=True)

    def get_phase1_checkpoint_path(self):
        """Path for Phase 1 (RoBERTa fine-tuning) checkpoint."""
        return os.path.join(self.checkpoint_dir, "phase1_roberta_finetuned.pt")

    def get_phase2_checkpoint_path(self, epoch: Optional[int] = None):
        """Path for Phase 2 (MMRM) checkpoint.

        With an explicit ``epoch``, returns the per-epoch checkpoint path;
        otherwise the "best" checkpoint path.
        """
        if epoch is not None:
            return os.path.join(self.checkpoint_dir, f"phase2_mmrm_epoch{epoch}.pt")
        return os.path.join(self.checkpoint_dir, "phase2_mmrm_best.pt")

    def get_baseline_checkpoint_path(self, baseline_name: str):
        """Path for a baseline model checkpoint, keyed by ``baseline_name``."""
        return os.path.join(self.checkpoint_dir, f"baseline_{baseline_name}.pt")
|
|
|