# File size: 5,650 Bytes
# ecadbd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from dataclasses import dataclass, field, fields, asdict
from typing import Optional, List, Literal, Dict, Any, Union
from transformers import TrainingArguments, Trainer
# from omegaconf import OmegaConf
import sys

from smpeft import SamaConfig


@dataclass
class ModelConfig:
    """Model-side options.

    Holds a model name, a dropout value, a maximum sequence length, the
    data-collator padding mode and optional adapter / merge paths. Every
    field has a default, so ``ModelConfig()`` is valid. How each value is
    consumed is not visible in this file.
    """

    model_name: str = ""
    dropout: float = 0.0
    model_max_seq_length: int = 512
    data_collator_mode: str = field(
        default="fixed",
        metadata={"help": "fixed or dynamic padding in DataCollator"},
    )
    # Disabled option, kept for reference:
    # lambda_reg: float = field(default=1e-4, metadata={"help": "The control strength of regularity"})
    adapter_path: Optional[str] = None

    # Paths used when merging an adapter; both are unset (None) by default.
    merge_adapter_path: Optional[str] = None
    merge_output_path: Optional[str] = None

@dataclass
class InferConfig:
    """Inference-time options.

    ``datasets`` defaults to a fresh list of eight dataset names for each
    instance; ``is_json`` and ``infer_max_seq_length`` are plain scalars.
    """

    datasets: List[str] = field(
        default_factory=lambda: [
            "boolq",
            "piqa",
            "social_i_qa",
            "hellaswag",
            "winogrande",
            "ARC-Easy",
            "ARC-Challenge",
            "openbookqa",
        ]
    )
    is_json: bool = True
    infer_max_seq_length: int = 1024
    
@dataclass
class SamaConfig:
    """Adapter hyper-parameters used as the ``sama_adapter`` section of ``MainConfig``.

    NOTE(review): this file also does ``from smpeft import SamaConfig``; this
    class statement rebinds that name, so the imported class is shadowed from
    here on. Confirm that is intended.
    """

    num_unique_blocks_L: int = 8
    num_unique_blocks_R: int = -1  # presumably -1 is a sentinel value; verify against the consumer
    # Disabled option, kept for reference:
    # share_factor: int = field(default=2)
    col_L: int = 64
    row_R: int = 64
    scaling: float = 1.0
    task_type: str = "CAUSAL_LM"
    # A new single-element list is created for every instance.
    target_modules: List[str] = field(default_factory=lambda: ["q_proj"])
    drop_out: float = 0.0

@dataclass
class DataConfig:
    """Dataset selection and sizing options.

    List-valued fields use ``default_factory`` so that instances never share
    the same list object.
    """

    dataset_name: str = "math"
    split_ratio: Union[int, float] = 0.01
    path: str = "./nl_tasks/data/MetaMathQA-40K/MetaMathQA-40K.json"
    dataset_split: str = field(
        default="train[:1000]",
        metadata={"help": "(`['train', 'test', 'eval']`):"},
    )
    adapter_names: List[Optional[str]] = field(default_factory=lambda: ["default"])
    dataset_field: List[str] = field(
        default_factory=list,
        metadata={"help": "Fields of dataset input and output."},
    )
    total_train_samples: int = 2800
    total_test_samples: int = 1200
    


@dataclass
class TrainingOverride:
    """Project defaults for the ``trainer_args`` section of ``MainConfig``.

    ``convert_to_trainer_args`` turns this dataclass into a dict and unpacks
    it as keyword arguments into ``HFTrainingArguments``, so every field name
    here has to be accepted by that constructor.
    """

    # --- optimizer / batching ---
    optim: str = "adamw_torch"
    eval_strategy: str = "no"
    per_device_train_batch_size: int = 8
    per_device_eval_batch_size: int = 8

    # --- learning-rate schedule ---
    learning_rate: float = 1e-05
    lr_scheduler_type: str = "cosine"
    warmup_steps: Union[int, float] = 100

    # --- gradient handling ---
    gradient_checkpointing: bool = False
    gradient_accumulation_steps: int = 1
    # A fresh dict is created per instance.
    gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = field(
        default_factory=lambda: {"use_reentrant": False}
    )

    # --- output / checkpointing / precision ---
    output_dir: str = "runs"
    save_steps: float = 0
    save_strategy: str = "no"
    save_total_limit: int = 1
    bf16: bool = False
    bf16_full_eval: bool = False
    save_safetensors: bool = False

    # --- logging / evaluation cadence ---
    report_to: Union[None, str, list[str]] = "none"
    logging_steps: int = 25  # only integer values are used here
    # Disabled option, kept for reference:
    # logging_first_step: bool=field(default=False)
    eval_steps: Union[None, int] = None  # only integer values are used here
    eval_delay: Union[int, float] = 0

    # --- data loading ---
    dataloader_num_workers: int = 2
    dataloader_pin_memory: bool = True
    dataloader_persistent_workers: bool = True
    dataloader_prefetch_factor: int = 1

    # --- run length ---
    num_train_epochs: float = 1.0
    max_steps: int = -1
    load_best_model_at_end: bool = True
    



@dataclass
class MainConfig:
    """Top-level configuration made of the per-area config sections.

    Each nested section is built through ``default_factory`` so a bare
    ``MainConfig()`` gets its own fresh sub-config instances. The
    ``trainer_args`` section is the one ``convert_to_trainer_args`` removes
    and forwards to the training-arguments constructor.
    """

    # Nested sections.
    model: ModelConfig = field(default_factory=ModelConfig)
    sama_adapter: SamaConfig = field(default_factory=SamaConfig)
    data: DataConfig = field(default_factory=DataConfig)
    trainer_args: TrainingOverride = field(default_factory=TrainingOverride)
    infer: InferConfig = field(default_factory=InferConfig)

    # Run-level scalars.
    project_name: str = "llm_sama"
    seed: int = 42
    run_text: str = "def"
    # Disabled option, kept for reference:
    # device: str = field(default='cpu')

@dataclass
class HFTrainingArguments(TrainingArguments):
    """``TrainingArguments`` subclass carrying one extra ``extension`` field.

    ``convert_to_trainer_args`` assigns to ``extension`` the dict form of
    everything in ``MainConfig`` except the ``trainer_args`` section.
    """

    extension: Optional[Dict[str, Any]] = field(
        metadata={"help": "Serialized MainConfig excluding training args"},
        default=None,
    )

def convert_to_trainer_args(main_cfg: MainConfig) -> HFTrainingArguments:
    """Build ``HFTrainingArguments`` from a ``MainConfig``.

    The ``trainer_args`` section is unpacked as keyword arguments into the
    ``HFTrainingArguments`` constructor. Everything else in the config (the
    remaining sections and top-level scalars) is attached, as plain dicts and
    values, to the ``extension`` attribute of the result.

    Args:
        main_cfg: A ``MainConfig`` dataclass instance.

    Returns:
        The constructed ``HFTrainingArguments`` with ``extension`` populated.

    Raises:
        SystemExit: With exit code 1 when the constructor raises ``TypeError``
            for the ``trainer_args`` values (for example an unexpected
            keyword). The reason is written to stderr before exiting.
    """
    section = "trainer_args"

    # asdict() recursively converts the nested dataclasses into plain dicts.
    config_dict = asdict(main_cfg)

    # Split off the constructor kwargs; what is left becomes the extension.
    trainer_kwargs = config_dict.pop(section)

    try:
        args = HFTrainingArguments(**trainer_kwargs)
    except TypeError as e:
        # Name the real config section and report on stderr so the message
        # is not mixed into regular program output.
        print(
            f"Error: Your '{section}' config contains keys unknown to "
            f"HF TrainingArguments: {e}",
            file=sys.stderr,
        )
        sys.exit(1)

    # Assigned after construction, so the constructor never sees it.
    args.extension = config_dict
    return args