VideoModelStudio
/
docs
/finetrainers-src-codebase
/finetrainers
/trainer
/control_trainer
/config.py
| import argparse | |
| from enum import Enum | |
| from typing import TYPE_CHECKING, Any, Dict, List, Union | |
| from finetrainers.utils import ArgsConfigMixin | |
| if TYPE_CHECKING: | |
| from finetrainers.args import BaseArgs | |
| class ControlType(str, Enum): | |
| r""" | |
| Enum class for the control types. | |
| """ | |
| CANNY = "canny" | |
| CUSTOM = "custom" | |
| NONE = "none" | |
| class FrameConditioningType(str, Enum): | |
| r""" | |
| Enum class for the frame conditioning types. | |
| """ | |
| INDEX = "index" | |
| PREFIX = "prefix" | |
| RANDOM = "random" | |
| FIRST_AND_LAST = "first_and_last" | |
| FULL = "full" | |
| class ControlLowRankConfig(ArgsConfigMixin): | |
| r""" | |
| Configuration class for SFT channel-concatenated Control low rank training. | |
| Args: | |
| control_type (`str`, defaults to `"canny"`): | |
| Control type for the low rank approximation matrices. Can be "canny", "custom". | |
| rank (int, defaults to `64`): | |
| Rank of the low rank approximation matrix. | |
| lora_alpha (int, defaults to `64`): | |
| The lora_alpha parameter to compute scaling factor (lora_alpha / rank) for low-rank matrices. | |
| target_modules (`str` or `List[str]`, defaults to `"(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0|ff.net.0.proj|ff.net.2)"`): | |
| Target modules for the low rank approximation matrices. Can be a regex string or a list of regex strings. | |
| train_qk_norm (`bool`, defaults to `False`): | |
| Whether to train the QK normalization layers. | |
| frame_conditioning_type (`str`, defaults to `"full"`): | |
| Type of frame conditioning. Can be "index", "prefix", "random", "first_and_last", or "full". | |
| frame_conditioning_index (int, defaults to `0`): | |
| Index of the frame conditioning. Only used if `frame_conditioning_type` is "index". | |
| frame_conditioning_concatenate_mask (`bool`, defaults to `False`): | |
| Whether to concatenate the frame mask with the latents across channel dim. | |
| """ | |
| control_type: str = ControlType.CANNY | |
| rank: int = 64 | |
| lora_alpha: int = 64 | |
| target_modules: Union[str, List[str]] = ( | |
| "(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0|ff.net.0.proj|ff.net.2)" | |
| ) | |
| train_qk_norm: bool = False | |
| # Specific to video models | |
| frame_conditioning_type: str = FrameConditioningType.FULL | |
| frame_conditioning_index: int = 0 | |
| frame_conditioning_concatenate_mask: bool = False | |
| def add_args(self, parser: argparse.ArgumentParser): | |
| parser.add_argument( | |
| "--control_type", | |
| type=str, | |
| default=ControlType.CANNY.value, | |
| choices=[x.value for x in ControlType.__members__.values()], | |
| ) | |
| parser.add_argument("--rank", type=int, default=64) | |
| parser.add_argument("--lora_alpha", type=int, default=64) | |
| parser.add_argument( | |
| "--target_modules", | |
| type=str, | |
| nargs="+", | |
| default=[ | |
| "(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0|ff.net.0.proj|ff.net.2)" | |
| ], | |
| ) | |
| parser.add_argument("--train_qk_norm", action="store_true") | |
| parser.add_argument( | |
| "--frame_conditioning_type", | |
| type=str, | |
| default=FrameConditioningType.INDEX.value, | |
| choices=[x.value for x in FrameConditioningType.__members__.values()], | |
| ) | |
| parser.add_argument("--frame_conditioning_index", type=int, default=0) | |
| parser.add_argument("--frame_conditioning_concatenate_mask", action="store_true") | |
| def validate_args(self, args: "BaseArgs"): | |
| assert self.rank > 0, "Rank must be a positive integer." | |
| assert self.lora_alpha > 0, "lora_alpha must be a positive integer." | |
| def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"): | |
| mapped_args.control_type = argparse_args.control_type | |
| mapped_args.rank = argparse_args.rank | |
| mapped_args.lora_alpha = argparse_args.lora_alpha | |
| mapped_args.target_modules = ( | |
| argparse_args.target_modules[0] if len(argparse_args.target_modules) == 1 else argparse_args.target_modules | |
| ) | |
| mapped_args.train_qk_norm = argparse_args.train_qk_norm | |
| mapped_args.frame_conditioning_type = argparse_args.frame_conditioning_type | |
| mapped_args.frame_conditioning_index = argparse_args.frame_conditioning_index | |
| mapped_args.frame_conditioning_concatenate_mask = argparse_args.frame_conditioning_concatenate_mask | |
| def to_dict(self) -> Dict[str, Any]: | |
| return { | |
| "control_type": self.control_type, | |
| "rank": self.rank, | |
| "lora_alpha": self.lora_alpha, | |
| "target_modules": self.target_modules, | |
| "train_qk_norm": self.train_qk_norm, | |
| "frame_conditioning_type": self.frame_conditioning_type, | |
| "frame_conditioning_index": self.frame_conditioning_index, | |
| "frame_conditioning_concatenate_mask": self.frame_conditioning_concatenate_mask, | |
| } | |
| class ControlFullRankConfig(ArgsConfigMixin): | |
| r""" | |
| Configuration class for SFT channel-concatenated Control full rank training. | |
| Args: | |
| control_type (`str`, defaults to `"canny"`): | |
| Control type for the low rank approximation matrices. Can be "canny", "custom". | |
| train_qk_norm (`bool`, defaults to `False`): | |
| Whether to train the QK normalization layers. | |
| frame_conditioning_type (`str`, defaults to `"index"`): | |
| Type of frame conditioning. Can be "index", "prefix", "random", "first_and_last", or "full". | |
| frame_conditioning_index (int, defaults to `0`): | |
| Index of the frame conditioning. Only used if `frame_conditioning_type` is "index". | |
| frame_conditioning_concatenate_mask (`bool`, defaults to `False`): | |
| Whether to concatenate the frame mask with the latents across channel dim. | |
| """ | |
| control_type: str = ControlType.CANNY | |
| train_qk_norm: bool = False | |
| # Specific to video models | |
| frame_conditioning_type: str = FrameConditioningType.INDEX | |
| frame_conditioning_index: int = 0 | |
| frame_conditioning_concatenate_mask: bool = False | |
| def add_args(self, parser: argparse.ArgumentParser): | |
| parser.add_argument( | |
| "--control_type", | |
| type=str, | |
| default=ControlType.CANNY.value, | |
| choices=[x.value for x in ControlType.__members__.values()], | |
| ) | |
| parser.add_argument("--train_qk_norm", action="store_true") | |
| parser.add_argument( | |
| "--frame_conditioning_type", | |
| type=str, | |
| default=FrameConditioningType.INDEX.value, | |
| choices=[x.value for x in FrameConditioningType.__members__.values()], | |
| ) | |
| parser.add_argument("--frame_conditioning_index", type=int, default=0) | |
| parser.add_argument("--frame_conditioning_concatenate_mask", action="store_true") | |
| def validate_args(self, args: "BaseArgs"): | |
| pass | |
| def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"): | |
| mapped_args.control_type = argparse_args.control_type | |
| mapped_args.train_qk_norm = argparse_args.train_qk_norm | |
| mapped_args.frame_conditioning_type = argparse_args.frame_conditioning_type | |
| mapped_args.frame_conditioning_index = argparse_args.frame_conditioning_index | |
| mapped_args.frame_conditioning_concatenate_mask = argparse_args.frame_conditioning_concatenate_mask | |
| def to_dict(self) -> Dict[str, Any]: | |
| return { | |
| "control_type": self.control_type, | |
| "train_qk_norm": self.train_qk_norm, | |
| "frame_conditioning_type": self.frame_conditioning_type, | |
| "frame_conditioning_index": self.frame_conditioning_index, | |
| "frame_conditioning_concatenate_mask": self.frame_conditioning_concatenate_mask, | |
| } | |