|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import torch
|
| import torch.nn as nn
|
| from datasets import Dataset
|
| from transformers import Trainer, TrainingArguments
|
|
|
| from trl.trainer.callbacks import RichProgressCallback
|
|
|
| from .testing_utils import TrlTestCase, require_rich
|
|
|
|
|
class DummyModel(nn.Module):
    """Minimal one-parameter model used to drive ``Trainer`` in tests.

    The forward pass scales its input by a single learnable scalar ``a``,
    which starts at 1.0.
    """

    def __init__(self):
        super().__init__()
        # 0-dim tensor holding 1.0; wrapping it in nn.Parameter registers it
        # as the module's only trainable parameter, named "a".
        initial_scale = torch.ones(())
        self.a = nn.Parameter(initial_scale)

    def forward(self, x):
        """Return ``x`` multiplied by the learnable scalar ``a``.

        NOTE(review): only ``x`` is accepted; presumably Trainer's
        ``remove_unused_columns`` drops the dataset's ``y`` column because it
        is absent from this signature — confirm if the datasets change.
        """
        scale = self.a
        return scale * x
|
|
|
|
|
@require_rich
class TestRichProgressCallback(TrlTestCase):
    """Smoke tests for ``RichProgressCallback`` attached to a ``Trainer``.

    The class is guarded by ``require_rich`` (presumably skipped when the
    ``rich`` package is unavailable).
    """

    def setup_method(self):
        """pytest per-test hook: build a fresh model and tiny datasets."""
        self.dummy_model = DummyModel()
        # 5 training rows with batch size 2 gives a partial final batch.
        self.dummy_train_dataset = Dataset.from_list(
            [{"x": 1.0, "y": 2.0}] * 5
        )
        # 101 eval rows (odd count), presumably to exercise many eval steps
        # plus a partial last batch — confirm intent.
        self.dummy_val_dataset = Dataset.from_list(
            [{"x": 1.0, "y": 2.0}] * 101
        )

    def test_rich_progress_callback_logging(self):
        """Train and evaluate with the Rich callback attached.

        There are no explicit assertions: the test passes as long as
        ``trainer.train()`` completes without raising.
        """
        args = TrainingArguments(
            output_dir=self.tmp_dir,
            num_train_epochs=4,
            per_device_train_batch_size=2,
            per_device_eval_batch_size=2,
            # Evaluate and log after every step so the callback's evaluation
            # and logging hooks fire repeatedly during training.
            eval_strategy="steps",
            eval_steps=1,
            logging_strategy="steps",
            logging_steps=1,
            # Keep the run side-effect free: no checkpoints, no reporters.
            save_strategy="no",
            report_to="none",
            # Disable Trainer's built-in tqdm bars so they don't compete
            # with the Rich progress display.
            disable_tqdm=True,
        )

        trainer = Trainer(
            model=self.dummy_model,
            args=args,
            train_dataset=self.dummy_train_dataset,
            eval_dataset=self.dummy_val_dataset,
            callbacks=[RichProgressCallback()],
        )
        trainer.train()
|
|
|