"""Supervised fine-tuning (SFT) of a causal LM with a LoRA adapter via TRL.

Configuration is read from environment variables:

* ``BASE_MODEL`` -- Hugging Face model id or local path
  (default ``Qwen/Qwen2.5-0.5B-Instruct``).
* ``SFT_DATA_FILE`` -- JSON Lines training file (default ``data/sft.jsonl``).
* ``SFT_OUTPUT_DIR`` -- directory for checkpoints, the saved adapter and the
  tokenizer (default ``adapter_sft``).
"""

import os

import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer

BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
DATA_FILE = os.environ.get("SFT_DATA_FILE", "data/sft.jsonl")
OUTPUT_DIR = os.environ.get("SFT_OUTPUT_DIR", "adapter_sft")


def _precision_settings():
    """Pick a model dtype and matching mixed-precision flags for this host.

    Returns a ``(torch_dtype, bf16, fp16)`` tuple:

    * CUDA GPU with bfloat16 support -> bf16 weights and bf16 mixed precision.
    * Other CUDA GPU -> fp32 weights with fp16 AMP. The AMP gradient scaler
      refuses to unscale half-precision gradients, so the trainable weights
      must not be fp16.
    * No CUDA -> fp32 and no mixed precision. ``fp16=True`` may be rejected
      by ``TrainingArguments`` without a supported accelerator.
    """
    if torch.cuda.is_available():
        if torch.cuda.is_bf16_supported():
            return torch.bfloat16, True, False
        return torch.float32, False, True
    return torch.float32, False, False


def main():
    """Fine-tune ``BASE_MODEL`` on ``DATA_FILE`` and save a LoRA adapter.

    The adapter and the tokenizer are both written to ``OUTPUT_DIR``.
    Intermediate checkpoints are also written there, every 200 steps.
    """
    ds = load_dataset("json", data_files=DATA_FILE)["train"]

    tok = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
    if tok.pad_token is None:
        # Batching needs a pad token; reuse EOS when the tokenizer defines none.
        tok.pad_token = tok.eos_token

    # Keep the weight dtype and the Trainer's mixed-precision mode consistent;
    # a hard-coded fp16=True clashes with bf16 weights and fails without CUDA.
    torch_dtype, use_bf16, use_fp16 = _precision_settings()
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        device_map="auto",
        torch_dtype=torch_dtype,
    )

    peft_cfg = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        # Presumably the attention and MLP projection layers of
        # Qwen/Llama-style models; adjust for other architectures.
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "up_proj",
            "down_proj",
            "gate_proj",
        ],
    )

    args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        # Effective batch size per device: 2 * 8 = 16 sequences per step.
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=2e-4,
        num_train_epochs=1,
        logging_steps=20,
        save_steps=200,
        bf16=use_bf16,
        fp16=use_fp16,
        report_to="none",
    )

    # NOTE(review): passing tokenizer/max_seq_length/packing/dataset_text_field
    # directly to SFTTrainer matches older TRL releases; newer releases move
    # these into SFTConfig (and rename tokenizer to processing_class). Pin the
    # TRL version or migrate -- confirm against the installed version.
    trainer = SFTTrainer(
        model=model,
        tokenizer=tok,
        train_dataset=ds,
        peft_config=peft_cfg,
        max_seq_length=1024,
        args=args,
        packing=False,
        dataset_text_field=None,  # because we use "messages"
    )
    trainer.train()

    trainer.save_model(OUTPUT_DIR)
    tok.save_pretrained(OUTPUT_DIR)
    print(f"Saved {OUTPUT_DIR}/")


if __name__ == "__main__":
    main()