# Instance-based-FT/iba/configIBA.py
from dataclasses import dataclass, field, fields
from typing import Optional, List, Union, Dict, Any
import draccus
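# NOTE: draccus is used elsewhere in the repo to parse these dataclasses from
# the command line / config files; a typical (assumed, not shown in this file)
# entry point would be:
#
#     cfg = draccus.parse(config_class=MainConfig)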
@dataclass
class HyperXSConfig:
    lora_attn_dim: int = 32
    module_embed_dim: int = 16
    layer_embed_dim: int = 48
    n_cross_attn_tokens: int = 4
    out_proj_dim: int = 64
    layer_norm_epsilon: float = 1e-5
    latent_feature_dim: int = 256
    modules_per_layer: int = 7
    drop_out: float = 0.0
@dataclass
class InferConfig:
    datasets: List[str] = field(default_factory=lambda: [
        "boolq", "piqa", "social_i_qa", "hellaswag",
        "winogrande", "ARC-Easy", "ARC-Challenge", "openbookqa",
    ])
is_json: bool = field(default=True)
model_path: str = field(default="")
eval_batch_size: int = field(default=32)
@dataclass
class ModelConfig:
    base_model_name: str = "meta-llama/Llama-2-7b-hf"
    # Alternatives: huggyllama/llama-7b, meta-llama/Meta-Llama-3-8B
    # n_layers: int = 24
    # feature_dim: int = 1024
cutoff_len: int = 512
train_on_inputs: bool = False
@dataclass
class TrainingConfig:
per_device_train_batch_size: int = field(default=16)
per_device_eval_batch_size: int = field(default=32)
num_workers: int = 2
    gradient_accumulation_steps: int = field(default=1)
    gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = field(
        default_factory=lambda: {"use_reentrant": False}
    )
    resume_from_checkpoint: bool = False
    optim: str = field(default="adamw_torch")
    eval_strategy: str = field(default='no')
learning_rate: float = field(default=1e-05)
lr_scheduler_type: str = field(default='cosine')
warmup_ratio: float = field(default=0.1)
gradient_checkpointing: bool = field(default=False)
output_dir: str = field(default="exps")
save_steps: float = field(default=0)
    bf16: bool = field(default=False)
    bf16_full_eval: bool = field(default=False)
    save_safetensors: bool = field(default=False)  # workaround: safetensors rejects tied/shared weights, so fall back to torch.save
    report_to: Union[None, str, List[str]] = field(default="none")
    logging_steps: int = field(default=25)  # int step counts only (no float ratios)
    # logging_first_step: bool = field(default=False)
    save_strategy: str = field(default='no')
    save_total_limit: int = field(default=1)
    eval_steps: Optional[int] = field(default=None)  # int step counts only (no float ratios)
    eval_delay: Union[int, float] = field(default=0)
    dataloader_num_workers: int = field(default=4)
    dataloader_pin_memory: bool = field(default=True)
    dataloader_persistent_workers: bool = field(default=True)
    dataloader_prefetch_factor: int = field(default=1)
num_train_epochs: float = field(default=1.0)
    max_steps: int = field(default=-1)
    # torch_compile: bool = field(default=False)
    load_best_model_at_end: bool = field(default=True)  # only takes effect when eval/save strategies are enabled and match
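
# The fields above mirror transformers.TrainingArguments keyword arguments.
# A minimal sketch of how a trainer script might forward them (assumed wiring;
# hypothetical, since the actual training entry point lives elsewhere):
#
#     from dataclasses import asdict, fields as dc_fields
#     from transformers import TrainingArguments
#
#     hf_keys = {f.name for f in dc_fields(TrainingArguments)}
#     cfg = TrainingConfig()
#     args = TrainingArguments(**{k: v for k, v in asdict(cfg).items() if k in hf_keys})
#
# num_workers and resume_from_checkpoint are not TrainingArguments fields and
# are filtered out above; resume_from_checkpoint is passed to Trainer.train() instead.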
@dataclass
class DataConfig:
dataset_name: str = field(default='Cifa')
#data_path: List[str] = field(default_factory=list)
data_path: str = field(default='./ft-training_set/math10k.json')
val_set_size: int = 128
@dataclass
class MainConfig:
hyperxs: HyperXSConfig = field(default_factory=HyperXSConfig)
model: ModelConfig = field(default_factory=ModelConfig)
training: TrainingConfig = field(default_factory=TrainingConfig)
data: DataConfig = field(default_factory=DataConfig)
infer: InferConfig = field(default_factory=InferConfig)
seed: int = 42
    run_text: str = field(default='def')
def from_dict(config_cls, config_dict):
    """Rebuild a (possibly nested) dataclass instance from a plain dict."""
    kwargs = {}
    for f in fields(config_cls):
        if f.name not in config_dict:
            # Key absent: fall back to the default declared on the dataclass,
            # which keeps old config dicts loadable after new fields are added.
            continue
        value = config_dict[f.name]
        if hasattr(f.type, "__dataclass_fields__"):
            # Nested dataclass field: rebuild it recursively from its sub-dict.
            value = from_dict(f.type, value)
        kwargs[f.name] = value
    return config_cls(**kwargs)
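

if __name__ == "__main__":
    # Minimal smoke test (illustrative only): round-trip the default MainConfig
    # through a plain dict, the same shape a JSON config file on disk would have.
    from dataclasses import asdict

    default_cfg = MainConfig()
    restored = from_dict(MainConfig, asdict(default_cfg))
    assert restored == default_cfg
    print(restored.training.learning_rate)  # -> 1e-05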