add option n_view
Browse files- unet/mv_unet.py +1 -0
unet/mv_unet.py
CHANGED
|
@@ -159,6 +159,7 @@ class UnifieldWrappedUNet(UNet2DConditionModel):
|
|
| 159 |
num_modalities = num_modalities,
|
| 160 |
base_img_size = latent_size,
|
| 161 |
chain_pos = multiview_chain_pose,
|
|
|
|
| 162 |
)
|
| 163 |
|
| 164 |
switch_multiview_processor(self, enable_filter=lambda name: name.endswith(f"{multiview_attn_position}.processor"))
|
|
|
|
| 159 |
num_modalities = num_modalities,
|
| 160 |
base_img_size = latent_size,
|
| 161 |
chain_pos = multiview_chain_pose,
|
| 162 |
+
views=n_views
|
| 163 |
)
|
| 164 |
|
| 165 |
switch_multiview_processor(self, enable_filter=lambda name: name.endswith(f"{multiview_attn_position}.processor"))
|