# MMRM / config.py
# (origin: "0-shot pipeline test", commit 87224ba)
"""
Configuration file with exact hyperparameters from the LREC-COLING 2024 paper.
"""
import os
from dataclasses import dataclass
from typing import Optional
@dataclass
class Config:
    """Central hyperparameter/configuration container for the MMRM pipeline.

    Values follow the LREC-COLING 2024 paper (see module docstring); fields
    marked "Not used in demo" are kept for interface compatibility.
    Instantiating a Config creates ``checkpoint_dir`` and ``log_dir`` on disk
    (see ``__post_init__``).
    """
    # Paths
    data_dir: str = "data"
    db_path: str = "data/ideograph.db"  # Not used in demo
    font_dir: str = "data/font"  # Not used in demo
    real_data_dir: str = "data/real"
    # None => let torchvision download pretrained weights (or use none at all)
    resnet_weights: Optional[str] = None
    checkpoint_dir: str = "checkpoints"
    log_dir: str = "logs"
    # Model configuration
    roberta_model: str = "ethanyt/guwenbert-base"
    image_size: int = 64
    vocab_size: int = 23292  # GuwenBERT vocab size (actual)
    hidden_dim: int = 768  # RoBERTa hidden size
    resnet_out_dim: int = 2048  # ResNet50 output dimension
    # Image decoder configuration
    num_deconv_layers: int = 5
    # Training hyperparameters (exact from paper, optimized for 4090 D)
    batch_size: int = 256  # Match paper's batch size (256)
    use_weighted_sampling_for_eval: bool = False  # Natural distribution for eval (match paper)
    num_epochs: int = 30
    curriculum_epochs: int = 10  # First 10 epochs use curriculum learning
    learning_rate: float = 0.0001
    min_lr: float = 1e-5
    alpha: float = 100.0  # Loss weight for image reconstruction
    # Optimizer
    optimizer: str = "adam"
    weight_decay: float = 0.0
    # Data sampling
    max_seq_length: int = 50
    num_masks_min: int = 1
    num_masks_max: int = 5
    # Font filtering threshold (from Appendix)
    min_black_pixels: int = 510
    # Image augmentation parameters (from Appendix)
    rotation_degrees: float = 5.0
    translation_percent: float = 0.05
    scale_percent: float = 0.10
    brightness_range: tuple = (0.7, 1.3)
    contrast_range: tuple = (0.2, 1.0)
    # NOTE(review): original comment said kernel sizes "must be odd", but this
    # range starts at an even value — confirm the consumer enforces oddness.
    blur_kernel_range: tuple = (2, 10)
    blur_sigma_range: tuple = (1.0, 10.0)
    # Damage simulation (from Appendix)
    num_small_masks_min: int = 1
    num_small_masks_max: int = 20
    # Evaluation
    num_eval_samples: int = 30  # Number of random samplings for evaluation
    # None sentinel: defaulted to [5, 10, 20] in __post_init__ (mutable
    # defaults are not allowed directly on dataclass fields).
    top_k_values: Optional[list] = None
    # Device
    device: str = "cuda"  # NVIDIA GeForce RTX 4090 D
    num_workers: int = 4  # Reduced from 16 to 4 to prevent hanging/deadlocks
    pin_memory: bool = True
    # Optimization
    use_amp: bool = True  # Automatic Mixed Precision
    gradient_accumulation_steps: int = 1  # Simulated batch size multiplier
    # Seed for reproducibility
    seed: int = 42
    # TensorBoard configuration
    tensorboard_log_dir: str = "logs/tensorboard"
    tensorboard_enabled: bool = True
    tensorboard_log_images_interval: int = 5  # Log sample images every N epochs

    def __post_init__(self):
        """Fill the mutable-default sentinel and ensure output dirs exist."""
        if self.top_k_values is None:
            self.top_k_values = [5, 10, 20]
        # Create directories if they don't exist
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        os.makedirs(self.log_dir, exist_ok=True)

    def get_phase1_checkpoint_path(self):
        """Path for Phase 1 (RoBERTa fine-tuning) checkpoint."""
        return os.path.join(self.checkpoint_dir, "phase1_roberta_finetuned.pt")

    def get_phase2_checkpoint_path(self, epoch: Optional[int] = None):
        """Path for Phase 2 (MMRM) checkpoint.

        With an explicit ``epoch``, returns the per-epoch snapshot path;
        otherwise the best-model path.
        """
        if epoch is not None:
            return os.path.join(self.checkpoint_dir, f"phase2_mmrm_epoch{epoch}.pt")
        return os.path.join(self.checkpoint_dir, "phase2_mmrm_best.pt")

    def get_baseline_checkpoint_path(self, baseline_name: str):
        """Path for a baseline model checkpoint named ``baseline_name``."""
        return os.path.join(self.checkpoint_dir, f"baseline_{baseline_name}.pt")