import os
import argparse
import json

from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)


def main():
    """Fine-tune a base causal LM with LoRA adapters on prompt/completion JSONL files.

    Reads hyperparameters from the command line, tokenizes each
    ``prompt + completion`` pair, trains with a standard causal-LM objective,
    saves the LoRA adapters and tokenizer to ``--out_dir``, then prints
    evaluation metrics as JSON.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--base_model", default="distilgpt2")
    ap.add_argument("--train_jsonl", default="chat_seed_train.jsonl")
    ap.add_argument("--valid_jsonl", default="chat_seed_valid.jsonl")
    ap.add_argument("--out_dir", default="trained_lora")
    ap.add_argument("--epochs", type=float, default=1.0)
    ap.add_argument("--lr", type=float, default=5e-4)
    ap.add_argument("--batch_size", type=int, default=8)
    ap.add_argument("--max_len", type=int, default=192)
    ap.add_argument("--lora_r", type=int, default=8)
    ap.add_argument("--lora_alpha", type=int, default=16)
    ap.add_argument("--lora_dropout", type=float, default=0.05)
    args = ap.parse_args()

    # Each JSONL record is expected to carry "prompt" and "completion" fields
    # (they are consumed in tok_fn below).
    data = load_dataset(
        "json", data_files={"train": args.train_jsonl, "valid": args.valid_jsonl}
    )

    tok = AutoTokenizer.from_pretrained(args.base_model)
    if tok.pad_token is None:
        # GPT-2-style tokenizers ship without a pad token; reuse EOS so the
        # collator can pad batches.
        tok.pad_token = tok.eos_token

    def tok_fn(batch):
        # Concatenate prompt and completion; loss is computed over the whole
        # sequence (no prompt masking in this minimal setup).
        joined = [p + c for p, c in zip(batch["prompt"], batch["completion"])]
        return tok(joined, truncation=True, max_length=args.max_len)

    tokenized = data.map(tok_fn, batched=True, remove_columns=["prompt", "completion"])

    base = AutoModelForCausalLM.from_pretrained(args.base_model)
    lconf = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=["c_attn"],  # GPT-2 family fused QKV/attn projection
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(base, lconf)

    # mlm=False -> causal-LM labels (the model shifts inputs internally).
    collator = DataCollatorForLanguageModeling(tok, mlm=False)

    # Minimal args compatible with older Transformers
    targs = TrainingArguments(
        output_dir=args.out_dir,
        per_device_train_batch_size=args.batch_size,
        num_train_epochs=args.epochs,
        learning_rate=args.lr,
        weight_decay=0.01,
        logging_steps=10,
        save_total_limit=1,
        fp16=False,
    )
    trainer = Trainer(
        model=model,
        args=targs,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["valid"],  # will call evaluate() explicitly
        data_collator=collator,
        tokenizer=tok,
    )
    trainer.train()

    # Save adapters locally (PEFT saves only the LoRA weights, not the base model).
    model.save_pretrained(args.out_dir)
    tok.save_pretrained(args.out_dir)

    # Try an evaluation pass; best-effort so a version mismatch doesn't
    # discard the training run that already completed.
    try:
        metrics = trainer.evaluate()
    except Exception as e:
        metrics = {"note": "evaluate() failed on this version", "error": str(e)}
    print(json.dumps(metrics, indent=2))
    print("Saved adapters to:", args.out_dir)


# FIX: original read `if _name_ == "_main_":` (single underscores), which
# raises NameError immediately — main() could never run.
if __name__ == "__main__":
    main()