# python_ai_coder / train.py
# Author: Percy3822 — "Update train.py" (commit 86cb75d, verified, 1.43 kB)
import argparse
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
import torch
def main():
    """Fine-tune distilgpt2 for causal language modeling on a JSON dataset.

    Expects ``--dataset`` to point at a JSON/JSONL file whose records have a
    ``"prompt"`` field. Saves the fine-tuned model and tokenizer to
    ``./trained_model``.
    """
    parser = argparse.ArgumentParser(description="Fine-tune distilgpt2 on a JSON dataset")
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Path to a JSON/JSONL file with a 'prompt' field",
    )
    args = parser.parse_args()

    print("📥 Loading dataset...")
    dataset = load_dataset("json", data_files=args.dataset, split="train")

    tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
    # GPT-2 ships without a pad token; reuse EOS so fixed-length batches work.
    tokenizer.pad_token = tokenizer.eos_token

    def tokenize_function(examples):
        tokens = tokenizer(
            examples["prompt"],
            truncation=True,
            padding="max_length",
            max_length=256,
        )
        # BUG FIX: causal-LM training needs a "labels" column — without it
        # Trainer produces no loss and training fails. For CLM the labels are
        # simply the input ids (the model shifts them internally).
        tokens["labels"] = tokens["input_ids"].copy()
        return tokens

    # BUG FIX: drop the raw string columns ("prompt", ...) so the default
    # collator only sees numeric features it can stack into tensors.
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset.column_names,
    )

    print("📦 Loading model...")
    model = AutoModelForCausalLM.from_pretrained("distilgpt2")

    training_args = TrainingArguments(
        output_dir="./trained_model",
        overwrite_output_dir=True,
        num_train_epochs=1,
        per_device_train_batch_size=2,
        save_strategy="epoch",
        logging_dir="./logs",
        logging_steps=10,
        no_cuda=not torch.cuda.is_available(),
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
    )

    print("🚀 Starting training...")
    trainer.train()

    # BUG FIX: save_strategy="epoch" only writes checkpoint-* subdirectories;
    # save the final model + tokenizer explicitly so the message below is true
    # and the output dir is directly loadable with from_pretrained().
    trainer.save_model("./trained_model")
    tokenizer.save_pretrained("./trained_model")
    print("✅ Training finished. Model saved to ./trained_model")


if __name__ == "__main__":
    main()