# Emotion Classification with DistilBERT

Fine-tunes `distilbert-base-uncased` on the `emotion` dataset for 6-class sequence classification.
---
license: apache-2.0
---
Install the required dependencies:

```bash
pip install transformers datasets torch scikit-learn
```
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
def load_and_prepare_data():
    """Download the `emotion` dataset and return its train/test splits.

    Returns:
        A ``(train_split, test_split)`` tuple of ``datasets.Dataset`` objects.
    """
    splits = load_dataset("emotion")
    return splits["train"], splits["test"]
def tokenize_dataset(dataset, model_name="distilbert-base-uncased"):
    """Tokenize the ``"text"`` column of a dataset for sequence classification.

    Args:
        dataset: A ``datasets.Dataset`` containing a ``"text"`` column.
        model_name: Hugging Face checkpoint whose tokenizer is used.
            Defaults to the DistilBERT checkpoint the original code
            hard-coded, so existing callers are unaffected.

    Returns:
        The dataset with tokenizer output columns (``input_ids``,
        ``attention_mask``) added via a batched ``map``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    def tokenize_function(examples):
        # Pad to the model's max length and truncate longer texts so all
        # examples have a fixed-size input tensor.
        return tokenizer(examples["text"], padding="max_length", truncation=True)

    return dataset.map(tokenize_function, batched=True)
def load_model(num_labels=6, model_name="distilbert-base-uncased"):
    """Load a sequence-classification model with a fresh classification head.

    Args:
        num_labels: Number of output classes. Defaults to 6, the label
            count of the `emotion` dataset (the original hard-coded value).
        model_name: Hugging Face checkpoint to initialize from. Defaults
            to the checkpoint the original code hard-coded.

    Returns:
        An ``AutoModelForSequenceClassification`` instance ready for training.
    """
    return AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=num_labels
    )
def define_training_arguments():
    """Build the ``TrainingArguments`` used by the fine-tuning run.

    Evaluates and checkpoints once per epoch, keeps the checkpoint with the
    best accuracy, and logs every 10 steps.
    """
    config = {
        "output_dir": "./results",
        "num_train_epochs": 3,
        "per_device_train_batch_size": 16,
        "per_device_eval_batch_size": 64,
        "warmup_steps": 500,
        "weight_decay": 0.01,
        "logging_dir": "./logs",
        "logging_steps": 10,
        # Evaluate/save each epoch so load_best_model_at_end can compare
        # checkpoints on the accuracy metric.
        "evaluation_strategy": "epoch",
        "save_strategy": "epoch",
        "load_best_model_at_end": True,
        "metric_for_best_model": "accuracy",
        "greater_is_better": True,
    }
    return TrainingArguments(**config)
def compute_metrics(eval_pred):
    """Compute accuracy and weighted F1 for a ``Trainer`` evaluation pass.

    Args:
        eval_pred: ``(logits, labels)`` pair as supplied by ``Trainer``.

    Returns:
        A dict with ``"accuracy"`` and ``"f1"`` (weighted-average) scores.
    """
    logits, labels = eval_pred
    # Class prediction = index of the largest logit per example.
    preds = torch.tensor(logits).argmax(dim=-1)
    return {
        "accuracy": accuracy_score(labels, preds),
        "f1": f1_score(labels, preds, average="weighted"),
    }
def main():
    """Run the full pipeline: load data, tokenize, train, evaluate, save."""
    raw_train, raw_test = load_and_prepare_data()
    train_data = tokenize_dataset(raw_train)
    test_data = tokenize_dataset(raw_test)

    trainer = Trainer(
        model=load_model(),
        args=define_training_arguments(),
        train_dataset=train_data,
        eval_dataset=test_data,
        compute_metrics=compute_metrics,
    )
    trainer.train()
    trainer.evaluate()
    # Persist the best checkpoint (load_best_model_at_end) to output_dir.
    trainer.save_model()


if __name__ == "__main__":
    main()