feat(model): allow bias (#152)
src/dalle_mini/__init__.py
CHANGED
@@ -1,3 +1,3 @@
-__version__ = "0.0.
+__version__ = "0.0.4"
 
 from .model import DalleBart, DalleBartProcessor
src/dalle_mini/model/configuration.py
CHANGED
@@ -58,6 +58,7 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
         tie_word_embeddings=False,  # different modalities and sizes
         do_sample=True,
         # transformer variants
+        use_bias=False,  # use bias in attention and dense layers (except for lm_head)
         ln_type="layernorm",  # layer normalization type, "rmsnorm", "layernorm"
         ln_positions="normformer",  # layer normalization positions, "normformer", "swinv2", "cogview", "postln", "preln", "deepnet" (same as postln)
         use_head_scale=False,  # used in NormFormer
@@ -65,7 +66,7 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
         tau_init=0.05,  # used only in cosine attention (Swin v2)
         use_deepnet_scaling=False,  # used in Deepnet
         use_glu=False,  # "GLU Variants Improve Transformer"
-        use_alibi=False,  # from "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation"
+        use_alibi=False,  # Not implemented yet - from "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation"
         sinkhorn_iters=1,  # used in SinkFormers
         use_final_ln_encoder=False,  # final layer normalization in encoder
         use_final_ln_decoder=False,  # final layer normalization in decoder
@@ -77,7 +78,7 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
         self.normalize_text = normalize_text
 
         # transformer variants
-        self.use_head_scale = use_head_scale
+        self.use_bias = use_bias
         assert ln_type in [
             "rmsnorm",
             "layernorm",
@@ -92,6 +93,7 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
             "postln",
             "preln",
         ], "ln_positions must be 'normformer', 'swinv2', 'cogview', 'postln', 'preln'"
+        self.use_head_scale = use_head_scale
         assert use_alibi is False, "use_alibi is not supported yet"
         self.ln_positions = ln_positions
         self.use_cosine_attention = use_cosine_attention
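The new flag defaults to False, so existing configurations keep bias-free layers unless they opt in. A minimal usage sketch, assuming the remaining DalleBartConfig arguments keep the defaults shown in the signature above and that the import path follows the file location src/dalle_mini/model/configuration.py:

from dalle_mini.model.configuration import DalleBartConfig

# Hypothetical example: only the flag introduced by this commit is set;
# every other constructor argument is assumed to have a usable default.
config = DalleBartConfig(use_bias=True)
print(config.use_bias)  # True -> attention and dense layers get bias terms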
src/dalle_mini/model/modeling.py
CHANGED
@@ -444,7 +444,7 @@ class GLU(nn.Module):
         w = nn.Dense(
             self.ffn_dim,
             dtype=self.dtype,
-            use_bias=False,
+            use_bias=self.config.use_bias,
             kernel_init=deepnet_init(gain)
             if self.config.use_deepnet_scaling
             else jax.nn.initializers.normal(self.config.init_std),
@@ -453,7 +453,7 @@ class GLU(nn.Module):
         v = nn.Dense(
             self.ffn_dim,
             dtype=self.dtype,
-            use_bias=False,
+            use_bias=self.config.use_bias,
             kernel_init=deepnet_init(gain)
             if self.config.use_deepnet_scaling
             else jax.nn.initializers.normal(self.config.init_std),
@@ -473,7 +473,7 @@ class GLU(nn.Module):
         x = nn.Dense(
             self.embed_dim,
             dtype=self.dtype,
-            use_bias=False,
+            use_bias=self.config.use_bias,
             kernel_init=deepnet_init(gain)
             if self.config.use_deepnet_scaling
             else jax.nn.initializers.normal(self.config.init_std),
@@ -509,7 +509,7 @@ class FFN(nn.Module):
         x = nn.Dense(
             self.ffn_dim,
             dtype=self.dtype,
-            use_bias=False,
+            use_bias=self.config.use_bias,
             kernel_init=deepnet_init(gain)
             if self.config.use_deepnet_scaling
             else jax.nn.initializers.normal(self.config.init_std),
@@ -528,7 +528,7 @@ class FFN(nn.Module):
         x = nn.Dense(
             self.embed_dim,
             dtype=self.dtype,
-            use_bias=False,
+            use_bias=self.config.use_bias,
             kernel_init=deepnet_init(gain)
             if self.config.use_deepnet_scaling
             else jax.nn.initializers.normal(self.config.init_std),
@@ -580,7 +580,7 @@ class FlaxBartEncoderLayer(nn.Module):
             embed_dim=embed_dim,
             num_heads=self.config.encoder_attention_heads,
             dropout=self.config.attention_dropout,
-            bias=False,
+            bias=self.config.use_bias,
             dtype=self.dtype,
             is_encoder=True,
         )(hidden_states=hidden_states, attention_mask=attention_mask)
@@ -686,7 +686,7 @@ class FlaxBartDecoderLayer(nn.Module):
             num_heads=self.config.decoder_attention_heads,
             dropout=self.config.attention_dropout,
             causal=True,
-            bias=False,
+            bias=self.config.use_bias,
             dtype=self.dtype,
             is_encoder=False,
         )(
@@ -724,7 +724,7 @@ class FlaxBartDecoderLayer(nn.Module):
             embed_dim=embed_dim,
             num_heads=self.config.decoder_attention_heads,
             dropout=self.config.attention_dropout,
-            bias=False,
+            bias=self.config.use_bias,
             dtype=self.dtype,
             is_encoder=False,
         )(
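In Flax, nn.Dense and the attention modules only create bias parameters when asked to, so passing config.use_bias through these calls is what actually toggles the extra parameters. A self-contained sketch of the same pattern with a toy module (this is not the repo's GLU/FFN code, just an illustration of the Flax use_bias behavior):

import jax
import jax.numpy as jnp
import flax.linen as nn


class TinyFFN(nn.Module):
    ffn_dim: int
    embed_dim: int
    use_bias: bool = False  # mirrors config.use_bias

    @nn.compact
    def __call__(self, x):
        # A bias parameter is created only when use_bias=True, exactly as in
        # the nn.Dense calls patched above.
        x = nn.Dense(self.ffn_dim, use_bias=self.use_bias)(x)
        x = nn.gelu(x)
        return nn.Dense(self.embed_dim, use_bias=self.use_bias)(x)


params = TinyFFN(ffn_dim=16, embed_dim=8, use_bias=True).init(
    jax.random.PRNGKey(0), jnp.ones((1, 8))
)
# With use_bias=True each Dense holds both a "kernel" and a "bias" parameter.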
tools/train/train.py
CHANGED
@@ -49,6 +49,7 @@ from scalable_shampoo.distributed_shampoo import GraftingType, distributed_shampoo
 from tqdm import tqdm
 from transformers import HfArgumentParser
 
+import dalle_mini
 from dalle_mini.data import Dataset
 from dalle_mini.model import (
     DalleBart,
@@ -675,6 +676,7 @@ def main():
                 "transformers": transformers.__version__,
                 "datasets": datasets.__version__,
                 "wandb": wandb.__version__,
+                "dalle_mini": dalle_mini.__version__,
             },
         }
     )
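With __version__ bumped and exported from the package root, the training script can log the exact dalle_mini release alongside the other library versions. A hedged sketch of that logging pattern (the project name and the surrounding dict keys are illustrative, not the script's actual values):

import dalle_mini
import datasets
import transformers
import wandb

wandb.init(project="dalle-mini-demo")  # illustrative project name
wandb.config.update(
    {
        "versions": {
            "transformers": transformers.__version__,
            "datasets": datasets.__version__,
            "wandb": wandb.__version__,
            "dalle_mini": dalle_mini.__version__,  # available after this change
        }
    }
)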