# chat-nlm-trainer / train.py
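"""Fine-tune a small causal LM (default: distilgpt2) on prompt/completion
JSONL data using LoRA adapters via PEFT, then save the adapters and
tokenizer to --out_dir."""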
import argparse, json
from datasets import load_dataset
from transformers import (
AutoTokenizer, AutoModelForCausalLM,
TrainingArguments, Trainer,
DataCollatorForLanguageModeling
)
from peft import LoraConfig, get_peft_model
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--base_model", default="distilgpt2")
ap.add_argument("--train_jsonl", default="chat_seed_train.jsonl")
ap.add_argument("--valid_jsonl", default="chat_seed_valid.jsonl")
ap.add_argument("--out_dir", default="trained_lora")
ap.add_argument("--epochs", type=float, default=1.0)
ap.add_argument("--lr", type=float, default=5e-4)
ap.add_argument("--batch_size", type=int, default=8)
ap.add_argument("--max_len", type=int, default=192)
ap.add_argument("--lora_r", type=int, default=8)
ap.add_argument("--lora_alpha", type=int, default=16)
ap.add_argument("--lora_dropout", type=float, default=0.05)
args = ap.parse_args()
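
    # Each JSONL record must carry "prompt" and "completion" string fields,
    # e.g. {"prompt": "User: hi\nBot:", "completion": " hello!"} (illustrative).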
data = load_dataset("json", data_files={"train": args.train_jsonl, "valid": args.valid_jsonl})
tok = AutoTokenizer.from_pretrained(args.base_model)
    # GPT-2-family tokenizers ship without a pad token; reuse EOS for padding.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    def tok_fn(batch):
        # Concatenate prompt + completion and append EOS so the model learns
        # where a reply should end; truncate to max_len tokens.
        joined = [p + c + tok.eos_token for p, c in zip(batch["prompt"], batch["completion"])]
        return tok(joined, truncation=True, max_length=args.max_len)

    # Drop the raw text columns; the Trainer only needs token ids and masks.
    tokenized = data.map(tok_fn, batched=True, remove_columns=["prompt", "completion"])
base = AutoModelForCausalLM.from_pretrained(args.base_model)
lconf = LoraConfig(
r=args.lora_r,
lora_alpha=args.lora_alpha,
        target_modules=["c_attn"],  # GPT-2-family attention: the fused QKV projection
lora_dropout=args.lora_dropout,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(base, lconf)
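    # Optional: log how many parameters LoRA leaves trainable vs. frozen.
    model.print_trainable_parameters()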
    # mlm=False -> causal-LM collation: labels are the input ids (shifted inside the model).
    collator = DataCollatorForLanguageModeling(tok, mlm=False)
# Minimal args compatible with older Transformers
targs = TrainingArguments(
output_dir=args.out_dir,
per_device_train_batch_size=args.batch_size,
num_train_epochs=args.epochs,
learning_rate=args.lr,
weight_decay=0.01,
logging_steps=10,
save_total_limit=1,
        fp16=False,  # full precision, so the script also runs on CPU-only machines
)
trainer = Trainer(
model=model,
args=targs,
train_dataset=tokenized["train"],
eval_dataset=tokenized["valid"], # will call evaluate() explicitly
data_collator=collator,
tokenizer=tok,
)
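    # Only the LoRA adapter weights receive gradient updates; the base model stays frozen.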
trainer.train()
    # Save only the LoRA adapter weights and the tokenizer (not the full base model)
model.save_pretrained(args.out_dir)
tok.save_pretrained(args.out_dir)
    # Try an evaluation pass (older Transformers versions still expose evaluate())
try:
metrics = trainer.evaluate()
except Exception as e:
metrics = {"note": "evaluate() failed on this version", "error": str(e)}
print(json.dumps(metrics, indent=2))
print("Saved adapters to:", args.out_dir)
if __name__ == "__main__":
main()
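
# Example usage (defaults shown; adjust paths to your data):
#   python train.py --base_model distilgpt2 \
#       --train_jsonl chat_seed_train.jsonl --valid_jsonl chat_seed_valid.jsonl \
#       --out_dir trained_lora
#
# Reloading the saved adapter for inference (a minimal sketch, assuming the
# same base model the adapter was trained on):
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   from peft import PeftModel
#   base = AutoModelForCausalLM.from_pretrained("distilgpt2")
#   model = PeftModel.from_pretrained(base, "trained_lora")
#   tok = AutoTokenizer.from_pretrained("trained_lora")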