# routercore / training /train_lora.py
# Jayteare's picture
# Deploy RouterCore Gradio demo
# 1137e50 verified
from __future__ import annotations
import argparse
from pathlib import Path
class OptionalTrainingDependencyError(RuntimeError):
    """Signal that the optional LoRA training stack could not be imported.

    Raised by ``load_training_dependencies`` when importing torch, datasets,
    peft, or transformers fails, so callers can tell "training extras are
    not installed" apart from other runtime failures.
    """
def load_training_dependencies():
    """Import the optional training stack on demand and return it as a mapping.

    The heavy third-party packages are imported inside this function rather
    than at module import time, so the module itself can be imported when
    they are absent.

    Returns:
        A dict keyed by the imported name (``"torch"``, ``"load_dataset"``,
        ``"LoraConfig"``, ``"get_peft_model"``, ``"AutoModelForCausalLM"``,
        ``"AutoTokenizer"``, ``"DataCollatorForLanguageModeling"``,
        ``"Trainer"``, ``"TrainingArguments"``) mapping to the imported
        object.

    Raises:
        OptionalTrainingDependencyError: If any of the imports fails with
            ``ImportError``; the original error is chained as the cause.
    """
    try:
        import torch
        from datasets import load_dataset
        from peft import LoraConfig, get_peft_model
        from transformers import (
            AutoModelForCausalLM,
            AutoTokenizer,
            DataCollatorForLanguageModeling,
            Trainer,
            TrainingArguments,
        )
    except ImportError as import_failure:
        message = (
            "Optional training dependencies are unavailable. "
            "Install transformers, datasets, peft, and torch."
        )
        raise OptionalTrainingDependencyError(message) from import_failure
    else:
        # Only reached when every import above succeeded.
        return dict(
            torch=torch,
            load_dataset=load_dataset,
            LoraConfig=LoraConfig,
            get_peft_model=get_peft_model,
            AutoModelForCausalLM=AutoModelForCausalLM,
            AutoTokenizer=AutoTokenizer,
            DataCollatorForLanguageModeling=DataCollatorForLanguageModeling,
            Trainer=Trainer,
            TrainingArguments=TrainingArguments,
        )
def find_lora_target_modules(model, preferred_targets: list[str]) -> list[str]:
    """Select the preferred LoRA targets that actually occur in ``model``.

    Args:
        model: Any object exposing ``named_modules()`` that yields
            ``(qualified_name, module)`` pairs.
        preferred_targets: Candidate module names, matched against the last
            dot-separated component of each qualified module name.

    Returns:
        The entries of ``preferred_targets`` found in the model, in their
        original order. Empty when none match.
    """
    # Collect the final path component of every module name once, so each
    # candidate lookup below is a set membership test.
    leaf_names = set()
    for qualified_name, _module in model.named_modules():
        leaf_names.add(qualified_name.rpartition(".")[2])

    matches = []
    for candidate in preferred_targets:
        if candidate in leaf_names:
            matches.append(candidate)
    return matches
def train_lora(args: argparse.Namespace) -> None:
    """Fine-tune a LoRA adapter on a causal LM and save it with its tokenizer.

    In order: load the optional dependencies, load tokenizer and base model,
    pick LoRA target modules present in the model, wrap the model with PEFT,
    load and tokenize the JSON train/eval files, run ``Trainer.train()``,
    then write the adapter and tokenizer to ``args.output_dir``.

    Args:
        args: Parsed CLI namespace. Reads ``model``, ``train_file``,
            ``eval_file``, ``output_dir``, ``max_steps``, ``batch_size``,
            ``gradient_accumulation_steps``, ``learning_rate``,
            ``max_seq_length``, ``logging_steps``, ``lora_r``,
            ``lora_alpha`` and ``lora_dropout``.

    Raises:
        OptionalTrainingDependencyError: Propagated from
            ``load_training_dependencies`` when the training stack is not
            installed.
        ValueError: If none of the preferred projection-layer names exists
            in the loaded model.
    """
    deps = load_training_dependencies()
    (
        torch,
        load_dataset,
        LoraConfig,
        get_peft_model,
        AutoModelForCausalLM,
        AutoTokenizer,
        DataCollatorForLanguageModeling,
        Trainer,
        TrainingArguments,
    ) = (
        deps[key]
        for key in (
            "torch",
            "load_dataset",
            "LoraConfig",
            "get_peft_model",
            "AutoModelForCausalLM",
            "AutoTokenizer",
            "DataCollatorForLanguageModeling",
            "Trainer",
            "TrainingArguments",
        )
    )

    # Tokenizer comes first because its pad id is copied onto the model config.
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token is None:
        # Fall back to the EOS token when the tokenizer defines no pad token.
        tokenizer.pad_token = tokenizer.eos_token

    base_model = AutoModelForCausalLM.from_pretrained(args.model)
    base_model.config.pad_token_id = tokenizer.pad_token_id

    preferred_targets = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    target_modules = find_lora_target_modules(base_model, preferred_targets)
    if not target_modules:
        expected = ", ".join(preferred_targets)
        raise ValueError(
            f"No common LoRA target modules were found. Expected one of: {expected}. "
            "Inspect the model architecture and set compatible targets."
        )

    adapter_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=target_modules,
    )
    peft_model = get_peft_model(base_model, adapter_config)
    if hasattr(peft_model, "print_trainable_parameters"):
        peft_model.print_trainable_parameters()

    split_files = {"train": str(args.train_file), "eval": str(args.eval_file)}
    raw_splits = load_dataset("json", data_files=split_files)

    def encode_texts(examples):
        # Each record must provide a "text" field. No padding is applied
        # here; batches are assembled by the collator given to Trainer.
        return tokenizer(
            examples["text"],
            truncation=True,
            max_length=args.max_seq_length,
            padding=False,
        )

    # Drop the raw columns (taken from the train split) so only tokenizer
    # outputs remain.
    encoded_splits = raw_splits.map(
        encode_texts,
        batched=True,
        remove_columns=raw_splits["train"].column_names,
    )

    run_config = TrainingArguments(
        output_dir=str(args.output_dir),
        max_steps=args.max_steps,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        logging_steps=args.logging_steps,
        # Checkpoint interval equals the total step count.
        save_steps=args.max_steps,
        report_to=[],
        remove_unused_columns=False,
    )
    # NOTE(review): an eval split is supplied, but no evaluation schedule is
    # configured above; whether evaluation runs depends on library defaults.
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    trainer = Trainer(
        model=peft_model,
        args=run_config,
        train_dataset=encoded_splits["train"],
        eval_dataset=encoded_splits["eval"],
        data_collator=collator,
    )
    trainer.train()

    args.output_dir.mkdir(parents=True, exist_ok=True)
    trainer.model.save_pretrained(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

    if torch.cuda.is_available():
        device_name = "cuda"
    else:
        device_name = "cpu"
    print(f"Saved LoRA adapter and tokenizer to {args.output_dir}")
    print(f"Training device detected by torch: {device_name}")
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for LoRA fine-tuning.

    Args:
        argv: Argument strings to parse, without the program name. When
            ``None`` (the default), argparse reads ``sys.argv[1:]``, which
            is the behavior existing callers rely on.

    Returns:
        Namespace with ``model``, ``train_file``, ``eval_file``,
        ``output_dir`` (the three paths as ``pathlib.Path``), ``max_steps``,
        ``batch_size``, ``gradient_accumulation_steps``, ``learning_rate``,
        ``max_seq_length``, ``logging_steps``, ``lora_r``, ``lora_alpha``
        and ``lora_dropout``.

    Raises:
        SystemExit: Raised by argparse when a required option is missing or
            a value cannot be converted.
    """
    parser = argparse.ArgumentParser(description="Fine-tune a RouterCore LoRA adapter.")

    # Required inputs and outputs.
    parser.add_argument("--model", required=True, help="Base Hugging Face model name or path.")
    parser.add_argument(
        "--train-file",
        type=Path,
        required=True,
        help="JSON training file; every record needs a 'text' field.",
    )
    parser.add_argument(
        "--eval-file",
        type=Path,
        required=True,
        help="JSON evaluation file; every record needs a 'text' field.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        required=True,
        help="Directory for checkpoints, the saved LoRA adapter, and the tokenizer.",
    )

    # Training schedule and optimizer settings.
    parser.add_argument("--max-steps", type=int, default=100, help="Total number of training steps.")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1,
        help="Per-device batch size for both training and evaluation.",
    )
    parser.add_argument(
        "--gradient-accumulation-steps",
        type=int,
        default=8,
        help="Gradient accumulation steps passed to the trainer.",
    )
    parser.add_argument("--learning-rate", type=float, default=2e-4, help="Learning rate passed to the trainer.")
    parser.add_argument(
        "--max-seq-length",
        type=int,
        default=1024,
        help="Maximum token length; longer examples are truncated.",
    )
    parser.add_argument("--logging-steps", type=int, default=10, help="Logging interval in steps.")

    # LoRA hyperparameters.
    parser.add_argument("--lora-r", type=int, default=8, help="LoRA rank (r).")
    parser.add_argument("--lora-alpha", type=int, default=16, help="LoRA alpha.")
    parser.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout.")

    return parser.parse_args(argv)
def main() -> None:
    """CLI entry point: parse arguments, run training, report known failures.

    Missing optional dependencies are reported and treated as a skip, so the
    function returns normally. A ``ValueError`` raised anywhere inside
    ``train_lora`` is reported as a configuration error and converted into
    exit status 2.

    Raises:
        SystemExit: With code 2 when ``train_lora`` raises ``ValueError``.
    """
    cli_args = parse_args()
    try:
        train_lora(cli_args)
    except OptionalTrainingDependencyError as missing_deps:
        # Deliberate soft skip: explain how to enable training, do not fail.
        print(
            str(missing_deps),
            "Skipping LoRA training. Run `pip install transformers datasets peft torch` to enable it.",
            sep="\n",
        )
    except ValueError as exc:
        print(f"LoRA training configuration error: {exc}")
        raise SystemExit(2) from exc
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()