Text Generation
PEFT
Safetensors
Chinese
English
qwen
qlora
radar
aircraft-cabin
structured-prediction
qa
conversational
Instructions to use sutama/CabinLavatoryPrediction with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- PEFT
How to use sutama/CabinLavatoryPrediction with PEFT:
```python
from peft import PeftModel
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-9B")
model = PeftModel.from_pretrained(base_model, "sutama/CabinLavatoryPrediction")
```

A short generation sketch follows the notebook links below.

- Notebooks
- Google Colab
- Kaggle
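Once the adapter is attached, inference goes through the standard transformers chat-template path. A minimal sketch, assuming the loading snippet above has already run and a GPU is available; the prompt text is illustrative only, not the model's actual input schema:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-9B", trust_remote_code=True)

# Render a chat prompt with the model's template (prompt content is hypothetical).
messages = [{"role": "user", "content": "Predict the cabin lavatory state from this radar reading: ..."}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=128)
# Decode only the newly generated tokens.
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```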
#!/usr/bin/env python3
import argparse
import json
from pathlib import Path
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForSeq2Seq,
    Trainer,
    TrainingArguments,
)

def load_messages(path):
    rows = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            if line.strip():
                obj = json.loads(line)
                rows.append({"messages": obj["messages"]})
    return rows
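
# Each JSONL line is expected to hold an OpenAI-style chat record, e.g.
# {"messages": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
# (layout inferred from how preprocess_example consumes it). Note that main()
# below loads data via datasets.load_dataset; this helper mirrors the same format.
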
def build_tokenizer(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    return tokenizer
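
# Reusing EOS as the pad token is a common convention for checkpoints that ship
# without one; right padding is the usual choice for training (left padding only
# matters for batched generation), and padded positions stay out of the loss.
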
def render_prompt(tokenizer, messages):
    try:
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)
    except TypeError:
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
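
# enable_thinking=False asks Qwen3-style chat templates to skip the reasoning
# block; tokenizers whose templates do not accept the kwarg raise TypeError,
# which the fallback path handles.
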
def preprocess_example(example, tokenizer, max_seq_length):
    messages = example["messages"]
    prompt_messages = messages[:-1]
    answer = messages[-1]["content"]
    prompt = render_prompt(tokenizer, prompt_messages)
    answer = str(answer) + tokenizer.eos_token
    prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
    full_ids = tokenizer(prompt + answer, add_special_tokens=False, truncation=True, max_length=max_seq_length)["input_ids"]
    labels = [-100] * min(len(prompt_ids), len(full_ids)) + full_ids[len(prompt_ids):]
    labels = labels[: len(full_ids)]
    return {"input_ids": full_ids, "attention_mask": [1] * len(full_ids), "labels": labels}
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", default="Qwen/Qwen3.5-9B")
    parser.add_argument("--train-file", default="data/processed/train_mixed.jsonl")
    parser.add_argument("--val-file", default="data/processed/val_mixed.jsonl")
    parser.add_argument("--output-dir", default="outputs/qwen35_9b_lora")
    parser.add_argument("--max-seq-length", type=int, default=2048)
    parser.add_argument("--num-train-epochs", type=float, default=1.0)
    parser.add_argument("--learning-rate", type=float, default=2e-4)
    parser.add_argument("--per-device-train-batch-size", type=int, default=1)
    parser.add_argument("--per-device-eval-batch-size", type=int, default=1)
    parser.add_argument("--gradient-accumulation-steps", type=int, default=8)
    parser.add_argument("--eval-steps", type=int, default=500)
    parser.add_argument("--save-steps", type=int, default=500)
    parser.add_argument("--logging-steps", type=int, default=20)
    parser.add_argument("--max-train-samples", type=int, default=None)
    parser.add_argument("--max-eval-samples", type=int, default=512)
    args = parser.parse_args()
    tokenizer = build_tokenizer(args.model_name)
    raw = load_dataset("json", data_files={"train": args.train_file, "validation": args.val_file})
    if args.max_train_samples:
        raw["train"] = raw["train"].select(range(min(args.max_train_samples, len(raw["train"]))))
    if args.max_eval_samples:
        raw["validation"] = raw["validation"].select(range(min(args.max_eval_samples, len(raw["validation"]))))
    tokenized = raw.map(
        lambda ex: preprocess_example(ex, tokenizer, args.max_seq_length),
        remove_columns=raw["train"].column_names,
        desc="Tokenizing chat SFT data",
    )
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
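
    # NF4 weights with double quantization shrink the 9B base to roughly 4 bits
    # per parameter; matmuls still run in bfloat16 via the compute dtype.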
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        trust_remote_code=True,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    model.config.use_cache = False
    model = prepare_model_for_kbit_training(model)
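
    # The KV cache is unused during training and conflicts with gradient
    # checkpointing; prepare_model_for_kbit_training additionally upcasts norm
    # layers and enables input gradients so LoRA can train on a quantized base.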
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
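
    # Adapters attach to every attention and MLP projection; with r=16 and
    # lora_alpha=32 the effective update scale is alpha/r = 2.0.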
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        bf16=True,
        gradient_checkpointing=True,
        optim="paged_adamw_8bit",
        logging_steps=args.logging_steps,
        eval_strategy="steps",
        eval_steps=args.eval_steps,
        save_strategy="steps",
        save_steps=args.save_steps,
        save_total_limit=3,
        report_to="none",
        remove_unused_columns=False,
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
    )
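
    # At the defaults this is an effective batch of 1 x 8 = 8 sequences per
    # optimizer step, with paged 8-bit AdamW capping optimizer-state memory.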
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["validation"],
        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
    )
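
    # DataCollatorForSeq2Seq pads labels with -100 alongside input_ids, so the
    # prompt-masking scheme from preprocess_example survives batching.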
    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    (Path(args.output_dir) / "run_config.json").write_text(json.dumps(vars(args), ensure_ascii=False, indent=2), encoding="utf-8")

if __name__ == "__main__":
    main()
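
# Example invocation (filename hypothetical; flags match the arguments above):
#   python train_qlora.py --output-dir outputs/qwen35_9b_lora --num-train-epochs 1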