| | import signal |
| | import sys |
| |
|
| | |
def signal_handler(sig, frame):
    """SIGINT handler: print a notice, then terminate with exit status 0.

    ``sig`` and ``frame`` are the standard signal-handler arguments;
    neither is used beyond triggering the exit.
    """
    message = 'You pressed Ctrl+C! Exiting...'
    print(message)
    sys.exit(0)
| |
|
| | |
| | signal.signal(signal.SIGINT, signal_handler) |
| |
|
| | from datasets import load_dataset |
| | from transformers import TrainingArguments |
| | from trl import SFTTrainer |
| | import torch |
| | from transformers import AutoTokenizer, AutoModelForCausalLM |
| | from peft import LoraConfig |
| | import math |
| |
|
| | |
| | dataset = load_dataset("HuggingFaceH4/deita-10k-v0-sft", split="train_sft") |
| |
|
| | |
# Base model plus a separately published ChatML-ready tokenizer.
model_id = "google/gemma-7b"
tokenizer_id = "philschmid/gemma-tokenizer-chatml"

# Load the model in bfloat16, sharded automatically across available devices.
# NOTE(review): attn_implementation="flash_attention_2" requires the
# flash-attn package and an Ampere-or-newer GPU — confirm the runtime has both.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
# Right padding is the conventional choice for causal-LM fine-tuning.
tokenizer.padding_side = 'right'
| |
|
| |
|
# LoRA adapter configuration: rank-32 adapters on every linear layer.
peft_config = LoraConfig(
    lora_alpha=16,                 # scaling applied to the adapter output
    lora_dropout=0.05,
    r=32,                          # adapter rank
    bias="none",                   # leave base-model biases untrained
    target_modules="all-linear",   # attach adapters to all linear layers
    task_type="CAUSAL_LM",
    use_dora=False,                # plain LoRA, not weight-decomposed DoRA
)
| |
|
# Trainer hyperparameters.
# NOTE(review): bf16=True and tf32=True both assume an Ampere-or-newer GPU —
# confirm the training hardware supports them.
args = TrainingArguments(
    output_dir="./out",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    gradient_checkpointing=True,   # trade recompute for activation memory
    optim="adamw_torch_fused",
    logging_steps=2,
    save_strategy="steps",
    save_steps=300,
    bf16=True,
    tf32=True,

    learning_rate=2e-4,
    max_grad_norm=0.3,
    warmup_ratio=0.00,             # no warmup: constant LR from step 0
    lr_scheduler_type="constant",
    report_to="wandb",
    push_to_hub=False,


)
| |
|
# Maximum tokenized sequence length used by SFTTrainer for truncation.
max_seq_length = 2048


trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,

    peft_config=peft_config,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    packing=False,                    # one example per sequence, no packing
    dataset_kwargs={
        "add_special_tokens": True,   # let the tokenizer add special tokens
        "append_concat_token": False, # irrelevant while packing is off
    }
)
| |
|
| |
|
| | |
def gemma_downscale_embeddings(model):
    """Divide the input-embedding weight matrix by sqrt(hidden_size), in place.

    NOTE(review): presumably compensates for a matching sqrt(hidden_size)
    upscale elsewhere in the model's forward pass — confirm before training.
    """
    scale = torch.tensor(math.sqrt(model.config.hidden_size), dtype=torch.float32)
    weights = model.model.embed_tokens.weight.data
    weights /= scale
| |
|
| | gemma_downscale_embeddings(model) |
| |
|
| | |
def patch_layernorm(model):
    """Force every LayerNorm to compute in float32, casting the result back.

    Normalizing in bfloat16 loses precision; this runs the normalization in
    float32 and casts the output back to the input's dtype.

    Bug fixed: the original assigned a lambda that closed over the loop
    variable ``module``. Closures late-bind, so every patched ``forward``
    ended up using the *last* LayerNorm's weight/bias/eps. Binding the
    module as a default argument freezes the correct one per iteration.
    """
    for module in model.modules():
        if isinstance(module, torch.nn.LayerNorm):
            def _fp32_forward(x, ln=module):
                return torch.nn.functional.layer_norm(
                    x.float(), ln.normalized_shape,
                    ln.weight.float(), ln.bias.float(), ln.eps,
                ).type_as(x)
            module.forward = _fp32_forward
| |
|
| | patch_layernorm(model) |
| |
|
| | |
def fix_rope_dtype(model):
    """Cast every rope module's cached cos/sin tables to float32.

    Bug fixed: the original cast the caches to ``torch.int32``, which
    truncates every cosine/sine value to -1, 0, or 1 and destroys the
    rotary position tables. float32 keeps full precision, consistent
    with ``cast_rope_to_float32`` later in this script.

    NOTE(review): modules are matched by a ``rope`` attribute — confirm
    that name against this model's rotary-embedding module.
    """
    for module in model.modules():
        if hasattr(module, 'rope'):
            module.rope.cos_cached = module.rope.cos_cached.to(torch.float32)
            module.rope.sin_cached = module.rope.sin_cached.to(torch.float32)
| |
|
| | fix_rope_dtype(model) |
| |
|
| | |
def fix_rope_calculation(model):
    """Install a corrected cos/sin cache builder on every rope module.

    Bug fixed: the original lambda closed over the loop variable ``module``,
    so every patched method late-bound to the *last* module's rope. The rope
    object is now frozen per iteration via a default argument.
    """
    for module in model.modules():
        if hasattr(module, 'rope'):
            def _patched(seq_len, device, dtype, rope=module.rope):
                return _set_cos_sin_cache_fixed(rope, seq_len, device, dtype)
            module.rope._set_cos_sin_cache = _patched


def _set_cos_sin_cache_fixed(self, seq_len, device, dtype):
    """Recompute interleaved rotary cos/sin caches of length ``seq_len``.

    Uses the standard RoPE frequencies: angle(t, i) = t / base**(2i / dim).
    Bug fixed: the original wrote ``self.freq_base ** arange(0, dim, 2) / dim``,
    which — because ``**`` binds tighter than ``/`` — computed
    (base ** 2i) / dim instead of base ** (2i / dim). The exponent range is
    also created on ``device`` so it cannot cross devices with ``t``.
    """
    self.max_seq_len_cached = seq_len
    t = torch.arange(seq_len, device=device, dtype=torch.float32)
    exponents = torch.arange(0, self.dim, 2, device=device,
                             dtype=torch.float32) / self.dim
    freqs = t[:, None] / (self.freq_base ** exponents)
    # Interleave each frequency twice: [f0, f0, f1, f1, ...].
    emb = torch.stack((freqs, freqs), dim=-1).reshape(seq_len, -1)
    self.cos_cached = emb.cos()[None, None, :, :].to(dtype)
    self.sin_cached = emb.sin()[None, None, :, :].to(dtype)
| |
|
| | fix_rope_calculation(model) |
| |
|
| | |
def cast_rope_to_float32(model):
    """Ensure every rope module's cached cos/sin tables are float32."""
    rope_owners = (m for m in model.modules() if hasattr(m, 'rope'))
    for owner in rope_owners:
        rope = owner.rope
        rope.cos_cached = rope.cos_cached.float()
        rope.sin_cached = rope.sin_cached.float()
| |
|
| | cast_rope_to_float32(model) |
| |
|
| | |
def fix_gelu(model):
    """Switch every GELU module to the tanh approximation.

    Bug fixed: ``torch.nn.GELU`` selects its variant via the string
    attribute ``approximate`` ('none' or 'tanh'), which its forward passes
    to ``F.gelu``. The original set a nonexistent ``_approximate`` flag to
    ``True``, which nothing reads — a silent no-op.
    """
    for module in model.modules():
        if isinstance(module, torch.nn.GELU):
            module.approximate = "tanh"
| |
|
| | fix_gelu(model) |
| |
|
| | |
| | trainer.train() |
| |
|
| | |
| | trainer.save_model() |