"""Tests for train CLI command.""" from unittest.mock import MagicMock, patch from axolotl.cli.main import cli from .test_cli_base import BaseCliTest class TestTrainCommand(BaseCliTest): """Test cases for train command.""" cli = cli def test_train_cli_validation(self, cli_runner): """Test CLI validation""" self._test_cli_validation(cli_runner, "train") def test_train_basic_execution(self, cli_runner, tmp_path, valid_test_config): """Test basic successful execution""" self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "train") def test_train_basic_execution_no_accelerate( self, cli_runner, tmp_path, valid_test_config ): """Test basic successful execution without accelerate""" config_path = tmp_path / "config.yml" config_path.write_text(valid_test_config) with patch("axolotl.cli.train.train") as mock_train: mock_train.return_value = (MagicMock(), MagicMock(), MagicMock()) result = cli_runner.invoke( cli, [ "train", str(config_path), "--no-accelerate", ], catch_exceptions=False, ) assert result.exit_code == 0 mock_train.assert_called_once() def test_train_cli_overrides(self, cli_runner, tmp_path, valid_test_config): """Test CLI arguments properly override config values""" config_path = self._test_cli_overrides(tmp_path, valid_test_config) with patch("axolotl.cli.train.train") as mock_train: mock_train.return_value = (MagicMock(), MagicMock(), MagicMock()) result = cli_runner.invoke( cli, [ "train", str(config_path), "--learning-rate", "1e-4", "--micro-batch-size", "2", "--no-accelerate", ], catch_exceptions=False, ) assert result.exit_code == 0 mock_train.assert_called_once() cfg = mock_train.call_args[1]["cfg"] assert cfg["learning_rate"] == 1e-4 assert cfg["micro_batch_size"] == 2