| | |
| |
|
| | import ast |
| | from typing import Dict, List, Optional |
| | import torch |
| | import torch.nn as nn |
| |
|
| | import logging |
| |
|
# Module-level logger for this file (not referenced in the visible portion of the module).
logger = logging.getLogger(__name__)
# NOTE(review): basicConfig runs at import time and configures the *root* logger
# at INFO level (it is a no-op if the root logger already has handlers). Importing
# this module therefore changes global logging configuration; library modules
# usually leave this to the application entry point — confirm this is intended.
logging.basicConfig(level=logging.INFO)
| |
|
| | import networks.lora as lora |
| |
|
| |
|
# Module class names handed to networks.lora as the target-replace list by both
# factory functions in this file. Presumably lora attaches LoRA layers inside
# modules whose class name appears here — confirm against networks/lora.py.
WAN_TARGET_REPLACE_MODULES = ["WanAttentionBlock"]
| |
|
| |
|
# Regex always appended to the caller's exclude list for WAN. Because it is wrapped
# in ``.*( … ).*``, any module name that merely *contains* one of these substrings
# (embeddings, time projection, any "norm", any "head") is matched.
_WAN_DEFAULT_EXCLUDE_PATTERN = r".*(patch_embedding|text_embedding|time_embedding|time_projection|norm|head).*"


def _parse_exclude_patterns(raw: object) -> List[str]:
    """Normalize the ``exclude_patterns`` kwarg into a fresh list of regex strings.

    Accepted forms:
      * ``None`` or an empty/whitespace-only string -> ``[]``
      * a string holding a Python literal, e.g. ``"['.*foo.*', '.*bar.*']"``
        (list or tuple), or a single quoted pattern such as ``"'.*foo.*'"``
      * an already-built list or tuple of patterns (copied, never mutated)

    Raises:
        ValueError: if the string is not a valid Python literal, or the value
            is not one of the accepted forms.
        SyntaxError: propagated from ``ast.literal_eval`` for malformed text.
    """
    if raw is None:
        return []
    if isinstance(raw, str):
        if not raw.strip():
            return []
        # literal_eval only evaluates Python literals, so command-line text
        # cannot execute arbitrary code here.
        raw = ast.literal_eval(raw)
        if isinstance(raw, str):
            # A single quoted pattern rather than a list of them.
            return [raw]
    if isinstance(raw, (list, tuple)):
        # Always return a new list so appending defaults never mutates the caller's object.
        return list(raw)
    raise ValueError(
        f"exclude_patterns must be a list of regex strings, got {type(raw).__name__}"
    )


def create_arch_network(
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
    vae: nn.Module,
    text_encoders: List[nn.Module],
    unet: nn.Module,
    neuron_dropout: Optional[float] = None,
    **kwargs,
):
    """Create a LoRA network targeting WAN attention blocks.

    Delegates to ``lora.create_network`` with ``WAN_TARGET_REPLACE_MODULES`` as
    the target module list and ``"lora_unet"`` as the second positional
    argument (presumably the LoRA name prefix — confirm in networks/lora.py).

    Before delegating, ``kwargs["exclude_patterns"]`` is normalized to a list
    (see ``_parse_exclude_patterns``), and the WAN default exclusion regex is
    appended to it. That regex skips embedding, time-projection, norm and head
    modules.

    Args:
        multiplier: Forwarded to ``lora.create_network``.
        network_dim: Forwarded; presumably the LoRA rank.
        network_alpha: Forwarded; presumably the LoRA alpha scaling.
        vae: Forwarded as-is.
        text_encoders: Forwarded as-is.
        unet: Forwarded as-is; presumably the WAN DiT to patch.
        neuron_dropout: Forwarded as a keyword argument.
        **kwargs: Extra network options, forwarded verbatim except for
            ``exclude_patterns``. That key may be ``None``, a string literal of
            a list/tuple/single pattern, or a list/tuple of regex strings.

    Returns:
        Whatever ``lora.create_network`` returns (presumably a ``LoRANetwork``).

    Raises:
        ValueError, SyntaxError: if ``exclude_patterns`` cannot be parsed.
    """
    exclude_patterns = _parse_exclude_patterns(kwargs.get("exclude_patterns"))
    exclude_patterns.append(_WAN_DEFAULT_EXCLUDE_PATTERN)
    kwargs["exclude_patterns"] = exclude_patterns

    return lora.create_network(
        WAN_TARGET_REPLACE_MODULES,
        "lora_unet",
        multiplier,
        network_dim,
        network_alpha,
        vae,
        text_encoders,
        unet,
        neuron_dropout=neuron_dropout,
        **kwargs,
    )
| |
|
| |
|
def create_arch_network_from_weights(
    multiplier: float,
    weights_sd: Dict[str, torch.Tensor],
    text_encoders: Optional[List[nn.Module]] = None,
    unet: Optional[nn.Module] = None,
    for_inference: bool = False,
    **kwargs,
) -> lora.LoRANetwork:
    """Rebuild a WAN LoRA network from an existing state dict.

    Thin adapter over ``lora.create_network_from_weights``. It supplies the WAN
    target module list, forwards every other argument positionally in the same
    order, and passes ``kwargs`` through untouched.

    Args:
        multiplier: Forwarded to ``lora.create_network_from_weights``.
        weights_sd: Mapping of parameter names to tensors; presumably a saved
            LoRA state dict — confirm against networks/lora.py.
        text_encoders: Optional text encoder modules, forwarded as-is.
        unet: Optional model to attach LoRA modules to (presumably the WAN DiT),
            forwarded as-is.
        for_inference: Forwarded flag; its effect is defined in networks.lora.
        **kwargs: Extra options, forwarded verbatim.

    Returns:
        The network produced by ``lora.create_network_from_weights``.
    """
    target_modules = WAN_TARGET_REPLACE_MODULES
    network = lora.create_network_from_weights(
        target_modules,
        multiplier,
        weights_sd,
        text_encoders,
        unet,
        for_inference,
        **kwargs,
    )
    return network
| |
|