#!/usr/bin/env python3
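"""QLoRA supervised fine-tuning for CabinLavatoryPrediction.

Loads chat-format JSONL data (one {"messages": [...]} object per line), masks
prompt tokens out of the loss, and trains a LoRA adapter on top of a 4-bit
(NF4) quantized causal LM.
"""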
import argparse
import json
from pathlib import Path
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForSeq2Seq,
    Trainer,
    TrainingArguments,
)
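

# Standalone JSONL reader (main() loads the same files via datasets' JSON
# loader instead); each line must be an object with an OpenAI-style
# "messages" list.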
def load_messages(path):
    rows = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            if line.strip():
                obj = json.loads(line)
                rows.append({"messages": obj["messages"]})
    return rows
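

# Ensure a pad token exists (many causal LMs ship without one) and pad on the
# right, which is the usual choice for training.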
def build_tokenizer(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    return tokenizer
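

# enable_thinking=False targets Qwen3-style chat templates; the TypeError
# fallback covers tokenizers whose template does not accept that keyword.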
def render_prompt(tokenizer, messages):
    try:
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)
    except TypeError:
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
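

# Tokenize the prompt and prompt+answer separately and label only the answer
# span. This assumes the chat template ends the prompt with a special token,
# so the prompt/answer boundary tokenizes stably.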
def preprocess_example(example, tokenizer, max_seq_length):
    messages = example["messages"]
    prompt_messages = messages[:-1]
    answer = messages[-1]["content"]
    prompt = render_prompt(tokenizer, prompt_messages)
    answer = str(answer) + tokenizer.eos_token
    prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
    full_ids = tokenizer(prompt + answer, add_special_tokens=False, truncation=True, max_length=max_seq_length)["input_ids"]
    # Mask the prompt with -100 so the loss is computed on answer tokens only;
    # min()/slicing keeps labels aligned with full_ids even when truncation
    # cuts into the prompt.
    labels = [-100] * min(len(prompt_ids), len(full_ids)) + full_ids[len(prompt_ids) :]
    labels = labels[: len(full_ids)]
    return {"input_ids": full_ids, "attention_mask": [1] * len(full_ids), "labels": labels}
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", default="Qwen/Qwen3.5-9B")
    parser.add_argument("--train-file", default="data/processed/train_mixed.jsonl")
    parser.add_argument("--val-file", default="data/processed/val_mixed.jsonl")
    parser.add_argument("--output-dir", default="outputs/qwen35_9b_lora")
    parser.add_argument("--max-seq-length", type=int, default=2048)
    parser.add_argument("--num-train-epochs", type=float, default=1.0)
    parser.add_argument("--learning-rate", type=float, default=2e-4)
    parser.add_argument("--per-device-train-batch-size", type=int, default=1)
    parser.add_argument("--per-device-eval-batch-size", type=int, default=1)
    parser.add_argument("--gradient-accumulation-steps", type=int, default=8)
    parser.add_argument("--eval-steps", type=int, default=500)
    parser.add_argument("--save-steps", type=int, default=500)
    parser.add_argument("--logging-steps", type=int, default=20)
    parser.add_argument("--max-train-samples", type=int, default=None)
    parser.add_argument("--max-eval-samples", type=int, default=512)
    args = parser.parse_args()
    tokenizer = build_tokenizer(args.model_name)
    raw = load_dataset("json", data_files={"train": args.train_file, "validation": args.val_file})
    if args.max_train_samples:
        raw["train"] = raw["train"].select(range(min(args.max_train_samples, len(raw["train"]))))
    if args.max_eval_samples:
        raw["validation"] = raw["validation"].select(range(min(args.max_eval_samples, len(raw["validation"]))))
    tokenized = raw.map(
        lambda ex: preprocess_example(ex, tokenizer, args.max_seq_length),
        remove_columns=raw["train"].column_names,
        desc="Tokenizing chat SFT data",
    )
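
    # 4-bit NF4 quantization with double quantization and bfloat16 compute,
    # following the QLoRA recipe.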
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        trust_remote_code=True,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    # The KV cache is incompatible with gradient checkpointing during training.
    model.config.use_cache = False
    model = prepare_model_for_kbit_training(model)
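
    # LoRA on all attention and MLP projections; r=16 / alpha=32 is a common
    # starting point for 4-bit SFT.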
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
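
    # Effective batch size = per_device_train_batch_size *
    # gradient_accumulation_steps (* number of devices); 1 * 8 = 8 sequences
    # per optimizer step with the defaults.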
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        bf16=True,
        gradient_checkpointing=True,
        optim="paged_adamw_8bit",
        logging_steps=args.logging_steps,
        eval_strategy="steps",
        eval_steps=args.eval_steps,
        save_strategy="steps",
        save_steps=args.save_steps,
        save_total_limit=3,
        report_to="none",
        remove_unused_columns=False,
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
    )
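
    # DataCollatorForSeq2Seq pads input_ids/attention_mask with the pad token
    # and labels with -100, so padded positions stay out of the loss.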
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["validation"],
        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
    )
    trainer.train()
    # For a PEFT-wrapped model, save_model writes only the adapter weights and config.
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    (Path(args.output_dir) / "run_config.json").write_text(json.dumps(vars(args), ensure_ascii=False, indent=2), encoding="utf-8")


if __name__ == "__main__":
    main()
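
# Example launch (illustrative; the flags simply restate the argparse defaults above):
#   python train_qlora.py \
#     --model-name Qwen/Qwen3.5-9B \
#     --train-file data/processed/train_mixed.jsonl \
#     --val-file data/processed/val_mixed.jsonl \
#     --output-dir outputs/qwen35_9b_lora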