File size: 3,798 Bytes
87224ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""
Configuration file with exact hyperparameters from the LREC-COLING 2024 paper.
"""

import os
from dataclasses import dataclass
from typing import Optional

@dataclass
class Config:
    """Hyperparameters, paths and runtime options for training and evaluation.

    Instantiating a ``Config`` has side effects. ``__post_init__`` creates
    ``checkpoint_dir`` and ``log_dir`` on disk if they do not already exist.
    It also fills in the default ``top_k_values`` when none is given.
    """

    # Paths
    data_dir: str = "data"
    db_path: str = "data/ideograph.db"  # Not used in demo
    font_dir: str = "data/font"  # Not used in demo
    real_data_dir: str = "data/real"
    # None means no local ResNet weights file is loaded; the consumer may
    # instead rely on a torchvision download or no pretrained weights.
    resnet_weights: Optional[str] = None
    checkpoint_dir: str = "checkpoints"
    log_dir: str = "logs"

    # Model configuration
    roberta_model: str = "ethanyt/guwenbert-base"
    image_size: int = 64
    vocab_size: int = 23292  # GuwenBERT vocab size (actual)
    # Hidden size of the base-sized encoder named in roberta_model
    # (a *-large* RoBERTa would use 1024 instead).
    hidden_dim: int = 768
    resnet_out_dim: int = 2048  # ResNet50 output dimension

    # Image decoder configuration
    num_deconv_layers: int = 5

    # Training hyperparameters (exact from paper, optimized for 4090 D)
    batch_size: int = 256  # Match paper's batch size (256)
    use_weighted_sampling_for_eval: bool = False  # Use natural distribution for eval (match paper)
    num_epochs: int = 30
    curriculum_epochs: int = 10  # First 10 epochs use curriculum learning
    learning_rate: float = 0.0001
    min_lr: float = 1e-5
    alpha: float = 100.0  # Loss weight for image reconstruction

    # Optimizer
    optimizer: str = "adam"
    weight_decay: float = 0.0

    # Data sampling
    max_seq_length: int = 50
    num_masks_min: int = 1
    num_masks_max: int = 5

    # Font filtering threshold (from Appendix)
    min_black_pixels: int = 510

    # Image augmentation parameters (from Appendix)
    rotation_degrees: float = 5.0
    translation_percent: float = 0.05
    scale_percent: float = 0.10
    brightness_range: tuple = (0.7, 1.3)
    contrast_range: tuple = (0.2, 1.0)
    # NOTE(review): blur kernel sizes must be odd, yet both bounds here are
    # even; presumably the augmentation code rounds the sampled size to an
    # odd value -- confirm against the sampler.
    blur_kernel_range: tuple = (2, 10)
    blur_sigma_range: tuple = (1.0, 10.0)

    # Damage simulation (from Appendix)
    num_small_masks_min: int = 1
    num_small_masks_max: int = 20

    # Evaluation
    num_eval_samples: int = 30  # Number of random samplings for evaluation
    # None is replaced by [5, 10, 20] in __post_init__. A None default is
    # used (rather than a list literal) so instances never share one list.
    top_k_values: Optional[list] = None

    # Device
    device: str = "cuda"  # NVIDIA GeForce RTX 4090 D
    num_workers: int = 4  # Reduced from 16 to 4 to prevent hanging/deadlocks
    pin_memory: bool = True

    # Optimization
    use_amp: bool = True  # Automatic Mixed Precision
    gradient_accumulation_steps: int = 1  # Simulated batch size multiplier

    # Seed for reproducibility
    seed: int = 42

    # TensorBoard configuration
    tensorboard_log_dir: str = "logs/tensorboard"
    tensorboard_enabled: bool = True
    tensorboard_log_images_interval: int = 5  # Log sample images every N epochs

    def __post_init__(self):
        """Fill derived defaults and ensure output directories exist.

        Sets ``top_k_values`` to a fresh ``[5, 10, 20]`` when it is None.
        Creates ``checkpoint_dir`` and ``log_dir``; existing directories
        are accepted.
        """
        if self.top_k_values is None:
            self.top_k_values = [5, 10, 20]

        # Create directories if they don't exist
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        os.makedirs(self.log_dir, exist_ok=True)

    def get_phase1_checkpoint_path(self):
        """Return the path of the Phase 1 (RoBERTa fine-tuning) checkpoint.

        The path is ``<checkpoint_dir>/phase1_roberta_finetuned.pt``.
        """
        return os.path.join(self.checkpoint_dir, "phase1_roberta_finetuned.pt")

    def get_phase2_checkpoint_path(self, epoch: Optional[int] = None):
        """Return the path of a Phase 2 (MMRM) checkpoint.

        Args:
            epoch: When given (including 0), the per-epoch path
                ``phase2_mmrm_epoch{epoch}.pt`` is returned. When None, the
                path ``phase2_mmrm_best.pt`` is returned.

        Returns:
            The checkpoint path inside ``checkpoint_dir``.
        """
        if epoch is not None:
            return os.path.join(self.checkpoint_dir, f"phase2_mmrm_epoch{epoch}.pt")
        return os.path.join(self.checkpoint_dir, "phase2_mmrm_best.pt")

    def get_baseline_checkpoint_path(self, baseline_name: str):
        """Return the checkpoint path for the baseline called *baseline_name*.

        The path is ``<checkpoint_dir>/baseline_<baseline_name>.pt``.
        """
        return os.path.join(self.checkpoint_dir, f"baseline_{baseline_name}.pt")