File size: 5,409 Bytes
9b57ce7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import sys
import yaml
from pathlib import Path
from typing import Any, Optional, Union, Tuple, List, Literal
from omegaconf import OmegaConf
from transformers import HfArgumentParser
from dataclasses import dataclass, field
from transformers import TrainingArguments

@dataclass
class DataConfig:
    """Dataset locations and preprocessing options.

    Each of the three ``*_json_list`` fields defaults to a freshly built
    one-element list holding a placeholder metadata path, so no two instances
    share the same list object.
    """

    train_json_list: List[str] = field(
        default_factory=lambda: ["/path/to/dataset/meta_data.json"]
    )
    val_json_list: List[str] = field(
        default_factory=lambda: ["/path/to/dataset/meta_data.json"]
    )
    test_json_list: List[str] = field(
        default_factory=lambda: ["/path/to/dataset/meta_data.json"]
    )
    soft_label: bool = field(default=False)
    confidence_threshold: Optional[float] = field(default=None)
    # Pixel bounds (presumably per image — confirm with the data loader).
    # Both bounds default to the same value: 256 * 28 * 28 == 200704.
    max_pixels: Optional[int] = field(default=28 * 28 * 256)
    min_pixels: Optional[int] = field(default=28 * 28 * 256)
    with_instruction: bool = field(default=True)
    tied_threshold: Optional[float] = field(default=None)

@dataclass
class TrainingConfig(TrainingArguments):
    """Hugging Face ``TrainingArguments`` extended with project-specific options.

    Adds per-component learning rates, epoch-based logging/eval/save
    intervals, checkpoint-loading settings and visualization knobs. The
    inherited ``max_grad_norm`` and ``remove_unused_columns`` fields are
    redeclared here.
    """

    max_grad_norm: Optional[float] = 1.0
    dataset_num_proc: Optional[int] = None
    center_rewards_coefficient: Optional[float] = None
    disable_flash_attn2: bool = field(default=False)
    disable_dropout: bool = field(default=False)

    # Optional per-component learning rates.
    # NOTE(review): presumably None means "fall back to learning_rate";
    # confirm in the optimizer setup.
    vision_lr: Optional[float] = None
    merger_lr: Optional[float] = None
    rm_head_lr: Optional[float] = None
    special_token_lr: Optional[float] = None

    conduct_eval: Optional[bool] = True
    # Both default to None, so they are annotated Optional.
    load_from_pretrained: Optional[str] = None
    load_from_pretrained_step: Optional[int] = None
    # Epoch-based intervals.
    # NOTE(review): presumably converted to step counts by the training
    # script; confirm.
    logging_epochs: Optional[float] = None
    eval_epochs: Optional[float] = None
    save_epochs: Optional[float] = None
    # Overrides the inherited TrainingArguments default with False.
    remove_unused_columns: Optional[bool] = False

    save_full_model: Optional[bool] = False

    # Visualization parameters
    visualization_steps: Optional[int] = 100
    max_viz_samples: Optional[int] = 4

@dataclass
class PEFTLoraConfig:
    """LoRA / PEFT adapter options.

    After construction, ``lora_target_modules`` and ``lora_namespan_exclude``
    are unwrapped when they hold a one-element list, so ``["q_proj"]`` is
    stored as ``"q_proj"``. Longer lists, empty lists and ``None`` are kept
    unchanged. ``lora_modules_to_save`` is never unwrapped.
    """

    # Feature switches.
    lora_enable: bool = False
    vision_lora: bool = False
    # Adapter hyperparameters.
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    # Module-name selectors (the first two may be unwrapped; see above).
    lora_target_modules: Optional[List[str]] = None
    lora_namespan_exclude: Optional[List[str]] = None
    lora_modules_to_save: Optional[List[str]] = None
    lora_task_type: str = "CAUSAL_LM"
    use_rslora: bool = False
    num_lora_modules: int = -1

    def __post_init__(self):
        # Store a lone list entry as a bare string.
        # NOTE(review): presumably so downstream code can treat it as a
        # single name/pattern; confirm against the consumer.
        for attr_name in ("lora_target_modules", "lora_namespan_exclude"):
            value = getattr(self, attr_name)
            if isinstance(value, list) and len(value) == 1:
                setattr(self, attr_name, value[0])


@dataclass
class ModelConfig:
    """Model loading, quantization, reward-head and loss options.

    Raises:
        ValueError: From ``__post_init__`` if both ``load_in_8bit`` and
            ``load_in_4bit`` are enabled.
    """

    model_name_or_path: Optional[str] = None
    model_revision: str = "main"
    rm_head_type: str = "default"
    rm_head_kwargs: Optional[dict] = None
    output_dim: int = 1

    use_special_tokens: bool = False

    freeze_vision_tower: bool = field(default=False)
    freeze_llm: bool = field(default=False)
    tune_merger: bool = field(default=False)
    trainable_visual_layers: Optional[int] = -1

    torch_dtype: Optional[Literal["auto", "bfloat16", "float16", "float32"]] = None
    trust_remote_code: bool = False
    attn_implementation: Optional[str] = None
    # Mutually exclusive; enforced in __post_init__.
    load_in_8bit: bool = False
    load_in_4bit: bool = False
    bnb_4bit_quant_type: Literal["fp4", "nf4"] = "nf4"
    use_bnb_nested_quant: bool = False
    reward_token: Literal["last", "mean", "special"] = "last"
    # The default "regular" is listed among the choices so the annotation
    # admits its own default value.
    # NOTE(review): confirm the loss code handles "regular", or whether
    # "reg" was the intended default.
    loss_type: Literal[
        "bt", "reg", "btt", "margin", "constant_margin", "scaled", "regular"
    ] = "regular"
    loss_hyperparameters: dict = field(default_factory=dict)
    checkpoint_path: Optional[str] = None

    def __post_init__(self):
        """Reject contradictory quantization settings."""
        if self.load_in_8bit and self.load_in_4bit:
            raise ValueError("You can't use 8 bit and 4 bit precision at the same time")


########## Argument parsing helpers ##########

def parse_args_with_yaml(
    dataclass_types: Tuple[type, ...],
    config_path: Optional[str] = None,
    allow_extra_keys: bool = True,
    is_train: bool = True,
) -> Tuple[Tuple[Any, ...], str]:
    """
    Parse a YAML config file into dataclass instances via HfArgumentParser.

    The file is loaded with OmegaConf, with ``${...}`` interpolations
    resolved, and converted to a plain dict. That dict is handed to
    ``HfArgumentParser.parse_dict``.

    Args:
        dataclass_types: Tuple of dataclass types for HfArgumentParser.
        config_path: Path to the YAML config file. Required despite the
            ``None`` default, which is kept only for signature compatibility.
        allow_extra_keys: Passed to ``parse_dict``. When False, keys that
            match no dataclass field raise an error.
        is_train: When False, the top-level ``deepspeed`` key is dropped
            before parsing.

    Returns:
        A pair ``(parsed, config_path)``. ``parsed`` is the tuple of dataclass
        instances, in the order of ``dataclass_types``.

    Raises:
        TypeError: If ``config_path`` is None or the YAML's top level is not
            a mapping.
    """
    if config_path is None:
        raise TypeError("config_path must be the path to a YAML config file, got None")

    # resolve=True expands ${...} interpolations. Without it, the raw
    # interpolation strings would be passed through as field values.
    args = OmegaConf.to_container(OmegaConf.load(config_path), resolve=True)
    if not isinstance(args, dict):
        raise TypeError(
            f"Expected a mapping at the top level of {config_path}, "
            f"got {type(args).__name__}"
        )

    if not is_train:
        # Non-training runs drop the deepspeed entry.
        args.pop("deepspeed", None)

    parser = HfArgumentParser(dataclass_types)
    return parser.parse_dict(args, allow_extra_keys=allow_extra_keys), config_path


if __name__ == "__main__":
    # Usage: python <this_script> CONFIG_YAML
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} CONFIG_YAML")
    # parse_args_with_yaml returns (parsed_dataclasses, config_path), so the
    # four configs are unpacked from the first element of the pair.
    (data_config, training_args, model_config, peft_lora_config), _ = parse_args_with_yaml(
        (DataConfig, TrainingConfig, ModelConfig, PEFTLoraConfig),
        config_path=sys.argv[1],
    )