Update train.py
train.py CHANGED
@@ -3,18 +3,22 @@ import os
 import torch
 import trl
 
-from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, TrainingArguments, PreTrainedTokenizerFast
+from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, TrainingArguments, PreTrainedTokenizerFast, AdamW, get_linear_schedule_with_warmup
 from datasets import load_dataset
 from tokenizers import ByteLevelBPETokenizer
 
 MAX_SEQ_LENGTH = 128
-BATCH_SIZE =
+BATCH_SIZE = 512
 EPOCHS = 10
-LEARNING_RATE =
+LEARNING_RATE = 2e-5
 FACTOR = 4
 VOCAB_SIZE = 32000
 INPUT_DATASET = "nroggendorff/oak"
 OUTPUT_REPO = "smallama"
+FP16 = True
+WARMUP_STEPS = 500
+DECAY = 0.01
+GRADIENT_ACCUMILATION_STEPS = 4
 PUSH_TO_HUB = True
 
 def load_data():
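For orientation, the new constants imply an effective batch of BATCH_SIZE * GRADIENT_ACCUMILATION_STEPS = 2048 sequences per optimizer step. A back-of-the-envelope sketch of the resulting schedule (DATASET_ROWS is a hypothetical placeholder; the real row count of nroggendorff/oak is not shown in this diff):

# Sanity check of the schedule implied by the new constants.
# DATASET_ROWS is a placeholder assumption, not the real size of
# the nroggendorff/oak dataset.
BATCH_SIZE = 512
EPOCHS = 10
GRADIENT_ACCUMILATION_STEPS = 4  # identifier spelled as in the commit
WARMUP_STEPS = 500
DATASET_ROWS = 1_000_000

effective_batch = BATCH_SIZE * GRADIENT_ACCUMILATION_STEPS  # 2048 rows per optimizer step
steps_per_epoch = DATASET_ROWS // effective_batch           # 488 with the placeholder size
total_steps = steps_per_epoch * EPOCHS                      # 4880; WARMUP_STEPS covers ~10%
print(effective_batch, steps_per_epoch, total_steps)

Separately, note that AdamW has long been deprecated in transformers in favor of torch.optim.AdamW, so the extended import may emit a FutureWarning or break outright on recent releases.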
@@ -94,8 +98,21 @@ def train_model(model, tokenizer, dataset, push):
         num_train_epochs=EPOCHS,
         per_device_train_batch_size=BATCH_SIZE,
         learning_rate=LEARNING_RATE,
-        optim="
+        optim="adamw_torch",
+        warmup_steps=WARMUP_STEPS,
+        weight_decay=DECAY,
+        gradient_accumulation_steps=GRADIENT_ACCUMILATION_STEPS,
+        fp16=True,
+        evaluation_strategy="steps"
     )
+
+    optimizer = AdamW(model.parameters(), lr=args.learning_rate)
+    scheduler = get_linear_schedule_with_warmup(
+        optimizer,
+        num_warmup_steps=args.warmup_steps,
+        num_training_steps=len(dataset) * args.num_train_epochs // args.gradient_accumulation_steps
+    )
+
     dataset = dataset.map(lambda examples: format_prompts(examples, tokenizer), batched=True)
     trainer = trl.SFTTrainer(
         model=model,
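Two remarks on this hunk. First, num_training_steps is computed from len(dataset) * num_train_epochs, which counts dataset rows rather than optimizer steps; if len(dataset) is the row count, the linear schedule is stretched by roughly a factor of the per-device batch size. Second, the hand-built optimizer and scheduler are created but, in the lines shown, never reach the trainer, and fp16=True is hard-coded rather than using the new FP16 constant. A minimal sketch of how the pair could be wired in, assuming the model, args, dataset, optimizer, and scheduler names from this script; SFTTrainer inherits the optimizers parameter from transformers.Trainer:

# Hypothetical continuation of train_model(); `model`, `args`, `dataset`,
# `optimizer`, and `scheduler` are the names defined above in train.py.
trainer = trl.SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    optimizers=(optimizer, scheduler),  # an explicit pair takes precedence over optim=
)
trainer.train()

Note also that evaluation_strategy="steps" makes the trainer expect an eval_dataset; without one, training fails either at trainer construction or at the first evaluation step, depending on the transformers version.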