Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import torch | |
| from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score | |
| from transformers import TrainingArguments, Trainer | |
| from transformers import EarlyStoppingCallback | |
| import pickle as pkl | |
| from datetime import datetime | |
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset wrapping tokenizer encodings and optional labels.

    Args:
        encodings: Mapping from feature name to an indexable sequence of
            per-example values. It must contain an ``"input_ids"`` entry,
            whose length defines the size of the dataset.
        labels: Optional indexable sequence of per-example labels. When it
            is omitted, returned items carry no ``"labels"`` key.
    """

    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return the features (and label, if any) of example ``idx`` as tensors."""
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        # ``labels`` defaults to None; indexing it unconditionally raised a
        # TypeError for unlabeled datasets, so only attach a label when
        # labels were actually supplied.
        if self.labels is not None:
            item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        """Return the number of examples, taken from ``encodings["input_ids"]``."""
        return len(self.encodings["input_ids"])
def compute_metrics(p):
    """Score an evaluation pass with accuracy and macro-averaged metrics.

    Args:
        p: A pair that unpacks into ``(predictions, labels)``. The
            predictions are reduced to class ids with an argmax along axis 1.

    Returns:
        A dict with the keys ``eval_accuracy``, ``eval_precision``,
        ``eval_recall`` and ``eval_f1``. Precision, recall and F1 are
        macro-averaged and use ``zero_division=0``.
    """
    scores, labels = p
    predicted = np.argmax(scores, axis=1)

    # Keyword arguments shared by the three macro-averaged metrics.
    macro = dict(y_true=labels, y_pred=predicted, average="macro", zero_division=0)

    return {
        "eval_accuracy": accuracy_score(y_true=labels, y_pred=predicted),
        "eval_precision": precision_score(**macro),
        "eval_recall": recall_score(**macro),
        "eval_f1": f1_score(**macro),
    }
| def train(model, train_dataset, val_dataset, output_dir, save_steps, num_train_epochs=10): | |
| args = TrainingArguments( | |
| output_dir=output_dir, | |
| overwrite_output_dir=True, | |
| evaluation_strategy="steps", | |
| eval_steps=save_steps, | |
| per_device_train_batch_size=16, | |
| per_device_eval_batch_size=16, | |
| num_train_epochs=num_train_epochs, | |
| seed=0, | |
| save_steps=save_steps, | |
| save_total_limit=2, | |
| load_best_model_at_end=True, | |
| metric_for_best_model='eval_f1' | |
| ) | |
| trainer = Trainer( | |
| model=model, | |
| args=args, | |
| train_dataset=train_dataset, | |
| eval_dataset=val_dataset, | |
| compute_metrics=compute_metrics, | |
| callbacks = [EarlyStoppingCallback(early_stopping_patience=3)] | |
| ) | |
| res = trainer.train() | |