| from peft.tuners.tuners_utils import BaseTunerLayer |
| from typing import List, Any, Optional, Type |
|
|
|
|
class enable_lora:
    """Context manager that temporarily switches LoRA adapters off.

    When ``activated`` is False, entering the context calls
    ``scale_layer(0)`` on every supplied ``BaseTunerLayer``. Exiting writes
    back the per-adapter ``scaling`` values captured at construction time.
    When ``activated`` is True the manager does nothing on enter or exit.

    Args:
        lora_modules: Modules to control. Entries that are not
            ``BaseTunerLayer`` instances are ignored.
        activated: If True, leave the adapters untouched (no-op manager).
    """

    def __init__(self, lora_modules: List[BaseTunerLayer], activated: bool) -> None:
        self.activated: bool = activated
        if activated:
            # Nothing to save or restore; __enter__/__exit__ return early.
            return
        # Keep only tuner layers, so later loops need no per-item type checks.
        self.lora_modules: List[BaseTunerLayer] = [
            each for each in lora_modules if isinstance(each, BaseTunerLayer)
        ]
        # Snapshot each layer's scaling for its active adapters so __exit__
        # can restore it. A layer may list an active adapter it holds no
        # scaling entry for; skip those instead of raising KeyError.
        self.scales = [
            {
                active_adapter: lora_module.scaling[active_adapter]
                for active_adapter in lora_module.active_adapters
                if active_adapter in lora_module.scaling
            }
            for lora_module in self.lora_modules
        ]

    def __enter__(self) -> None:
        if self.activated:
            return
        for lora_module in self.lora_modules:
            # Assumes PEFT's scale_layer multiplies the active adapters'
            # scaling by the factor, so 0 disables them. TODO confirm against
            # the installed PEFT version.
            lora_module.scale_layer(0)

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> None:
        """Restore the scaling snapshot taken in ``__init__``.

        Returns None, so exceptions raised inside the with-block propagate.
        """
        if self.activated:
            return
        # Restore from the saved snapshot rather than re-reading the live
        # active_adapters. That set may have changed inside the with-block,
        # and re-reading could raise KeyError or leave an adapter at scale 0.
        for lora_module, saved in zip(self.lora_modules, self.scales):
            for adapter_name, value in saved.items():
                lora_module.scaling[adapter_name] = value
|
|
|
|
class set_lora_scale:
    """Context manager that temporarily applies a LoRA scale factor.

    On enter, calls ``scale_layer(scale)`` on every supplied
    ``BaseTunerLayer``. On exit, writes back the per-adapter ``scaling``
    values captured at construction time.

    NOTE(review): PEFT's ``scale_layer`` appears to *multiply* the current
    scaling by ``scale`` rather than assign it. Despite the class name, this
    therefore acts as a multiplier. Confirm this is the intended semantics.

    Args:
        lora_modules: Modules to control. Entries that are not
            ``BaseTunerLayer`` instances are ignored.
        scale: Factor passed to ``scale_layer`` on enter.
    """

    def __init__(self, lora_modules: List[BaseTunerLayer], scale: float) -> None:
        # Keep only tuner layers, so later loops need no per-item type checks.
        self.lora_modules: List[BaseTunerLayer] = [
            each for each in lora_modules if isinstance(each, BaseTunerLayer)
        ]
        # Snapshot each layer's scaling for its active adapters so __exit__
        # can restore it. A layer may list an active adapter it holds no
        # scaling entry for; skip those instead of raising KeyError.
        self.scales = [
            {
                active_adapter: lora_module.scaling[active_adapter]
                for active_adapter in lora_module.active_adapters
                if active_adapter in lora_module.scaling
            }
            for lora_module in self.lora_modules
        ]
        self.scale = scale

    def __enter__(self) -> None:
        for lora_module in self.lora_modules:
            lora_module.scale_layer(self.scale)

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> None:
        """Restore the scaling snapshot taken in ``__init__``.

        Returns None, so exceptions raised inside the with-block propagate.
        """
        # Restore from the saved snapshot rather than re-reading the live
        # active_adapters, which may have changed inside the with-block.
        for lora_module, saved in zip(self.lora_modules, self.scales):
            for adapter_name, value in saved.items():
                lora_module.scaling[adapter_name] = value
|
|