| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from typing import Literal, TypedDict |
| |
|
| | from peft import LoraConfig, PeftModel, get_peft_model |
| |
|
| | from ...utils.plugin import BasePlugin |
| | from ...utils.types import HFModel |
| |
|
| |
|
class LoraConfigDict(TypedDict, total=False):
    """Schema of the config dict accepted by the ``"lora"`` plugin.

    All keys are optional (``total=False``); unspecified keys fall back to
    ``peft.LoraConfig`` defaults.
    """

    name: Literal["lora"]
    """Plugin name."""
    r: int
    """Lora rank."""
    lora_alpha: int
    """Lora alpha."""
    target_modules: list[str]
    """Target modules."""
| |
|
| |
|
class FreezeConfigDict(TypedDict, total=False):
    """Schema of the config dict accepted by the ``"freeze"`` plugin.

    All keys are optional (``total=False``).
    """

    name: Literal["freeze"]
    """Plugin name."""
    freeze_trainable_layers: int
    """Freeze trainable layers."""
    freeze_trainable_modules: list[str] | None
    """Freeze trainable modules."""
| |
|
| |
|
class PeftPlugin(BasePlugin):
    """Plugin entry point for parameter-efficient fine-tuning backends.

    Implementations are attached via ``@PeftPlugin(<name>).register()``;
    calling an instance dispatches through :class:`BasePlugin`.
    """

    def __call__(self, model: HFModel, config: dict, is_train: bool) -> HFModel:
        """Dispatch *model* to the registered implementation selected by *config*.

        Args:
            model: The model to transform.
            config: Plugin config; its ``"name"`` key selects the implementation.
            is_train: Whether the model is being prepared for training.

        Returns:
            The transformed model.
        """
        # Forward is_train: every registered implementation in this module
        # declares an is_train parameter, so dropping it here (as the
        # original `super().__call__(model, config)` did) starved them of
        # the flag. NOTE(review): assumes BasePlugin.__call__ forwards
        # extra positional args to the implementation — confirm.
        return super().__call__(model, config, is_train)
| |
|
| |
|
| | @PeftPlugin("lora").register() |
| | def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool) -> PeftModel: |
| | peft_config = LoraConfig(**config) |
| | model = get_peft_model(model, peft_config) |
| | return model |
| |
|
| |
|
| | @PeftPlugin("freeze").register() |
| | def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool) -> HFModel: |
| | raise NotImplementedError() |
| |
|