# python_ai_coder/train.py
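"""Fine-tune distilgpt2 on a prompt/completion dataset with the Hugging Face Trainer.

Example invocation (the dataset path is illustrative):

    python train.py --dataset data.jsonl --output trained_model
"""
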
import argparse
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling

# Command-line arguments: dataset path (JSON/JSON Lines) and output directory.
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", required=True)
parser.add_argument("--output", default="trained_model")
args = parser.parse_args()
print("πŸ“Š Loading dataset...")
dataset = load_dataset("json", data_files=args.dataset, split="train")
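# Each record is expected to provide "prompt" and "completion" fields (the
# keys consumed by tokenize() below), e.g. this illustrative sample:
#   {"prompt": "Write a function that adds two numbers.", "completion": "def add(a, b):\n    return a + b"}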
print("🧠 Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

# Batch-safe tokenization: concatenate each prompt with its completion into a
# single training sequence, then pad/truncate to a fixed length of 256 tokens.
def tokenize(batch):
    full_texts = [str(p) + str(c) for p, c in zip(batch["prompt"], batch["completion"])]
    return tokenizer(full_texts, padding="max_length", truncation=True, max_length=256)
print("πŸ” Tokenizing...")
tokenized = dataset.map(tokenize, batched=True)
print("πŸ“¦ Setting up trainer...")
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir=args.output,
    per_device_train_batch_size=2,
    num_train_epochs=1,
    logging_steps=1,      # log the loss every step
    save_steps=5,
    save_total_limit=1,   # keep only the most recent checkpoint
    report_to=[],         # disable external loggers such as W&B
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized,
    tokenizer=tokenizer,  # newer transformers versions rename this kwarg to processing_class
    data_collator=data_collator,
)
print("πŸš€ Starting training...")
trainer.train()
trainer.save_model(args.output)
print("βœ… Done.")