# GemmaBugFix-TRL / train.py
# Crystalcareai's picture
# Update train.py
# 936c87d verified
import signal
import sys
# Graceful-shutdown handler for Ctrl+C (SIGINT).
def signal_handler(sig, frame):
    """Print a notice and exit with status 0 when SIGINT is received."""
    print('You pressed Ctrl+C! Exiting...')
    raise SystemExit(0)


# Install the handler so Ctrl+C triggers a clean shutdown instead of a traceback.
signal.signal(signal.SIGINT, signal_handler)
from datasets import load_dataset
from transformers import TrainingArguments
from trl import SFTTrainer
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
import math
# Load jsonl data from HF or disk (downloads the deita-10k SFT split from the Hub).
dataset = load_dataset("HuggingFaceH4/deita-10k-v0-sft", split="train_sft")

# Hugging Face model id
model_id = "google/gemma-7b"
# Separate tokenizer repo that ships a ChatML chat template for Gemma.
tokenizer_id = "philschmid/gemma-tokenizer-chatml"

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",                        # shard/place layers across available devices
    attn_implementation="flash_attention_2",  # requires the flash-attn package and a supported GPU
    torch_dtype=torch.bfloat16,               # load weights in bf16
)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
tokenizer.padding_side = 'right'  # to prevent warnings
# LoRA adapter configuration, applied by SFTTrainer via peft_config below.
peft_config = LoraConfig(
    lora_alpha=16,                # LoRA scaling factor
    lora_dropout=0.05,            # dropout on the adapter layers
    r=32,                         # LoRA rank
    bias="none",                  # do not train bias terms
    target_modules="all-linear",  # attach adapters to every linear layer
    task_type="CAUSAL_LM",
    use_dora=False,               # DoRA disabled (set True to enable the DoRA method)
)
args = TrainingArguments(
    output_dir="./out",              # directory to save and repository id
    num_train_epochs=3,              # number of training epochs
    per_device_train_batch_size=8,   # batch size per device during training
    gradient_checkpointing=True,     # use gradient checkpointing to save memory
    optim="adamw_torch_fused",       # fused AdamW optimizer
    logging_steps=2,                 # log metrics every 2 steps
    save_strategy="steps",
    save_steps=300,                  # checkpoint every 300 steps
    bf16=True,                       # use bfloat16 precision
    tf32=True,                       # use tf32 precision
    ### peft specific arguments ###
    learning_rate=2e-4,
    max_grad_norm=0.3,
    warmup_ratio=0.00,               # no learning-rate warmup
    lr_scheduler_type="constant",
    report_to="wandb",               # log to Weights & Biases
    push_to_hub=False,               # do NOT push the model to the Hub
)
max_seq_length = 2048  # max sequence length for model and packing of the dataset

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    ### peft specific arguments ###
    peft_config=peft_config,  # wraps the base model with the LoRA adapters configured above
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    packing=False,            # one example per sequence; no example packing
    dataset_kwargs={
        "add_special_tokens": True,    # <bos> and <eos> should be part of the dataset.
        "append_concat_token": False,  # make sure to not add additional tokens when packing
    }
)
# 3. Use float32 for sqrt(3072) and sqrt(2048) calculations
def gemma_downscale_embeddings(model):
    """Divide the token-embedding weights in place by sqrt(hidden_size).

    The divisor is computed as a float32 tensor; the weights keep their
    own dtype because the division happens in place.
    NOTE(review): presumably counteracts Gemma's embedding normalizer —
    confirm against the modeling code.
    """
    divisor = torch.tensor(math.sqrt(model.config.hidden_size), dtype=torch.float32)
    embedding_weights = model.model.embed_tokens.weight.data
    embedding_weights /= divisor
# Apply the embedding rescale once, after the model is loaded and before training.
gemma_downscale_embeddings(model)
# 4. Upcast Layernorm to float32
def patch_layernorm(model):
    """Replace every nn.LayerNorm forward so normalization runs in float32.

    The input is upcast, normalized with float32 weight/bias, then cast back
    to the input dtype.
    NOTE(review): Gemma uses RMSNorm rather than nn.LayerNorm, so this may
    match nothing on the real model — confirm the intended target class.
    """
    for module in model.modules():
        if isinstance(module, torch.nn.LayerNorm):
            # Bind the current module through a default argument. A plain
            # `lambda x: ... module ...` late-binds `module`, so after the
            # loop EVERY patched layer would use the parameters of the
            # last LayerNorm visited.
            module.forward = lambda x, m=module: torch.nn.functional.layer_norm(
                x.float(), m.normalized_shape, m.weight.float(),
                m.bias.float(), m.eps).type_as(x)
# Patch LayerNorm modules in place on the loaded model.
patch_layernorm(model)
# 5. Fix Keras mixed_bfloat16 RoPE issue
def fix_rope_dtype(model):
    """Upcast the cached RoPE cos/sin tables to float32.

    The original cast the caches to torch.int32, which truncates values in
    [-1, 1] to -1/0/1 and destroys the rotary embedding; the tables must be
    floating point. float32 also matches cast_rope_to_float32 below.
    """
    for module in model.modules():
        if hasattr(module, 'rope'):
            module.rope.cos_cached = module.rope.cos_cached.to(torch.float32)
            module.rope.sin_cached = module.rope.sin_cached.to(torch.float32)
# Upcast the RoPE caches on the loaded model.
fix_rope_dtype(model)
# 6. Use division instead of multiplication for RoPE calculation
def fix_rope_calculation(model):
    """Point every module's rope at the fixed cache builder below.

    The rope object is bound through a lambda default argument: a plain
    closure over `module` would late-bind, so every patched rope would
    rebuild the cache of the LAST module visited.
    """
    for module in model.modules():
        if hasattr(module, 'rope'):
            module.rope._set_cos_sin_cache = (
                lambda seq_len, device, dtype, rope=module.rope:
                    _set_cos_sin_cache_fixed(rope, seq_len, device, dtype))


def _set_cos_sin_cache_fixed(self, seq_len, device, dtype):
    """Recompute the RoPE cos/sin caches for `seq_len` positions.

    Caches are shaped (1, 1, seq_len, dim) and stored in `dtype`.
    """
    self.max_seq_len_cached = seq_len
    # Positions, upcast to float32 for the outer product.
    t = torch.arange(seq_len, device=device, dtype=torch.int32).float()
    # Frequency divisors: base ** (2i / dim). The original wrote
    # `self.freq_base ** torch.arange(...) / self.dim`, which parses as
    # (base ** 2i) / dim because ** binds tighter than / — parenthesize.
    # The exponent arange is also built on `device` to avoid a CPU/GPU mismatch.
    divisors = self.freq_base ** (
        torch.arange(0, self.dim, 2, device=device).float() / self.dim)
    freqs = torch.div(t[:, None], divisors)
    # NOTE(review): interleaved layout (f0,f0,f1,f1,...) via stack+reshape;
    # HF Llama/Gemma use cat((freqs, freqs), dim=-1) — confirm which layout
    # the model's rotate-half implementation expects.
    emb = torch.stack((freqs, freqs), dim=-1).reshape(seq_len, -1)
    self.cos_cached = emb.cos()[None, None, :, :].to(dtype)
    self.sin_cached = emb.sin()[None, None, :, :].to(dtype)
# Install the fixed RoPE cache builder on the loaded model.
fix_rope_calculation(model)
# 7. Use float32 for RoPE
def cast_rope_to_float32(model):
    """Upcast the cached RoPE cos/sin tables to float32 in place."""
    rope_owners = (m for m in model.modules() if hasattr(m, 'rope'))
    for owner in rope_owners:
        rope = owner.rope
        rope.cos_cached = rope.cos_cached.float()
        rope.sin_cached = rope.sin_cached.float()
# Upcast RoPE tables on the loaded model (after the dtype fix above).
cast_rope_to_float32(model)
# 8. Use approximate tanh instead of exact for GELU
def fix_gelu(model):
    """Switch every nn.GELU module to the tanh approximation.

    nn.GELU stores its mode in the string attribute ``approximate``
    ('none' or 'tanh'). The original set a boolean ``_approximate``
    attribute, which nn.GELU never reads, so the patch was a silent no-op.
    """
    for module in model.modules():
        if isinstance(module, torch.nn.GELU):
            module.approximate = 'tanh'
# Switch GELU activations to the tanh approximation on the loaded model.
fix_gelu(model)
# Start training; checkpoints go to output_dir. (push_to_hub=False above,
# so nothing is uploaded to the Hub despite the original comment.)
trainer.train()
# Save the final model/adapters to output_dir.
trainer.save_model()