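# LoRA fine-tuning of Gemma 7B on TPU with PyTorch/XLA FSDP v2 (SPMD) via TRL's SFTTrainer.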
import torch
import torch_xla
import torch_xla.core.xla_model as xm

from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from trl import SFTTrainer

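# Target the XLA (TPU) device and the Gemma 7B checkpoint.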
device = xm.xla_device()
model_id = "google/gemma-7b"

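# Load the tokenizer and the model weights in bfloat16.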
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

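# LoRA adapters of rank 8 on the attention key and value projections.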
lora_config = LoraConfig(
    r=8,
    target_modules=["k_proj", "v_proj"],
    task_type="CAUSAL_LM",
)

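# Training data: short English quotes, later packed into sequences of up to 1024 tokens.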
data = load_dataset("Abirate/english_quotes", split="train")
max_seq_length = 1024

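# FSDP v2 (SPMD) settings for XLA: shard the model at the GemmaDecoderLayer level
# and enable gradient checkpointing to save memory.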
fsdp_config = {
    "fsdp_transformer_layer_cls_to_wrap": ["GemmaDecoderLayer"],
    "xla": True,
    "xla_fsdp_v2": True,
    "xla_fsdp_grad_ckpt": True,
}

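# Set up the trainer. Under SPMD the training loop runs in a single process, so
# per_device_train_batch_size is effectively the global batch size;
# dataloader_drop_last=True keeps batch shapes static, avoiding XLA recompilation
# for a smaller final batch.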
trainer = SFTTrainer(
    model=model,
    train_dataset=data,
    args=TrainingArguments(
        per_device_train_batch_size=64,
        num_train_epochs=100,
        max_steps=-1,
        output_dir="./output",
        optim="adafactor",
        logging_steps=1,
        dataloader_drop_last=True,
        fsdp="full_shard",
        fsdp_config=fsdp_config,
    ),
    peft_config=lora_config,
    dataset_text_field="quote",
    max_seq_length=max_seq_length,
    packing=True,
)

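# Kick off training.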
trainer.train()
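# On a TPU VM, one typical launch (assuming this file is saved as sft_gemma.py,
# a placeholder name) is: PJRT_DEVICE=TPU python sft_gemma.py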