import os

# HF_HOME must be set before transformers/datasets are imported, otherwise
# the cache location is already fixed by the time this assignment runs.
os.environ["HF_HOME"] = "D:\\huggingface_cache"

import torch
from datasets import load_from_disk
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model_name = "mergekit-community/Qwen-2.5-Coder"
out_dir = "D:\\out_peft"

# Load the preprocessed dataset (expects "train" and "test" splits).
ds = load_from_disk("processed_ds")
train_ds = ds["train"]
eval_ds = ds["test"]

# Load the tokenizer and the base model quantized to 4-bit.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # ensure padding is possible

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    # Passing load_in_4bit=True directly is deprecated; use a BitsAndBytesConfig.
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)

# Prepare the quantized model for training, then attach LoRA adapters
# to the attention projections.
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
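model.print_trainable_parameters()  # sanity check: only the LoRA weights should be trainable
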
def tokenize_fn(batch):
    # Join prompt and target with an EOS separator, then tokenize to a fixed length.
    inputs = [a + tokenizer.eos_token + b for a, b in zip(batch["input_text"], batch["target_text"])]
    out = tokenizer(inputs, truncation=True, padding="max_length", max_length=1024)
    # Labels mirror input_ids, but padded positions are masked with -100 so
    # they are ignored by the loss instead of being trained on.
    out["labels"] = [
        [tok if mask == 1 else -100 for tok, mask in zip(ids, attn)]
        for ids, attn in zip(out["input_ids"], out["attention_mask"])
    ]
    return out

# Tokenize both splits and drop the raw text columns.
train_ds = train_ds.map(tokenize_fn, batched=True, remove_columns=train_ds.column_names)
eval_ds = eval_ds.map(tokenize_fn, batched=True, remove_columns=eval_ds.column_names)

training_args = TrainingArguments(
    output_dir=out_dir,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,  # effective batch size of 8 per device
    num_train_epochs=3,
    learning_rate=2e-4,
    fp16=True,
    logging_steps=50,
    eval_strategy="epoch",  # actually run evaluation on the eval split passed below
    save_total_limit=2,  # keep only the two most recent checkpoints
    optim="paged_adamw_8bit",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
)
trainer.train()

# Persist only the LoRA adapter weights plus the tokenizer.
model.save_pretrained(out_dir)
tokenizer.save_pretrained(out_dir)
print("Saved PEFT adapter to", out_dir)