Spaces:
Paused
Paused
| """Tests for trajectory_compressor.py — config, metrics, and compression logic.""" | |
| import json | |
| from types import SimpleNamespace | |
| from unittest.mock import AsyncMock, patch, MagicMock | |
| import pytest | |
| from trajectory_compressor import ( | |
| CompressionConfig, | |
| TrajectoryMetrics, | |
| AggregateMetrics, | |
| TrajectoryCompressor, | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # CompressionConfig | |
| # --------------------------------------------------------------------------- | |
class TestCompressionConfig:
    """Checks for ``CompressionConfig`` defaults and ``from_yaml`` loading."""

    def test_defaults(self):
        """A config built with no arguments exposes the expected defaults."""
        config = CompressionConfig()
        assert config.target_max_tokens == 15250
        assert config.summary_target_tokens == 750
        assert config.protect_last_n_turns == 4
        assert config.skip_under_target is True

    def test_from_yaml(self, tmp_path):
        """Values supplied in each YAML section are reflected on the config."""
        # The option keys must be indented beneath their section header;
        # without that nesting YAML reads them as unrelated top-level keys
        # and the section itself as empty.
        yaml_content = """\
tokenizer:
  name: custom-tokenizer
  trust_remote_code: false
compression:
  target_max_tokens: 10000
  summary_target_tokens: 500
protected_turns:
  first_system: true
  first_human: false
  last_n_turns: 6
summarization:
  model: gpt-4
  temperature: 0.5
  max_retries: 5
output:
  add_summary_notice: false
  output_suffix: _short
processing:
  num_workers: 8
  max_concurrent_requests: 100
  skip_under_target: false
  save_over_limit: false
metrics:
  enabled: false
  per_trajectory: false
  output_file: my_metrics.json
"""
        yaml_file = tmp_path / "config.yaml"
        yaml_file.write_text(yaml_content)

        config = CompressionConfig.from_yaml(str(yaml_file))

        # tokenizer section
        assert config.tokenizer_name == "custom-tokenizer"
        assert config.trust_remote_code is False
        # compression section
        assert config.target_max_tokens == 10000
        assert config.summary_target_tokens == 500
        # protected_turns section
        assert config.protect_first_human is False
        assert config.protect_last_n_turns == 6
        # summarization section
        assert config.summarization_model == "gpt-4"
        assert config.temperature == 0.5
        assert config.max_retries == 5
        # output section
        assert config.add_summary_notice is False
        assert config.output_suffix == "_short"
        # processing section
        assert config.num_workers == 8
        assert config.max_concurrent_requests == 100
        assert config.skip_under_target is False
        assert config.save_over_limit is False
        # metrics section
        assert config.metrics_enabled is False
        assert config.metrics_output_file == "my_metrics.json"

    def test_from_yaml_partial(self, tmp_path):
        """Only specified sections override defaults."""
        yaml_file = tmp_path / "config.yaml"
        yaml_file.write_text("compression:\n target_max_tokens: 8000\n")

        config = CompressionConfig.from_yaml(str(yaml_file))

        assert config.target_max_tokens == 8000
        # Other sections keep defaults
        assert config.protect_last_n_turns == 4
        assert config.num_workers == 4

    def test_from_yaml_empty(self, tmp_path):
        """An empty YAML mapping leaves the config at its defaults."""
        yaml_file = tmp_path / "config.yaml"
        yaml_file.write_text("{}\n")

        config = CompressionConfig.from_yaml(str(yaml_file))

        assert config.target_max_tokens == 15250  # all defaults
| # --------------------------------------------------------------------------- | |
| # TrajectoryMetrics | |
| # --------------------------------------------------------------------------- | |
class TestTrajectoryMetrics:
    """Checks for the dict produced by ``TrajectoryMetrics.to_dict``."""

    def test_to_dict(self):
        """Assigned fields show up in the serialised dict."""
        metrics = TrajectoryMetrics()
        assigned_fields = {
            "original_tokens": 10000,
            "compressed_tokens": 5000,
            "tokens_saved": 5000,
            "compression_ratio": 0.5,
            "original_turns": 20,
            "compressed_turns": 10,
            "turns_removed": 10,
            "was_compressed": True,
        }
        for field_name, field_value in assigned_fields.items():
            setattr(metrics, field_name, field_value)

        serialised = metrics.to_dict()

        assert serialised["original_tokens"] == 10000
        assert serialised["compressed_tokens"] == 5000
        assert serialised["compression_ratio"] == 0.5
        assert serialised["was_compressed"] is True
        # The region start index was never assigned, so the test expects -1.
        assert serialised["compression_region"]["start_idx"] == -1

    def test_default_values(self):
        """An untouched metrics object serialises zero/False values."""
        serialised = TrajectoryMetrics().to_dict()

        assert serialised["original_tokens"] == 0
        assert serialised["was_compressed"] is False
        assert serialised["skipped_under_target"] is False
| # --------------------------------------------------------------------------- | |
| # AggregateMetrics | |
| # --------------------------------------------------------------------------- | |
class TestAggregateMetrics:
    """Checks for ``AggregateMetrics`` accumulation and ``to_dict`` output."""

    @staticmethod
    def _metrics_with(**fields):
        """Return a fresh ``TrajectoryMetrics`` with *fields* assigned as attributes."""
        metrics = TrajectoryMetrics()
        for attr_name, attr_value in fields.items():
            setattr(metrics, attr_name, attr_value)
        return metrics

    def test_empty_to_dict(self):
        """An aggregate with no trajectories serialises neutral values."""
        report = AggregateMetrics().to_dict()

        assert report["summary"]["total_trajectories"] == 0
        assert report["averages"]["avg_compression_ratio"] == 1.0
        assert report["averages"]["avg_tokens_saved_per_compressed"] == 0

    def test_add_compressed_trajectory(self):
        """Adding a compressed trajectory updates the counters and ratio list."""
        aggregate = AggregateMetrics()
        compressed = self._metrics_with(
            original_tokens=20000,
            compressed_tokens=10000,
            tokens_saved=10000,
            compression_ratio=0.5,
            original_turns=30,
            compressed_turns=15,
            turns_removed=15,
            was_compressed=True,
        )

        aggregate.add_trajectory_metrics(compressed)

        assert aggregate.total_trajectories == 1
        assert aggregate.trajectories_compressed == 1
        assert aggregate.total_tokens_saved == 10000
        assert len(aggregate.compression_ratios) == 1

    def test_add_skipped_trajectory(self):
        """A trajectory flagged as skipped counts as skipped, not compressed."""
        aggregate = AggregateMetrics()
        skipped = self._metrics_with(
            original_tokens=5000,
            compressed_tokens=5000,
            skipped_under_target=True,
        )

        aggregate.add_trajectory_metrics(skipped)

        assert aggregate.trajectories_skipped_under_target == 1
        assert aggregate.trajectories_compressed == 0

    def test_add_over_limit_trajectory(self):
        """A trajectory flagged still-over-limit bumps the over-limit counter."""
        aggregate = AggregateMetrics()
        over_limit = self._metrics_with(
            original_tokens=20000,
            compressed_tokens=16000,
            still_over_limit=True,
            was_compressed=True,
            compression_ratio=0.8,
        )

        aggregate.add_trajectory_metrics(over_limit)

        assert aggregate.trajectories_still_over_limit == 1

    def test_multiple_trajectories_aggregation(self):
        """Three identical compressed trajectories sum and average as expected."""
        aggregate = AggregateMetrics()
        for _ in range(3):
            aggregate.add_trajectory_metrics(
                self._metrics_with(
                    original_tokens=10000,
                    compressed_tokens=5000,
                    tokens_saved=5000,
                    turns_removed=5,
                    was_compressed=True,
                    compression_ratio=0.5,
                )
            )

        report = aggregate.to_dict()

        assert report["summary"]["total_trajectories"] == 3
        assert report["summary"]["trajectories_compressed"] == 3
        assert report["tokens"]["total_saved"] == 15000
        assert report["averages"]["avg_compression_ratio"] == 0.5

    def test_to_dict_no_division_by_zero(self):
        """Ensure no ZeroDivisionError with empty data."""
        report = AggregateMetrics().to_dict()

        assert report["summarization"]["success_rate"] == 1.0
        assert report["tokens"]["overall_compression_ratio"] == 0.0
| # --------------------------------------------------------------------------- | |
| # TrajectoryCompressor._find_protected_indices | |
| # --------------------------------------------------------------------------- | |
def _make_compressor(config=None):
    """Build a ``TrajectoryCompressor`` with its tokenizer and summarizer stubbed.

    Args:
        config: Optional ``CompressionConfig``; a default one is created
            when omitted.

    Returns:
        A compressor whose ``tokenizer.encode`` yields one token per four
        characters of input.
    """
    effective_config = CompressionConfig() if config is None else config

    # Replace both initialisers with mocks for the duration of construction.
    tokenizer_patch = patch.object(TrajectoryCompressor, '_init_tokenizer')
    summarizer_patch = patch.object(TrajectoryCompressor, '_init_summarizer')
    with tokenizer_patch, summarizer_patch:
        compressor = TrajectoryCompressor(effective_config)

    def _four_chars_per_token(text):
        # Simple token counter for tests (1 token per 4 chars).
        return [0] * (len(text) // 4)

    stub_tokenizer = MagicMock()
    stub_tokenizer.encode = _four_chars_per_token
    compressor.tokenizer = stub_tokenizer
    return compressor
class TestFindProtectedIndices:
    """Checks for ``TrajectoryCompressor._find_protected_indices``."""

    @staticmethod
    def _turns(*pairs):
        """Turn ``(role, text)`` pairs into ``{"from", "value"}`` turn dicts."""
        return [{"from": role, "value": text} for role, text in pairs]

    def test_basic_trajectory(self):
        """Head and tail turns are protected; the middle is compressible."""
        compressor = _make_compressor()
        trajectory = self._turns(
            ("system", "You are an agent."),
            ("human", "Do something."),
            ("gpt", "I will use a tool."),
            ("tool", "Tool result."),
            ("gpt", "More work."),
            ("tool", "Another result."),
            ("gpt", "Work continues."),
            ("tool", "Result 3."),
            ("gpt", "Done."),
            ("human", "Thanks."),
        )

        protected, start, end = compressor._find_protected_indices(trajectory)

        # First system (0), human (1), gpt (2), tool (3) are protected
        for head_idx in (0, 1, 2, 3):
            assert head_idx in protected
        # Last 4 turns (6,7,8,9) are protected
        for tail_idx in (6, 7, 8, 9):
            assert tail_idx in protected
        # Compressible region should be between head and tail
        assert start >= 4
        assert end <= 6

    def test_short_trajectory_all_protected(self):
        """A three-turn trajectory leaves nothing to compress."""
        compressor = _make_compressor()
        trajectory = self._turns(
            ("system", "sys"),
            ("human", "hi"),
            ("gpt", "hello"),
        )

        protected, start, end = compressor._find_protected_indices(trajectory)

        # All 3 turns should be protected (first of each + last 4 covers all)
        assert len(protected) == 3
        assert start >= end  # Nothing to compress

    def test_protect_last_n_zero(self):
        """With ``protect_last_n_turns = 0`` the final turn is not protected."""
        config = CompressionConfig()
        config.protect_last_n_turns = 0
        compressor = _make_compressor(config)
        trajectory = self._turns(
            ("system", "sys"),
            ("human", "q"),
            ("gpt", "a"),
            ("tool", "r"),
            ("gpt", "b"),
            ("tool", "r2"),
            ("gpt", "c"),
            ("tool", "r3"),
        )

        protected, _, _ = compressor._find_protected_indices(trajectory)

        # Only first occurrences protected, no tail protection
        for head_idx in (0, 1, 2, 3):
            assert head_idx in protected
        assert 7 not in protected

    def test_no_system_turn(self):
        """Without a system turn, the leading human turn is still protected."""
        compressor = _make_compressor()
        trajectory = self._turns(
            ("human", "hi"),
            ("gpt", "hello"),
            ("tool", "data"),
            ("gpt", "result"),
            ("human", "thanks"),
        )

        protected, _, _ = compressor._find_protected_indices(trajectory)

        assert 0 in protected  # first human

    def test_disable_protect_first_system(self):
        """Turning off ``protect_first_system`` leaves index 0 unprotected."""
        config = CompressionConfig()
        config.protect_first_system = False
        compressor = _make_compressor(config)
        trajectory = self._turns(
            ("system", "sys"),
            ("human", "q"),
            ("gpt", "a"),
            ("tool", "r"),
            ("gpt", "b"),
            ("tool", "r2"),
            ("gpt", "c"),
            ("tool", "r3"),
        )

        protected, _, _ = compressor._find_protected_indices(trajectory)

        assert 0 not in protected  # system not protected
| # --------------------------------------------------------------------------- | |
| # TrajectoryCompressor._extract_turn_content_for_summary | |
| # --------------------------------------------------------------------------- | |
class TestExtractTurnContent:
    """Checks for ``TrajectoryCompressor._extract_turn_content_for_summary``."""

    def test_basic_extraction(self):
        """Turns in ``[start, end)`` are rendered with a labelled header."""
        compressor = _make_compressor()
        trajectory = [
            {"from": "gpt", "value": "I will search."},
            {"from": "tool", "value": "Search result: found it."},
            {"from": "gpt", "value": "Great, done."},
        ]

        rendered = compressor._extract_turn_content_for_summary(trajectory, 0, 2)

        expected_fragments = (
            "[Turn 0 - GPT]",
            "I will search.",
            "[Turn 1 - TOOL]",
            "Search result: found it.",
        )
        for fragment in expected_fragments:
            assert fragment in rendered
        # Turn 2 should NOT be included (end is exclusive)
        assert "[Turn 2" not in rendered

    def test_long_content_truncated(self):
        """A 5000-character turn is shortened and carries a truncation marker."""
        compressor = _make_compressor()
        oversized_turn = {"from": "tool", "value": "x" * 5000}

        rendered = compressor._extract_turn_content_for_summary([oversized_turn], 0, 1)

        assert "...[truncated]..." in rendered
        assert len(rendered) < 5000

    def test_empty_range(self):
        """An empty ``[0, 0)`` range yields an empty string."""
        compressor = _make_compressor()
        single_turn = [{"from": "gpt", "value": "hello"}]

        rendered = compressor._extract_turn_content_for_summary(single_turn, 0, 0)

        assert rendered == ""
| # --------------------------------------------------------------------------- | |
| # TrajectoryCompressor.count_tokens / count_trajectory_tokens | |
| # --------------------------------------------------------------------------- | |
class TestTokenCounting:
    """Checks for the compressor's token-counting helpers.

    The compressor under test comes from ``_make_compressor``, so counts
    follow its stub tokenizer.
    """

    def test_count_tokens_empty(self):
        """The empty string counts as zero tokens."""
        compressor = _make_compressor()
        assert compressor.count_tokens("") == 0

    def test_count_tokens_basic(self):
        """Eight characters count as two tokens."""
        compressor = _make_compressor()
        # Our mock: 1 token per 4 chars
        eight_chars = "12345678"
        assert compressor.count_tokens(eight_chars) == 2

    def test_count_trajectory_tokens(self):
        """The trajectory total is the sum over its turns."""
        compressor = _make_compressor()
        system_turn = {"from": "system", "value": "12345678"}  # 2 tokens
        human_turn = {"from": "human", "value": "1234567890ab"}  # 3 tokens

        total = compressor.count_trajectory_tokens([system_turn, human_turn])

        assert total == 5

    def test_count_turn_tokens(self):
        """Per-turn counts are returned in trajectory order."""
        compressor = _make_compressor()
        system_turn = {"from": "system", "value": "1234"}  # 1 token
        human_turn = {"from": "human", "value": "12345678"}  # 2 tokens

        per_turn = compressor.count_turn_tokens([system_turn, human_turn])

        assert per_turn == [1, 2]

    def test_count_tokens_fallback_on_error(self):
        """When ``encode`` raises, counting still returns a value."""
        compressor = _make_compressor()
        failing_encode = MagicMock(side_effect=Exception("fail"))
        compressor.tokenizer.encode = failing_encode

        # Should fallback to len(text) // 4
        assert compressor.count_tokens("12345678") == 2
class TestGenerateSummary:
    """Summary generation when the completion response has ``content=None``."""

    def test_generate_summary_handles_none_content(self):
        """The sync path returns the bare summary prefix for ``None`` content."""
        tc = _make_compressor()
        tc.client = MagicMock()
        tc.client.chat.completions.create.return_value = SimpleNamespace(
            choices=[SimpleNamespace(message=SimpleNamespace(content=None))]
        )
        metrics = TrajectoryMetrics()

        summary = tc._generate_summary("Turn content", metrics)

        assert summary == "[CONTEXT SUMMARY]:"

    def test_generate_summary_async_handles_none_content(self):
        """The async path returns the bare summary prefix for ``None`` content.

        The coroutine is driven with ``asyncio.run`` so that the assertion
        executes under plain pytest. A bare ``async def`` test is only
        awaited when an async plugin is installed and configured to pick
        up unmarked coroutines.
        """
        import asyncio

        tc = _make_compressor()
        mock_client = MagicMock()
        mock_client.chat.completions.create = AsyncMock(
            return_value=SimpleNamespace(
                choices=[SimpleNamespace(message=SimpleNamespace(content=None))]
            )
        )
        tc._get_async_client = MagicMock(return_value=mock_client)
        metrics = TrajectoryMetrics()

        summary = asyncio.run(tc._generate_summary_async("Turn content", metrics))

        assert summary == "[CONTEXT SUMMARY]:"