File size: 6,786 Bytes
b5b9c2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
"""Tests for the low context length warning in the CLI banner."""

import os
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest


@pytest.fixture
def _isolate(tmp_path, monkeypatch):
    """Point HERMES_HOME at a fresh per-test directory.

    The directory lives under pytest's ``tmp_path`` and the environment
    variable is set through ``monkeypatch``, so it is restored after the
    test and the user's real Hermes config is never read or written.
    """
    hermes_home = tmp_path / ".hermes"
    hermes_home.mkdir()
    monkeypatch.setenv("HERMES_HOME", str(hermes_home))


@pytest.fixture
def cli_obj(_isolate):
    """Return a bare HermesCLI wired up just enough to call show_banner().

    ``HermesCLI.__new__`` is used so ``__init__`` never runs; the attributes
    these tests rely on are assigned by hand instead.  Config loading, tool
    discovery and the welcome banner are patched only while the object is
    being built.  The agent is a stub whose context compressor starts with
    an unknown (``None``) context length; each test sets the value it needs.
    """
    fake_config = {
        "display": {"tool_progress": "new"},
        "terminal": {},
    }
    with patch("cli.load_cli_config", return_value=fake_config), \
         patch("cli.get_tool_definitions", return_value=[]), \
         patch("cli.build_welcome_banner"):
        from cli import HermesCLI

        banner_cli = HermesCLI.__new__(HermesCLI)
        banner_cli.model = "test-model"
        banner_cli.enabled_toolsets = ["hermes-core"]
        banner_cli.compact = False
        banner_cli.console = MagicMock()
        banner_cli.session_id = None
        banner_cli.api_key = "test"
        banner_cli.base_url = ""
        banner_cli.provider = "test"
        banner_cli._provider_source = None
        # Stub agent; tests set context_compressor.context_length per case.
        banner_cli.agent = SimpleNamespace(
            context_compressor=SimpleNamespace(context_length=None),
        )
    return banner_cli


class TestLowContextWarning:
    """Tests that the CLI warns about low context lengths.

    Each test sets ``context_length`` on the stubbed context compressor,
    renders the banner, and then searches the stringified ``console.print``
    calls for the warning ("too low") or a server-specific hint.
    """

    @staticmethod
    def _printed(cli_obj):
        """Return every ``console.print`` call rendered via ``str()``."""
        return [str(c) for c in cli_obj.console.print.call_args_list]

    @staticmethod
    def _matching(calls, needle):
        """Return the rendered print calls that contain *needle*."""
        return [c for c in calls if needle in c]

    @classmethod
    def _render_banner(cls, cli_obj):
        """Call ``show_banner()`` with tool discovery and the full banner stubbed.

        Returns the rendered ``console.print`` calls (see ``_printed``).
        """
        with patch("cli.get_tool_definitions", return_value=[]), \
             patch("cli.build_welcome_banner"):
            cli_obj.show_banner()
        return cls._printed(cli_obj)

    def test_no_warning_for_normal_context(self, cli_obj):
        """No warning when context is 32k+."""
        cli_obj.agent.context_compressor.context_length = 32768
        calls = self._render_banner(cli_obj)

        # Check that no yellow warning was printed
        assert len(self._matching(calls, "too low")) == 0

    def test_warning_for_low_context(self, cli_obj):
        """Warning shown when context is 4096 (Ollama default)."""
        cli_obj.agent.context_compressor.context_length = 4096
        calls = self._render_banner(cli_obj)

        warning_calls = self._matching(calls, "too low")
        assert len(warning_calls) == 1
        assert "4,096" in warning_calls[0]

    def test_warning_for_2048_context(self, cli_obj):
        """Warning shown for 2048 tokens (common LM Studio default)."""
        cli_obj.agent.context_compressor.context_length = 2048
        calls = self._render_banner(cli_obj)

        assert len(self._matching(calls, "too low")) == 1

    def test_warning_at_boundary(self, cli_obj):
        """Warning still shown at exactly 8192: the boundary is inclusive."""
        cli_obj.agent.context_compressor.context_length = 8192
        calls = self._render_banner(cli_obj)

        assert len(self._matching(calls, "too low")) == 1

    def test_no_warning_above_boundary(self, cli_obj):
        """No warning at 16384."""
        cli_obj.agent.context_compressor.context_length = 16384
        calls = self._render_banner(cli_obj)

        assert len(self._matching(calls, "too low")) == 0

    def test_ollama_specific_hint(self, cli_obj):
        """Ollama-specific fix shown when port 11434 detected."""
        cli_obj.agent.context_compressor.context_length = 4096
        cli_obj.base_url = "http://localhost:11434/v1"
        calls = self._render_banner(cli_obj)

        assert len(self._matching(calls, "OLLAMA_CONTEXT_LENGTH")) == 1

    def test_lm_studio_specific_hint(self, cli_obj):
        """LM Studio-specific fix shown when port 1234 detected."""
        cli_obj.agent.context_compressor.context_length = 2048
        cli_obj.base_url = "http://localhost:1234/v1"
        calls = self._render_banner(cli_obj)

        assert len(self._matching(calls, "LM Studio")) == 1

    def test_generic_hint_for_other_servers(self, cli_obj):
        """Generic fix shown for unknown servers."""
        cli_obj.agent.context_compressor.context_length = 4096
        cli_obj.base_url = "http://localhost:8080/v1"
        calls = self._render_banner(cli_obj)

        assert len(self._matching(calls, "config.yaml")) == 1

    def test_no_warning_when_no_context_length(self, cli_obj):
        """No warning when context length is not yet known."""
        cli_obj.agent.context_compressor.context_length = None
        calls = self._render_banner(cli_obj)

        assert len(self._matching(calls, "too low")) == 0

    def test_compact_banner_does_not_crash_on_narrow_terminal(self, cli_obj):
        """Compact mode should still have ctx_len defined for warning logic."""
        cli_obj.agent.context_compressor.context_length = 4096
        # A 70-column terminal is intended to select the compact banner;
        # the compact builder itself is stubbed out.
        narrow = os.terminal_size((70, 40))

        with patch("shutil.get_terminal_size", return_value=narrow), \
             patch("cli._build_compact_banner", return_value="compact banner"):
            cli_obj.show_banner()

        calls = self._printed(cli_obj)
        assert len(self._matching(calls, "too low")) == 1