Spaces:
Runtime error
Runtime error
Update models/model.py
Browse files- models/model.py +4 -4
models/model.py
CHANGED
|
@@ -665,7 +665,7 @@ class NextDiT(nn.Module):
|
|
| 665 |
qk_norm: bool = False,
|
| 666 |
cap_feat_dim: int = 5120,
|
| 667 |
rope_scaling_factor: float = 1.0,
|
| 668 |
-
|
| 669 |
) -> None:
|
| 670 |
super().__init__()
|
| 671 |
self.learn_sigma = learn_sigma
|
|
@@ -715,10 +715,10 @@ class NextDiT(nn.Module):
|
|
| 715 |
dim // n_heads,
|
| 716 |
384,
|
| 717 |
rope_scaling_factor=rope_scaling_factor,
|
| 718 |
-
|
| 719 |
)
|
| 720 |
self.rope_scaling_factor = rope_scaling_factor
|
| 721 |
-
self.
|
| 722 |
# self.eol_token = nn.Parameter(torch.empty(dim))
|
| 723 |
self.pad_token = nn.Parameter(torch.empty(dim))
|
| 724 |
# nn.init.normal_(self.eol_token, std=0.02)
|
|
@@ -875,7 +875,7 @@ class NextDiT(nn.Module):
|
|
| 875 |
cap_mask,
|
| 876 |
cfg_scale,
|
| 877 |
rope_scaling_factor=None,
|
| 878 |
-
|
| 879 |
base_seqlen: Optional[int] = None,
|
| 880 |
proportional_attn: bool = False,
|
| 881 |
):
|
|
|
|
| 665 |
qk_norm: bool = False,
|
| 666 |
cap_feat_dim: int = 5120,
|
| 667 |
rope_scaling_factor: float = 1.0,
|
| 668 |
+
scale_factor: float = 1.0,
|
| 669 |
) -> None:
|
| 670 |
super().__init__()
|
| 671 |
self.learn_sigma = learn_sigma
|
|
|
|
| 715 |
dim // n_heads,
|
| 716 |
384,
|
| 717 |
rope_scaling_factor=rope_scaling_factor,
|
| 718 |
+
scale_factor=scale_factor,
|
| 719 |
)
|
| 720 |
self.rope_scaling_factor = rope_scaling_factor
|
| 721 |
+
self.scale_factor = scale_factor
|
| 722 |
# self.eol_token = nn.Parameter(torch.empty(dim))
|
| 723 |
self.pad_token = nn.Parameter(torch.empty(dim))
|
| 724 |
# nn.init.normal_(self.eol_token, std=0.02)
|
|
|
|
| 875 |
cap_mask,
|
| 876 |
cfg_scale,
|
| 877 |
rope_scaling_factor=None,
|
| 878 |
+
scale_factor=None,
|
| 879 |
base_seqlen: Optional[int] = None,
|
| 880 |
proportional_attn: bool = False,
|
| 881 |
):
|