update unet
Browse files- unet/mv_unet.py +15 -1
unet/mv_unet.py
CHANGED
|
@@ -2,6 +2,7 @@ import torch
|
|
| 2 |
from typing import Optional, Tuple, Union
|
| 3 |
from diffusers import UNet2DConditionModel
|
| 4 |
from diffusers.models.attention_processor import Attention
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
def switch_multiview_processor(model, enable_filter=lambda x:True):
|
|
@@ -160,4 +161,17 @@ class UnifieldWrappedUNet(UNet2DConditionModel):
|
|
| 160 |
)
|
| 161 |
|
| 162 |
switch_multiview_processor(self, enable_filter=lambda name: name.endswith(f"{multiview_attn_position}.processor"))
|
| 163 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
from typing import Optional, Tuple, Union
|
| 3 |
from diffusers import UNet2DConditionModel
|
| 4 |
from diffusers.models.attention_processor import Attention
|
| 5 |
+
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
|
| 6 |
|
| 7 |
|
| 8 |
def switch_multiview_processor(model, enable_filter=lambda x:True):
|
|
|
|
| 161 |
)
|
| 162 |
|
| 163 |
switch_multiview_processor(self, enable_filter=lambda name: name.endswith(f"{multiview_attn_position}.processor"))
|
| 164 |
+
|
| 165 |
+
def __call__(
|
| 166 |
+
self,
|
| 167 |
+
sample: torch.Tensor,
|
| 168 |
+
timestep: Union[torch.Tensor, float, int],
|
| 169 |
+
encoder_hidden_states: torch.Tensor,
|
| 170 |
+
condition_latens: torch.Tensor = None,
|
| 171 |
+
class_labels: Optional[torch.Tensor] = None,
|
| 172 |
+
) -> Union[UNet2DConditionOutput, Tuple]:
|
| 173 |
+
|
| 174 |
+
sample = torch.cat([sample, condition_latens], dim=1)
|
| 175 |
+
return self.forward(
|
| 176 |
+
sample, timestep, encoder_hidden_states, class_labels=class_labels,
|
| 177 |
+
)
|