from typing import Optional, Tuple, Type

import torch
import torch.nn as nn

from .base import LycorisBaseModule
from .locon import LoConModule
from .loha import LohaModule
from .lokr import LokrModule
from .full import FullModule
from .norms import NormModule
from .diag_oft import DiagOFTModule
from .boft import ButterflyOFTModule
from .glora import GLoRAModule
from .dylora import DyLoraModule
from .ia3 import IA3Module

# NOTE(review): `nn` and `factorization` are not used in this module; they may be
# re-exported for other callers -- confirm before removing.
from ..functional.general import factorization


# Candidate module classes, probed IN ORDER by `get_module`: the first class whose
# `algo_check` accepts a state dict wins, so keep this ordering stable.
MODULE_LIST = [
    LoConModule,
    LohaModule,
    IA3Module,
    LokrModule,
    FullModule,
    NormModule,
    DiagOFTModule,
    ButterflyOFTModule,
    GLoRAModule,
    DyLoraModule,
]


def get_module(
    lyco_state_dict, lora_name
) -> Tuple[Optional[Type[LycorisBaseModule]], Optional[tuple]]:
    """Find the module class that recognises the weights stored under *lora_name*.

    Each class in ``MODULE_LIST`` is asked, in order, whether it recognises the
    entries for *lora_name* via its ``algo_check`` classmethod-style hook.

    Args:
        lyco_state_dict: The LyCORIS state dict to inspect (presumably a mapping
            of key -> tensor; TODO confirm against callers).
        lora_name: The key prefix identifying one adapted layer.

    Returns:
        ``(module_class, params)`` for the first matching class, where ``params``
        is the tuple produced by that class's ``extract_state_dict``; or
        ``(None, None)`` when no class in ``MODULE_LIST`` matches.
    """
    for module in MODULE_LIST:
        if module.algo_check(lyco_state_dict, lora_name):
            params = tuple(module.extract_state_dict(lyco_state_dict, lora_name))
            return module, params
    return None, None


@torch.no_grad()
def make_module(
    lyco_type: Type[LycorisBaseModule], params, lora_name, orig_module
):
    """Instantiate a LyCORIS module of class *lyco_type* from extracted weights.

    Runs under ``torch.no_grad()`` so that any tensor work done while building
    the module is not recorded by autograd.

    Args:
        lyco_type: The module CLASS (not an instance), typically the first value
            returned by ``get_module``.
        params: The extracted weights, typically the second value returned by
            ``get_module``; they are splatted positionally into the factory.
        lora_name: The key prefix identifying the adapted layer.
        orig_module: The original module being wrapped.

    Returns:
        Whatever ``lyco_type.make_module_from_state_dict`` returns, or ``None``
        if that factory raises ``NotImplementedError``.
    """
    try:
        return lyco_type.make_module_from_state_dict(lora_name, orig_module, *params)
    except NotImplementedError:
        # Deliberate best-effort: an algorithm that cannot be rebuilt from a
        # state dict is reported as "unsupported" (None) rather than raising.
        return None