import math
import signal
import sys

import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer


# Signal handler function
def signal_handler(sig, frame):
    """Exit cleanly on Ctrl+C instead of dying mid-checkpoint."""
    print('You pressed Ctrl+C! Exiting...')
    sys.exit(0)


# Register signal handler
signal.signal(signal.SIGINT, signal_handler)

# Load jsonl data from HF or disk
dataset = load_dataset("HuggingFaceH4/deita-10k-v0-sft", split="train_sft")

# Hugging Face model id
model_id = "google/gemma-7b"
tokenizer_id = "philschmid/gemma-tokenizer-chatml"

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
tokenizer.padding_side = 'right'  # to prevent warnings

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=32,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    use_dora=False,  # set to True to enable the DoRA method (disabled here)
)

args = TrainingArguments(
    output_dir="./out",                  # directory to save and repository id
    num_train_epochs=3,                  # number of training epochs
    per_device_train_batch_size=8,       # batch size per device during training
    gradient_checkpointing=True,         # use gradient checkpointing to save memory
    optim="adamw_torch_fused",
    logging_steps=2,
    save_strategy="steps",
    save_steps=300,
    bf16=True,                           # use bfloat16 precision
    tf32=True,                           # use tf32 precision
    ### peft specific arguments ###
    learning_rate=2e-4,
    max_grad_norm=0.3,
    warmup_ratio=0.00,
    lr_scheduler_type="constant",
    report_to="wandb",
    push_to_hub=False,                   # push model to hub
)

max_seq_length = 2048  # max sequence length for model and packing of the dataset

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    ### peft specific arguments ###
    peft_config=peft_config,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    packing=False,
    dataset_kwargs={
        "add_special_tokens": True,    # special tokens handled here, not in the dataset
        "append_concat_token": False,  # make sure to not add additional tokens when packing
    },
)


# 3. Use float32 for the sqrt(hidden_size) embedding scale calculation
def gemma_downscale_embeddings(model):
    """Divide the embedding weights in place by sqrt(hidden_size), computed in
    float32 to avoid the precision loss of doing it in bfloat16."""
    model.model.embed_tokens.weight.data[:] /= torch.tensor(
        math.sqrt(model.config.hidden_size), dtype=torch.float32)


gemma_downscale_embeddings(model)


# 4. Upcast LayerNorm to float32
def patch_layernorm(model):
    """Replace every LayerNorm forward so the normalization runs in float32
    and the result is cast back to the input dtype."""
    for module in model.modules():
        if isinstance(module, torch.nn.LayerNorm):
            # FIX: bind the loop variable as a default argument (m=module).
            # The original lambda closed over `module` late-binding, so after
            # the loop every patched forward used the *last* LayerNorm's
            # weights/eps instead of its own.
            module.forward = lambda x, m=module: torch.nn.functional.layer_norm(
                x.float(), m.normalized_shape, m.weight.float(),
                m.bias.float(), m.eps).type_as(x)


patch_layernorm(model)


# 5. Keep RoPE caches in float32 (mixed_bfloat16 precision issue)
def fix_rope_dtype(model):
    """Upcast the cached RoPE cos/sin tables to float32.

    FIX: the original cast to torch.int32, which truncates every cos/sin
    value to 0 or +/-1 and destroys the rotary embedding. float32 is the
    intended dtype (consistent with step 7 below).
    """
    for module in model.modules():
        if hasattr(module, 'rope'):
            module.rope.cos_cached = module.rope.cos_cached.to(torch.float32)
            module.rope.sin_cached = module.rope.sin_cached.to(torch.float32)


fix_rope_dtype(model)


def _set_cos_sin_cache_fixed(self, seq_len, device, dtype):
    """Rebuild the RoPE cos/sin caches using division (t / base^(2i/dim))
    instead of multiplication by a precomputed inverse frequency."""
    self.max_seq_len_cached = seq_len
    t = torch.arange(seq_len, device=device, dtype=torch.int32).float()
    # FIX precedence: the exponent must be (2i / dim). The original wrote
    # `freq_base ** arange / dim`, which `**` binding tighter than `/`
    # evaluates as (freq_base ** arange) / dim — wrong frequencies.
    freqs = torch.div(
        t[:, None],
        self.freq_base ** (torch.arange(0, self.dim, 2).float() / self.dim))
    # NOTE(review): stack+reshape interleaves (f0,f0,f1,f1,...); confirm this
    # matches the layout the model's rotate-half implementation expects.
    emb = torch.stack((freqs, freqs), dim=-1).reshape(seq_len, -1)
    self.cos_cached = emb.cos()[None, None, :, :].to(dtype)
    self.sin_cached = emb.sin()[None, None, :, :].to(dtype)


# 6. Use division instead of multiplication for the RoPE calculation
def fix_rope_calculation(model):
    """Route every rope module's cache builder through the fixed version."""
    for module in model.modules():
        if hasattr(module, 'rope'):
            # FIX: bind `module.rope` as a default argument. The original
            # lambda closed over `module` late-binding, so every patched
            # rope rebuilt the cache of the *last* module in the loop.
            module.rope._set_cos_sin_cache = (
                lambda seq_len, device, dtype, rope=module.rope:
                    _set_cos_sin_cache_fixed(rope, seq_len, device, dtype))


fix_rope_calculation(model)


# 7. Use float32 for RoPE
def cast_rope_to_float32(model):
    """Ensure the RoPE cos/sin caches are float32."""
    for module in model.modules():
        if hasattr(module, 'rope'):
            module.rope.cos_cached = module.rope.cos_cached.float()
            module.rope.sin_cached = module.rope.sin_cached.float()


cast_rope_to_float32(model)
# 8. Use approximate tanh instead of exact for GELU
def fix_gelu(model):
    """Switch every nn.GELU in the model to the tanh approximation.

    FIX: the original set `module._approximate = True`, an attribute that
    nn.GELU never reads — its forward uses the *string* attribute
    `approximate` ('none' or 'tanh') — so the patch was a silent no-op.
    """
    for module in model.modules():
        if isinstance(module, torch.nn.GELU):
            module.approximate = 'tanh'


fix_gelu(model)

# start training, the model will be automatically saved to the hub and the output directory
trainer.train()

# save model
trainer.save_model()