add option n_view
Browse files- unet/config.json +1 -1
- unet/mv_unet.py +2 -2
unet/config.json
CHANGED
|
@@ -64,7 +64,7 @@
|
|
| 64 |
],
|
| 65 |
"upcast_attention": false,
|
| 66 |
"use_linear_projection": false,
|
| 67 |
-
|
| 68 |
"multiview_attn_position": "attn1",
|
| 69 |
"num_modalities": 1,
|
| 70 |
"latent_size": 64,
|
|
|
|
| 64 |
],
|
| 65 |
"upcast_attention": false,
|
| 66 |
"use_linear_projection": false,
|
| 67 |
+
"n_view": 4,
|
| 68 |
"multiview_attn_position": "attn1",
|
| 69 |
"num_modalities": 1,
|
| 70 |
"latent_size": 64,
|
unet/mv_unet.py
CHANGED
|
@@ -142,6 +142,7 @@ class UnifieldWrappedUNet(UNet2DConditionModel):
|
|
| 142 |
cross_attention_norm: Optional[str] = None,
|
| 143 |
addition_embed_type_num_heads: int = 64,
|
| 144 |
multiview_attn_position: str = "attn1",
|
|
|
|
| 145 |
num_modalities: int = 1,
|
| 146 |
latent_size: int = 64,
|
| 147 |
multiview_chain_pose: str = "parralle",
|
|
@@ -149,7 +150,7 @@ class UnifieldWrappedUNet(UNet2DConditionModel):
|
|
| 149 |
):
|
| 150 |
super().__init__(**{
|
| 151 |
k: v for k, v in locals().items() if k not in
|
| 152 |
-
["self", "kwargs", "__class__", "multiview_attn_position", "num_modalities", "latent_size", "multiview_chain_pose"]
|
| 153 |
})
|
| 154 |
|
| 155 |
add_multiview_processor(
|
|
@@ -170,7 +171,6 @@ class UnifieldWrappedUNet(UNet2DConditionModel):
|
|
| 170 |
condition_latens: torch.Tensor = None,
|
| 171 |
class_labels: Optional[torch.Tensor] = None,
|
| 172 |
) -> Union[UNet2DConditionOutput, Tuple]:
|
| 173 |
-
|
| 174 |
sample = torch.cat([sample, condition_latens], dim=1)
|
| 175 |
return self.forward(
|
| 176 |
sample, timestep, encoder_hidden_states, class_labels=class_labels,
|
|
|
|
| 142 |
cross_attention_norm: Optional[str] = None,
|
| 143 |
addition_embed_type_num_heads: int = 64,
|
| 144 |
multiview_attn_position: str = "attn1",
|
| 145 |
+
n_views: int = 4,
|
| 146 |
num_modalities: int = 1,
|
| 147 |
latent_size: int = 64,
|
| 148 |
multiview_chain_pose: str = "parralle",
|
|
|
|
| 150 |
):
|
| 151 |
super().__init__(**{
|
| 152 |
k: v for k, v in locals().items() if k not in
|
| 153 |
+
["self", "kwargs", "__class__", "multiview_attn_position", "n_views", "num_modalities", "latent_size", "multiview_chain_pose"]
|
| 154 |
})
|
| 155 |
|
| 156 |
add_multiview_processor(
|
|
|
|
| 171 |
condition_latens: torch.Tensor = None,
|
| 172 |
class_labels: Optional[torch.Tensor] = None,
|
| 173 |
) -> Union[UNet2DConditionOutput, Tuple]:
|
|
|
|
| 174 |
sample = torch.cat([sample, condition_latens], dim=1)
|
| 175 |
return self.forward(
|
| 176 |
sample, timestep, encoder_hidden_states, class_labels=class_labels,
|