|
|
import os |
|
|
import tempfile |
|
|
import unittest |
|
|
|
|
|
from transformers import TrainingArguments |
|
|
|
|
|
|
|
|
class TestTrainingArguments(unittest.TestCase): |
|
|
def test_default_output_dir(self): |
|
|
"""Test that output_dir defaults to 'trainer_output' when not specified.""" |
|
|
args = TrainingArguments(output_dir=None) |
|
|
self.assertEqual(args.output_dir, "trainer_output") |
|
|
|
|
|
def test_custom_output_dir(self): |
|
|
"""Test that output_dir is respected when specified.""" |
|
|
with tempfile.TemporaryDirectory() as tmp_dir: |
|
|
args = TrainingArguments(output_dir=tmp_dir) |
|
|
self.assertEqual(args.output_dir, tmp_dir) |
|
|
|
|
|
def test_output_dir_creation(self): |
|
|
"""Test that output_dir is created only when needed.""" |
|
|
with tempfile.TemporaryDirectory() as tmp_dir: |
|
|
output_dir = os.path.join(tmp_dir, "test_output") |
|
|
|
|
|
|
|
|
self.assertFalse(os.path.exists(output_dir)) |
|
|
|
|
|
|
|
|
args = TrainingArguments( |
|
|
output_dir=output_dir, |
|
|
do_train=True, |
|
|
save_strategy="no", |
|
|
report_to=None, |
|
|
) |
|
|
self.assertFalse(os.path.exists(output_dir)) |
|
|
|
|
|
|
|
|
args.save_strategy = "steps" |
|
|
args.save_steps = 1 |
|
|
self.assertFalse(os.path.exists(output_dir)) |
|
|
|
|
|
|
|
|
|