Spaces:
Running
Running
| """Tests for the Traffic Pattern Learner. | |
| Tests pattern extraction from proxy traffic without requiring | |
| a real memory backend. | |
| """ | |
| from __future__ import annotations | |
| from datetime import datetime, timedelta, timezone | |
| import pytest | |
| from headroom.memory.traffic_learner import ( | |
| ExtractedPattern, | |
| PatternCategory, | |
| TrafficLearner, | |
| _classify_error, | |
| _is_error, | |
| _load_persisted_patterns_from_sqlite, | |
| _normalize_bash_for_hash, | |
| _parse_iso_timestamp, | |
| _patterns_to_recommendations, | |
| _project_for_pattern, | |
| _refine_error_recovery, | |
| ) | |
| UTC = timezone.utc | |
| # ============================================================================= | |
| # Error Classification Tests | |
| # ============================================================================= | |
class TestErrorClassification:
    """Exercise `_classify_error` and `_is_error` on representative tool output."""

    def test_file_not_found(self):
        # Shell-style and Python-exception phrasings share one label.
        for message in (
            "No such file or directory: foo.py",
            "FileNotFoundError: [Errno 2]",
        ):
            assert _classify_error(message) == "file_not_found"

    def test_command_not_found(self):
        label = _classify_error("zsh: command not found: ruff")
        assert label == "command_not_found"

    def test_module_not_found(self):
        label = _classify_error("ModuleNotFoundError: No module named 'foo'")
        assert label == "module_not_found"

    def test_permission_denied(self):
        label = _classify_error("Permission denied: /etc/shadow")
        assert label == "permission_denied"

    def test_not_an_error(self):
        # Benign output and the empty string yield no classification.
        for message in ("Everything is fine, tests passed!", ""):
            assert _classify_error(message) is None

    def test_is_error_helper(self):
        assert _is_error("No such file or directory")
        for harmless in ("All tests passed", "", "short"):
            assert not _is_error(harmless)
| # ============================================================================= | |
| # Traffic Learner Core Tests | |
| # ============================================================================= | |
class TestTrafficLearner:
    """Core `TrafficLearner` behaviour, driven with ``backend=None``."""

    # Must be registered as a fixture: most tests below request it by
    # declaring a ``learner`` parameter, which pytest resolves by fixture name.
    @pytest.fixture
    def learner(self):
        """Create a learner with low evidence threshold for testing."""
        return TrafficLearner(
            backend=None,
            user_id="test-user",
            min_evidence=1,  # Save on first sighting for tests
        )

    async def test_error_recovery_bash(self, learner: TrafficLearner):
        """Test error→recovery pattern extraction for Bash commands."""
        # First: a failed command
        await learner.on_tool_result(
            tool_name="Bash",
            tool_input={"command": "ruff check ."},
            tool_output="zsh: command not found: ruff",
            is_error=True,
        )
        # Then: the recovery
        await learner.on_tool_result(
            tool_name="Bash",
            tool_input={"command": "source .venv/bin/activate && ruff check ."},
            tool_output="All checks passed!",
            is_error=False,
        )
        stats = learner.get_stats()
        assert stats["patterns_extracted"] >= 1
        assert stats["requests_processed"] == 2

    async def test_error_recovery_read(self, learner: TrafficLearner):
        """Test error→recovery for Read tool (wrong path → correct path)."""
        await learner.on_tool_result(
            tool_name="Read",
            tool_input={"file_path": "/src/old_module.py"},
            tool_output="No such file or directory: /src/old_module.py",
            is_error=True,
        )
        await learner.on_tool_result(
            tool_name="Read",
            tool_input={"file_path": "/src/new_module.py"},
            tool_output="# Module content here\nclass Foo: pass",
            is_error=False,
        )
        stats = learner.get_stats()
        assert stats["patterns_extracted"] >= 1

    async def test_environment_venv_detection(self, learner: TrafficLearner):
        """Test detection of virtual environment activation patterns."""
        await learner.on_tool_result(
            tool_name="Bash",
            tool_input={"command": "source /project/.venv/bin/activate && pytest"},
            tool_output="5 passed in 2.1s",
            is_error=False,
        )
        stats = learner.get_stats()
        assert stats["patterns_extracted"] >= 1

    async def test_preference_extraction(self, learner: TrafficLearner):
        """Test extraction of user preference signals."""
        await learner.on_messages(
            [
                {"role": "user", "content": "don't use git push, I'll push manually"},
            ]
        )
        stats = learner.get_stats()
        assert stats["patterns_extracted"] >= 1

    async def test_preference_from_content_blocks(self, learner: TrafficLearner):
        """Test preference extraction from Anthropic content block format."""
        await learner.on_messages(
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "stop running the full test suite without asking"},
                    ],
                },
            ]
        )
        stats = learner.get_stats()
        assert stats["patterns_extracted"] >= 1

    async def test_evidence_accumulation(self):
        """Test that patterns need min_evidence before saving."""
        learner = TrafficLearner(backend=None, min_evidence=3)
        # Same error→recovery pattern 3 times
        for _ in range(3):
            await learner.on_tool_result(
                tool_name="Bash",
                tool_input={"command": "python test.py"},
                tool_output="command not found: python",
                is_error=True,
            )
            await learner.on_tool_result(
                tool_name="Bash",
                tool_input={"command": "python3 test.py"},
                tool_output="OK",
                is_error=False,
            )
        stats = learner.get_stats()
        assert stats["patterns_extracted"] >= 3

    async def test_dedup(self, learner: TrafficLearner):
        """Test that identical patterns are deduplicated."""
        # Same pattern twice
        for _ in range(2):
            await learner.on_tool_result(
                tool_name="Bash",
                tool_input={"command": "ruff check ."},
                tool_output="command not found: ruff",
                is_error=True,
            )
            await learner.on_tool_result(
                tool_name="Bash",
                tool_input={"command": ".venv/bin/ruff check ."},
                tool_output="OK",
                is_error=False,
            )
        # Should not double-count the same pattern
        stats = learner.get_stats()
        # First extraction saves, second is deduped
        assert stats["patterns_extracted"] >= 1

    async def test_extract_tool_results_from_messages(self, learner: TrafficLearner):
        """Test extraction of tool results from Anthropic message format."""
        messages = [
            {
                "role": "assistant",
                "content": [
                    {
                        "type": "tool_use",
                        "id": "tu_1",
                        "name": "Bash",
                        "input": {"command": "ls"},
                    }
                ],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "tool_result",
                        "tool_use_id": "tu_1",
                        "content": [{"type": "text", "text": "file1.py\nfile2.py"}],
                    }
                ],
            },
        ]
        results = learner.extract_tool_results_from_messages(messages)
        assert len(results) == 1
        assert results[0]["tool_name"] == "Bash"
        assert "file1.py" in results[0]["output"]
        assert not results[0]["is_error"]

    async def test_tool_history_bounded(self, learner: TrafficLearner):
        """Test that tool history stays within max_history."""
        for i in range(30):
            await learner.on_tool_result(
                tool_name="Read",
                tool_input={"file_path": f"/file{i}.py"},
                tool_output=f"content {i}",
                is_error=False,
            )
        assert len(learner._tool_history) <= learner._max_history

    async def test_no_pattern_from_success_only(self, learner: TrafficLearner):
        """Test that success without prior error doesn't generate error_recovery pattern."""
        await learner.on_tool_result(
            tool_name="Bash",
            tool_input={"command": "echo hello"},
            tool_output="hello",
            is_error=False,
        )
        stats = learner.get_stats()
        # Only environment patterns possible, no error_recovery
        assert stats["requests_processed"] == 1
| # ============================================================================= | |
| # Pattern Model Tests | |
| # ============================================================================= | |
class TestExtractedPattern:
    """Checks on `ExtractedPattern.content_hash`."""

    def test_content_hash_deterministic(self):
        shared = {"category": PatternCategory.ENVIRONMENT, "content": "Use venv"}
        low = ExtractedPattern(importance=0.5, **shared)
        high = ExtractedPattern(importance=0.8, **shared)  # Different importance, same hash
        assert low.content_hash == high.content_hash

    def test_different_content_different_hash(self):
        venv_pattern, conda_pattern = (
            ExtractedPattern(
                category=PatternCategory.ENVIRONMENT,
                content=text,
                importance=0.5,
            )
            for text in ("Use venv", "Use conda")
        )
        assert venv_pattern.content_hash != conda_pattern.content_hash
| # ============================================================================= | |
| # Project Routing | |
| # ============================================================================= | |
class TestProjectForPattern:
    """Routing of a pattern to one of several candidate projects."""

    def _project(self, path: str):
        """Build a ProjectInfo whose project and data paths are both *path*."""
        from pathlib import Path as _P

        from headroom.learn.models import ProjectInfo

        root = _P(path)
        return ProjectInfo(name=root.name, project_path=root, data_path=root)

    def test_matches_longest_root(self):
        outer = self._project("/x/a")
        inner = self._project("/x/a/b")
        candidate = ExtractedPattern(
            category=PatternCategory.ERROR_RECOVERY,
            content="File `/x/a/b/foo.py` does not exist.",
            importance=0.5,
        )
        chosen = _project_for_pattern(candidate, [outer, inner])
        assert chosen is inner

    def test_returns_none_for_unanchored(self):
        only_project = self._project("/x/a")
        candidate = ExtractedPattern(
            category=PatternCategory.PREFERENCE,
            content="User preference: use terse responses",
            importance=0.7,
        )
        chosen = _project_for_pattern(candidate, [only_project])
        assert chosen is None

    def test_matches_via_entity_refs(self):
        only_project = self._project("/x/a")
        candidate = ExtractedPattern(
            category=PatternCategory.ERROR_RECOVERY,
            content="Command failed.",
            importance=0.5,
            entity_refs=["/x/a/tool.py"],
        )
        chosen = _project_for_pattern(candidate, [only_project])
        assert chosen is only_project

    def test_no_false_match_on_prefix_boundary(self):
        # /x/ab should not match a project rooted at /x/a
        only_project = self._project("/x/a")
        candidate = ExtractedPattern(
            category=PatternCategory.ERROR_RECOVERY,
            content="File `/x/ab/foo.py` does not exist.",
            importance=0.5,
        )
        chosen = _project_for_pattern(candidate, [only_project])
        assert chosen is None
| # ============================================================================= | |
| # Persisted-pattern loading from memory.db | |
| # ============================================================================= | |
class TestLoadPersistedPatterns:
    """`_load_persisted_patterns_from_sqlite` against hand-built ``memories`` tables."""

    def _make_db(self, tmp_path, rows: list[dict]):
        """Create ``memory.db`` under *tmp_path* with one ``memories`` row per dict.

        Each dict must provide ``content``. ``metadata`` (default ``{}``) and
        ``entity_refs`` (default ``[]``) are stored as JSON text; ``importance``
        defaults to 0.5 and ``created_at`` to NULL. Row ids are the stringified
        list index. Returns the path of the database file.
        """
        import json as _json
        import sqlite3 as _sql

        db = tmp_path / "memory.db"
        conn = _sql.connect(db)
        try:
            conn.execute(
                "CREATE TABLE memories ("
                "id TEXT PRIMARY KEY, content TEXT NOT NULL, "
                "metadata TEXT NOT NULL DEFAULT '{}', "
                "entity_refs TEXT NOT NULL DEFAULT '[]', "
                "importance REAL NOT NULL DEFAULT 0.5, "
                "created_at TEXT)"
            )
            for i, r in enumerate(rows):
                conn.execute(
                    "INSERT INTO memories "
                    "(id, content, metadata, entity_refs, importance, created_at) "
                    "VALUES (?,?,?,?,?,?)",
                    (
                        str(i),
                        r["content"],
                        _json.dumps(r.get("metadata", {})),
                        _json.dumps(r.get("entity_refs", [])),
                        r.get("importance", 0.5),
                        r.get("created_at"),
                    ),
                )
            conn.commit()
        finally:
            # Release the handle even when the schema or an insert fails, the
            # same way the other SQLite helpers in this file do.
            conn.close()
        return db

    def test_dedupes_by_content_and_sums_evidence(self, tmp_path):
        """Two rows with identical content collapse into one pattern with summed evidence."""
        db = self._make_db(
            tmp_path,
            [
                {
                    "content": "Command `foo` fails.",
                    "metadata": {
                        "source": "traffic_learner",
                        "category": "error_recovery",
                        "evidence_count": 2,
                    },
                },
                {
                    "content": "Command `foo` fails.",
                    "metadata": {
                        "source": "traffic_learner",
                        "category": "error_recovery",
                        "evidence_count": 3,
                    },
                },
            ],
        )
        patterns = _load_persisted_patterns_from_sqlite(db)
        assert len(patterns) == 1
        assert patterns[0].evidence_count == 5
        assert patterns[0].category == PatternCategory.ERROR_RECOVERY

    def test_skips_non_traffic_rows(self, tmp_path):
        """Rows whose metadata source is not ``traffic_learner`` are ignored."""
        db = self._make_db(
            tmp_path,
            [
                {
                    "content": "Something else",
                    "metadata": {"source": "other"},
                },
                {
                    "content": "From traffic",
                    "metadata": {
                        "source": "traffic_learner",
                        "category": "environment",
                    },
                },
            ],
        )
        patterns = _load_persisted_patterns_from_sqlite(db)
        assert len(patterns) == 1
        assert patterns[0].content == "From traffic"

    def test_reads_importance_column(self, tmp_path):
        """The loaded pattern carries the value of the ``importance`` column."""
        db = self._make_db(
            tmp_path,
            [
                {
                    "content": "High-importance pattern",
                    "metadata": {
                        "source": "traffic_learner",
                        "category": "environment",
                    },
                    "importance": 0.85,
                },
            ],
        )
        patterns = _load_persisted_patterns_from_sqlite(db)
        assert len(patterns) == 1
        assert patterns[0].importance == 0.85

    def test_skips_unknown_category(self, tmp_path):
        """A row with an unrecognised category yields no pattern."""
        db = self._make_db(
            tmp_path,
            [
                {
                    "content": "X",
                    "metadata": {"source": "traffic_learner", "category": "bogus"},
                },
            ],
        )
        assert _load_persisted_patterns_from_sqlite(db) == []
| # ============================================================================= | |
| # Category → recommendation routing | |
| # ============================================================================= | |
class TestPatternsToRecommendations:
    """Mapping of pattern categories onto recommendation targets."""

    def test_routes_preference_to_memory_file(self):
        from headroom.learn.models import RecommendationTarget

        preference = ExtractedPattern(
            category=PatternCategory.PREFERENCE,
            content="User prefers terse output",
            importance=0.8,
            evidence_count=3,
        )
        recs = _patterns_to_recommendations([preference])
        assert len(recs) == 1
        only = recs[0]
        assert only.target == RecommendationTarget.MEMORY_FILE
        assert "User prefers terse output" in only.content

    def test_routes_environment_to_context_file(self):
        from headroom.learn.models import RecommendationTarget

        environment = ExtractedPattern(
            category=PatternCategory.ENVIRONMENT,
            content="Use uv run python",
            importance=0.7,
            evidence_count=4,
        )
        recs = _patterns_to_recommendations([environment])
        assert len(recs) == 1
        assert recs[0].target == RecommendationTarget.CONTEXT_FILE

    def test_groups_by_category(self):
        patterns = [
            ExtractedPattern(
                category=PatternCategory.ERROR_RECOVERY,
                content=label,
                importance=0.5,
                evidence_count=sightings,
            )
            for label, sightings in (("A", 2), ("B", 5))
        ]
        recs = _patterns_to_recommendations(patterns)
        assert len(recs) == 1
        grouped = recs[0]
        # B has higher evidence, should sort first
        rendered = grouped.content.splitlines()
        assert rendered[:2] == ["- B", "- A"]
        assert grouped.evidence_count == 7
| # ============================================================================= | |
| # Debounced flush worker | |
| # ============================================================================= | |
class TestFlushDebounce:
    """The background flush worker must debounce rapid dirty flags."""

    async def test_flush_worker_rate_limits(self, monkeypatch):
        """Rapid dirty flags should not cause rapid flush_to_file calls."""
        import asyncio as _asyncio

        from headroom.memory import traffic_learner as tl_mod

        # Shorten debounce for a fast test
        monkeypatch.setattr(tl_mod, "FLUSH_DEBOUNCE_SECONDS", 0.5)
        learner = TrafficLearner(backend=None, min_evidence=1)
        call_count = 0

        async def fake_flush() -> None:
            nonlocal call_count
            call_count += 1

        # Swap in a counter so the test observes how often the learner flushes.
        learner.flush_to_file = fake_flush  # type: ignore[method-assign]
        await learner.start()
        try:
            # Toggle dirty rapidly over ~1.2s, which permits at most ~2 flushes.
            for _ in range(30):
                learner._flush_dirty = True
                await _asyncio.sleep(0.04)
        finally:
            # Always stop the learner that was started above, even if the
            # burst loop is interrupted.
            await learner.stop()
        # start() kicked a flush dirty→false at some point; stop() also calls
        # flush_to_file once (final flush). We want evidence the worker did
        # NOT call flush on every sleep tick — cap is generous.
        assert call_count <= 5, f"Expected few flushes, got {call_count}"
        assert call_count >= 1, "Expected at least one flush during the burst"
| # ============================================================================= | |
| # Evidence-count persistence & re-sighting bumps | |
| # ============================================================================= | |
class _FakeBackend:
    """Tiny LocalBackend substitute backed by a real SQLite file.

    Exposes only what TrafficLearner touches: `_config.db_path`
    (read by `_resolve_backend_db_path`) and an `async save_memory` that
    inserts one row and hands back an object carrying its `.id`.
    """

    def __init__(self, db_path):
        import types as _types

        location = str(db_path)
        self._db_path = location
        self._config = _types.SimpleNamespace(db_path=location)

    async def save_memory(
        self,
        *,
        content: str,
        user_id: str,
        importance: float,
        metadata: dict,
    ):
        """Insert one ``memories`` row under a fresh UUID and return ``SimpleNamespace(id=...)``.

        ``user_id`` is accepted for interface parity but is not stored.
        """
        import json as _json
        import sqlite3 as _sql
        import types as _types
        import uuid

        memory_id = str(uuid.uuid4())
        values = (memory_id, content, _json.dumps(metadata), "[]", importance)
        connection = _sql.connect(self._db_path)
        try:
            connection.execute(
                "INSERT INTO memories (id, content, metadata, entity_refs, importance) "
                "VALUES (?,?,?,?,?)",
                values,
            )
            connection.commit()
        finally:
            connection.close()
        return _types.SimpleNamespace(id=memory_id)
def _init_db(path):
    """Create the ``memories`` table in the SQLite database at *path*.

    The file is created if it does not exist. Raises ``sqlite3.OperationalError``
    if the table is already present.
    """
    import sqlite3 as _sql

    conn = _sql.connect(path)
    try:
        conn.execute(
            "CREATE TABLE memories ("
            "id TEXT PRIMARY KEY, content TEXT NOT NULL, "
            "metadata TEXT NOT NULL DEFAULT '{}', "
            "entity_refs TEXT NOT NULL DEFAULT '[]', "
            "importance REAL NOT NULL DEFAULT 0.5, "
            "created_at TEXT)"
        )
        conn.commit()
    finally:
        # Close the handle even when the CREATE fails, matching the other
        # SQLite helpers in this file.
        conn.close()
def _read_traffic_rows(db_path):
    """Return ``(id, content, metadata_dict)`` for every row whose metadata source is traffic_learner."""
    import json as _json
    import sqlite3 as _sql

    query = (
        "SELECT id, content, metadata FROM memories "
        "WHERE json_extract(metadata, '$.source') = 'traffic_learner'"
    )
    connection = _sql.connect(db_path)
    try:
        fetched = connection.execute(query).fetchall()
    finally:
        connection.close()
    decoded = []
    for row_id, content, raw_metadata in fetched:
        # Metadata is stored as JSON text; hand it back as a parsed object.
        decoded.append((row_id, content, _json.loads(raw_metadata)))
    return decoded
async def _wait_for_saved(learner: TrafficLearner, count: int, db_path) -> None:
    """Poll the DB until at least `count` traffic_learner rows are present.

    Performs up to 100 checks, sleeping 0.02s after each unsuccessful one, and
    raises AssertionError if the rows never show up. `learner` is not used.
    """
    import asyncio as _asyncio

    attempts_left = 100
    while attempts_left > 0:
        found = len(_read_traffic_rows(db_path))
        if found >= count:
            return
        await _asyncio.sleep(0.02)
        attempts_left -= 1
    # Re-read once more so the message reports the state at timeout.
    raise AssertionError(
        f"Timeout waiting for {count} saved row(s); got {len(_read_traffic_rows(db_path))}"
    )
class TestEvidencePersistence:
    """Evidence counts written to the DB, including bumps on re-sighting.

    Each test stops its learner in a ``finally`` block so a failure in the
    body (for example a `_wait_for_saved` timeout) cannot skip ``stop()``.
    """

    async def test_save_persists_actual_evidence_count(self, tmp_path):
        """The count written to the DB reflects total sightings, not the default 1."""
        db = tmp_path / "memory.db"
        _init_db(db)
        backend = _FakeBackend(db)
        learner = TrafficLearner(backend=backend, min_evidence=3)
        await learner.start()
        try:
            pattern_kwargs = {
                "category": PatternCategory.ENVIRONMENT,
                "content": "Use /usr/bin/python3 for system scripts.",
                "importance": 0.6,
            }
            for _ in range(3):
                await learner._accumulate(ExtractedPattern(**pattern_kwargs))
            await _wait_for_saved(learner, 1, db)
        finally:
            await learner.stop()
        rows = _read_traffic_rows(db)
        assert len(rows) == 1
        assert rows[0][2]["evidence_count"] == 3

    async def test_resighting_bumps_persisted_row(self, tmp_path):
        """Sightings after save bump the existing row instead of creating duplicates."""
        db = tmp_path / "memory.db"
        _init_db(db)
        backend = _FakeBackend(db)
        learner = TrafficLearner(backend=backend, min_evidence=2)

        def mk() -> ExtractedPattern:
            return ExtractedPattern(
                category=PatternCategory.PREFERENCE,
                content="User preference: terse replies.",
                importance=0.7,
            )

        await learner.start()
        try:
            # Two sightings → save with evidence_count=2.
            await learner._accumulate(mk())
            await learner._accumulate(mk())
            await _wait_for_saved(learner, 1, db)
            # Three more sightings → three bumps.
            for _ in range(3):
                await learner._accumulate(mk())
        finally:
            await learner.stop()
        rows = _read_traffic_rows(db)
        assert len(rows) == 1, "re-sightings must not create duplicate rows"
        assert rows[0][2]["evidence_count"] == 5

    async def test_hydrate_prevents_cross_session_duplicates(self, tmp_path):
        """A second session re-sighting an already-persisted pattern bumps, doesn't insert."""
        import json as _json
        import sqlite3 as _sql

        db = tmp_path / "memory.db"
        _init_db(db)
        # Session 1 row pre-seeded directly.
        seeded_content = "Command `foo` fails; use `bar` instead."
        conn = _sql.connect(db)
        try:
            conn.execute(
                "INSERT INTO memories (id, content, metadata, entity_refs, importance) "
                "VALUES (?,?,?,?,?)",
                (
                    "seed-id",
                    seeded_content,
                    _json.dumps(
                        {
                            "source": "traffic_learner",
                            "category": "error_recovery",
                            "evidence_count": 2,
                        }
                    ),
                    "[]",
                    0.7,
                ),
            )
            conn.commit()
        finally:
            conn.close()
        # Session 2: fresh learner, hydrates from DB on start().
        backend = _FakeBackend(db)
        learner = TrafficLearner(backend=backend, min_evidence=2)

        def mk() -> ExtractedPattern:
            return ExtractedPattern(
                category=PatternCategory.ERROR_RECOVERY,
                content=seeded_content,
                importance=0.7,
            )

        await learner.start()
        try:
            # Two sightings: both should bump the seeded row (no duplicates).
            await learner._accumulate(mk())
            await learner._accumulate(mk())
        finally:
            await learner.stop()
        rows = _read_traffic_rows(db)
        assert len(rows) == 1
        assert rows[0][0] == "seed-id"
        assert rows[0][2]["evidence_count"] == 4
| # ============================================================================= | |
| # flush_to_file end-to-end + early-return paths | |
| # ============================================================================= | |
class _FakeWriteResult:
    """Result object exposing a single ``files_written`` attribute."""

    def __init__(self, files_written):
        # Stored as given; no copy is made here.
        self.files_written = files_written
class _FakeWriter:
    """Writer stand-in that records each ``write`` call and can be told to fail."""

    def __init__(self):
        self.raise_on_write = False
        self.files_to_return: list = []
        self.calls: list[tuple] = []

    def write(self, recommendations, project, *, dry_run):
        """Record the call, then raise RuntimeError("boom") or return a `_FakeWriteResult`."""
        # The call is logged before any failure, so a raising write is still visible.
        recorded = (list(recommendations), project, dry_run)
        self.calls.append(recorded)
        if not self.raise_on_write:
            return _FakeWriteResult(list(self.files_to_return))
        raise RuntimeError("boom")
class _FakePlugin:
    """Plugin stand-in returning preset project roots and a preset writer."""

    def __init__(self, roots, writer, discover_raises=False):
        self._discover_raises = discover_raises
        self._writer = writer
        self._roots = roots

    def discover_projects(self):
        """Return a fresh list of the configured roots, or raise when configured to."""
        if not self._discover_raises:
            return list(self._roots)
        raise RuntimeError("discover blew up")

    def create_writer(self):
        """Return the writer supplied at construction."""
        return self._writer
def _install_plugin_registry(monkeypatch, plugin):
    """Register a stub ``headroom.learn.registry`` module in ``sys.modules`` that serves *plugin*.

    The stub's ``auto_detect_plugins()`` returns ``[plugin]`` (or ``[]`` when
    *plugin* is None) and ``get_plugin(agent_type)`` always returns *plugin*.
    """
    import sys
    import types as _types

    def _auto_detect_plugins():
        if plugin is None:
            return []
        return [plugin]

    def _get_plugin(agent_type):
        return plugin

    stub = _types.ModuleType("headroom.learn.registry")
    stub.auto_detect_plugins = _auto_detect_plugins  # type: ignore[attr-defined]
    stub.get_plugin = _get_plugin  # type: ignore[attr-defined]
    monkeypatch.setitem(sys.modules, "headroom.learn.registry", stub)
def _make_project(path):
    """Build a ProjectInfo named after the last component of *path*, used as both project and data path."""
    from pathlib import Path as _P

    from headroom.learn.models import ProjectInfo

    root = _P(path)
    return ProjectInfo(name=root.name, project_path=root, data_path=root)
class TestFlushToFile:
    """`flush_to_file` exercised against a stubbed plugin registry and writer."""

    @staticmethod
    def _seed_accumulator(learner, *, category, content, importance):
        """Put one pattern (evidence_count=2, count 2) under key "h" in the accumulator."""
        seeded = ExtractedPattern(
            category=category,
            content=content,
            importance=importance,
            evidence_count=2,
        )
        learner._pattern_counts["h"] = (seeded, 2)

    async def test_end_to_end_writes_per_project(self, tmp_path, monkeypatch):
        """Happy path: anchored patterns → bucketed per project → writer called."""
        db = tmp_path / "memory.db"
        _init_db(db)
        learner = TrafficLearner(
            backend=_FakeBackend(db), agent_type="claude", min_evidence=2
        )
        writer = _FakeWriter()
        writer.files_to_return = [tmp_path / "CLAUDE.md"]
        proj = _make_project(str(tmp_path))
        _install_plugin_registry(monkeypatch, _FakePlugin(roots=[proj], writer=writer))

        def mk() -> ExtractedPattern:
            return ExtractedPattern(
                category=PatternCategory.ENVIRONMENT,
                content=f"Use /usr/bin/python3 at {tmp_path}/main.py",
                importance=0.6,
            )

        # Need the save worker running so accumulated patterns actually land in
        # the DB where flush_to_file reads them.
        await learner.start()
        try:
            # Two sightings → save at evidence_count=2 (crosses live-flush gate).
            for _ in range(2):
                await learner._accumulate(mk())
            await _wait_for_saved(learner, 1, db)
            await learner.flush_to_file()
        finally:
            await learner.stop()
        assert len(writer.calls) >= 1
        recs, written_proj, dry_run = writer.calls[0]
        assert dry_run is False
        assert written_proj is proj
        assert len(recs) == 1
        assert "python3" in recs[0].content

    async def test_early_returns_no_plugin(self, monkeypatch):
        """No plugin detected → flush is a no-op."""
        learner = TrafficLearner(backend=None, agent_type="unknown", min_evidence=1)
        _install_plugin_registry(monkeypatch, None)
        # Seed an accumulator entry so the check isn't vacuously "no patterns".
        self._seed_accumulator(
            learner,
            category=PatternCategory.ENVIRONMENT,
            content="x",
            importance=0.5,
        )
        await learner.flush_to_file()  # returns without raising

    async def test_early_return_no_patterns(self, monkeypatch):
        """Empty accumulator and empty DB → flush returns without calling writer."""
        writer = _FakeWriter()
        _install_plugin_registry(
            monkeypatch, _FakePlugin(roots=[_make_project("/x/a")], writer=writer)
        )
        learner = TrafficLearner(backend=None, agent_type="claude", min_evidence=1)
        await learner.flush_to_file()
        assert writer.calls == []

    async def test_discover_projects_failure_is_swallowed(self, monkeypatch):
        """If plugin.discover_projects raises, flush logs and returns."""
        writer = _FakeWriter()
        _install_plugin_registry(
            monkeypatch, _FakePlugin(roots=[], writer=writer, discover_raises=True)
        )
        learner = TrafficLearner(backend=None, agent_type="claude", min_evidence=1)
        self._seed_accumulator(
            learner,
            category=PatternCategory.ENVIRONMENT,
            content="whatever",
            importance=0.5,
        )
        await learner.flush_to_file()
        assert writer.calls == []  # no roots → short-circuits before writer

    async def test_unanchored_patterns_dropped(self, tmp_path, monkeypatch):
        """Patterns with no path anchoring are dropped before writer is called."""
        writer = _FakeWriter()
        _install_plugin_registry(
            monkeypatch,
            _FakePlugin(roots=[_make_project(str(tmp_path))], writer=writer),
        )
        learner = TrafficLearner(backend=None, agent_type="claude", min_evidence=1)
        # Content has no absolute path — should be dropped as un-anchored.
        self._seed_accumulator(
            learner,
            category=PatternCategory.PREFERENCE,
            content="User preference: use terse output",
            importance=0.7,
        )
        await learner.flush_to_file()
        assert writer.calls == []

    async def test_writer_exception_does_not_propagate(self, tmp_path, monkeypatch):
        """A writer raising should be logged; flush must not bubble the error."""
        writer = _FakeWriter()
        writer.raise_on_write = True
        _install_plugin_registry(
            monkeypatch,
            _FakePlugin(roots=[_make_project(str(tmp_path))], writer=writer),
        )
        learner = TrafficLearner(backend=None, agent_type="claude", min_evidence=1)
        self._seed_accumulator(
            learner,
            category=PatternCategory.ENVIRONMENT,
            content=f"Use {tmp_path}/tool.py",
            importance=0.6,
        )
        await learner.flush_to_file()  # must not raise
        assert len(writer.calls) == 1
| # ============================================================================= | |
| # Internal helper edge cases — _resolve_backend_db_path / _collect_all_patterns | |
| # / _hydrate_persisted_state / _bump_persisted_evidence | |
| # ============================================================================= | |
class TestBackendResolution:
    """Inputs for which `_resolve_backend_db_path` yields None."""

    def test_resolve_none_backend(self):
        from headroom.memory.traffic_learner import _resolve_backend_db_path

        resolved = _resolve_backend_db_path(None)
        assert resolved is None

    def test_resolve_backend_without_config(self):
        from headroom.memory.traffic_learner import _resolve_backend_db_path

        class _Bare:
            """Backend-like object carrying no attributes at all."""

        resolved = _resolve_backend_db_path(_Bare())
        assert resolved is None

    def test_resolve_backend_with_empty_db_path(self):
        import types as _types

        from headroom.memory.traffic_learner import _resolve_backend_db_path

        empty_config = _types.SimpleNamespace(db_path="")
        backend = _types.SimpleNamespace(_config=empty_config)
        resolved = _resolve_backend_db_path(backend)
        assert resolved is None
class TestCollectAllPatterns:
    """`_collect_all_patterns` merging persisted rows with the in-memory accumulator."""

    async def test_merges_db_and_accumulator(self, tmp_path):
        """Patterns in both DB and accumulator get evidence_count summed by hash."""
        db = tmp_path / "memory.db"
        _init_db(db)
        backend = _FakeBackend(db)
        # Seed DB with a traffic_learner row at evidence_count=3.
        persisted_metadata = {
            "source": "traffic_learner",
            "category": "environment",
            "evidence_count": 3,
        }
        await backend.save_memory(
            content="shared pattern",
            user_id="t",
            importance=0.5,
            metadata=persisted_metadata,
        )
        learner = TrafficLearner(backend=backend, min_evidence=1)
        # Same content in accumulator with count=2; hash matches.
        in_memory = ExtractedPattern(
            category=PatternCategory.ENVIRONMENT,
            content="shared pattern",
            importance=0.5,
        )
        learner._pattern_counts[in_memory.content_hash] = (in_memory, 2)
        merged = learner._collect_all_patterns()
        assert len(merged) == 1
        assert merged[0].evidence_count == 3 + 2

    def test_handles_missing_db_gracefully(self, tmp_path):
        """A backend pointing to a nonexistent DB is skipped, not raised."""
        missing = tmp_path / "absent.db"  # file not created
        learner = TrafficLearner(backend=_FakeBackend(missing), min_evidence=1)
        assert learner._collect_all_patterns() == []
class TestHydrateEdgeCases:
    """Hydration when there is nothing on disk to hydrate from."""

    async def test_no_backend(self):
        """start() with backend=None hydrates to empty state and still runs."""
        learner = TrafficLearner(backend=None, min_evidence=1)
        await learner.start()
        try:
            saved, persisted = learner._saved_hashes, learner._persisted_ids
            assert saved == set()
            assert persisted == {}
        finally:
            # Always shut the learner down, even if an assertion above failed.
            await learner.stop()

    async def test_missing_db_file(self, tmp_path):
        """Backend with a db_path that doesn't exist → hydrate is a no-op."""
        missing_db = tmp_path / "not-there.db"
        learner = TrafficLearner(backend=_FakeBackend(missing_db), min_evidence=1)

        await learner._hydrate_persisted_state()

        saved, persisted = learner._saved_hashes, learner._persisted_ids
        assert saved == set()
        assert persisted == {}
class TestBumpEdgeCases:
    """`_bump_persisted_evidence` must tolerate missing backends, files and ids."""

    async def test_bump_with_no_backend_is_noop(self):
        """Bumping without any backend completes without raising."""
        backendless = TrafficLearner(backend=None, min_evidence=1)
        await backendless._bump_persisted_evidence("some-id")

    async def test_bump_with_missing_db_is_noop(self, tmp_path):
        """Bumping against a DB file that was never created completes without raising."""
        absent_db = tmp_path / "absent.db"
        learner = TrafficLearner(backend=_FakeBackend(absent_db), min_evidence=1)
        await learner._bump_persisted_evidence("some-id")

    async def test_bump_unknown_id_is_noop(self, tmp_path):
        """Updating a non-existent memory id silently affects zero rows."""
        db_file = tmp_path / "memory.db"
        _init_db(db_file)
        learner = TrafficLearner(backend=_FakeBackend(db_file), min_evidence=1)

        await learner._bump_persisted_evidence("no-such-id")

        remaining_rows = _read_traffic_rows(db_file)
        assert remaining_rows == []
| # ============================================================================= | |
| # stop() cancels the flush task | |
| # ============================================================================= | |
class TestStopCancels:
    """stop() must leave no live flush task behind."""

    async def test_stop_cancels_flush_task(self):
        """After start() the flush task is running; after stop() it is gone or done."""
        learner = TrafficLearner(backend=None, min_evidence=1)
        await learner.start()
        try:
            running = learner._flush_task
            assert running is not None and not running.done()
        finally:
            # Stop unconditionally so a failed precondition above cannot leak
            # the learner's background task into later tests.
            await learner.stop()
        assert learner._flush_task is None or learner._flush_task.done()
class TestNormalizedHash:
    """Error-recovery patterns hash on recovery intent, not literal text."""

    def _mk(self, **meta) -> ExtractedPattern:
        """Build an error-recovery pattern whose metadata is the given keywords."""
        tool_name = meta.get("tool", "none")
        return ExtractedPattern(
            category=PatternCategory.ERROR_RECOVERY,
            content=f"content-{tool_name}-{len(meta)}",
            importance=0.7,
            metadata=meta,
        )

    @staticmethod
    def _read_recovery(root: str) -> ExtractedPattern:
        """Read recovery from ``<root>/state.rs`` to ``<root>/lib.rs``."""
        missing, found = f"{root}/state.rs", f"{root}/lib.rs"
        return ExtractedPattern(
            category=PatternCategory.ERROR_RECOVERY,
            content=f"File `{missing}` does not exist. The correct path is `{found}`.",
            importance=0.7,
            metadata={"tool": "Read", "error_path": missing, "success_path": found},
        )

    @staticmethod
    def _bash_recovery(content: str, failed_cmd: str, success_cmd: str) -> ExtractedPattern:
        """Bash recovery pattern with explicit bullet text and command pair."""
        return ExtractedPattern(
            category=PatternCategory.ERROR_RECOVERY,
            content=content,
            importance=0.7,
            metadata={
                "tool": "Bash",
                "failed_cmd": failed_cmd,
                "success_cmd": success_cmd,
            },
        )

    def test_read_recovery_basename_hash(self):
        """Same file names under different directories produce one hash."""
        under_a = self._read_recovery("/a")
        under_b = self._read_recovery("/b")
        assert under_a.content_hash == under_b.content_hash

    def test_bash_recovery_tail_count_collapse(self):
        """Only the `tail -N` count differs, so the hashes agree."""
        base_fix = "cargo check --manifest-path src-tauri/Cargo.toml"
        variants = []
        for tail_count in (10, 50):
            fix = f"{base_fix} | tail -{tail_count}"
            variants.append(
                self._bash_recovery(
                    f"Command `cargo check` fails. Use `{fix}` instead.",
                    "cargo check",
                    fix,
                )
            )
        short_tail, long_tail = variants
        assert short_tail.content_hash == long_tail.content_hash

    def test_bash_recovery_pipe_boundary(self):
        """Different commands after the pipe do not change the hash."""
        failing = "grep foo bar.txt"
        piped_to_head = self._bash_recovery("x", failing, "grep -n foo bar.txt | head -5")
        piped_to_wc = self._bash_recovery("y", failing, "grep -n foo bar.txt | wc -l")
        assert piped_to_head.content_hash == piped_to_wc.content_hash

    def test_bash_recovery_different_primary_cmd_different_hash(self):
        """Distinct primary recovery commands keep distinct hashes."""
        via_build = self._bash_recovery("x", "cargo check", "cargo build")
        via_test = self._bash_recovery("y", "cargo check", "cargo test")
        assert via_build.content_hash != via_test.content_hash

    def test_non_error_recovery_unchanged(self):
        """Environment patterns with different text hash differently."""
        usr_python, opt_python = (
            ExtractedPattern(
                category=PatternCategory.ENVIRONMENT,
                content=text,
                importance=0.7,
            )
            for text in ("Use /usr/bin/python3.", "Use /opt/bin/python3.")
        )
        assert usr_python.content_hash != opt_python.content_hash

    def test_error_recovery_without_tool_falls_back_to_content(self):
        """Legacy error_recovery rows without a `tool` metadata key still work."""
        first, second = (
            ExtractedPattern(
                category=PatternCategory.ERROR_RECOVERY,
                content="Some legacy bullet.",
                importance=0.7,
            )
            for _ in range(2)
        )
        assert first.content_hash == second.content_hash
class TestRefineErrorRecovery:
    """Render-time pipeline: hard floor, re-validate, collapse, rank, cap."""

    def _mk_read(
        self,
        *,
        error_path: str,
        success_path: str,
        evidence: int = 1,
        last_seen: datetime | None = None,
    ) -> ExtractedPattern:
        """Build a Read error-recovery pattern.

        Both ``first_seen_at`` and ``last_seen_at`` are set to *last_seen*, or
        to the current UTC time when *last_seen* is omitted.
        """
        stamp = last_seen or datetime.now(UTC)
        read_meta = {
            "tool": "Read",
            "error_path": error_path,
            "success_path": success_path,
        }
        bullet = f"File `{error_path}` does not exist. The correct path is `{success_path}`."
        return ExtractedPattern(
            category=PatternCategory.ERROR_RECOVERY,
            content=bullet,
            importance=0.7,
            evidence_count=evidence,
            metadata=read_meta,
            last_seen_at=stamp,
            first_seen_at=stamp,
        )

    def test_drops_patterns_beyond_hard_floor(self, tmp_path):
        """A pattern last seen 22 days ago is dropped; a fresh one is kept."""
        existing = tmp_path / "lib.rs"
        existing.write_text("pub fn x() {}")

        twenty_two_days_ago = datetime.now(UTC) - timedelta(days=22)
        aged = self._mk_read(
            error_path=str(tmp_path / "state.rs"),
            success_path=str(existing),
            last_seen=twenty_two_days_ago,
        )
        recent = self._mk_read(
            error_path=str(tmp_path / "other.rs"),
            success_path=str(existing),
        )

        survivors = _refine_error_recovery([aged, recent])

        assert recent in survivors
        assert aged not in survivors

    def test_revalidates_read_success_path(self, tmp_path):
        """Read recoveries whose success_path is missing on disk are removed."""
        on_disk = tmp_path / "present.rs"
        on_disk.write_text("x")

        valid = self._mk_read(
            error_path=str(tmp_path / "miss.rs"),
            success_path=str(on_disk),
        )
        dangling = self._mk_read(
            error_path=str(tmp_path / "other.rs"),
            success_path=str(tmp_path / "gone.rs"),
        )

        survivors = _refine_error_recovery([valid, dangling])

        assert valid in survivors
        assert dangling not in survivors

    def test_collapses_ambiguous_error_path(self, tmp_path):
        """One error_path with three different targets collapses into one bullet."""
        shared_error = str(tmp_path / "ambiguous.rs")

        candidates = []
        for file_name, body, evidence in (("a.rs", "x", 3), ("b.rs", "y", 2), ("c.rs", "z", 1)):
            target = tmp_path / file_name
            target.write_text(body)
            candidates.append(
                self._mk_read(
                    error_path=shared_error,
                    success_path=str(target),
                    evidence=evidence,
                )
            )

        survivors = _refine_error_recovery(candidates)

        assert len(survivors) == 1
        merged = survivors[0]
        assert merged.metadata.get("collapsed") is True
        assert merged.evidence_count == 6
        assert "ambiguous.rs" in merged.content
        assert "Glob/Grep" in merged.content

    def test_single_success_path_not_collapsed(self, tmp_path):
        """Repeated rows agreeing on one success_path are left as they are."""
        only_target = tmp_path / "a.rs"
        only_target.write_text("x")
        shared_error = str(tmp_path / "only-one-target.rs")

        same_target_rows = [
            self._mk_read(
                error_path=shared_error,
                success_path=str(only_target),
                evidence=evidence,
            )
            for evidence in (3, 2)
        ]

        survivors = _refine_error_recovery(same_target_rows)

        # Not collapsed — only one distinct success_path.
        assert all(row.metadata.get("collapsed") is not True for row in survivors)
        assert len(survivors) == 2

    def test_recency_ranking_prefers_fresh_over_stale_heavy(self, tmp_path):
        """A fresh low-evidence pattern is ranked above an older high-evidence one."""
        existing = tmp_path / "lib.rs"
        existing.write_text("x")

        # Heavy but old: evidence=10, seen 10 days ago → score ~10 * 0.5**2 = 2.5
        ten_days_ago = datetime.now(UTC) - timedelta(days=10)
        stale_heavy = self._mk_read(
            error_path=str(tmp_path / "old.rs"),
            success_path=str(existing),
            evidence=10,
            last_seen=ten_days_ago,
        )
        # Light but fresh: evidence=3, seen now → score ~3
        fresh_light = self._mk_read(
            error_path=str(tmp_path / "fresh.rs"),
            success_path=str(existing),
            evidence=3,
        )

        ranked = _refine_error_recovery([stale_heavy, fresh_light])

        assert ranked[0] is fresh_light
        assert ranked[1] is stale_heavy

    def test_section_cap_enforced(self, tmp_path):
        """Twenty-five valid patterns are trimmed to the fifteen strongest."""
        existing = tmp_path / "lib.rs"
        existing.write_text("x")

        flood = []
        for index in range(25):
            flood.append(
                self._mk_read(
                    error_path=str(tmp_path / f"miss_{index}.rs"),
                    success_path=str(existing),
                    evidence=index + 1,
                )
            )

        survivors = _refine_error_recovery(flood)

        assert len(survivors) == 15
        # Highest-evidence ones kept (all are equally fresh, so evidence wins).
        weakest_kept = min(row.evidence_count for row in survivors)
        assert weakest_kept >= 11  # Bottom of top-15 out of 1..25

    def test_read_recovery_without_success_path_not_revalidated(self):
        """Read patterns lacking `success_path` in metadata skip re-validation cleanly."""
        legacy_read = ExtractedPattern(
            category=PatternCategory.ERROR_RECOVERY,
            content="Some legacy Read bullet",
            importance=0.7,
            metadata={"tool": "Read", "error_path": "/something.rs"},
            last_seen_at=datetime.now(UTC),
        )

        survivors = _refine_error_recovery([legacy_read])

        assert legacy_read in survivors

    def test_bash_recoveries_not_revalidated(self, tmp_path):
        """Bash patterns pass through re-validation regardless of command content."""
        bash_meta = {
            "tool": "Bash",
            "failed_cmd": "x",
            "success_cmd": "y",
        }
        bash_row = ExtractedPattern(
            category=PatternCategory.ERROR_RECOVERY,
            content="Command `x` fails. Use `y` instead.",
            importance=0.7,
            evidence_count=1,
            metadata=bash_meta,
            last_seen_at=datetime.now(UTC),
        )

        survivors = _refine_error_recovery([bash_row])

        assert bash_row in survivors

    def test_empty_input_returns_empty(self):
        """No patterns in, no patterns out."""
        nothing = _refine_error_recovery([])
        assert nothing == []

    def test_missing_timestamps_survive_one_render(self):
        """Patterns without timestamps are kept rather than silently dropped."""
        untimed = ExtractedPattern(
            category=PatternCategory.ERROR_RECOVERY,
            content="legacy bullet",
            importance=0.7,
        )
        # Precondition: neither timestamp is populated by default.
        assert untimed.first_seen_at is None
        assert untimed.last_seen_at is None

        survivors = _refine_error_recovery([untimed])

        assert untimed in survivors

    def test_refined_empty_skips_section_in_recommendations(self, tmp_path):
        """If all error_recovery patterns fail re-validation, no recommendation is emitted."""
        # Only pattern is a Read recovery pointing at a nonexistent success_path.
        vanished_target = str(tmp_path / "does-not-exist.rs")
        lone_stale = ExtractedPattern(
            category=PatternCategory.ERROR_RECOVERY,
            content="File `/a.rs` does not exist. The correct path is `/gone.rs`.",
            importance=0.7,
            metadata={
                "tool": "Read",
                "error_path": "/a.rs",
                "success_path": vanished_target,
            },
            last_seen_at=datetime.now(UTC),
        )

        recommendations = _patterns_to_recommendations([lone_stale])

        # Section should be skipped entirely — no recommendation produced.
        assert recommendations == []

    def test_oserror_during_revalidation_keeps_row(self, monkeypatch):
        """Transient OS errors during path checks should not drop the row."""
        unverifiable = ExtractedPattern(
            category=PatternCategory.ERROR_RECOVERY,
            content="File `/a.rs` does not exist. The correct path is `/b.rs`.",
            importance=0.7,
            metadata={"tool": "Read", "error_path": "/a.rs", "success_path": "/b.rs"},
            last_seen_at=datetime.now(UTC),
        )

        def _always_oserror(self):
            raise OSError("simulated permission error")

        # Every Path.exists() call now fails the way a permission problem would.
        monkeypatch.setattr("pathlib.Path.exists", _always_oserror)

        survivors = _refine_error_recovery([unverifiable])

        assert unverifiable in survivors
class TestNormalizeBashForHash:
    """Bash command normalization for hash-key collapse."""

    def test_empty_string_returns_empty(self):
        """The empty command normalizes to the empty string."""
        normalized = _normalize_bash_for_hash("")
        assert normalized == ""

    def test_no_volatile_suffix_unchanged(self):
        """A plain command comes back untouched."""
        plain = "cargo check"
        assert _normalize_bash_for_hash(plain) == plain

    def test_strips_head_suffix(self):
        """A trailing `| head -N` stage is removed."""
        normalized = _normalize_bash_for_hash("grep foo bar | head -20")
        assert normalized == "grep foo bar"

    def test_strips_tail_suffix(self):
        """A trailing `| tail -N` stage is removed."""
        normalized = _normalize_bash_for_hash("cargo check | tail -5")
        assert normalized == "cargo check"

    def test_strips_trailing_context_flags(self):
        """A trailing `-A 3` context flag is removed."""
        # The regex is anchored to end-of-string; context flags must be trailing.
        normalized = _normalize_bash_for_hash("grep foo bar -A 3")
        assert normalized == "grep foo bar"

    def test_strips_stderr_redirect(self):
        """A trailing `2>&1` redirect is removed."""
        normalized = _normalize_bash_for_hash("cargo check 2>&1")
        assert normalized == "cargo check"

    def test_cuts_at_first_chain(self):
        """Everything from the first `&&` onward is cut."""
        # && boundary collapses to just the primary command
        normalized = _normalize_bash_for_hash("cd /tmp && ls")
        assert normalized == "cd /tmp"
class TestParseIsoTimestamp:
    """Edge-case coverage for _parse_iso_timestamp."""

    def test_none_returns_none(self):
        """``None`` input yields ``None``."""
        outcome = _parse_iso_timestamp(None)
        assert outcome is None

    def test_empty_string_returns_none(self):
        """The empty string yields ``None``."""
        outcome = _parse_iso_timestamp("")
        assert outcome is None

    def test_non_string_returns_none(self):
        """Numeric inputs yield ``None``."""
        for not_a_string in (12345, 3.14):
            assert _parse_iso_timestamp(not_a_string) is None

    def test_invalid_format_returns_none(self):
        """Text that is not ISO formatted yields ``None``."""
        outcome = _parse_iso_timestamp("not an iso string")
        assert outcome is None

    def test_naive_timestamp_assumed_utc(self):
        """A timestamp without an offset comes back tagged as UTC."""
        naive_text = "2026-04-20T12:00:00"
        outcome = _parse_iso_timestamp(naive_text)
        assert outcome is not None
        assert outcome.tzinfo == UTC

    def test_aware_timestamp_preserved(self):
        """A timestamp with an explicit offset keeps timezone information."""
        aware_text = "2026-04-20T12:00:00+00:00"
        outcome = _parse_iso_timestamp(aware_text)
        assert outcome is not None
        assert outcome.tzinfo is not None
class TestLoadPersistedPatternsTimestamps:
    """The sqlite load path reads first_seen_at / last_seen_at correctly."""

    @staticmethod
    def _create_memories_table(conn, importance_decl: str = "REAL NOT NULL DEFAULT 0.5") -> None:
        """Create the ``memories`` table on *conn*.

        *importance_decl* is the column declaration for ``importance``; tests
        that need to store a non-numeric value pass a looser one (``"TEXT"``).
        """
        conn.execute(
            "CREATE TABLE memories ("
            "id TEXT PRIMARY KEY, content TEXT NOT NULL, "
            "metadata TEXT NOT NULL DEFAULT '{}', "
            "entity_refs TEXT NOT NULL DEFAULT '[]', "
            f"importance {importance_decl}, "
            "created_at TEXT)"
        )

    def _make_db(self, tmp_path, rows: list[dict]):
        """Create ``memory.db`` under *tmp_path* with one memory per entry of *rows*.

        Each row dict must provide ``content``. Optional keys and their
        defaults: ``metadata`` (``{}``), ``entity_refs`` (``[]``),
        ``importance`` (``0.5``) and ``created_at`` (``None``). Row ids are the
        stringified list index. Returns the path of the database file.
        """
        import json as _json
        import sqlite3 as _sql
        from contextlib import closing

        db = tmp_path / "memory.db"
        # closing() releases the connection even if a statement raises; the
        # connection's own context manager would only scope the transaction.
        with closing(_sql.connect(db)) as conn:
            self._create_memories_table(conn)
            conn.executemany(
                "INSERT INTO memories "
                "(id, content, metadata, entity_refs, importance, created_at) "
                "VALUES (?,?,?,?,?,?)",
                [
                    (
                        str(i),
                        r["content"],
                        _json.dumps(r.get("metadata", {})),
                        _json.dumps(r.get("entity_refs", [])),
                        r.get("importance", 0.5),
                        r.get("created_at"),
                    )
                    for i, r in enumerate(rows)
                ],
            )
            conn.commit()
        return db

    def test_reads_timestamps_from_metadata(self, tmp_path):
        """Timestamps stored in the metadata JSON are surfaced on the pattern."""
        db = self._make_db(
            tmp_path,
            [
                {
                    "content": "env bullet",
                    "metadata": {
                        "source": "traffic_learner",
                        "category": "environment",
                        "evidence_count": 3,
                        "first_seen_at": "2026-04-10T10:00:00+00:00",
                        "last_seen_at": "2026-04-20T15:00:00+00:00",
                    },
                }
            ],
        )
        patterns = _load_persisted_patterns_from_sqlite(db)
        assert len(patterns) == 1
        p = patterns[0]
        assert p.first_seen_at is not None
        assert p.first_seen_at.year == 2026 and p.first_seen_at.month == 4
        assert p.last_seen_at is not None
        assert p.last_seen_at.day == 20

    def test_falls_back_to_created_at(self, tmp_path):
        """When metadata has no timestamps, `created_at` is used."""
        db = self._make_db(
            tmp_path,
            [
                {
                    "content": "env bullet",
                    "metadata": {
                        "source": "traffic_learner",
                        "category": "environment",
                        "evidence_count": 1,
                    },
                    "created_at": "2026-03-01T09:00:00+00:00",
                }
            ],
        )
        patterns = _load_persisted_patterns_from_sqlite(db)
        assert len(patterns) == 1
        assert patterns[0].first_seen_at is not None
        assert patterns[0].first_seen_at.month == 3
        # last_seen defaults to first_seen when metadata lacks both.
        assert patterns[0].last_seen_at == patterns[0].first_seen_at

    def test_collision_merges_timestamps_max_last_min_first(self, tmp_path):
        """Two rows collapsing to the same hash keep the widest timestamp range."""
        db = self._make_db(
            tmp_path,
            [
                {
                    "content": "dup bullet",
                    "importance": 0.4,
                    "metadata": {
                        "source": "traffic_learner",
                        "category": "preference",
                        "evidence_count": 2,
                        "first_seen_at": "2026-04-10T00:00:00+00:00",
                        "last_seen_at": "2026-04-15T00:00:00+00:00",
                    },
                },
                {
                    "content": "dup bullet",
                    "importance": 0.9,
                    "metadata": {
                        "source": "traffic_learner",
                        "category": "preference",
                        "evidence_count": 3,
                        "first_seen_at": "2026-04-01T00:00:00+00:00",
                        "last_seen_at": "2026-04-20T00:00:00+00:00",
                    },
                },
            ],
        )
        patterns = _load_persisted_patterns_from_sqlite(db)
        assert len(patterns) == 1
        p = patterns[0]
        assert p.evidence_count == 5
        # Higher importance wins when collision merges.
        assert p.importance == 0.9
        assert p.first_seen_at is not None and p.first_seen_at.day == 1
        assert p.last_seen_at is not None and p.last_seen_at.day == 20

    def test_non_numeric_importance_falls_back_to_default(self, tmp_path):
        """Rows with an unparseable importance value use 0.5."""
        import json as _json
        import sqlite3 as _sql
        from contextlib import closing

        db = tmp_path / "memory.db"
        with closing(_sql.connect(db)) as conn:
            # A TEXT column lets the row hold a value that is not a number.
            self._create_memories_table(conn, importance_decl="TEXT")
            conn.execute(
                "INSERT INTO memories (id, content, metadata, importance) VALUES (?,?,?,?)",
                (
                    "0",
                    "bullet",
                    _json.dumps(
                        {
                            "source": "traffic_learner",
                            "category": "environment",
                            "evidence_count": 1,
                        }
                    ),
                    "not-a-number",
                ),
            )
            conn.commit()
        patterns = _load_persisted_patterns_from_sqlite(db)
        assert len(patterns) == 1
        assert patterns[0].importance == 0.5

    def test_malformed_metadata_json_skipped_gracefully(self, tmp_path):
        """Rows with invalid JSON metadata don't crash the load."""
        import sqlite3 as _sql
        from contextlib import closing

        db = tmp_path / "memory.db"
        with closing(_sql.connect(db)) as conn:
            self._create_memories_table(conn)
            # Invalid JSON in metadata
            conn.execute(
                "INSERT INTO memories VALUES (?,?,?,?,?,?)",
                ("0", "bullet", "{not json", "[]", 0.5, None),
            )
            conn.commit()
        # Should not raise — the row is simply skipped (no recognizable category).
        patterns = _load_persisted_patterns_from_sqlite(db)
        assert patterns == []
class TestBumpPersistsLastSeenAt:
    """_bump_persisted_evidence sets $.last_seen_at on every bump."""

    async def test_bump_sets_last_seen_at_in_metadata(self, tmp_path):
        """Bumping a row increments evidence_count and writes a parseable last_seen_at."""
        import json as _json
        import sqlite3 as _sql
        from contextlib import closing

        db = tmp_path / "memory.db"
        _init_db(db)

        # Seed a traffic_learner row with no last_seen_at.
        seed_meta = {
            "source": "traffic_learner",
            "category": "environment",
            "evidence_count": 1,
        }
        # closing() releases the connection even when a statement raises.
        with closing(_sql.connect(db)) as conn:
            conn.execute(
                "INSERT INTO memories (id, content, metadata) VALUES (?,?,?)",
                ("row-1", "bullet", _json.dumps(seed_meta)),
            )
            conn.commit()

        backend = _FakeBackend(db)
        learner = TrafficLearner(backend=backend, min_evidence=1)
        await learner._bump_persisted_evidence("row-1")

        with closing(_sql.connect(db)) as conn:
            row = conn.execute("SELECT metadata FROM memories WHERE id='row-1'").fetchone()

        # Fail with a clear assertion instead of a TypeError if the row vanished.
        assert row is not None
        meta = _json.loads(row[0])
        assert meta["evidence_count"] == 2
        assert "last_seen_at" in meta
        # Should be parseable back.
        parsed = _parse_iso_timestamp(meta["last_seen_at"])
        assert parsed is not None
class TestHydrateLegacyRow:
    """Legacy rows without `category` metadata fall back to literal-content hashing."""

    @staticmethod
    def _seed_rows(db, rows) -> None:
        """Insert ``(id, content, metadata_dict)`` tuples into ``memories``.

        The metadata dict is JSON-encoded before insertion. The connection is
        closed on every path, including when an insert raises.
        """
        import json as _json
        import sqlite3 as _sql
        from contextlib import closing

        with closing(_sql.connect(db)) as conn:
            conn.executemany(
                "INSERT INTO memories (id, content, metadata) VALUES (?,?,?)",
                [(row_id, content, _json.dumps(meta)) for row_id, content, meta in rows],
            )
            conn.commit()

    async def test_hydrate_legacy_row_without_category(self, tmp_path):
        """A row lacking `category` is hydrated under sha256(content)[:16]."""
        import hashlib as _h

        db = tmp_path / "memory.db"
        _init_db(db)
        # No `category` key in metadata — must still hydrate.
        self._seed_rows(db, [("legacy-1", "legacy bullet", {"source": "traffic_learner"})])

        backend = _FakeBackend(db)
        learner = TrafficLearner(backend=backend, min_evidence=1)
        await learner._hydrate_persisted_state()

        # Falls back to sha256(content) for the hash key.
        expected = _h.sha256(b"legacy bullet").hexdigest()[:16]
        assert expected in learner._saved_hashes
        assert learner._persisted_ids[expected] == "legacy-1"

    async def test_hydrate_skips_empty_content(self, tmp_path):
        """Rows with empty content are skipped during hydration."""
        db = tmp_path / "memory.db"
        _init_db(db)
        self._seed_rows(
            db,
            [
                ("empty", "", {"source": "traffic_learner"}),
                (
                    "ok",
                    "normal bullet",
                    {"source": "traffic_learner", "category": "environment"},
                ),
            ],
        )

        backend = _FakeBackend(db)
        learner = TrafficLearner(backend=backend, min_evidence=1)
        await learner._hydrate_persisted_state()

        assert "empty" not in learner._persisted_ids.values()
        assert "ok" in learner._persisted_ids.values()

    async def test_hydrate_invalid_category_falls_back(self, tmp_path):
        """Unknown category values (e.g., typos) are handled as legacy rows."""
        db = tmp_path / "memory.db"
        _init_db(db)
        self._seed_rows(
            db,
            [
                (
                    "bad-cat",
                    "mystery bullet",
                    {"source": "traffic_learner", "category": "mystery_type"},
                )
            ],
        )

        backend = _FakeBackend(db)
        learner = TrafficLearner(backend=backend, min_evidence=1)
        # Must not raise.
        await learner._hydrate_persisted_state()
class TestCollectAllPatternsTimestamps:
    """_collect_all_patterns bumps last_seen_at on in-session re-sightings."""

    async def test_re_sighting_bumps_last_seen_at(self, tmp_path):
        """A persisted pattern re-observed in this session gets last_seen_at=now."""
        import json as _json
        import sqlite3 as _sql
        from contextlib import closing

        db = tmp_path / "memory.db"
        _init_db(db)

        # Seed relative to the current clock rather than a fixed calendar date,
        # so the "stale" row is always in the past whenever the test runs.
        stale_seen = (datetime.now(UTC) - timedelta(days=90)).replace(microsecond=0)
        old_last_seen = stale_seen.isoformat()
        # closing() releases the connection even when the insert raises.
        with closing(_sql.connect(db)) as conn:
            conn.execute(
                "INSERT INTO memories (id, content, metadata) VALUES (?,?,?)",
                (
                    "seed-1",
                    "some env bullet",
                    _json.dumps(
                        {
                            "source": "traffic_learner",
                            "category": "environment",
                            "evidence_count": 1,
                            "first_seen_at": old_last_seen,
                            "last_seen_at": old_last_seen,
                        }
                    ),
                ),
            )
            conn.commit()

        backend = _FakeBackend(db)
        learner = TrafficLearner(backend=backend, min_evidence=1)

        # Simulate in-session accumulation of the same pattern.
        pattern = ExtractedPattern(
            category=PatternCategory.ENVIRONMENT,
            content="some env bullet",
            importance=0.7,
        )
        learner._pattern_counts[pattern.content_hash] = (pattern, 1)

        merged = learner._collect_all_patterns()

        assert len(merged) == 1
        m = merged[0]
        assert m.last_seen_at is not None
        # last_seen_at should be bumped past the stale seeded timestamp.
        assert m.last_seen_at > stale_seen
        # "Now" is checked with a tolerance instead of comparing calendar
        # years, which could disagree across a New Year boundary.
        assert abs(datetime.now(UTC) - m.last_seen_at) < timedelta(hours=1)