import sys
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Dict, List, Literal, Optional, Union

from transformers import Trainer, TrainingArguments

# from omegaconf import OmegaConf

# NOTE(review): ``SamaConfig`` is redefined as a local dataclass further down in
# this module, which shadows this import from that point on (MainConfig below
# uses the local dataclass). The import is kept unchanged in case it is needed
# for its side effects or by other code; confirm which SamaConfig is intended.
from smpeft import SamaConfig


@dataclass
class ModelConfig:
    """Model name, dropout, sequence-length and adapter-path settings."""

    model_name: str = ""
    dropout: float = 0.0
    model_max_seq_length: int = field(default=512)
    data_collator_mode: str = field(
        default='fixed', metadata={"help": "fixed or dynamic padding in DataCollator"}
    )
    # lambda_reg: float = field(default=1e-4, metadata={"help": "The control strength of regularity"})
    adapter_path: Optional[str] = field(default=None)
    merge_adapter_path: Optional[str] = field(default=None)
    merge_output_path: Optional[str] = field(default=None)


@dataclass
class InferConfig:
    """Inference settings: dataset names to run and the max sequence length."""

    datasets: List[str] = field(
        default_factory=lambda: [
            "boolq",
            "piqa",
            "social_i_qa",
            "hellaswag",
            "winogrande",
            "ARC-Easy",
            "ARC-Challenge",
            "openbookqa",
        ]
    )
    is_json: bool = field(default=True)
    infer_max_seq_length: int = field(default=1024)


@dataclass
class SamaConfig:
    """Adapter hyperparameters, stored under ``MainConfig.sama_adapter``.

    NOTE(review): this shadows ``smpeft.SamaConfig`` imported at the top of the
    module; every later reference to ``SamaConfig`` here resolves to this class.
    """

    num_unique_blocks_L: int = field(default=8)
    num_unique_blocks_R: int = field(default=-1)
    # share_factor: int = field(default=2)
    col_L: int = field(default=64)
    row_R: int = field(default=64)
    scaling: float = field(default=1.0)
    task_type: str = "CAUSAL_LM"
    target_modules: List[str] = field(default_factory=lambda: ["q_proj",])
    drop_out: float = field(default=0.0)


@dataclass
class DataConfig:
    """Dataset name, location, split and sample-count settings."""

    dataset_name: str = 'math'
    split_ratio: Union[int, float] = field(default=0.01)
    path: str = "./nl_tasks/data/MetaMathQA-40K/MetaMathQA-40K.json"
    dataset_split: str = field(
        default="train[:1000]", metadata={"help": "(`['train', 'test', 'eval']`):"}
    )
    adapter_names: List[Optional[str]] = field(default_factory=lambda: ["default"])
    dataset_field: List[str] = field(
        default_factory=list, metadata={"help": "Fields of dataset input and output."}
    )
    total_train_samples: int = field(default=2800)
    total_test_samples: int = field(default=1200)


@dataclass
class TrainingOverride:
    """Training defaults forwarded to ``HFTrainingArguments``.

    ``convert_to_trainer_args`` unpacks every field of this class as a keyword
    argument of ``HFTrainingArguments``, so each field name must be an accepted
    ``transformers.TrainingArguments`` constructor argument.
    """

    optim: str = field(default="adamw_torch")
    eval_strategy: str = field(default='no')
    per_device_train_batch_size: int = field(default=8)
    per_device_eval_batch_size: int = field(default=8)
    learning_rate: float = field(default=1e-05)
    lr_scheduler_type: str = field(default='cosine')
    warmup_steps: Union[int, float] = field(default=100)
    gradient_checkpointing: bool = field(default=False)
    gradient_accumulation_steps: int = field(default=1)
    gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = field(
        default_factory=lambda: {"use_reentrant": False}
    )
    output_dir: str = field(default="runs")
    save_steps: float = field(default=0)
    save_strategy: str = field(default='no')
    save_total_limit: int = field(default=1)
    bf16: bool = field(default=False)
    bf16_full_eval: bool = field(default=False)
    save_safetensors: bool = field(default=False)
    report_to: Union[None, str, List[str]] = field(default="none")
    logging_steps: int = field(default=25)  # we use int only
    # logging_first_step: bool=field(default=False)
    eval_steps: Optional[int] = field(default=None)  # we use int only
    eval_delay: Union[int, float] = field(default=0)
    # NOTE(review): HF TrainingArguments presumably rejects a prefetch factor
    # when dataloader_num_workers is 0 — keep these two in sync.
    dataloader_num_workers: int = field(default=2)
    dataloader_pin_memory: bool = field(default=True)
    dataloader_persistent_workers: bool = field(default=True)
    dataloader_prefetch_factor: int = field(default=1)
    num_train_epochs: float = field(default=1.0)
    max_steps: int = field(default=-1)
    # NOTE(review): with save_strategy='no' no checkpoint is written, so there
    # is presumably no "best model" to reload at the end — confirm intent.
    load_best_model_at_end: bool = field(default=True)


@dataclass
class MainConfig:
    """Top-level experiment configuration grouping all config sections."""

    model: ModelConfig = field(default_factory=ModelConfig)
    sama_adapter: SamaConfig = field(default_factory=SamaConfig)
    data: DataConfig = field(default_factory=DataConfig)
    trainer_args: TrainingOverride = field(default_factory=TrainingOverride)
    infer: InferConfig = field(default_factory=InferConfig)
    project_name: str = "llm_sama"
    seed: int = 42
    run_text: str = field(default='def')
    # device: str = field(default='cpu')


@dataclass
class HFTrainingArguments(TrainingArguments):
    """``TrainingArguments`` with an extra ``extension`` slot.

    ``extension`` carries the non-training sections of a ``MainConfig`` as a
    plain dict (filled in by ``convert_to_trainer_args``).
    """

    extension: Optional[Dict[str, Any]] = field(
        default=None, metadata={"help": "Serialized MainConfig excluding training args"}
    )


def convert_to_trainer_args(main_cfg: MainConfig) -> HFTrainingArguments:
    """Build an ``HFTrainingArguments`` from a ``MainConfig``.

    The ``trainer_args`` section is unpacked as keyword arguments of the
    ``HFTrainingArguments`` constructor. Every other top-level entry of
    ``main_cfg`` (model, sama_adapter, data, infer, project_name, seed,
    run_text, ...) is stored as plain dicts/values in ``args.extension``.

    Args:
        main_cfg: The full experiment configuration.

    Returns:
        The constructed ``HFTrainingArguments`` with ``extension`` populated.

    Exits:
        Prints the offending key names to stderr and calls ``sys.exit(1)``
        when the ``trainer_args`` section contains keys that
        ``HFTrainingArguments`` does not accept as constructor arguments.
        Any other error raised while constructing the arguments (for example
        from ``TrainingArguments.__post_init__`` validation) propagates with
        its original traceback.
    """
    train_section = "trainer_args"

    # asdict() recursively turns nested dataclasses into fresh plain dicts, so
    # the payload below does not alias main_cfg.
    full_dict = asdict(main_cfg)
    train_args_dict = full_dict.pop(train_section)
    extension_payload = full_dict

    # Validate keys up front rather than catching TypeError around the
    # constructor: a TypeError raised inside __post_init__ would otherwise be
    # misreported as an unknown key. Only init=True fields are constructor
    # arguments; 'extension' is supplied below, so it may not come from the
    # training section (it would be passed twice).
    accepted = {
        f.name for f in fields(HFTrainingArguments) if f.init and f.name != "extension"
    }
    unknown = sorted(set(train_args_dict) - accepted)
    if unknown:
        print(
            f"Error: Your '{train_section}' config contains keys unknown to "
            f"HF TrainingArguments: {unknown}",
            file=sys.stderr,
        )
        sys.exit(1)

    return HFTrainingArguments(**train_args_dict, extension=extension_payload)