Flair
English
text-generation-inference
HGRXv1 / train.py
vipisdeath's picture
Create train.py
fa59067 verified
raw
history blame contribute delete
914 Bytes
# train.py
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
# 1. Get the data
# Downloads the IMDB sentiment dataset (train/test splits of movie reviews,
# binary labels). Requires network access on first run; cached afterwards.
dataset = load_dataset("imdb") # Or "csv", data_files="my_data.csv"
# 2. Get the tools
# Tokenizer and model must come from the same checkpoint so the vocabulary
# matches the embedding table.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
# num_labels=2 attaches a fresh (randomly initialized) binary-classification
# head on top of the pretrained DistilBERT encoder.
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
# 3. Process the data
def tokenize_func(examples):
    """Tokenize a batch's "text" column, truncating long inputs and padding
    every example out to the tokenizer's max length.

    Uses the module-level ``tokenizer``; intended for ``dataset.map(..., batched=True)``.
    """
    texts = examples["text"]
    return tokenizer(texts, padding="max_length", truncation=True)
# 3b. Apply tokenization to every split; batched=True passes chunks of rows
# to tokenize_func at once, which is much faster than row-by-row.
tokenized_datasets = dataset.map(tokenize_func, batched=True)
# 4. Set the rules
# Minimal config: checkpoints/logs go to output_dir; one pass over the data.
# All other TrainingArguments keep their library defaults.
training_args = TrainingArguments(output_dir="my_model_folder", num_train_epochs=1)
# 5. Do the training
# NOTE(review): eval_dataset is provided but no eval schedule is configured,
# so evaluation only runs if triggered explicitly — confirm this is intended.
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_datasets["train"],
eval_dataset=tokenized_datasets["test"],
)
# Runs the fine-tuning loop and writes checkpoints under output_dir.
trainer.train()