import os, argparse, json

from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    TrainingArguments, Trainer,
    DataCollatorForLanguageModeling,
)
from peft import LoraConfig, get_peft_model


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--base_model", default="distilgpt2")
    ap.add_argument("--train_jsonl", default="chat_seed_train.jsonl")
    ap.add_argument("--valid_jsonl", default="chat_seed_valid.jsonl")
    ap.add_argument("--out_dir", default="trained_lora")
    ap.add_argument("--epochs", type=float, default=1.0)
    ap.add_argument("--lr", type=float, default=5e-4)
    ap.add_argument("--batch_size", type=int, default=8)
    ap.add_argument("--max_len", type=int, default=192)
    ap.add_argument("--lora_r", type=int, default=8)
    ap.add_argument("--lora_alpha", type=int, default=16)
    ap.add_argument("--lora_dropout", type=float, default=0.05)
    args = ap.parse_args()
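
    # Each JSONL line is expected to carry "prompt" and "completion" string fields
    # (the names tok_fn consumes below); the example record is only an illustrative
    # guess at the format, not data shipped with this script:
    #   {"prompt": "User: What is LoRA?\nAssistant:", "completion": " A low-rank adapter."}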
    data = load_dataset("json", data_files={"train": args.train_jsonl, "valid": args.valid_jsonl})

    tok = AutoTokenizer.from_pretrained(args.base_model)
    if tok.pad_token is None:
        # GPT-2-family tokenizers ship without a pad token; reuse EOS for padding.
        tok.pad_token = tok.eos_token

    def tok_fn(batch):
        # Concatenate prompt and completion into one causal-LM training string.
        joined = [p + c for p, c in zip(batch["prompt"], batch["completion"])]
        return tok(joined, truncation=True, max_length=args.max_len)

    tokenized = data.map(tok_fn, batched=True, remove_columns=["prompt", "completion"])

    base = AutoModelForCausalLM.from_pretrained(args.base_model)
    lconf = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=["c_attn"],  # fused QKV projection in GPT-2-style blocks
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(base, lconf)

    collator = DataCollatorForLanguageModeling(tok, mlm=False)

    # Minimal args compatible with older Transformers
    targs = TrainingArguments(
        output_dir=args.out_dir,
        per_device_train_batch_size=args.batch_size,
        num_train_epochs=args.epochs,
        learning_rate=args.lr,
        weight_decay=0.01,
        logging_steps=10,
        save_total_limit=1,
        fp16=False,
    )
    trainer = Trainer(
        model=model,
        args=targs,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["valid"],  # will call evaluate() explicitly
        data_collator=collator,
        tokenizer=tok,
    )
    trainer.train()

    # Save the LoRA adapters and tokenizer locally
    model.save_pretrained(args.out_dir)
    tok.save_pretrained(args.out_dir)

    # Try an evaluation pass (older Transformers versions still expose evaluate())
    try:
        metrics = trainer.evaluate()
    except Exception as e:
        metrics = {"note": "evaluate() failed on this version", "error": str(e)}
    print(json.dumps(metrics, indent=2))
    print("Saved adapters to:", args.out_dir)


if __name__ == "__main__":
    main()