File size: 2,233 Bytes
fcca8c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""Tests for evaluate CLI command."""

from unittest.mock import patch

from axolotl.cli.main import cli

from .test_cli_base import BaseCliTest


class TestEvaluateCommand(BaseCliTest):
    """Exercise the ``evaluate`` subcommand of the axolotl CLI.

    Generic checks (argument validation, basic invocation) are delegated to
    helpers inherited from ``BaseCliTest``. The remaining tests patch
    ``axolotl.cli.evaluate.do_evaluate`` so the command's wiring can be
    inspected without running a real evaluation.
    """

    cli = cli

    def _invoke_evaluate(self, cli_runner, *extra_args):
        """Invoke ``evaluate`` via *cli_runner*, appending *extra_args* to argv.

        ``catch_exceptions=False`` lets any exception raised inside the command
        propagate out of the runner, so it fails the test directly.
        """
        argv = ["evaluate", *extra_args]
        return cli_runner.invoke(cli, argv, catch_exceptions=False)

    def test_evaluate_cli_validation(self, cli_runner):
        """``evaluate`` passes the shared CLI validation check."""
        self._test_cli_validation(cli_runner, "evaluate")

    def test_evaluate_basic_execution(self, cli_runner, tmp_path, valid_test_config):
        """``evaluate`` passes the shared basic-execution check."""
        self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "evaluate")

    def test_evaluate_basic_execution_no_accelerate(
        self, cli_runner, tmp_path, valid_test_config
    ):
        """With ``--no-accelerate``, a valid config exits 0 and calls do_evaluate once."""
        cfg_file = tmp_path.joinpath("config.yml")
        cfg_file.write_text(valid_test_config)

        with patch("axolotl.cli.evaluate.do_evaluate") as evaluate_mock:
            outcome = self._invoke_evaluate(
                cli_runner, str(cfg_file), "--no-accelerate"
            )

            assert outcome.exit_code == 0
            evaluate_mock.assert_called_once()

    def test_evaluate_cli_overrides(self, cli_runner, tmp_path, valid_test_config):
        """Command-line flags take precedence over values from the config file."""
        # Config path prepared by the shared base helper.
        cfg_file = self._test_cli_overrides(tmp_path, valid_test_config)
        overrides = ["--micro-batch-size", "2", "--sequence-len", "128"]

        with patch("axolotl.cli.evaluate.do_evaluate") as evaluate_mock:
            outcome = self._invoke_evaluate(
                cli_runner, str(cfg_file), *overrides, "--no-accelerate"
            )

            assert outcome.exit_code == 0
            evaluate_mock.assert_called_once()
            # The test expects the resolved config as do_evaluate's first
            # positional argument.
            resolved_cfg = evaluate_mock.call_args.args[0]
            assert resolved_cfg.micro_batch_size == 2
            assert resolved_cfg.sequence_len == 128