Spaces:
Running
Running
| # LoRA module for FramePack | |
| import ast | |
| from typing import Dict, List, Optional | |
| import torch | |
| import torch.nn as nn | |
| import logging | |
| logger = logging.getLogger(__name__) | |
| logging.basicConfig(level=logging.INFO) | |
| import networks.lora as lora | |
| FRAMEPACK_TARGET_REPLACE_MODULES = ["HunyuanVideoTransformerBlock", "HunyuanVideoSingleTransformerBlock"] | |
def create_arch_network(
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
    vae: nn.Module,
    text_encoders: List[nn.Module],
    unet: nn.Module,
    neuron_dropout: Optional[float] = None,
    **kwargs,
):
    """Build a LoRA network targeting the FramePack transformer blocks.

    All arguments are forwarded to ``lora.create_network`` together with
    ``FRAMEPACK_TARGET_REPLACE_MODULES`` and the ``"lora_unet"`` prefix.
    Before forwarding, the ``exclude_patterns`` entry of ``kwargs`` is
    normalized to a list and a default pattern excluding every module whose
    name contains ``norm`` is appended.

    Args:
        multiplier: Passed through to ``lora.create_network``.
        network_dim: Passed through to ``lora.create_network``.
        network_alpha: Passed through to ``lora.create_network``.
        vae: Passed through to ``lora.create_network``.
        text_encoders: Passed through to ``lora.create_network``.
        unet: Passed through to ``lora.create_network``.
        neuron_dropout: Passed through as the ``neuron_dropout`` keyword.
        **kwargs: Extra options. ``exclude_patterns`` may be ``None``/absent,
            a string holding a Python literal (e.g. ``"['foo', 'bar']"``),
            or an already-parsed sequence of pattern strings.

    Returns:
        Whatever ``lora.create_network`` returns.

    Raises:
        ValueError, SyntaxError: If ``exclude_patterns`` is a string that is
            not a valid Python literal.
    """
    # add default exclude patterns
    raw_patterns = kwargs.get("exclude_patterns", None)
    if raw_patterns is None:
        exclude_patterns = []
    else:
        # Only strings need parsing; ast.literal_eval rejects a list that
        # was already parsed by the caller.
        parsed = ast.literal_eval(raw_patterns) if isinstance(raw_patterns, str) else raw_patterns
        # A lone string literal is a single pattern; tuples or other
        # sequences are copied into a fresh list so that appending below is
        # valid and never mutates the caller's object.
        exclude_patterns = [parsed] if isinstance(parsed, str) else list(parsed)

    # exclude if 'norm' in the name of the module
    exclude_patterns.append(r".*(norm).*")
    kwargs["exclude_patterns"] = exclude_patterns

    return lora.create_network(
        FRAMEPACK_TARGET_REPLACE_MODULES,
        "lora_unet",
        multiplier,
        network_dim,
        network_alpha,
        vae,
        text_encoders,
        unet,
        neuron_dropout=neuron_dropout,
        **kwargs,
    )
def create_arch_network_from_weights(
    multiplier: float,
    weights_sd: Dict[str, torch.Tensor],
    text_encoders: Optional[List[nn.Module]] = None,
    unet: Optional[nn.Module] = None,
    for_inference: bool = False,
    **kwargs,
) -> lora.LoRANetwork:
    """Build a FramePack LoRA network from an existing weights state dict.

    Thin wrapper that delegates to ``lora.create_network_from_weights``,
    supplying ``FRAMEPACK_TARGET_REPLACE_MODULES`` as the first positional
    argument and forwarding everything else unchanged.

    Args:
        multiplier: Forwarded as-is.
        weights_sd: Mapping of names to tensors, forwarded as-is.
        text_encoders: Forwarded as-is; defaults to ``None``.
        unet: Forwarded as-is; defaults to ``None``.
        for_inference: Forwarded as-is; defaults to ``False``.
        **kwargs: Additional keyword options forwarded untouched.

    Returns:
        The object returned by ``lora.create_network_from_weights``.
    """
    # Positional order must match the delegate's signature.
    forwarded_args = (
        FRAMEPACK_TARGET_REPLACE_MODULES,
        multiplier,
        weights_sd,
        text_encoders,
        unet,
        for_inference,
    )
    network = lora.create_network_from_weights(*forwarded_args, **kwargs)
    return network