# township-chatbot / model_training.py
# puseletso55's picture
# Added training data, script, and updated chatbot code
# 4b1883f
from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments
from datasets import load_dataset
# Read the training and validation CSV files, each into its own DatasetDict
# keyed by split name. Point the paths at your own data as needed.
train_data = load_dataset(
    "csv",
    data_files=dict(train="data/train_data.csv"),
)
val_data = load_dataset(
    "csv",
    data_files=dict(validation="data/validation_data.csv"),
)
# Pull the pre-trained GPT-2 checkpoint: tokenizer first, then the LM-head model.
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
# Tokenize datasets
# GPT-2 ships without a padding token, so padding="max_length" would raise a
# ValueError. Reuse the end-of-sequence token for padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


def tokenize_function(examples):
    """Tokenize a batch of rows and attach causal-LM training labels.

    Args:
        examples: Batch mapping supplied by ``map(..., batched=True)``; its
            ``"text"`` entry is passed to the tokenizer.

    Returns:
        The tokenizer output with an added ``"labels"`` entry: a copy of
        ``input_ids`` in which every padded position is replaced by -100.
    """
    encoded = tokenizer(examples["text"], padding="max_length", truncation=True)
    # GPT2LMHeadModel only returns a loss when labels are supplied, and Trainer
    # needs that loss. Padded positions (attention_mask == 0) are set to -100,
    # the index the loss ignores, so padding does not contribute to training.
    # The mask is used rather than the token id because pad and EOS share an id.
    encoded["labels"] = [
        [token_id if keep else -100 for token_id, keep in zip(ids, mask)]
        for ids, mask in zip(encoded["input_ids"], encoded["attention_mask"])
    ]
    return encoded


train_data = train_data.map(tokenize_function, batched=True)
val_data = val_data.map(tokenize_function, batched=True)
# Hyperparameters and bookkeeping for the fine-tuning run.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` -- confirm against the installed version before upgrading.
training_args = TrainingArguments(
    # Where checkpoints and logs go, and how often to evaluate.
    output_dir="./results",
    evaluation_strategy="epoch",
    # Run length and batch sizes.
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    # Optimizer settings.
    learning_rate=2e-5,
    weight_decay=0.01,
)
# Wire the model, the run configuration and the two dataset splits together.
trainer = Trainer(
    args=training_args,
    model=model,
    eval_dataset=val_data["validation"],
    train_dataset=train_data["train"],
)
# Run fine-tuning.
trainer.train()

# Persist the model first, then the tokenizer, into the same directory so the
# pair can be reloaded together.
for artifact in (model, tokenizer):
    artifact.save_pretrained("township_business_model")