Luffuly committed on
Commit
7a8dd96
·
1 Parent(s): 817b35b

add option n_view

Browse files
Files changed (2) hide show
  1. unet/config.json +1 -1
  2. unet/mv_unet.py +2 -2
unet/config.json CHANGED
@@ -64,7 +64,7 @@
64
  ],
65
  "upcast_attention": false,
66
  "use_linear_projection": false,
67
-
68
  "multiview_attn_position": "attn1",
69
  "num_modalities": 1,
70
  "latent_size": 64,
 
64
  ],
65
  "upcast_attention": false,
66
  "use_linear_projection": false,
67
+ "n_view": 4,
68
  "multiview_attn_position": "attn1",
69
  "num_modalities": 1,
70
  "latent_size": 64,
unet/mv_unet.py CHANGED
@@ -142,6 +142,7 @@ class UnifieldWrappedUNet(UNet2DConditionModel):
142
  cross_attention_norm: Optional[str] = None,
143
  addition_embed_type_num_heads: int = 64,
144
  multiview_attn_position: str = "attn1",
 
145
  num_modalities: int = 1,
146
  latent_size: int = 64,
147
  multiview_chain_pose: str = "parralle",
@@ -149,7 +150,7 @@ class UnifieldWrappedUNet(UNet2DConditionModel):
149
  ):
150
  super().__init__(**{
151
  k: v for k, v in locals().items() if k not in
152
- ["self", "kwargs", "__class__", "multiview_attn_position", "num_modalities", "latent_size", "multiview_chain_pose"]
153
  })
154
 
155
  add_multiview_processor(
@@ -170,7 +171,6 @@ class UnifieldWrappedUNet(UNet2DConditionModel):
170
  condition_latens: torch.Tensor = None,
171
  class_labels: Optional[torch.Tensor] = None,
172
  ) -> Union[UNet2DConditionOutput, Tuple]:
173
-
174
  sample = torch.cat([sample, condition_latens], dim=1)
175
  return self.forward(
176
  sample, timestep, encoder_hidden_states, class_labels=class_labels,
 
142
  cross_attention_norm: Optional[str] = None,
143
  addition_embed_type_num_heads: int = 64,
144
  multiview_attn_position: str = "attn1",
145
+ n_views: int = 4,
146
  num_modalities: int = 1,
147
  latent_size: int = 64,
148
  multiview_chain_pose: str = "parralle",
 
150
  ):
151
  super().__init__(**{
152
  k: v for k, v in locals().items() if k not in
153
+ ["self", "kwargs", "__class__", "multiview_attn_position", "n_views", "num_modalities", "latent_size", "multiview_chain_pose"]
154
  })
155
 
156
  add_multiview_processor(
 
171
  condition_latens: torch.Tensor = None,
172
  class_labels: Optional[torch.Tensor] = None,
173
  ) -> Union[UNet2DConditionOutput, Tuple]:
 
174
  sample = torch.cat([sample, condition_latens], dim=1)
175
  return self.forward(
176
  sample, timestep, encoder_hidden_states, class_labels=class_labels,