import argparse

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)

parser = argparse.ArgumentParser()
parser.add_argument("--dataset", required=True)
parser.add_argument("--output", default="trained_model")
args = parser.parse_args()

print("📊 Loading dataset...")
dataset = load_dataset("json", data_files=args.dataset, split="train")

print("🧠 Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
tokenizer.pad_token = tokenizer.eos_token  # distilgpt2 has no pad token; reuse EOS for padding
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

# ✅ Clean, batch-safe tokenize
def tokenize(batch):
    full_texts = [str(p) + str(c) for p, c in zip(batch["prompt"], batch["completion"])]
    return tokenizer(full_texts, padding="max_length", truncation=True, max_length=256)

print("🔁 Tokenizing...")
tokenized = dataset.map(tokenize, batched=True)

print("📦 Setting up trainer...")
# mlm=False gives a causal-LM collator: labels are a copy of input_ids
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
training_args = TrainingArguments(
    output_dir=args.output,
    per_device_train_batch_size=2,
    num_train_epochs=1,
    logging_steps=1,
    save_steps=5,
    save_total_limit=1,
    report_to=[],
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

print("🚀 Starting training...")
trainer.train()
trainer.save_model(args.output)
print("✅ Done.")
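
# Usage sketch: a minimal example of how this script can be invoked, assuming
# the --dataset file is JSON Lines with "prompt" and "completion" string
# fields (the keys tokenize() reads). The file names below are hypothetical.
#
#   data.jsonl:
#     {"prompt": "Q: What is 2+2?\nA:", "completion": " 4"}
#     {"prompt": "Q: Capital of France?\nA:", "completion": " Paris"}
#
#   python train.py --dataset data.jsonl --output trained_model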