Spaces:
Runtime error
Runtime error
Update models/model.py
Browse files- models/model.py +21 -33
models/model.py
CHANGED
|
@@ -885,29 +885,14 @@ class NextDiT(nn.Module):
|
|
| 885 |
# """
|
| 886 |
# # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
|
| 887 |
# print(ntk_factor, rope_scaling_factor, self.ntk_factor, self.rope_scaling_factor)
|
| 888 |
-
if
|
| 889 |
-
|
| 890 |
-
|
| 891 |
-
|
| 892 |
-
|
|
|
|
|
|
|
| 893 |
)
|
| 894 |
-
ntk_factor = ntk_factor if ntk_factor is not None else self.ntk_factor
|
| 895 |
-
if (
|
| 896 |
-
rope_scaling_factor != self.rope_scaling_factor
|
| 897 |
-
or ntk_factor != self.ntk_factor
|
| 898 |
-
):
|
| 899 |
-
print(
|
| 900 |
-
f"override freqs_cis, rope_scaling {rope_scaling_factor}, ntk {ntk_factor}",
|
| 901 |
-
flush=True,
|
| 902 |
-
)
|
| 903 |
-
self.freqs_cis = NextDiT.precompute_freqs_cis(
|
| 904 |
-
self.dim // self.n_heads,
|
| 905 |
-
384,
|
| 906 |
-
rope_scaling_factor=rope_scaling_factor,
|
| 907 |
-
ntk_factor=ntk_factor,
|
| 908 |
-
)
|
| 909 |
-
self.rope_scaling_factor = rope_scaling_factor
|
| 910 |
-
self.ntk_factor = ntk_factor
|
| 911 |
|
| 912 |
if proportional_attn:
|
| 913 |
assert base_seqlen is not None
|
|
@@ -938,7 +923,8 @@ class NextDiT(nn.Module):
|
|
| 938 |
end: int,
|
| 939 |
theta: float = 10000.0,
|
| 940 |
rope_scaling_factor: float = 1.0,
|
| 941 |
-
|
|
|
|
| 942 |
):
|
| 943 |
"""
|
| 944 |
Precompute the frequency tensor for complex exponentials (cis) with
|
|
@@ -959,23 +945,25 @@ class NextDiT(nn.Module):
|
|
| 959 |
torch.Tensor: Precomputed frequency tensor with complex
|
| 960 |
exponentials.
|
| 961 |
"""
|
|
|
|
| 962 |
|
| 963 |
-
|
|
|
|
|
|
|
| 964 |
|
| 965 |
-
|
| 966 |
-
|
| 967 |
-
)
|
| 968 |
-
|
| 969 |
-
|
| 970 |
-
|
| 971 |
-
|
| 972 |
-
t = t / rope_scaling_factor
|
| 973 |
-
freqs = torch.outer(t, freqs).float() # type: ignore
|
| 974 |
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
| 975 |
|
| 976 |
freqs_cis_h = freqs_cis.view(end, 1, dim // 4, 1).repeat(1, end, 1, 1)
|
| 977 |
freqs_cis_w = freqs_cis.view(1, end, dim // 4, 1).repeat(end, 1, 1, 1)
|
| 978 |
freqs_cis = torch.cat([freqs_cis_h, freqs_cis_w], dim=-1).flatten(2)
|
|
|
|
| 979 |
return freqs_cis
|
| 980 |
|
| 981 |
def parameter_count(self) -> int:
|
|
|
|
| 885 |
# """
|
| 886 |
# # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
|
| 887 |
# print(ntk_factor, rope_scaling_factor, self.ntk_factor, self.rope_scaling_factor)
|
| 888 |
+
if scale_factor is not None:
|
| 889 |
+
assert scale_factor is not None
|
| 890 |
+
self.freqs_cis = NextDiT.precompute_freqs_cis(
|
| 891 |
+
self.dim // self.n_heads,
|
| 892 |
+
384,
|
| 893 |
+
scale_factor=scale_factor,
|
| 894 |
+
timestep=t[0],
|
| 895 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 896 |
|
| 897 |
if proportional_attn:
|
| 898 |
assert base_seqlen is not None
|
|
|
|
| 923 |
end: int,
|
| 924 |
theta: float = 10000.0,
|
| 925 |
rope_scaling_factor: float = 1.0,
|
| 926 |
+
scale_factor: float = 1.0,
|
| 927 |
+
timestep: float = 1.0,
|
| 928 |
):
|
| 929 |
"""
|
| 930 |
Precompute the frequency tensor for complex exponentials (cis) with
|
|
|
|
| 945 |
torch.Tensor: Precomputed frequency tensor with complex
|
| 946 |
exponentials.
|
| 947 |
"""
|
| 948 |
+
freqs_inter = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float().cuda() / dim)) / scale_factor
|
| 949 |
|
| 950 |
+
target_dim = timestep * dim + 1
|
| 951 |
+
scale_factor = scale_factor ** (dim / target_dim)
|
| 952 |
+
theta = theta * scale_factor
|
| 953 |
|
| 954 |
+
freqs_time_scaled = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float().cuda() / dim))
|
| 955 |
+
|
| 956 |
+
freqs = torch.max(freqs_inter, freqs_time_scaled)
|
| 957 |
+
|
| 958 |
+
timestep = torch.arange(end, device=freqs.device, dtype=torch.float) # type: ignore
|
| 959 |
+
|
| 960 |
+
freqs = torch.outer(timestep, freqs).float() # type: ignore
|
|
|
|
|
|
|
| 961 |
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
| 962 |
|
| 963 |
freqs_cis_h = freqs_cis.view(end, 1, dim // 4, 1).repeat(1, end, 1, 1)
|
| 964 |
freqs_cis_w = freqs_cis.view(1, end, dim // 4, 1).repeat(end, 1, 1, 1)
|
| 965 |
freqs_cis = torch.cat([freqs_cis_h, freqs_cis_w], dim=-1).flatten(2)
|
| 966 |
+
|
| 967 |
return freqs_cis
|
| 968 |
|
| 969 |
def parameter_count(self) -> int:
|