| |
|
|
| import ast |
| from typing import Dict, List, Optional |
| import torch |
| import torch.nn as nn |
|
|
| import logging |
|
|
# Configure root logging at INFO when this module is imported (no-op if the
# root logger already has handlers), then obtain this module's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
| import networks.lora as lora |
|
|
|
|
| FRAMEPACK_TARGET_REPLACE_MODULES = ["HunyuanVideoTransformerBlock", "HunyuanVideoSingleTransformerBlock"] |
|
|
|
|
def create_arch_network(
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
    vae: nn.Module,
    text_encoders: List[nn.Module],
    unet: nn.Module,
    neuron_dropout: Optional[float] = None,
    **kwargs,
):
    """Create a LoRA network for the FramePack transformer blocks.

    Thin wrapper around ``lora.create_network`` that fixes the target module
    classes to ``FRAMEPACK_TARGET_REPLACE_MODULES`` and the prefix to
    ``"lora_unet"``, and always adds ``r".*(norm).*"`` to the
    ``exclude_patterns`` keyword argument.

    Args:
        multiplier: LoRA output multiplier, forwarded unchanged.
        network_dim: LoRA rank, forwarded unchanged.
        network_alpha: LoRA alpha, forwarded unchanged.
        vae: Forwarded unchanged to ``lora.create_network``.
        text_encoders: Forwarded unchanged to ``lora.create_network``.
        unet: The transformer model, forwarded unchanged.
        neuron_dropout: Optional dropout, forwarded unchanged.
        **kwargs: Extra options forwarded to ``lora.create_network``.
            ``exclude_patterns`` may be absent/None, a string holding a
            Python literal (a list of regex strings or a single regex
            string, e.g. from the command line), or an already-parsed
            list/tuple of regex strings.

    Returns:
        Whatever ``lora.create_network`` returns.

    Raises:
        ValueError: if ``exclude_patterns`` is a string that is not a valid
            Python literal.
    """
    raw_patterns = kwargs.get("exclude_patterns", None)
    if raw_patterns is None:
        exclude_patterns = []
    elif isinstance(raw_patterns, str):
        # String form (e.g. "['.*attn.*']" from CLI network_args): parse it as
        # a Python literal only; literal_eval never executes arbitrary code.
        try:
            parsed = ast.literal_eval(raw_patterns)
        except (ValueError, SyntaxError) as err:
            raise ValueError(
                "exclude_patterns must be a Python literal list of regex strings, "
                f"got {raw_patterns!r}"
            ) from err
        # A single quoted regex evaluates to a str; wrap it instead of
        # splitting it into characters or failing on .append().
        exclude_patterns = [parsed] if isinstance(parsed, str) else list(parsed)
    else:
        # Already-parsed sequence passed programmatically: copy it so the
        # caller's list is not mutated by the append below.
        exclude_patterns = list(raw_patterns)

    # Always exclude modules whose name contains "norm" (presumably the
    # normalization layers inside the target blocks — lora.create_network
    # interprets the patterns).
    norm_pattern = r".*(norm).*"
    if norm_pattern not in exclude_patterns:
        exclude_patterns.append(norm_pattern)

    kwargs["exclude_patterns"] = exclude_patterns

    return lora.create_network(
        FRAMEPACK_TARGET_REPLACE_MODULES,
        "lora_unet",
        multiplier,
        network_dim,
        network_alpha,
        vae,
        text_encoders,
        unet,
        neuron_dropout=neuron_dropout,
        **kwargs,
    )
|
|
|
|
def create_arch_network_from_weights(
    multiplier: float,
    weights_sd: Dict[str, torch.Tensor],
    text_encoders: Optional[List[nn.Module]] = None,
    unet: Optional[nn.Module] = None,
    for_inference: bool = False,
    **kwargs,
) -> lora.LoRANetwork:
    """Rebuild a FramePack LoRA network from a state dict of weights.

    Delegates to ``lora.create_network_from_weights`` with
    ``FRAMEPACK_TARGET_REPLACE_MODULES`` as the target module classes. All
    other arguments are forwarded positionally in their declared order, and
    any extra keyword arguments are passed through unchanged.

    Args:
        multiplier: LoRA output multiplier.
        weights_sd: Mapping of parameter names to tensors.
        text_encoders: Optional text encoder modules.
        unet: Optional transformer model the network is built for.
        for_inference: Forwarded flag; its meaning is defined by
            ``lora.create_network_from_weights``.
        **kwargs: Extra options forwarded unchanged.

    Returns:
        The ``lora.LoRANetwork`` built by ``lora.create_network_from_weights``.
    """
    network = lora.create_network_from_weights(
        FRAMEPACK_TARGET_REPLACE_MODULES,
        multiplier,
        weights_sd,
        text_encoders,
        unet,
        for_inference,
        **kwargs,
    )
    return network
|
|