aapot committed
Commit · 0b67ff4
1 Parent(s): 916632f

Update optimizers

- EasyLM/data.py +7 -0
- EasyLM/optimizers.py +47 -3
- pretrain_llama_3b.sh +2 -1
EasyLM/data.py
CHANGED

@@ -153,6 +153,7 @@ class HuggingfaceDataset(object):
         config.start_seek_loc = 0
         config.tokens_count_at_start = 0
         config.batch_token_dtype = 'i4'
+        config.reset_dataset_loc = False
 
         if updates is not None:
             config.update(ConfigDict(updates).copy_and_resolve_references())
@@ -173,6 +174,8 @@ class HuggingfaceDataset(object):
         self._dataset_loc = self.config.start_seek_loc
         self._total_tokens = self.config.tokens_count_at_start
         self._index = 0
+        self.reset_dataset_loc = self.config.reset_dataset_loc
+
 
     def __iter__(self):
         if not self._eval_dataset and self._train_epochs > 0:
@@ -236,6 +239,10 @@ class HuggingfaceDataset(object):
         self._dataset_loc = state_dict.get('dataset_loc', self.config.start_seek_loc)
         self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)
         self._train_epochs = state_dict.get('epochs', 0)
+        if self.reset_dataset_loc:
+            self._dataset_loc = 0
+            self._train_epochs = 0
+
 
     @property
     def seq_length(self):
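In effect, when resuming from a checkpoint with reset_dataset_loc=True, the restored token counter is kept but the dataset cursor and epoch count restart from zero (e.g. to continue training on a fresh or re-shuffled dataset). A minimal standalone sketch of that resume logic; MockDataset is hypothetical, and only the state-dict keys and field names come from the diff:

    # Standalone sketch of the resume behaviour added above (not the real class).
    class MockDataset:
        def __init__(self, start_seek_loc=0, tokens_count_at_start=0, reset_dataset_loc=False):
            self.start_seek_loc = start_seek_loc
            self.tokens_count_at_start = tokens_count_at_start
            self.reset_dataset_loc = reset_dataset_loc

        def load_state_dict(self, state_dict):
            self._dataset_loc = state_dict.get('dataset_loc', self.start_seek_loc)
            self._total_tokens = state_dict.get('total_tokens', self.tokens_count_at_start)
            self._train_epochs = state_dict.get('epochs', 0)
            if self.reset_dataset_loc:
                # Keep the running token count, but restart reading from the top.
                self._dataset_loc = 0
                self._train_epochs = 0

    ds = MockDataset(reset_dataset_loc=True)
    ds.load_state_dict({'dataset_loc': 123456, 'total_tokens': 10**9, 'epochs': 2})
    assert ds._dataset_loc == 0 and ds._train_epochs == 0
    assert ds._total_tokens == 10**9  # token counter survives the reset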
EasyLM/optimizers.py
CHANGED

@@ -205,8 +205,9 @@ class LionOptimizerFactory(object):
         config.init_lr = 0.0
         config.end_lr = 0.0001
         config.lr = 0.001
-        config.lr_warmup_steps = …
-        config.…
+        config.lr_warmup_steps = 60000
+        config.lr_constant_steps = 840000
+        config.lr_decay_steps = 100000
         config.b1 = 0.9
         config.b2 = 0.98
         config.clip_gradient = 1.0
@@ -243,6 +244,43 @@ class LionOptimizerFactory(object):
             ],
             [config.lr_warmup_steps],
         )
+        elif config.lr_schedule_type == "warmup_constant_linear_decay":
+            learning_rate_schedule = optax.join_schedules(
+                [
+                    optax.linear_schedule(
+                        init_value=config.init_lr,
+                        end_value=config.lr,
+                        transition_steps=config.lr_warmup_steps,
+                    ),
+                    optax.constant_schedule(config.lr),
+                    optax.linear_schedule(
+                        init_value=config.lr,
+                        end_value=config.end_lr,
+                        transition_steps=config.lr_decay_steps,
+                    )
+                ],
+                [config.lr_warmup_steps, config.lr_constant_steps],
+            )
+        elif config.lr_schedule_type == "warmup_constant_exponential_decay":
+            learning_rate_schedule = optax.join_schedules(
+                [
+                    optax.linear_schedule(
+                        init_value=config.init_lr,
+                        end_value=config.lr,
+                        transition_steps=config.lr_warmup_steps,
+                    ),
+                    optax.constant_schedule(config.lr),
+                    optax.exponential_decay(
+                        init_value=config.lr,
+                        transition_steps=config.lr_decay_steps,
+                        decay_rate=config.lr_decay_rate,
+                        transition_begin=0,
+                        staircase=False,
+                        end_value=config.end_lr,
+                    )
+                ],
+                [config.lr_warmup_steps, config.lr_constant_steps],
+            )
         elif config.lr_schedule_type == "exponential_decay":
             learning_rate_schedule = optax.exponential_decay(
                 init_value=config.lr,
@@ -252,8 +290,14 @@ class LionOptimizerFactory(object):
                 staircase=False,
                 end_value=config.end_lr,
             )
+        elif config.lr_schedule_type == "linear_decay":
+            learning_rate_schedule = optax.linear_schedule(
+                init_value=config.lr,
+                end_value=config.end_lr,
+                transition_steps=config.lr_decay_steps,
+            )
         else:
-            raise ValueError('config.lr_schedule_type must be "warmup_cosine_decay_schedule", "warmup_constant", …')
+            raise ValueError('config.lr_schedule_type must be "warmup_cosine_decay_schedule", "warmup_constant", "warmup_constant_linear_decay", "warmup_constant_exponential_decay", "exponential_decay" or "linear_decay"')
 
         optimizer_info = dict(
             learning_rate_schedule=learning_rate_schedule,
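One subtlety in the new schedules: optax.join_schedules takes absolute step boundaries, and each schedule after the first is evaluated at the global step minus its boundary. On that reading, lr_constant_steps is the global step at which decay begins, so the constant phase itself lasts lr_constant_steps - lr_warmup_steps steps. A minimal sketch checking this with the new defaults (values from the diff; only standard optax calls assumed):

    # Sanity check of the "warmup_constant_linear_decay" schedule built above.
    # join_schedules boundaries are absolute steps, so decay starts at step 840000.
    import optax

    schedule = optax.join_schedules(
        [
            optax.linear_schedule(init_value=0.0, end_value=0.001, transition_steps=60000),
            optax.constant_schedule(0.001),
            optax.linear_schedule(init_value=0.001, end_value=0.0001, transition_steps=100000),
        ],
        [60000, 840000],
    )

    print(schedule(0))       # 0.0     (warmup begins)
    print(schedule(60000))   # 0.001   (peak LR, constant phase starts)
    print(schedule(840000))  # 0.001   (decay begins here)
    print(schedule(940000))  # 0.0001  (end_lr reached)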
pretrain_llama_3b.sh
CHANGED

@@ -23,10 +23,11 @@ python3 -m EasyLM.models.llama.llama_train \
     --tokenizer.vocab_file='tokenizer.model' \
     --optimizer.type='lion' \
     --optimizer.lion_optimizer.weight_decay=1.0 \
-    --optimizer.lion_optimizer.lr_schedule_type='…' \
+    --optimizer.lion_optimizer.lr_schedule_type='warmup_constant_linear_decay' \
     --optimizer.lion_optimizer.lr=1e-4 \
     --optimizer.lion_optimizer.end_lr=1e-5 \
     --optimizer.lion_optimizer.lr_warmup_steps=60000 \
+    --optimizer.lion_optimizer.lr_constant_steps=900000 \
     --optimizer.lion_optimizer.lr_decay_steps=100000 \
     --optimizer.lion_optimizer.bf16_momentum=True \
     --train_dataset.type='huggingface' \
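Under the boundary semantics noted above, these flags give a roughly 1M-step run: warmup through step 60000, constant peak LR until step 900000, then linear decay to end_lr by step 1000000. A quick hypothetical sanity check of that arithmetic:

    # Hypothetical sanity check of the phase lengths implied by the flags above.
    warmup_steps = 60_000      # linear warmup 0 -> 1e-4
    constant_until = 900_000   # lr_constant_steps: absolute step where decay starts
    decay_steps = 100_000      # linear decay 1e-4 -> 1e-5

    constant_phase = constant_until - warmup_steps  # 840,000 steps at peak LR
    total_steps = constant_until + decay_steps      # 1,000,000 steps to reach end_lr
    print(constant_phase, total_steps)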