#!/usr/bin/env python3
import argparse
import json
from pathlib import Path

import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForSeq2Seq,
    Trainer,
    TrainingArguments,
)


def load_messages(path):
    """Load chat transcripts from a JSONL file, one {"messages": [...]} object per line.

    Standalone helper for inspecting the data; the training path below loads
    the same files via `datasets.load_dataset` instead.
    """
    rows = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            if line.strip():
                obj = json.loads(line)
                rows.append({"messages": obj["messages"]})
    return rows


def build_tokenizer(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=True)
    # Many causal-LM tokenizers ship without a pad token; reuse EOS so padding works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    return tokenizer


def render_prompt(tokenizer, messages):
    # Qwen3-style chat templates accept `enable_thinking`; fall back gracefully
    # for templates that do not take that keyword.
    try:
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
        )
    except TypeError:
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)


def preprocess_example(example, tokenizer, max_seq_length):
    messages = example["messages"]
    prompt_messages = messages[:-1]
    answer = messages[-1]["content"]

    prompt = render_prompt(tokenizer, prompt_messages)
    answer = str(answer) + tokenizer.eos_token

    prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
    full_ids = tokenizer(
        prompt + answer, add_special_tokens=False, truncation=True, max_length=max_seq_length
    )["input_ids"]

    # Mask prompt tokens with -100 so loss is computed on the answer only.
    # The min() guard covers the case where truncation cut into the prompt itself.
    labels = [-100] * min(len(prompt_ids), len(full_ids)) + full_ids[len(prompt_ids):]
    labels = labels[: len(full_ids)]
    return {"input_ids": full_ids, "attention_mask": [1] * len(full_ids), "labels": labels}


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", default="Qwen/Qwen3.5-9B")
    parser.add_argument("--train-file", default="data/processed/train_mixed.jsonl")
    parser.add_argument("--val-file", default="data/processed/val_mixed.jsonl")
    parser.add_argument("--output-dir", default="outputs/qwen35_9b_lora")
    parser.add_argument("--max-seq-length", type=int, default=2048)
    parser.add_argument("--num-train-epochs", type=float, default=1.0)
    parser.add_argument("--learning-rate", type=float, default=2e-4)
    parser.add_argument("--per-device-train-batch-size", type=int, default=1)
    parser.add_argument("--per-device-eval-batch-size", type=int, default=1)
    parser.add_argument("--gradient-accumulation-steps", type=int, default=8)
    parser.add_argument("--eval-steps", type=int, default=500)
    parser.add_argument("--save-steps", type=int, default=500)
    parser.add_argument("--logging-steps", type=int, default=20)
    parser.add_argument("--max-train-samples", type=int, default=None)
    parser.add_argument("--max-eval-samples", type=int, default=512)
    args = parser.parse_args()

    tokenizer = build_tokenizer(args.model_name)

    raw = load_dataset("json", data_files={"train": args.train_file, "validation": args.val_file})
    if args.max_train_samples:
        raw["train"] = raw["train"].select(range(min(args.max_train_samples, len(raw["train"]))))
    if args.max_eval_samples:
        raw["validation"] = raw["validation"].select(
            range(min(args.max_eval_samples, len(raw["validation"])))
        )

    tokenized = raw.map(
        lambda ex: preprocess_example(ex, tokenizer, args.max_seq_length),
        remove_columns=raw["train"].column_names,
        desc="Tokenizing chat SFT data",
    )

    # QLoRA setup: the frozen base weights are stored in 4-bit NF4 while compute
    # runs in bfloat16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        trust_remote_code=True,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    model.config.use_cache = False  # KV caching is incompatible with gradient checkpointing
    model = prepare_model_for_kbit_training(model)

    # LoRA adapters on all attention and MLP projection matrices.
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        bf16=True,
        gradient_checkpointing=True,
        optim="paged_adamw_8bit",
        logging_steps=args.logging_steps,
        eval_strategy="steps",
        eval_steps=args.eval_steps,
        save_strategy="steps",
        save_steps=args.save_steps,
        save_total_limit=3,
        report_to="none",
        remove_unused_columns=False,
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["validation"],
        # DataCollatorForSeq2Seq pads input_ids/attention_mask with the pad token
        # and pads labels with -100, which is exactly what the masked-label SFT
        # format above needs.
        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
    )
    trainer.train()

    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    (Path(args.output_dir) / "run_config.json").write_text(
        json.dumps(vars(args), ensure_ascii=False, indent=2), encoding="utf-8"
    )


if __name__ == "__main__":
    main()
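
# ---------------------------------------------------------------------------
# Usage sketch (illustrative; the script filename below is a placeholder, and
# the paths simply echo the argparse defaults above):
#
#   python train_qlora.py \
#       --model-name Qwen/Qwen3.5-9B \
#       --train-file data/processed/train_mixed.jsonl \
#       --val-file data/processed/val_mixed.jsonl \
#       --output-dir outputs/qwen35_9b_lora
#
# trainer.save_model() on a PEFT model writes only the LoRA adapter weights,
# not the full base model. One common way to run inference afterwards is to
# reattach the saved adapter to the base model with peft.PeftModel:
#
#   from peft import PeftModel
#   from transformers import AutoModelForCausalLM
#
#   base = AutoModelForCausalLM.from_pretrained(
#       "Qwen/Qwen3.5-9B", trust_remote_code=True, device_map="auto"
#   )
#   model = PeftModel.from_pretrained(base, "outputs/qwen35_9b_lora")
# ---------------------------------------------------------------------------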