from dataclasses import dataclass, field
from typing import List, Optional

from ..core import flatten_dict


@dataclass
class ModelConfig:
    """
    Arguments which define the model and tokenizer to load.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "The model checkpoint for weights initialization."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."})
    attn_implementation: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Which attention implementation to use. You can pass `--attn_implementation=flash_attention_2`, in "
                "which case you must install FlashAttention manually via `pip install flash-attn --no-build-isolation`."
            )
        },
    )
    use_peft: bool = field(
        default=False,
        metadata={"help": "Whether to use PEFT for training."},
    )
    lora_r: Optional[int] = field(
        default=16,
        metadata={"help": "LoRA rank (the `r` dimension of the low-rank update matrices)."},
    )
    lora_alpha: Optional[int] = field(
        default=32,
        metadata={"help": "LoRA alpha (scaling factor for the low-rank updates)."},
    )
    lora_dropout: Optional[float] = field(
        default=0.05,
        metadata={"help": "Dropout probability applied to the LoRA layers."},
    )
    lora_target_modules: Optional[List[str]] = field(
        default=None,
        metadata={"help": "Names of the modules to apply LoRA to."},
    )
    lora_modules_to_save: Optional[List[str]] = field(
        default=None,
        metadata={"help": "Model layers to unfreeze and train alongside the LoRA adapters."},
    )
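    # The LoRA fields above are typically forwarded to `peft.LoraConfig`. A minimal
    # sketch (assumes `peft` is installed, `config` is a parsed ModelConfig, and the
    # CAUSAL_LM task type is an assumption, not something this class mandates):
    #
    #   from peft import LoraConfig
    #
    #   peft_config = LoraConfig(
    #       r=config.lora_r,
    #       lora_alpha=config.lora_alpha,
    #       lora_dropout=config.lora_dropout,
    #       target_modules=config.lora_target_modules,
    #       modules_to_save=config.lora_modules_to_save,
    #       task_type="CAUSAL_LM",
    #   )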
    load_in_8bit: bool = field(
        default=False, metadata={"help": "Load the base model in 8-bit precision (works only with LoRA)."}
    )
    load_in_4bit: bool = field(
        default=False, metadata={"help": "Load the base model in 4-bit precision (works only with LoRA)."}
    )
    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4", metadata={"help": "Quantization type for 4-bit loading (`fp4` or `nf4`)."}
    )
    use_bnb_nested_quant: bool = field(
        default=False, metadata={"help": "Whether to use nested (double) quantization."}
    )
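    # The quantization fields above map onto `transformers.BitsAndBytesConfig`.
    # A minimal sketch (assumes `bitsandbytes` is installed):
    #
    #   from transformers import BitsAndBytesConfig
    #
    #   quantization_config = BitsAndBytesConfig(
    #       load_in_4bit=config.load_in_4bit,
    #       bnb_4bit_quant_type=config.bnb_4bit_quant_type,
    #       bnb_4bit_use_double_quant=config.use_bnb_nested_quant,
    #   )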

    def to_dict(self):
        # Return the config as a flat dict (e.g. for experiment logging); copy
        # `__dict__` first so `flatten_dict` never touches the live instance state.
        return flatten_dict(dict(self.__dict__))

    def __post_init__(self):
        if self.load_in_8bit and self.load_in_4bit:
            raise ValueError("You can't use 8-bit and 4-bit precision at the same time")
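

# Illustrative usage sketch, not part of the class itself: parse a ModelConfig from
# the CLI and load a model with `transformers`. The `from_pretrained` keyword names
# are real; the overall flow is an assumption about how callers consume this config.
# Run as a module (`python -m ...`) so the relative import above resolves.
if __name__ == "__main__":
    import torch
    from transformers import AutoModelForCausalLM, HfArgumentParser

    (config,) = HfArgumentParser(ModelConfig).parse_args_into_dataclasses()

    # `torch_dtype` arrives as a string; `from_pretrained` accepts "auto"/None
    # as-is but expects a real `torch.dtype` otherwise.
    torch_dtype = (
        config.torch_dtype if config.torch_dtype in (None, "auto") else getattr(torch, config.torch_dtype)
    )

    model = AutoModelForCausalLM.from_pretrained(
        config.model_name_or_path,
        revision=config.model_revision,
        trust_remote_code=config.trust_remote_code,
        torch_dtype=torch_dtype,
        attn_implementation=config.attn_implementation,
    )
    print(model.config)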