Spaces:
Runtime error
Runtime error
Update models/model.py
Browse files- models/model.py +5 -25
models/model.py
CHANGED
|
@@ -592,7 +592,7 @@ class ParallelFinalLayer(nn.Module):
|
|
| 592 |
return x
|
| 593 |
|
| 594 |
|
| 595 |
-
class DiT_Llama(nn.Module):
|
| 596 |
"""
|
| 597 |
Diffusion model with a Transformer backbone.
|
| 598 |
"""
|
|
@@ -645,7 +645,7 @@ class DiT_Llama(nn.Module):
|
|
| 645 |
assert (dim // n_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4"
|
| 646 |
self.dim = dim
|
| 647 |
self.n_heads = n_heads
|
| 648 |
-
self.freqs_cis = DiT_Llama.precompute_freqs_cis(
|
| 649 |
dim // n_heads, 384, rope_scaling_factor=rope_scaling_factor, ntk_factor=ntk_factor
|
| 650 |
)
|
| 651 |
self.rope_scaling_factor = rope_scaling_factor
|
|
@@ -781,7 +781,7 @@ class DiT_Llama(nn.Module):
|
|
| 781 |
ntk_factor = ntk_factor if ntk_factor is not None else self.ntk_factor
|
| 782 |
if rope_scaling_factor != self.rope_scaling_factor or ntk_factor != self.ntk_factor:
|
| 783 |
print(f"override freqs_cis, rope_scaling {rope_scaling_factor}, ntk {ntk_factor}", flush=True)
|
| 784 |
-
self.freqs_cis = DiT_Llama.precompute_freqs_cis(
|
| 785 |
self.dim // self.n_heads, 384,
|
| 786 |
rope_scaling_factor=rope_scaling_factor, ntk_factor=ntk_factor
|
| 787 |
)
|
|
@@ -882,27 +882,7 @@ class DiT_Llama(nn.Module):
|
|
| 882 |
#############################################################################
|
| 883 |
# DiT Configs #
|
| 884 |
#############################################################################
|
| 885 |
-
|
| 886 |
-
|
| 887 |
-
def DiT_Llama_600M_patch2(**kwargs):
|
| 888 |
-
return DiT_Llama(
|
| 889 |
-
patch_size=2, dim=1536, n_layers=16, n_heads=32, **kwargs
|
| 890 |
-
)
|
| 891 |
-
|
| 892 |
-
|
| 893 |
-
def DiT_Llama_2B_patch2(**kwargs):
|
| 894 |
-
return DiT_Llama(
|
| 895 |
patch_size=2, dim=2304, n_layers=24, n_heads=32, **kwargs
|
| 896 |
)
|
| 897 |
-
|
| 898 |
-
|
| 899 |
-
def DiT_Llama_3B_patch2(**kwargs):
|
| 900 |
-
return DiT_Llama(
|
| 901 |
-
patch_size=2, dim=3072, n_layers=32, n_heads=32, **kwargs
|
| 902 |
-
)
|
| 903 |
-
|
| 904 |
-
|
| 905 |
-
def DiT_Llama_7B_patch2(**kwargs):
|
| 906 |
-
return DiT_Llama(
|
| 907 |
-
patch_size=2, dim=4096, n_layers=32, n_heads=32, **kwargs
|
| 908 |
-
)
|
|
|
|
| 592 |
return x
|
| 593 |
|
| 594 |
|
| 595 |
+
class NextDiT(nn.Module):
|
| 596 |
"""
|
| 597 |
Diffusion model with a Transformer backbone.
|
| 598 |
"""
|
|
|
|
| 645 |
assert (dim // n_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4"
|
| 646 |
self.dim = dim
|
| 647 |
self.n_heads = n_heads
|
| 648 |
+
self.freqs_cis = NextDiT.precompute_freqs_cis(
|
| 649 |
dim // n_heads, 384, rope_scaling_factor=rope_scaling_factor, ntk_factor=ntk_factor
|
| 650 |
)
|
| 651 |
self.rope_scaling_factor = rope_scaling_factor
|
|
|
|
| 781 |
ntk_factor = ntk_factor if ntk_factor is not None else self.ntk_factor
|
| 782 |
if rope_scaling_factor != self.rope_scaling_factor or ntk_factor != self.ntk_factor:
|
| 783 |
print(f"override freqs_cis, rope_scaling {rope_scaling_factor}, ntk {ntk_factor}", flush=True)
|
| 784 |
+
self.freqs_cis = NextDiT.precompute_freqs_cis(
|
| 785 |
self.dim // self.n_heads, 384,
|
| 786 |
rope_scaling_factor=rope_scaling_factor, ntk_factor=ntk_factor
|
| 787 |
)
|
|
|
|
| 882 |
#############################################################################
|
| 883 |
# DiT Configs #
|
| 884 |
#############################################################################
|
| 885 |
+
def NextDiT_2B_patch2(**kwargs):
|
| 886 |
+
return NextDiT(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 887 |
patch_size=2, dim=2304, n_layers=24, n_heads=32, **kwargs
|
| 888 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|