Update train.py
train.py CHANGED
@@ -6,6 +6,7 @@ import trl
 from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, TrainingArguments, PreTrainedTokenizerFast, AdamW, get_linear_schedule_with_warmup
 from datasets import load_dataset
 from tokenizers import ByteLevelBPETokenizer
+from accelerate import Accelerator
 
 MAX_SEQ_LENGTH = 512
 BATCH_SIZE = 32
@@ -22,6 +23,8 @@ GRADIENT_ACCUMULATION_STEPS = 8
 CLIPPING = 1.0
 PUSH_TO_HUB = True
 
+accelerator = Accelerator()
+
 def load_data():
     dataset = load_dataset(INPUT_DATASET, split="train")#.select(range(int(2e+4)))
     return dataset
@@ -124,6 +127,11 @@ def train_model(model, tokenizer, dataset, push):
         max_seq_length=MAX_SEQ_LENGTH,
         optimizers=(optimizer, scheduler)
     )
+
+    model, optimizer = accelerator.prepare(model, optimizer)
+    trainer.model = model
+    trainer.optimizer = optimizer
+    trainer = accelerator.prepare(trainer)
     trainer.train()
 
     trained_model = trainer.model
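
For context, the commit follows the usual accelerate recipe: construct a single Accelerator and pass the training objects through prepare() so they are moved to the right device and wrapped for the current launch configuration. Below is a minimal, self-contained sketch of that recipe; the toy model, optimizer, data, and loop are illustrative stand-ins, not code from train.py.

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()

# Toy stand-ins for the real model/optimizer built in train.py.
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
dataloader = DataLoader(
    TensorDataset(torch.randn(64, 8), torch.randn(64, 1)),
    batch_size=16,
)

# prepare() wraps each object for the current (possibly distributed)
# setup and returns them in the order the arguments were passed.
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for inputs, targets in dataloader:
    loss = F.mse_loss(model(inputs), targets)
    accelerator.backward(loss)  # replaces loss.backward()
    optimizer.step()
    optimizer.zero_grad()

One caveat: Accelerator.prepare is documented for models, optimizers, dataloaders, and schedulers, so the commit's final accelerator.prepare(trainer) call sits outside that documented set; the model/optimizer preparation shown above is the load-bearing part of the pattern.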