# Uploaded by Fahad-sha ("Upload 5 files", commit a365d48) — Hub page residue
# converted to a comment so this file is valid Python.
import os

import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from trl import SFTTrainer
# Base checkpoint to fine-tune; override via the BASE_MODEL env var.
BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
def main():
    """Run LoRA supervised fine-tuning and save the adapter to ./adapter_sft.

    Loads the train split from data/sft.jsonl, fine-tunes the BASE_MODEL
    checkpoint for one epoch with TRL's SFTTrainer + a LoRA adapter, then
    writes the adapter weights and tokenizer to the "adapter_sft" directory.
    """
    # JSONL records are expected to carry a "messages" field (conversational
    # format); SFTTrainer renders them via the tokenizer's chat template.
    ds = load_dataset("json", data_files="data/sft.jsonl")["train"]

    tok = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
    if tok.pad_token is None:
        # Many causal-LM tokenizers ship without a pad token; reuse EOS so
        # padded batching works.
        tok.pad_token = tok.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        device_map="auto",
        torch_dtype="auto",
    )

    # LoRA on every attention and MLP projection matrix.
    peft_cfg = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "up_proj", "down_proj", "gate_proj",
        ],
    )

    args = TrainingArguments(
        output_dir="adapter_sft",
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,  # effective batch of 16 per device
        learning_rate=2e-4,
        num_train_epochs=1,
        logging_steps=20,
        save_steps=200,
        # FIX: fp16=True unconditionally raises on CPU-only machines; enable
        # mixed precision only when a CUDA device is actually present.
        fp16=torch.cuda.is_available(),
        report_to="none",
    )

    trainer = SFTTrainer(
        model=model,
        tokenizer=tok,
        train_dataset=ds,
        peft_config=peft_cfg,
        max_seq_length=1024,
        args=args,
        packing=False,
        dataset_text_field=None,  # because we use "messages"
    )
    trainer.train()

    trainer.save_model("adapter_sft")
    tok.save_pretrained("adapter_sft")
    print("Saved adapter_sft/")
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()