from dataclasses import dataclass, field, fields
from typing import Optional, List, Union, Dict, Any

import draccus


@dataclass
class HyperXSConfig:
    lora_attn_dim: int = 32
    module_embed_dim: int = 16
    layer_embed_dim: int = 48
    n_cross_attn_tokens: int = 4
    out_proj_dim: int = field(default=64)
    layer_norm_epsilon: float = field(default=1e-5)
    latent_feature_dim: int = field(default=256)
    modules_per_layer: int = field(default=7)
    drop_out: float = field(default=0.0)


@dataclass
class InferConfig:
    datasets: List[str] = field(default_factory=lambda: [
        "boolq", "piqa", "social_i_qa", "hellaswag",
        "winogrande", "ARC-Easy", "ARC-Challenge", "openbookqa",
    ])
    is_json: bool = field(default=True)
    model_path: str = field(default="")
    eval_batch_size: int = field(default=32)


@dataclass
class ModelConfig:
    # Alternatives: huggyllama/llama-7b, meta-llama/Meta-Llama-3-8B
    base_model_name: str = "meta-llama/Llama-2-7b-hf"
    # n_layers: int = 24
    # feature_dim: int = 1024
    cutoff_len: int = 512
    train_on_inputs: bool = False


@dataclass
class TrainingConfig:
    per_device_train_batch_size: int = field(default=16)
    per_device_eval_batch_size: int = field(default=32)
    num_workers: int = 2
    gradient_accumulation_steps: int = field(default=1)
    gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = field(
        default_factory=lambda: {"use_reentrant": False}
    )
    resume_from_checkpoint: bool = False
    optim: str = field(default="adamw_torch")
    eval_strategy: str = field(default="no")
    learning_rate: float = field(default=1e-05)
    lr_scheduler_type: str = field(default="cosine")
    warmup_ratio: float = field(default=0.1)
    gradient_checkpointing: bool = field(default=False)
    output_dir: str = field(default="exps")
    save_steps: float = field(default=0)
    bf16: bool = field(default=False)
    bf16_full_eval: bool = field(default=False)
    save_safetensors: bool = field(default=False)  # workaround: Trainer and tied weights
    report_to: Union[None, str, List[str]] = field(default="none")
    logging_steps: int = field(default=25)  # integer steps only
    # logging_first_step: bool = field(default=False)
    save_strategy: str = field(default="no")
    save_total_limit: int = field(default=1)
    eval_steps: Union[None, int] = field(default=None)  # integer steps only
    eval_delay: Union[int, float] = field(default=0)
    dataloader_num_workers: int = field(default=4)
    dataloader_pin_memory: bool = field(default=True)
    dataloader_persistent_workers: bool = field(default=True)
    dataloader_prefetch_factor: int = field(default=1)
    num_train_epochs: float = field(default=1.0)
    max_steps: int = field(default=-1)
    # torch_compile: bool = field(default=False)
    load_best_model_at_end: bool = field(default=True)


@dataclass
class DataConfig:
    dataset_name: str = field(default="Cifa")
    # data_path: List[str] = field(default_factory=list)
    data_path: str = field(default="./ft-training_set/math10k.json")
    val_set_size: int = 128


@dataclass
class MainConfig:
    hyperxs: HyperXSConfig = field(default_factory=HyperXSConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    data: DataConfig = field(default_factory=DataConfig)
    infer: InferConfig = field(default_factory=InferConfig)
    seed: int = 42
    run_text: str = field(default="def")


def from_dict(classConfig, config_dict):
    """Recursively build a dataclass instance from a (possibly nested) dict.

    Keys missing from ``config_dict`` keep the defaults defined on the
    dataclass, so configs saved before new attributes were added stay loadable.
    """
    kwargs = {}
    for f in fields(classConfig):
        if f.name not in config_dict:
            # Skip: fall back to the default value defined in the dataclass
            # (handles newly added attributes).
            continue
        value = config_dict[f.name]
        if hasattr(f.type, "__dataclass_fields__"):
            # Nested dataclass: rebuild it recursively from the sub-dict.
            value = from_dict(f.type, value)
        kwargs[f.name] = value
    return classConfig(**kwargs)