Spaces:
Runtime error
Runtime error
fix(config): set min/max for generation
Browse files
dalle_mini/model/configuration.py
CHANGED
|
@@ -85,8 +85,6 @@ class DalleBartConfig(PretrainedConfig):
|
|
| 85 |
self.scale_embedding = (
|
| 86 |
scale_embedding # scale factor will be sqrt(d_model) if True
|
| 87 |
)
|
| 88 |
-
self.min_length = image_length + 1
|
| 89 |
-
self.max_length = image_length + 1
|
| 90 |
|
| 91 |
# remove inferred keys to prevent errors when loading config (passed as kwargs)
|
| 92 |
for k in [
|
|
@@ -94,6 +92,8 @@ class DalleBartConfig(PretrainedConfig):
|
|
| 94 |
"bos_token_id",
|
| 95 |
"eos_token_id",
|
| 96 |
"decoder_start_token_id",
|
|
|
|
|
|
|
| 97 |
]:
|
| 98 |
kwargs.pop(k, None)
|
| 99 |
|
|
@@ -106,6 +106,8 @@ class DalleBartConfig(PretrainedConfig):
|
|
| 106 |
decoder_start_token_id=image_vocab_size, # BOS appended to vocab
|
| 107 |
forced_eos_token_id=forced_eos_token_id,
|
| 108 |
tie_word_embeddings=tie_word_embeddings,
|
|
|
|
|
|
|
| 109 |
**kwargs,
|
| 110 |
)
|
| 111 |
|
|
|
|
| 85 |
self.scale_embedding = (
|
| 86 |
scale_embedding # scale factor will be sqrt(d_model) if True
|
| 87 |
)
|
|
|
|
|
|
|
| 88 |
|
| 89 |
# remove inferred keys to prevent errors when loading config (passed as kwargs)
|
| 90 |
for k in [
|
|
|
|
| 92 |
"bos_token_id",
|
| 93 |
"eos_token_id",
|
| 94 |
"decoder_start_token_id",
|
| 95 |
+
"min_length",
|
| 96 |
+
"max_length",
|
| 97 |
]:
|
| 98 |
kwargs.pop(k, None)
|
| 99 |
|
|
|
|
| 106 |
decoder_start_token_id=image_vocab_size, # BOS appended to vocab
|
| 107 |
forced_eos_token_id=forced_eos_token_id,
|
| 108 |
tie_word_embeddings=tie_word_embeddings,
|
| 109 |
+
min_length=image_length + 1,
|
| 110 |
+
max_length=image_length + 1,
|
| 111 |
**kwargs,
|
| 112 |
)
|
| 113 |
|
dalle_mini/model/modeling.py
CHANGED
|
@@ -46,6 +46,8 @@ from transformers.models.bart.modeling_flax_bart import (
|
|
| 46 |
FlaxBartForConditionalGeneration,
|
| 47 |
)
|
| 48 |
|
|
|
|
|
|
|
| 49 |
logger = logging.get_logger(__name__)
|
| 50 |
|
| 51 |
|
|
@@ -296,8 +298,11 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
|
|
| 296 |
"""
|
| 297 |
Edits:
|
| 298 |
- added num_params property
|
|
|
|
| 299 |
"""
|
| 300 |
|
|
|
|
|
|
|
| 301 |
@property
|
| 302 |
def num_params(self):
|
| 303 |
num_params = jax.tree_map(
|
|
|
|
| 46 |
FlaxBartForConditionalGeneration,
|
| 47 |
)
|
| 48 |
|
| 49 |
+
from .configuration import DalleBartConfig
|
| 50 |
+
|
| 51 |
logger = logging.get_logger(__name__)
|
| 52 |
|
| 53 |
|
|
|
|
| 298 |
"""
|
| 299 |
Edits:
|
| 300 |
- added num_params property
|
| 301 |
+
- config_class replaced to DalleBartConfig
|
| 302 |
"""
|
| 303 |
|
| 304 |
+
config_class = DalleBartConfig
|
| 305 |
+
|
| 306 |
@property
|
| 307 |
def num_params(self):
|
| 308 |
num_params = jax.tree_map(
|