---
license: mit
base_model: roberta-large
tags:
- generated_from_trainer
metrics:
- accuracy
model-index:
- name: imdb_roberta_large
  results: []
---

# imdb_roberta_large

This model is a fine-tuned version of [roberta-large](https://huggingface.co/roberta-large) on the IMDB movie review dataset.
It achieves the following results on the evaluation set:
- Loss: 0.1728
- Accuracy: 0.9627

## Model description

RoBERTa-large fine-tuned for binary sentiment classification (NEGATIVE/POSITIVE) of IMDB movie reviews.

The training and evaluation code used to produce this model:

```python
from datasets import load_dataset
import numpy as np
import torch
import evaluate
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
)

# Load the IMDB movie review dataset (train and test splits).
imdb = load_dataset("imdb")

# model_name = 'xlnet-large-cased'
model_name = 'roberta-large'

id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {"NEGATIVE": 0, "POSITIVE": 1}

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)

tokenizer = AutoTokenizer.from_pretrained(model_name)

def preprocess_function(examples):
    # Truncate reviews to the model's maximum sequence length.
    return tokenizer(examples["text"], truncation=True)

tokenized_imdb = imdb.map(preprocess_function, batched=True)

# Pad dynamically to the longest sequence in each batch.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

model = AutoModelForSequenceClassification.from_pretrained(
    model_name, num_labels=2, id2label=id2label, label2id=label2id
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

bts = 8               # per-device batch size
accumulated_step = 2  # gradient accumulation steps -> effective batch size 16
training_args = TrainingArguments(
    output_dir=f"5imdb_{model_name.replace('-','_')}",
    learning_rate=2e-5,
    per_device_train_batch_size=bts,
    per_device_eval_batch_size=bts,
    num_train_epochs=2,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=True,  # requires an authenticated Hugging Face Hub login
    gradient_accumulation_steps=accumulated_step,
)

# Stop early if the evaluation metric does not improve for 3 evaluations.
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_imdb["train"],
    eval_dataset=tokenized_imdb["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)

trainer.train()
```
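
Once trained, the checkpoint can be loaded for inference with the `text-classification` pipeline. This is a minimal sketch: the model path below is a placeholder for the `output_dir` used above or for whatever Hub repo the model was pushed to, not a confirmed repo id.

```python
from transformers import pipeline

# Placeholder path: point this at the TrainingArguments output_dir
# or at the Hub repo the fine-tuned model was pushed to.
classifier = pipeline("text-classification", model="5imdb_roberta_large")

# Returns a list of {'label': 'NEGATIVE'/'POSITIVE', 'score': ...} dicts.
print(classifier("This movie was an absolute joy to watch."))
```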

## Intended uses & limitations

The model is intended for binary sentiment classification (NEGATIVE/POSITIVE) of English movie reviews and similar English text. It has only been evaluated on the IMDB test split, so performance on other domains or languages is untested. Inputs are truncated to roberta-large's 512-token limit during preprocessing, so only the beginning of very long reviews is considered.

## Training and evaluation data

The model was fine-tuned and evaluated on the IMDB movie review dataset loaded via `load_dataset("imdb")` (25,000 labeled training reviews and 25,000 labeled test reviews). The `train` split was used for fine-tuning and the `test` split served as the evaluation set.

## Training procedure

### Training hyperparameters

The following hyperparameters were used during training:
- learning_rate: 2e-05
- train_batch_size: 8
- eval_batch_size: 8
- seed: 42
- gradient_accumulation_steps: 2
- total_train_batch_size: 16 (derived below)
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- num_epochs: 2
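
The reported total training batch size follows from the per-device batch size and gradient accumulation steps (assuming a single GPU, which matches the reported value of 16):

```python
per_device_train_batch_size = 8
gradient_accumulation_steps = 2
num_devices = 1  # assumption: training on a single GPU

total_train_batch_size = (
    per_device_train_batch_size * gradient_accumulation_steps * num_devices
)
print(total_train_batch_size)  # 16
```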

### Training results

| Training Loss | Epoch | Step | Validation Loss | Accuracy |
|:-------------:|:-----:|:----:|:---------------:|:--------:|
| 0.1732        | 1.0   | 1562 | 0.1323          | 0.9574   |
| 0.0978        | 2.0   | 3124 | 0.1728          | 0.9627   |

### Framework versions

- Transformers 4.38.2
- Pytorch 2.2.1
- Datasets 2.18.0
- Tokenizers 0.15.2