Luffuly commited on
Commit
5189c0a
·
1 Parent(s): 7487caa

update unet

Browse files
Files changed (1) hide show
  1. unet/mv_unet.py +15 -1
unet/mv_unet.py CHANGED
@@ -2,6 +2,7 @@ import torch
2
  from typing import Optional, Tuple, Union
3
  from diffusers import UNet2DConditionModel
4
  from diffusers.models.attention_processor import Attention
 
5
 
6
 
7
  def switch_multiview_processor(model, enable_filter=lambda x:True):
@@ -160,4 +161,17 @@ class UnifieldWrappedUNet(UNet2DConditionModel):
160
  )
161
 
162
  switch_multiview_processor(self, enable_filter=lambda name: name.endswith(f"{multiview_attn_position}.processor"))
163
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from typing import Optional, Tuple, Union
3
  from diffusers import UNet2DConditionModel
4
  from diffusers.models.attention_processor import Attention
5
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
6
 
7
 
8
  def switch_multiview_processor(model, enable_filter=lambda x:True):
 
161
  )
162
 
163
  switch_multiview_processor(self, enable_filter=lambda name: name.endswith(f"{multiview_attn_position}.processor"))
164
+
165
+ def __call__(
166
+ self,
167
+ sample: torch.Tensor,
168
+ timestep: Union[torch.Tensor, float, int],
169
+ encoder_hidden_states: torch.Tensor,
170
+ condition_latens: torch.Tensor = None,
171
+ class_labels: Optional[torch.Tensor] = None,
172
+ ) -> Union[UNet2DConditionOutput, Tuple]:
173
+
174
+ sample = torch.cat([sample, condition_latens], dim=1)
175
+ return self.forward(
176
+ sample, timestep, encoder_hidden_states, class_labels=class_labels,
177
+ )