license: apache-2.0
"""Fine-tune DistilBERT for sequence classification on the ``emotion`` dataset.

Dependencies (install from a shell, not from Python):

    pip install transformers datasets torch scikit-learn
"""

import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
)
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score

# Checkpoint shared by the tokenizer and the model so the two cannot drift.
MODEL_NAME = "distilbert-base-uncased"

# Size of the classification head.
# NOTE(review): presumably the number of classes in the "emotion" dataset —
# confirm against the dataset's label feature.
NUM_LABELS = 6


def load_and_prepare_data():
    """Load the ``emotion`` dataset and return its train and test splits.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple taken from the ``"train"``
        and ``"test"`` splits of the loaded dataset.
    """
    dataset = load_dataset("emotion")
    train_dataset = dataset["train"]
    test_dataset = dataset["test"]
    return train_dataset, test_dataset


def tokenize_dataset(dataset):
    """Tokenize the ``"text"`` column of *dataset* with the model's tokenizer.

    Args:
        dataset: A dataset exposing a ``"text"`` column and a ``map`` method.

    Returns:
        The result of ``dataset.map`` with the tokenizer outputs added.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    def tokenize_function(examples):
        # Fixed-length padding plus truncation gives every example the same
        # length, so no padding collator is needed at batch time.
        return tokenizer(examples["text"], padding="max_length", truncation=True)

    tokenized_dataset = dataset.map(tokenize_function, batched=True)
    return tokenized_dataset


def load_model():
    """Return the pretrained checkpoint with a ``NUM_LABELS``-way classifier head."""
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME, num_labels=NUM_LABELS
    )
    return model


def define_training_arguments():
    """Build the ``TrainingArguments`` used for fine-tuning.

    Evaluation and checkpointing both happen once per epoch, and the
    checkpoint with the highest ``accuracy`` is reloaded when training ends.
    """
    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=3,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=64,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=10,
        # NOTE(review): recent transformers releases rename this keyword to
        # `eval_strategy`; confirm against the installed version.
        evaluation_strategy="epoch",
        # Must match the evaluation strategy for load_best_model_at_end.
        save_strategy="epoch",
        load_best_model_at_end=True,
        # Refers to the "accuracy" key returned by compute_metrics.
        metric_for_best_model="accuracy",
        greater_is_better=True,
    )
    return training_args


def compute_metrics(eval_pred):
    """Compute accuracy and weighted F1 from a ``(logits, labels)`` pair.

    Args:
        eval_pred: A pair unpackable into ``logits`` and ``labels``.

    Returns:
        A dict with keys ``"accuracy"`` and ``"f1"``.
    """
    logits, labels = eval_pred
    # The predicted class is the index of the largest logit on the last axis.
    predictions = torch.argmax(torch.tensor(logits), dim=-1)
    accuracy = accuracy_score(labels, predictions)
    f1 = f1_score(labels, predictions, average="weighted")
    return {"accuracy": accuracy, "f1": f1}


def main():
    """Run the full pipeline: load data, tokenize, train, evaluate, save."""
    train_dataset, test_dataset = load_and_prepare_data()
    tokenized_train_dataset = tokenize_dataset(train_dataset)
    tokenized_test_dataset = tokenize_dataset(test_dataset)

    model = load_model()
    training_args = define_training_arguments()

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train_dataset,
        eval_dataset=tokenized_test_dataset,
        compute_metrics=compute_metrics,
    )

    trainer.train()
    trainer.evaluate()
    trainer.save_model()


# The original guard read `if name == "main":`, which raises NameError because
# the dunder underscores were lost; this is the standard script entry check.
if __name__ == "__main__":
    main()