# Uploaded by nvan13 via huggingface_hub ("Upload folder using huggingface_hub"), commit ecadbd9.
from dataclasses import dataclass, field, fields, asdict
from typing import Optional, List, Literal, Dict, Any, Union
from transformers import TrainingArguments, Trainer
# from omegaconf import OmegaConf
import sys
from smpeft import SamaConfig
@dataclass
class ModelConfig:
    """Base-model and adapter-path settings.

    Attributes:
        model_name: Name of the base model (empty string by default);
            presumably a Hub id or local path — confirm with the loader.
        dropout: Dropout probability (default 0.0).
        model_max_seq_length: Maximum sequence length (default 512).
        data_collator_mode: Padding mode of the data collator, "fixed" or
            "dynamic" (default "fixed").
        adapter_path: Optional adapter location (default None).
        merge_adapter_path: Optional adapter location used for merging
            (default None).
        merge_output_path: Optional output location for a merge (default None).
    """

    model_name: str = ""
    dropout: float = 0.0
    model_max_seq_length: int = 512
    data_collator_mode: str = field(
        default="fixed",
        metadata={"help": "fixed or dynamic padding in DataCollator"},
    )
    adapter_path: Optional[str] = None
    merge_adapter_path: Optional[str] = None
    merge_output_path: Optional[str] = None
@dataclass
class InferConfig:
    """Inference / evaluation settings.

    Attributes:
        datasets: Benchmark names to evaluate; each instance gets its own
            fresh list.
        is_json: Boolean flag, True by default; presumably selects
            JSON-formatted benchmark files — confirm with the caller.
        infer_max_seq_length: Maximum sequence length at inference
            (default 1024).
    """

    datasets: List[str] = field(
        default_factory=lambda: [
            "boolq",
            "piqa",
            "social_i_qa",
            "hellaswag",
            "winogrande",
            "ARC-Easy",
            "ARC-Challenge",
            "openbookqa",
        ]
    )
    is_json: bool = True
    infer_max_seq_length: int = 1024
@dataclass
class SamaConfig:
    """Hyper-parameters for the SAMA adapter.

    NOTE(review): this class shadows the ``SamaConfig`` imported from
    ``smpeft`` at the top of this module. Any later reference to
    ``SamaConfig`` in this module (e.g. the ``MainConfig.sama_adapter``
    default factory) therefore resolves to this local dataclass, not the
    library class. Confirm which one is intended.

    Field meanings are inferred from names only — TODO confirm against smpeft.
    """

    num_unique_blocks_L: int = 8
    # -1 presumably means "not set / derive elsewhere" — confirm with consumer.
    num_unique_blocks_R: int = -1
    col_L: int = 64
    row_R: int = 64
    scaling: float = 1.0
    task_type: str = "CAUSAL_LM"
    target_modules: List[str] = field(default_factory=lambda: ["q_proj"])
    drop_out: float = 0.0
@dataclass
class DataConfig:
    """Dataset selection, splitting and sizing options.

    Attributes:
        dataset_name: Dataset identifier (default "math").
        split_ratio: Split ratio, int or float (default 0.01).
        path: Location of the dataset file (defaults to a MetaMathQA-40K
            JSON file).
        dataset_split: Split specifier string (default "train[:1000]");
            looks like HF ``datasets`` slicing syntax — confirm with loader.
        adapter_names: Adapter names; each instance gets a fresh
            ``["default"]`` list.
        dataset_field: Input/output field names (empty list by default).
        total_train_samples: Number of training samples (default 2800).
        total_test_samples: Number of test samples (default 1200).
    """

    dataset_name: str = "math"
    split_ratio: Union[int, float] = 0.01
    path: str = "./nl_tasks/data/MetaMathQA-40K/MetaMathQA-40K.json"
    dataset_split: str = field(
        default="train[:1000]",
        metadata={"help": "(`['train', 'test', 'eval']`):"},
    )
    adapter_names: List[Optional[str]] = field(default_factory=lambda: ["default"])
    dataset_field: List[str] = field(
        default_factory=list,
        metadata={"help": "Fields of dataset input and output."},
    )
    total_train_samples: int = 2800
    total_test_samples: int = 1200
@dataclass
class TrainingOverride:
    """Project defaults for a subset of Hugging Face ``TrainingArguments``.

    ``convert_to_trainer_args`` unpacks these fields verbatim as keyword
    arguments of the ``TrainingArguments`` subclass. Every field name here
    must therefore match a ``TrainingArguments`` field exactly.
    """

    # Optimisation and schedule.
    optim: str = "adamw_torch"
    eval_strategy: str = "no"
    per_device_train_batch_size: int = 8
    per_device_eval_batch_size: int = 8
    learning_rate: float = 1e-05
    lr_scheduler_type: str = "cosine"
    warmup_steps: Union[int, float] = 100
    gradient_checkpointing: bool = False
    gradient_accumulation_steps: int = 1
    # default_factory gives every instance its own dict (no shared mutable default).
    gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = field(
        default_factory=lambda: {"use_reentrant": False}
    )

    # Output and checkpointing.
    output_dir: str = "runs"
    save_steps: float = 0
    save_strategy: str = "no"
    save_total_limit: int = 1
    bf16: bool = False
    bf16_full_eval: bool = False
    save_safetensors: bool = False

    # Logging and evaluation cadence.
    # typing.List (not builtin list[...]) for consistency with the rest of the module.
    report_to: Union[None, str, List[str]] = "none"
    logging_steps: int = 25  # this project only uses integer step counts
    eval_steps: Optional[int] = None  # this project only uses integer step counts
    eval_delay: Union[int, float] = 0

    # Data loading.
    dataloader_num_workers: int = 2
    dataloader_pin_memory: bool = True
    dataloader_persistent_workers: bool = True
    dataloader_prefetch_factor: int = 1

    # Run length.
    num_train_epochs: float = 1.0
    max_steps: int = -1

    # NOTE(review): save_strategy and eval_strategy both default to "no", so no
    # checkpoint is written for this flag to reload. Confirm the True default
    # is intended.
    load_best_model_at_end: bool = True
@dataclass
class MainConfig:
    """Top-level experiment configuration composed of per-concern sections.

    Each nested section is created by its own default factory, so instances
    never share sub-config objects. ``trainer_args`` holds the Hugging Face
    training overrides. The remaining fields describe the model, adapter,
    data, inference and run metadata.
    """

    model: ModelConfig = field(default_factory=ModelConfig)
    sama_adapter: SamaConfig = field(default_factory=SamaConfig)
    data: DataConfig = field(default_factory=DataConfig)
    trainer_args: TrainingOverride = field(default_factory=TrainingOverride)
    infer: InferConfig = field(default_factory=InferConfig)
    project_name: str = "llm_sama"
    seed: int = 42
    run_text: str = "def"
@dataclass
class HFTrainingArguments(TrainingArguments):
    """``TrainingArguments`` plus one free-form ``extension`` field.

    Attributes:
        extension: Optional dict holding the non-training parts of the
            experiment configuration next to the regular training arguments.
            Defaults to ``None``.
    """

    extension: Optional[Dict[str, Any]] = field(
        default=None,
        metadata={"help": "Serialized MainConfig excluding training args"},
    )
def convert_to_trainer_args(main_cfg: MainConfig) -> HFTrainingArguments:
    """Build ``HFTrainingArguments`` from a ``MainConfig``.

    The ``trainer_args`` section is unpacked as keyword arguments of
    ``HFTrainingArguments``. Every other top-level entry of ``main_cfg``
    (model, sama_adapter, data, infer, project_name, seed, run_text) is
    stored, as plain dicts/values, in the ``extension`` field.

    Args:
        main_cfg: Full experiment configuration. It is not modified:
            ``asdict`` works on a deep copy.

    Returns:
        A populated ``HFTrainingArguments`` instance.

    Raises:
        SystemExit: With status 1 if ``trainer_args`` contains keys that
            ``HFTrainingArguments`` does not accept, or if construction
            raises ``TypeError``. A diagnostic is printed to stderr first.
    """
    key = "trainer_args"

    # asdict() recursively turns nested dataclasses into plain dicts (copying
    # them), so popping from the result never touches main_cfg.
    full_dict = asdict(main_cfg)

    # Training arguments become constructor kwargs; everything left over
    # (model, data, seed, ...) travels in `extension`.
    train_args_dict = full_dict.pop(key)
    extension_payload = full_dict

    # Validate names up front so the user sees *all* unknown keys at once,
    # rather than only the first one a TypeError would mention.
    # `extension` is excluded because it is supplied separately below.
    accepted = {f.name for f in fields(HFTrainingArguments) if f.init} - {"extension"}
    unknown = sorted(set(train_args_dict) - accepted)
    if unknown:
        print(
            f"Error: Your '{key}' config contains keys unknown to HF TrainingArguments: {unknown}",
            file=sys.stderr,
        )
        sys.exit(1)

    try:
        # Pass `extension` through the constructor instead of assigning it after init.
        args = HFTrainingArguments(**train_args_dict, extension=extension_payload)
    except TypeError as e:
        # Reached only for non-naming problems, e.g. a TypeError raised while
        # TrainingArguments validates values in __post_init__.
        print(
            f"Error: could not build HF TrainingArguments from the '{key}' config: {e}",
            file=sys.stderr,
        )
        sys.exit(1)
    return args