from dataclasses import dataclass, field, fields
from typing import Optional, List, Union, Dict, Any
import draccus

@dataclass
class HyperXSConfig:
    lora_attn_dim: int = 32
    module_embed_dim: int = 16
    layer_embed_dim: int = 48
    n_cross_attn_tokens: int = 4
    out_proj_dim: int = field(default=64)
    layer_norm_epsilon: float = field(default=1e-5)
    latent_feature_dim: int = field(default=256)
    modules_per_layer: int = field(default=7)
    drop_out: float = field(default=0.0)

@dataclass
class InferConfig:
    datasets: List[str] = field(default_factory=lambda: [
        "boolq", "piqa", "social_i_qa", "hellaswag",
        "winogrande", "ARC-Easy", "ARC-Challenge", "openbookqa",
    ])
    is_json: bool = field(default=True)
    model_path: str = field(default="")
    eval_batch_size: int = field(default=32)
    
@dataclass
class ModelConfig:
    base_model_name: str = "meta-llama/Llama-2-7b-hf"
    # alternatives: huggyllama/llama-7b, meta-llama/Meta-Llama-3-8B
    # n_layers: int = 24
    # feature_dim: int = 1024
    cutoff_len: int = 512
    train_on_inputs: bool = False

@dataclass
class TrainingConfig:
    per_device_train_batch_size: int = field(default=16)
    per_device_eval_batch_size: int = field(default=32)
    num_workers: int = 2

    gradient_accumulation_steps: int = field(default=1)
    gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = field(
        default_factory=lambda: {"use_reentrant": False}
    )

    resume_from_checkpoint: bool = False

    optim: str = field(default="adamw_torch")
    eval_strategy: str = field(default='no')

    learning_rate: float = field(default=1e-05)
    lr_scheduler_type: str = field(default='cosine')
    warmup_ratio: float = field(default=0.1)

    gradient_checkpointing: bool = field(default=False)
    output_dir: str = field(default="exps")
    save_steps: float = field(default=0)
    bf16: bool = field(default=False)
    bf16_full_eval: bool = field(default=False)
    save_safetensors: bool = field(default=False)  # workaround for Trainer with tied weights

    report_to: Union[None, str, List[str]] = field(default="none")
    logging_steps: int = field(default=25)  # int only
    # logging_first_step: bool = field(default=False)
    save_strategy: str = field(default='no')
    save_total_limit: int = field(default=1)
    eval_steps: Union[None, int] = field(default=None)  # int only
    eval_delay: Union[int, float] = field(default=0)

    dataloader_num_workers: int = field(default=4)
    dataloader_pin_memory: bool = field(default=True)
    dataloader_persistent_workers: bool = field(default=True)
    dataloader_prefetch_factor: int = field(default=1)

    num_train_epochs: float = field(default=1.0)
    max_steps: int = field(default=-1)

    # torch_compile: bool = field(default=False)
    load_best_model_at_end: bool = field(default=True)
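
# Illustrative sketch (not part of the original file): TrainingConfig mirrors many
# `transformers.TrainingArguments` fields, so one plausible way to hand it to the
# HF Trainer is to keep only the names TrainingArguments actually defines. The
# helper name `to_training_arguments` is hypothetical; the real training script
# may wire this up differently.
def to_training_arguments(cfg: TrainingConfig):
    from dataclasses import asdict
    from transformers import TrainingArguments  # assumed dependency of the training code

    allowed = {f.name for f in fields(TrainingArguments)}
    # Drop fields TrainingArguments does not know about (e.g. num_workers).
    kwargs = {k: v for k, v in asdict(cfg).items() if k in allowed}
    return TrainingArguments(**kwargs)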

@dataclass
class DataConfig:
    dataset_name: str = field(default='Cifa')
    #data_path: List[str] = field(default_factory=list)
    data_path: str = field(default='./ft-training_set/math10k.json')
    val_set_size: int = 128


@dataclass
class MainConfig:
    hyperxs: HyperXSConfig = field(default_factory=HyperXSConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    data: DataConfig = field(default_factory=DataConfig)
    infer: InferConfig = field(default_factory=InferConfig)
    
    seed: int = 42
    run_text: str = field(default='def')
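
# Typical usage with draccus (sketch only; assumes draccus's pyrallis-style
# `parse()` entry point -- the actual scripts in this repo may wrap it differently):
#
#   cfg = draccus.parse(config_class=MainConfig)   # CLI overrides, e.g. --training.learning_rate 2e-5
#   cfg = draccus.parse(config_class=MainConfig, config_path="configs/example.yaml")  # hypothetical path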

def from_dict(classConfig, config_dict):
    """Rebuild a (possibly nested) dataclass from a plain dict.

    Keys missing from ``config_dict`` are skipped so the defaults declared on the
    dataclass apply; this keeps old config files compatible with newly added fields.
    """
    kwargs = {}
    for f in fields(classConfig):
        if f.name not in config_dict:
            # Missing key: fall back to the default defined on the dataclass.
            continue
        value = config_dict[f.name]
        if hasattr(f.type, "__dataclass_fields__"):
            # Nested dataclass (e.g. MainConfig.training): recurse.
            value = from_dict(f.type, value)
        kwargs[f.name] = value
    return classConfig(**kwargs)
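

if __name__ == "__main__":
    # Minimal round-trip check (illustration only): dump the default config to a
    # plain dict and rebuild it with `from_dict`. The nested dataclasses
    # (hyperxs/model/training/data/infer) are reconstructed recursively.
    from dataclasses import asdict

    default_cfg = MainConfig()
    restored = from_dict(MainConfig, asdict(default_cfg))
    assert restored == default_cfg
    print(restored.model.base_model_name, restored.training.learning_rate)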