Spaces:
Sleeping
Sleeping
| from typing import Any, Dict, List, Optional, Tuple, Union | |
| import torch | |
| from torch import nn | |
| from ...models.controlnet import ControlNetModel, ControlNetOutput | |
| from ...models.modeling_utils import ModelMixin | |
class MultiControlNetModel(ModelMixin):
    r"""
    Multiple `ControlNetModel` wrapper class for Multi-ControlNet

    This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be
    compatible with `ControlNetModel`.

    Args:
        controlnets (`List[ControlNetModel]`):
            Provides additional conditioning to the unet during the denoising process. You must set multiple
            `ControlNetModel` as a list.
    """

    def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
        super().__init__()
        # Registering through nn.ModuleList makes every wrapped controlnet a tracked submodule.
        self.nets = nn.ModuleList(controlnets)

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: List[torch.Tensor],
        conditioning_scale: List[float],
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guess_mode: bool = False,
        return_dict: bool = True,
    ) -> Union[ControlNetOutput, Tuple]:
        r"""
        Run every wrapped controlnet on the same inputs and sum their residuals.

        The i-th entry of `controlnet_cond` and the i-th entry of `conditioning_scale` are paired with the i-th
        controlnet in `self.nets`. All other arguments are forwarded unchanged to every controlnet.

        Args:
            sample: Forwarded to each controlnet as its first positional argument.
            timestep: Forwarded to each controlnet.
            encoder_hidden_states: Forwarded to each controlnet.
            controlnet_cond: One conditioning input per controlnet.
            conditioning_scale: One scale per controlnet.
            class_labels: Forwarded to each controlnet.
            timestep_cond: Forwarded to each controlnet.
            attention_mask: Forwarded to each controlnet.
            cross_attention_kwargs: Forwarded to each controlnet.
            guess_mode: Forwarded to each controlnet.
            return_dict: Forwarded to each controlnet.

        Returns:
            A tuple `(down_block_res_samples, mid_block_res_sample)`. When a single controlnet ran, these are the
            objects that controlnet returned; otherwise `down_block_res_samples` is a list holding the element-wise
            sums across controlnets and `mid_block_res_sample` is the sum of the individual mid samples.

        Raises:
            ValueError: If no controlnet was run because `controlnet_cond`, `conditioning_scale` or the wrapped
                controlnet list is empty.
        """
        # NOTE(review): `zip` stops at the shortest of the three sequences, so surplus conditions, scales or
        # controlnets are silently ignored. Kept as-is for backward compatibility -- confirm this is intended.
        # NOTE(review): `return_dict` is forwarded, yet the result of each controlnet is tuple-unpacked below and
        # this method always returns a plain tuple -- callers presumably pass `return_dict=False`; confirm.
        down_block_res_samples = None
        mid_block_res_sample = None
        num_merged = 0

        for image, scale, controlnet in zip(controlnet_cond, conditioning_scale, self.nets):
            down_samples, mid_sample = controlnet(
                sample,
                timestep,
                encoder_hidden_states,
                image,
                scale,
                class_labels,
                timestep_cond,
                attention_mask,
                cross_attention_kwargs,
                guess_mode,
                return_dict,
            )

            # merge samples
            if num_merged == 0:
                down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
            else:
                down_block_res_samples = [
                    samples_prev + samples_curr
                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
                ]
                # Out-of-place add, like the down-block merge above: never mutate the tensor that the first
                # controlnet handed back.
                mid_block_res_sample = mid_block_res_sample + mid_sample
            num_merged += 1

        if num_merged == 0:
            raise ValueError(
                "MultiControlNetModel.forward() needs at least one controlnet together with a matching "
                "`controlnet_cond` and `conditioning_scale` entry, but nothing was run."
            )

        return down_block_res_samples, mid_block_res_sample