JerrettDavis committed on
Commit
7ecbffe
·
2 Parent(s): 545dffd 7c97511

Merge branch 'chopratejas:main' into ci/release-automation

Browse files
Files changed (39) hide show
  1. .github/dependabot.yml +27 -0
  2. .github/workflows/docker.yml +17 -0
  3. .github/workflows/publish.yml +15 -2
  4. .gitignore +3 -0
  5. Dockerfile +7 -3
  6. README.md +32 -35
  7. headroom/cli/learn.py +14 -1
  8. headroom/cli/mcp.py +23 -3
  9. headroom/cli/proxy.py +35 -4
  10. headroom/cli/wrap.py +260 -12
  11. headroom/install/state.py +16 -7
  12. headroom/learn/base.py +9 -3
  13. headroom/learn/plugins/claude.py +18 -10
  14. headroom/learn/plugins/codex.py +17 -6
  15. headroom/learn/plugins/gemini.py +14 -6
  16. headroom/memory/factory.py +10 -0
  17. headroom/memory/mcp_server.py +375 -0
  18. headroom/memory/sync.py +395 -0
  19. headroom/memory/sync_adapters/__init__.py +1 -0
  20. headroom/memory/sync_adapters/claude_code.py +233 -0
  21. headroom/memory/sync_adapters/codex_agent.py +106 -0
  22. headroom/memory/writers/claude_writer.py +2 -1
  23. headroom/proxy/handlers/openai.py +370 -20
  24. headroom/proxy/handlers/streaming.py +7 -4
  25. headroom/proxy/memory_handler.py +33 -13
  26. headroom/proxy/models.py +3 -0
  27. headroom/proxy/request_logger.py +23 -7
  28. headroom/telemetry/toin.py +4 -2
  29. headroom/transforms/kompress_compressor.py +98 -79
  30. headroom/transforms/smart_crusher.py +60 -48
  31. plugins/openclaw/package.json +54 -54
  32. tests/test_cli/test_wrap_copilot.py +17 -4
  33. tests/test_learn/test_scanner.py +64 -0
  34. tests/test_memory_sync.py +647 -0
  35. tests/test_package_init_lazy.py +2 -1
  36. tests/test_transforms/test_kompress_compressor.py +2 -3
  37. tests/test_transforms/test_smart_crusher_bugs.py +212 -0
  38. tests/test_transforms/test_universal_json_crush.py +20 -18
  39. tests/test_ws_memory_relay.py +523 -0
.github/dependabot.yml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version: 2
2
+ updates:
3
+ # Docker base image digest updates
4
+ - package-ecosystem: docker
5
+ directory: /
6
+ schedule:
7
+ interval: weekly
8
+ commit-message:
9
+ prefix: "docker"
10
+
11
+ # GitHub Actions version updates
12
+ - package-ecosystem: github-actions
13
+ directory: /
14
+ schedule:
15
+ interval: weekly
16
+ commit-message:
17
+ prefix: "ci"
18
+
19
+ # Python dependency updates (pip)
20
+ - package-ecosystem: pip
21
+ directory: /
22
+ schedule:
23
+ interval: weekly
24
+ commit-message:
25
+ prefix: "deps"
26
+ # Cap the number of concurrently open version-update PRs at 5 to avoid noise
27
+ open-pull-requests-limit: 5
.github/workflows/docker.yml CHANGED
@@ -12,6 +12,7 @@ env:
12
  permissions:
13
  contents: read
14
  packages: write
 
15
 
16
  jobs:
17
  docker-variant-tags:
@@ -81,3 +82,19 @@ jobs:
81
  set: |
82
  *.cache-from=type=gha
83
  *.cache-to=type=gha,mode=max
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  permissions:
13
  contents: read
14
  packages: write
15
+ id-token: write # For cosign keyless signing via Sigstore OIDC
16
 
17
  jobs:
18
  docker-variant-tags:
 
82
  set: |
83
  *.cache-from=type=gha
84
  *.cache-to=type=gha,mode=max
85
+
86
+ - name: Install cosign
87
+ uses: sigstore/cosign-installer@v3
88
+
89
+ - name: Sign images with cosign (keyless via Sigstore OIDC)
90
+ env:
91
+ BAKE_META: ${{ steps.bake.outputs.metadata }}
92
+ run: |
93
+ # Extract all pushed image digests from bake metadata and sign each
94
+ echo "$BAKE_META" | jq -r '
95
+ to_entries[].value."containerimage.digest" // empty
96
+ ' | while read -r digest; do
97
+ image="${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}@${digest}"
98
+ echo "Signing ${image}"
99
+ cosign sign --yes "${image}"
100
+ done
.github/workflows/publish.yml CHANGED
@@ -11,7 +11,8 @@ jobs:
11
  runs-on: ubuntu-latest
12
  environment: pypi
13
  permissions:
14
- id-token: write # For trusted publishing
 
15
 
16
  steps:
17
  - uses: actions/checkout@v4
@@ -23,11 +24,23 @@ jobs:
23
 
24
  - name: Install build tools
25
  run: |
26
- python -m pip install --upgrade pip build
27
 
28
  - name: Build package
29
  run: |
30
  python -m build
31
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  - name: Publish to PyPI
33
  uses: pypa/gh-action-pypi-publish@release/v1
 
11
  runs-on: ubuntu-latest
12
  environment: pypi
13
  permissions:
14
+ id-token: write # For trusted publishing
15
+ contents: write # For uploading release assets
16
 
17
  steps:
18
  - uses: actions/checkout@v4
 
24
 
25
  - name: Install build tools
26
  run: |
27
+ python -m pip install --upgrade pip build cyclonedx-bom
28
 
29
  - name: Build package
30
  run: |
31
  python -m build
32
 
33
+ - name: Generate SBOM (CycloneDX)
34
+ run: |
35
+ pip install -e ".[proxy]"
36
+ cyclonedx-py environment \
37
+ --output-format json \
38
+ --outfile dist/headroom-sbom.cdx.json
39
+
40
+ - name: Upload SBOM to release
41
+ uses: softprops/action-gh-release@v2
42
+ with:
43
+ files: dist/headroom-sbom.cdx.json
44
+
45
  - name: Publish to PyPI
46
  uses: pypa/gh-action-pypi-publish@release/v1
.gitignore CHANGED
@@ -12,6 +12,9 @@ scripts/*
12
  # Swift SDK (separate repo)
13
  swift/
14
 
 
 
 
15
  # Audit/scan outputs (contain security findings — never commit)
16
  bandit_result.txt
17
  pip_audit_result.txt
 
12
  # Swift SDK (separate repo)
13
  swift/
14
 
15
+ # Local planning docs (never commit)
16
+ ENTERPRISE_HARDENING.md
17
+
18
  # Audit/scan outputs (contain security findings — never commit)
19
  bandit_result.txt
20
  pip_audit_result.txt
Dockerfile CHANGED
@@ -1,10 +1,14 @@
1
  ARG PYTHON_VERSION=3.11
2
  ARG UV_VERSION=0.6.17
 
 
 
 
3
  ARG DISTROLESS_IMAGE=gcr.io/distroless/python3-debian13
4
  ARG PYTHON_SITE_PACKAGES=/usr/local/lib/python${PYTHON_VERSION}/site-packages
5
 
6
  # ---- Build stage: compile native extensions, build wheel ----
7
- FROM python:${PYTHON_VERSION}-slim AS builder
8
 
9
  ARG UV_VERSION
10
 
@@ -32,7 +36,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
32
  uv pip install --system --no-deps --reinstall-package headroom-ai .
33
 
34
  # ---- Runtime stage (python-slim): supports root/nonroot via build arg ----
35
- FROM python:${PYTHON_VERSION}-slim AS runtime-slim-base
36
 
37
  ARG RUNTIME_USER=nonroot
38
  ARG PYTHON_SITE_PACKAGES
@@ -69,7 +73,7 @@ HEALTHCHECK --interval=30s --timeout=5s --start-period=20s --retries=3 \
69
  ENTRYPOINT ["headroom", "proxy"]
70
  CMD ["--host", "0.0.0.0", "--port", "8787"]
71
 
72
- FROM ${DISTROLESS_IMAGE} AS runtime-slim
73
 
74
  ARG RUNTIME_USER=nonroot
75
  ARG PYTHON_SITE_PACKAGES
 
1
  ARG PYTHON_VERSION=3.11
2
  ARG UV_VERSION=0.6.17
3
+ # Pinned 2026-04-15. Update via Dependabot or: docker pull python:3.11-slim
4
+ ARG PYTHON_DIGEST=sha256:233de06753d30d120b1a3ce359d8d3be8bda78524cd8f520c99883bfe33964cf
5
+ # Pinned 2026-04-15. Update via Dependabot or: docker pull gcr.io/distroless/python3-debian13
6
+ ARG DISTROLESS_DIGEST=sha256:ed3a4beb46f8f8baac068743ba1b1f95ea3f793422129cf6dd23967f779b6018
7
  ARG DISTROLESS_IMAGE=gcr.io/distroless/python3-debian13
8
  ARG PYTHON_SITE_PACKAGES=/usr/local/lib/python${PYTHON_VERSION}/site-packages
9
 
10
  # ---- Build stage: compile native extensions, build wheel ----
11
+ FROM python:${PYTHON_VERSION}-slim@${PYTHON_DIGEST} AS builder
12
 
13
  ARG UV_VERSION
14
 
 
36
  uv pip install --system --no-deps --reinstall-package headroom-ai .
37
 
38
  # ---- Runtime stage (python-slim): supports root/nonroot via build arg ----
39
+ FROM python:${PYTHON_VERSION}-slim@${PYTHON_DIGEST} AS runtime-slim-base
40
 
41
  ARG RUNTIME_USER=nonroot
42
  ARG PYTHON_SITE_PACKAGES
 
73
  ENTRYPOINT ["headroom", "proxy"]
74
  CMD ["--host", "0.0.0.0", "--port", "8787"]
75
 
76
+ FROM ${DISTROLESS_IMAGE}@${DISTROLESS_DIGEST} AS runtime-slim
77
 
78
  ARG RUNTIME_USER=nonroot
79
  ARG PYTHON_SITE_PACKAGES
README.md CHANGED
@@ -155,9 +155,9 @@ OPENAI_BASE_URL=http://localhost:8787/v1 your-app
155
  Use `token` mode for short/medium sessions where raw compression savings matter most.
156
  Use `cache` mode for long-running chats where preserving prior-turn bytes improves provider cache reuse.
157
 
158
- Works with any language, any tool, any framework. **[Proxy docs](docs/proxy.md)**
159
 
160
- Prefer Docker as the runtime provider? See **[Docker-native install](docs/docker-install.md)**. Want Headroom to stay up in the background? See **[Persistent installs](docs/persistent-installs.md)**.
161
 
162
  ### Coding agents — one command
163
 
@@ -189,7 +189,7 @@ summary = ctx.get("research") # Agent B reads (~80% smaller)
189
  full = ctx.get("research", full=True) # Agent B gets original if needed
190
  ```
191
 
192
- Compress what moves between agents — any framework. **[SharedContext Guide](docs/shared-context.md)**
193
 
194
  ### MCP Tools (Claude Code, Cursor)
195
 
@@ -197,7 +197,7 @@ Compress what moves between agents — any framework. **[SharedContext Guide](do
197
  headroom mcp install && claude
198
  ```
199
 
200
- Gives your AI tool three MCP tools: `headroom_compress`, `headroom_retrieve`, `headroom_stats`. **[MCP Guide](docs/mcp.md)**
201
 
202
  ### Drop into your existing stack
203
 
@@ -219,7 +219,7 @@ Gives your AI tool three MCP tools: `headroom_compress`, `headroom_retrieve`, `h
219
  | **Codex / Aider** | Wrap | `headroom wrap codex` or `headroom wrap aider` |
220
  | **Always-on local proxy** | Persistent install | `headroom install apply --preset persistent-service --providers auto` |
221
 
222
- **[Full Integration Guide](docs/integration-guide.md)** | **[TypeScript SDK](docs/typescript-sdk.md)**
223
 
224
  ---
225
 
@@ -292,7 +292,7 @@ python -m headroom.evals suite --tier 1 -o eval_results/
292
  python -m headroom.evals suite --tier 1 --ci
293
  ```
294
 
295
- Full methodology: [Benchmarks](docs/benchmarks.md) | [Evals Framework](headroom/evals/README.md)
296
 
297
  ---
298
 
@@ -317,7 +317,7 @@ headroom wrap claude --memory # Claude with persistent memory
317
  headroom wrap codex --memory # Codex shares the SAME memory store
318
  ```
319
 
320
- Claude saves a fact, Codex reads it back. All agents sharing one proxy share one memory — project-scoped, user-isolated, with agent provenance tracking and automatic deduplication. No SDK changes needed. **[Memory docs](docs/memory.md)**
321
 
322
  ### Failure Learning
323
 
@@ -327,7 +327,7 @@ headroom learn --apply # Write learnings to agent-native files
327
  headroom learn --agent codex --all # Analyze all Codex sessions
328
  ```
329
 
330
- Plugin-based: reads conversation history from Claude Code, Codex, or Gemini CLI. Finds failure patterns, correlates with successes, writes corrections to CLAUDE.md / AGENTS.md / GEMINI.md. External plugins via entry points. **[Learn docs](docs/learn.md)**
331
 
332
  <p align="center">
333
  <img src="headroom_learn.gif" alt="headroom learn demo" width="800">
@@ -405,7 +405,7 @@ Context compression is a new space. Here's how the approaches differ:
405
  Originals are in the Compressed Store — nothing is thrown away.
406
  ```
407
 
408
- **Overhead**: 15-200ms compression latency (net positive for Sonnet/Opus). Full data: [Latency Benchmarks](docs/LATENCY_BENCHMARKS.md)
409
 
410
  ---
411
 
@@ -413,16 +413,16 @@ Context compression is a new space. Here's how the approaches differ:
413
 
414
  | Integration | Status | Docs |
415
  |-------------|--------|------|
416
- | `headroom wrap claude/copilot/codex/aider/cursor` | **Stable** | [Proxy Docs](docs/proxy.md) |
417
- | `compress()` — one function | **Stable** | [Integration Guide](docs/integration-guide.md) |
418
- | `SharedContext` — multi-agent | **Stable** | [SharedContext Guide](docs/shared-context.md) |
419
- | LiteLLM callback | **Stable** | [Integration Guide](docs/integration-guide.md#litellm) |
420
- | ASGI middleware | **Stable** | [Integration Guide](docs/integration-guide.md#asgi-middleware) |
421
- | Proxy server | **Stable** | [Proxy Docs](docs/proxy.md) |
422
- | Agno | **Stable** | [Agno Guide](docs/agno.md) |
423
- | MCP (Claude Code, Cursor, etc.) | **Stable** | [MCP Guide](docs/mcp.md) |
424
- | Strands | **Stable** | [Strands Guide](docs/strands.md) |
425
- | LangChain | **Stable** | [LangChain Guide](docs/langchain.md) |
426
  | **OpenClaw** | **Stable** | [OpenClaw plugin](#openclaw-plugin) |
427
 
428
  ---
@@ -521,23 +521,20 @@ Python 3.10+
521
 
522
  | | |
523
  |---|---|
524
- | [Integration Guide](docs/integration-guide.md) | LiteLLM, ASGI, compress(), proxy |
525
- | [Proxy Docs](docs/proxy.md) | Proxy server configuration |
526
- | [Architecture](docs/ARCHITECTURE.md) | How the pipeline works |
527
- | [CCR Guide](docs/ccr.md) | Reversible compression |
528
- | [Benchmarks](docs/benchmarks.md) | Accuracy validation |
529
- | [Latency Benchmarks](docs/LATENCY_BENCHMARKS.md) | Compression overhead & cost-benefit analysis |
530
- | [Limitations](docs/LIMITATIONS.md) | When compression helps, when it doesn't |
531
  | [Evals Framework](headroom/evals/README.md) | Prove compression preserves accuracy |
532
- | [Memory](docs/memory.md) | Cross-agent persistent memory with provenance + dedup |
533
- | [Agno](docs/agno.md) | Agno agent framework |
534
- | [MCP](docs/mcp.md) | Context engineering toolkit (compress, retrieve, stats) |
535
- | [SharedContext](docs/shared-context.md) | Compressed inter-agent context sharing |
536
- | [Learn](docs/learn.md) | Plugin-based failure learning (Claude, Codex, Gemini, extensible) |
537
- | [CLI Reference](docs/cli.md) | Complete command surface, help output, and Docker parity matrix |
538
- | [Docker-Native Install](docs/docker-install.md) | Host wrapper install, compose support, and Docker runtime behavior |
539
- | [Persistent Installs](docs/persistent-installs.md) | Service/task/docker deployment models and provider scopes |
540
- | [Configuration](docs/configuration.md) | All options |
541
 
542
  ---
543
 
 
155
  Use `token` mode for short/medium sessions where raw compression savings matter most.
156
  Use `cache` mode for long-running chats where preserving prior-turn bytes improves provider cache reuse.
157
 
158
+ Works with any language, any tool, any framework. **[Proxy docs](docs/content/docs/proxy.mdx)**
159
 
160
+ Prefer Docker as the runtime provider? See **[Installation — Docker](docs/content/docs/installation.mdx)**.
161
 
162
  ### Coding agents — one command
163
 
 
189
  full = ctx.get("research", full=True) # Agent B gets original if needed
190
  ```
191
 
192
+ Compress what moves between agents — any framework. **[SharedContext Guide](docs/content/docs/shared-context.mdx)**
193
 
194
  ### MCP Tools (Claude Code, Cursor)
195
 
 
197
  headroom mcp install && claude
198
  ```
199
 
200
+ Gives your AI tool three MCP tools: `headroom_compress`, `headroom_retrieve`, `headroom_stats`. **[MCP Guide](docs/content/docs/mcp.mdx)**
201
 
202
  ### Drop into your existing stack
203
 
 
219
  | **Codex / Aider** | Wrap | `headroom wrap codex` or `headroom wrap aider` |
220
  | **Always-on local proxy** | Persistent install | `headroom install apply --preset persistent-service --providers auto` |
221
 
222
+ **[Full Integration Guide](docs/content/docs/index.mdx)**
223
 
224
  ---
225
 
 
292
  python -m headroom.evals suite --tier 1 --ci
293
  ```
294
 
295
+ Full methodology: [Benchmarks](docs/content/docs/benchmarks.mdx) | [Evals Framework](headroom/evals/README.md)
296
 
297
  ---
298
 
 
317
  headroom wrap codex --memory # Codex shares the SAME memory store
318
  ```
319
 
320
+ Claude saves a fact, Codex reads it back. All agents sharing one proxy share one memory — project-scoped, user-isolated, with agent provenance tracking and automatic deduplication. No SDK changes needed. **[Memory docs](docs/content/docs/memory.mdx)**
321
 
322
  ### Failure Learning
323
 
 
327
  headroom learn --agent codex --all # Analyze all Codex sessions
328
  ```
329
 
330
+ Plugin-based: reads conversation history from Claude Code, Codex, or Gemini CLI. Finds failure patterns, correlates with successes, writes corrections to CLAUDE.md / AGENTS.md / GEMINI.md. External plugins via entry points. **[Learn docs](docs/content/docs/failure-learning.mdx)**
331
 
332
  <p align="center">
333
  <img src="headroom_learn.gif" alt="headroom learn demo" width="800">
 
405
  Originals are in the Compressed Store — nothing is thrown away.
406
  ```
407
 
408
+ **Overhead**: 15-200ms compression latency (net positive for Sonnet/Opus). Full data: [Benchmarks](docs/content/docs/benchmarks.mdx)
409
 
410
  ---
411
 
 
413
 
414
  | Integration | Status | Docs |
415
  |-------------|--------|------|
416
+ | `headroom wrap claude/copilot/codex/aider/cursor` | **Stable** | [Proxy Docs](docs/content/docs/proxy.mdx) |
417
+ | `compress()` — one function | **Stable** | [Integration Guide](docs/content/docs/index.mdx) |
418
+ | `SharedContext` — multi-agent | **Stable** | [SharedContext Guide](docs/content/docs/shared-context.mdx) |
419
+ | LiteLLM callback | **Stable** | [LiteLLM Guide](docs/content/docs/litellm.mdx) |
420
+ | ASGI middleware | **Stable** | [Integration Guide](docs/content/docs/index.mdx) |
421
+ | Proxy server | **Stable** | [Proxy Docs](docs/content/docs/proxy.mdx) |
422
+ | Agno | **Stable** | [Agno Guide](docs/content/docs/agno.mdx) |
423
+ | MCP (Claude Code, Cursor, etc.) | **Stable** | [MCP Guide](docs/content/docs/mcp.mdx) |
424
+ | Strands | **Stable** | [Strands Guide](docs/content/docs/strands.mdx) |
425
+ | LangChain | **Stable** | [LangChain Guide](docs/content/docs/langchain.mdx) |
426
  | **OpenClaw** | **Stable** | [OpenClaw plugin](#openclaw-plugin) |
427
 
428
  ---
 
521
 
522
  | | |
523
  |---|---|
524
+ | [Integration Guide](docs/content/docs/index.mdx) | LiteLLM, ASGI, compress(), proxy |
525
+ | [Proxy Docs](docs/content/docs/proxy.mdx) | Proxy server configuration |
526
+ | [Architecture](docs/content/docs/architecture.mdx) | How the pipeline works |
527
+ | [CCR Guide](docs/content/docs/ccr.mdx) | Reversible compression |
528
+ | [Benchmarks](docs/content/docs/benchmarks.mdx) | Accuracy validation |
529
+ | [Limitations](docs/content/docs/limitations.mdx) | When compression helps, when it doesn't |
 
530
  | [Evals Framework](headroom/evals/README.md) | Prove compression preserves accuracy |
531
+ | [Memory](docs/content/docs/memory.mdx) | Cross-agent persistent memory with provenance + dedup |
532
+ | [Agno](docs/content/docs/agno.mdx) | Agno agent framework |
533
+ | [MCP](docs/content/docs/mcp.mdx) | Context engineering toolkit (compress, retrieve, stats) |
534
+ | [SharedContext](docs/content/docs/shared-context.mdx) | Compressed inter-agent context sharing |
535
+ | [Learn](docs/content/docs/failure-learning.mdx) | Plugin-based failure learning (Claude, Codex, Gemini, extensible) |
536
+ | [Installation](docs/content/docs/installation.mdx) | pip, npm, Docker install methods |
537
+ | [Configuration](docs/content/docs/configuration.mdx) | All options |
 
 
538
 
539
  ---
540
 
headroom/cli/learn.py CHANGED
@@ -90,12 +90,21 @@ Use 'auto' (default) to scan all detected agents."""
90
  help="LLM model for analysis (e.g., claude-sonnet-4-6, gpt-4o, gemini/gemini-2.0-flash). "
91
  "Auto-detected from API keys if not specified.",
92
  )
 
 
 
 
 
 
 
 
93
  def learn(
94
  project: Path | None,
95
  analyze_all: bool,
96
  apply: bool,
97
  agent: str,
98
  model: str | None,
 
99
  ) -> None:
100
  """Learn from past tool call failures to prevent future ones.
101
 
@@ -115,9 +124,13 @@ def learn(
115
  headroom learn --all # Analyze all projects
116
  headroom learn --agent codex --all # Analyze all Codex sessions
117
  """
 
 
118
  from ..learn.analyzer import SessionAnalyzer, _detect_default_model
119
  from ..learn.registry import auto_detect_plugins, get_plugin
120
 
 
 
121
  # Resolve model early to fail fast with a clear message
122
  try:
123
  resolved_model = model or _detect_default_model()
@@ -185,7 +198,7 @@ def learn(
185
  click.echo(f"Path: {proj.project_path}")
186
  click.echo(f"{'=' * 60}")
187
 
188
- sessions = plugin.scan_project(proj)
189
  if not sessions:
190
  click.echo(" No conversation data found.")
191
  continue
 
90
  help="LLM model for analysis (e.g., claude-sonnet-4-6, gpt-4o, gemini/gemini-2.0-flash). "
91
  "Auto-detected from API keys if not specified.",
92
  )
93
+ @click.option(
94
+ "--workers",
95
+ "-j",
96
+ type=int,
97
+ default=None,
98
+ help="Parallel workers for session scanning. "
99
+ "Default: auto (min of CPU count, 8). Use 1 for serial.",
100
+ )
101
  def learn(
102
  project: Path | None,
103
  analyze_all: bool,
104
  apply: bool,
105
  agent: str,
106
  model: str | None,
107
+ workers: int | None,
108
  ) -> None:
109
  """Learn from past tool call failures to prevent future ones.
110
 
 
124
  headroom learn --all # Analyze all projects
125
  headroom learn --agent codex --all # Analyze all Codex sessions
126
  """
127
+ import os
128
+
129
  from ..learn.analyzer import SessionAnalyzer, _detect_default_model
130
  from ..learn.registry import auto_detect_plugins, get_plugin
131
 
132
+ max_workers = workers if workers is not None else min(os.cpu_count() or 4, 8)
133
+
134
  # Resolve model early to fail fast with a clear message
135
  try:
136
  resolved_model = model or _detect_default_model()
 
198
  click.echo(f"Path: {proj.project_path}")
199
  click.echo(f"{'=' * 60}")
200
 
201
+ sessions = plugin.scan_project(proj, max_workers=max_workers)
202
  if not sessions:
203
  click.echo(" No conversation data found.")
204
  continue
headroom/cli/mcp.py CHANGED
@@ -244,13 +244,33 @@ def mcp_uninstall() -> None:
244
  err=True,
245
  )
246
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  # Also remove from mcp.json fallback config if present
248
  if MCP_CONFIG_PATH.exists():
249
  config = load_mcp_config()
250
- if "headroom" in config.get("mcpServers", {}):
251
- del config["mcpServers"]["headroom"]
 
 
 
 
252
  save_mcp_config(config)
253
- click.echo(f"✓ Headroom MCP server removed from {MCP_CONFIG_PATH}")
254
  removed = True
255
 
256
  if not removed:
 
244
  err=True,
245
  )
246
 
247
+ # Also remove codebase-memory-mcp if registered (installed by --code-graph)
248
+ if claude_cli:
249
+ cbm_check = subprocess.run(
250
+ [claude_cli, "mcp", "get", "codebase-memory-mcp"],
251
+ capture_output=True,
252
+ )
253
+ if cbm_check.returncode == 0:
254
+ cbm_rm = subprocess.run(
255
+ [claude_cli, "mcp", "remove", "codebase-memory-mcp", "-s", "user"],
256
+ capture_output=True,
257
+ text=True,
258
+ )
259
+ if cbm_rm.returncode == 0:
260
+ click.echo("✓ codebase-memory-mcp MCP server removed")
261
+ removed = True
262
+
263
  # Also remove from mcp.json fallback config if present
264
  if MCP_CONFIG_PATH.exists():
265
  config = load_mcp_config()
266
+ changed = False
267
+ for server_name in ("headroom", "codebase-memory-mcp"):
268
+ if server_name in config.get("mcpServers", {}):
269
+ del config["mcpServers"][server_name]
270
+ changed = True
271
+ if changed:
272
  save_mcp_config(config)
273
+ click.echo(f"✓ MCP servers removed from {MCP_CONFIG_PATH}")
274
  removed = True
275
 
276
  if not removed:
headroom/cli/proxy.py CHANGED
@@ -179,6 +179,13 @@ from .main import main
179
  is_flag=True,
180
  help="Disable anonymous usage telemetry (env: HEADROOM_TELEMETRY=off)",
181
  )
 
 
 
 
 
 
 
182
  @click.pass_context
183
  def proxy(
184
  ctx: click.Context,
@@ -213,6 +220,7 @@ def proxy(
213
  bedrock_region: str | None,
214
  bedrock_profile: str | None,
215
  no_telemetry: bool,
 
216
  ) -> None:
217
  """Start the optimization proxy server.
218
 
@@ -251,10 +259,22 @@ def proxy(
251
  mode or os.environ.get("HEADROOM_MODE") or PROXY_MODE_TOKEN
252
  )
253
 
 
 
 
 
 
 
 
 
254
  # Telemetry opt-out: --no-telemetry flag sets the env var
255
  if no_telemetry:
256
  os.environ["HEADROOM_TELEMETRY"] = "off"
257
 
 
 
 
 
258
  # License key for managed/enterprise deployments (optional)
259
  license_key = os.environ.get("HEADROOM_LICENSE_KEY")
260
 
@@ -272,7 +292,7 @@ def proxy(
272
  connect_timeout_seconds=connect_timeout_seconds
273
  if connect_timeout_seconds is not None
274
  else 10,
275
- log_file=log_file,
276
  budget_limit_usd=budget,
277
  # Code graph: live file watcher for incremental reindexing
278
  code_graph_watcher=code_graph,
@@ -284,13 +304,15 @@ def proxy(
284
  intelligent_context_compress_first=not no_compress_first,
285
  # Memory System (Multi-Provider with auto-detection)
286
  # --learn implies --memory (need backend for storing patterns)
287
- memory_enabled=memory or (learn and not no_learn),
 
288
  memory_db_path=memory_db_path,
289
  memory_inject_tools=not no_memory_tools,
290
  memory_inject_context=not no_memory_context,
291
  memory_top_k=memory_top_k,
292
  # Traffic Learning: only with --learn, never with --no-learn
293
- traffic_learning_enabled=learn and not no_learn,
 
294
  traffic_learning_agent_type=os.environ.get("HEADROOM_AGENT_TYPE", "unknown"),
295
  # Backend (Anthropic direct, Bedrock, LiteLLM, or any-llm)
296
  backend=backend,
@@ -299,6 +321,8 @@ def proxy(
299
  anyllm_provider=effective_anyllm_provider,
300
  # License / Usage Reporting (managed/enterprise)
301
  license_key=license_key,
 
 
302
  )
303
 
304
  memory_status = "DISABLED"
@@ -355,6 +379,13 @@ Memory (Multi-Provider):
355
  - Database: {config.memory_db_path}
356
  """
357
 
 
 
 
 
 
 
 
358
  from headroom.telemetry.beacon import is_telemetry_enabled
359
 
360
  # Build telemetry section for the startup banner
@@ -381,7 +412,7 @@ Starting proxy server...
381
  Rate Limit: {"ENABLED" if config.rate_limit_enabled else "DISABLED"}
382
  Memory: {memory_status}
383
  License: {license_status}
384
- {telemetry_line}
385
  {backend_section}
386
  Routing:
387
  /v1/messages → {anthropic_url}
 
179
  is_flag=True,
180
  help="Disable anonymous usage telemetry (env: HEADROOM_TELEMETRY=off)",
181
  )
182
+ @click.option(
183
+ "--stateless",
184
+ is_flag=True,
185
+ help="Disable all filesystem writes — run purely in-memory. "
186
+ "For containerized / read-only / load-balanced deployments. "
187
+ "(env: HEADROOM_STATELESS=true)",
188
+ )
189
  @click.pass_context
190
  def proxy(
191
  ctx: click.Context,
 
220
  bedrock_region: str | None,
221
  bedrock_profile: str | None,
222
  no_telemetry: bool,
223
+ stateless: bool,
224
  ) -> None:
225
  """Start the optimization proxy server.
226
 
 
259
  mode or os.environ.get("HEADROOM_MODE") or PROXY_MODE_TOKEN
260
  )
261
 
262
+ # Stateless mode: CLI flag or env var
263
+ is_stateless = stateless or os.environ.get("HEADROOM_STATELESS", "").lower() in (
264
+ "true",
265
+ "1",
266
+ "yes",
267
+ "on",
268
+ )
269
+
270
  # Telemetry opt-out: --no-telemetry flag sets the env var
271
  if no_telemetry:
272
  os.environ["HEADROOM_TELEMETRY"] = "off"
273
 
274
+ # Stateless mode: suppress TOIN filesystem persistence
275
+ if is_stateless:
276
+ os.environ["HEADROOM_TOIN_BACKEND"] = "none"
277
+
278
  # License key for managed/enterprise deployments (optional)
279
  license_key = os.environ.get("HEADROOM_LICENSE_KEY")
280
 
 
292
  connect_timeout_seconds=connect_timeout_seconds
293
  if connect_timeout_seconds is not None
294
  else 10,
295
+ log_file=None if is_stateless else log_file,
296
  budget_limit_usd=budget,
297
  # Code graph: live file watcher for incremental reindexing
298
  code_graph_watcher=code_graph,
 
304
  intelligent_context_compress_first=not no_compress_first,
305
  # Memory System (Multi-Provider with auto-detection)
306
  # --learn implies --memory (need backend for storing patterns)
307
+ # Stateless mode disables memory (requires SQLite on disk)
308
+ memory_enabled=False if is_stateless else (memory or (learn and not no_learn)),
309
  memory_db_path=memory_db_path,
310
  memory_inject_tools=not no_memory_tools,
311
  memory_inject_context=not no_memory_context,
312
  memory_top_k=memory_top_k,
313
  # Traffic Learning: only with --learn, never with --no-learn
314
+ # Stateless mode disables learning (requires filesystem)
315
+ traffic_learning_enabled=False if is_stateless else (learn and not no_learn),
316
  traffic_learning_agent_type=os.environ.get("HEADROOM_AGENT_TYPE", "unknown"),
317
  # Backend (Anthropic direct, Bedrock, LiteLLM, or any-llm)
318
  backend=backend,
 
321
  anyllm_provider=effective_anyllm_provider,
322
  # License / Usage Reporting (managed/enterprise)
323
  license_key=license_key,
324
+ # Stateless mode: disable all filesystem writes
325
+ stateless=is_stateless,
326
  )
327
 
328
  memory_status = "DISABLED"
 
379
  - Database: {config.memory_db_path}
380
  """
381
 
382
+ # Stateless mode warning
383
+ stateless_line = ""
384
+ if is_stateless:
385
+ stateless_line = (
386
+ " Stateless: YES (no filesystem writes — memory, logs, TOIN disabled)\n"
387
+ )
388
+
389
  from headroom.telemetry.beacon import is_telemetry_enabled
390
 
391
  # Build telemetry section for the startup banner
 
412
  Rate Limit: {"ENABLED" if config.rate_limit_enabled else "DISABLED"}
413
  Memory: {memory_status}
414
  License: {license_status}
415
+ {stateless_line}{telemetry_line}
416
  {backend_section}
417
  Routing:
418
  /v1/messages → {anthropic_url}
headroom/cli/wrap.py CHANGED
@@ -192,13 +192,51 @@ def _setup_rtk(verbose: bool = False) -> Path | None:
192
  return rtk_path
193
 
194
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  def _setup_code_graph(verbose: bool = False) -> bool:
196
- """Ensure codebase-memory-mcp is installed and project is indexed.
197
 
198
  codebase-memory-mcp builds a knowledge graph of the codebase using
199
  tree-sitter, enabling the LLM to query code structure (call chains,
200
  function definitions, impact analysis) instead of reading entire files.
201
 
 
 
 
 
 
202
  With Claude Code's MCP Tool Search, the 14 graph tools add ~200 tokens
203
  overhead per request (not the full ~1,915) — they're lazy-loaded.
204
 
@@ -218,6 +256,9 @@ def _setup_code_graph(verbose: bool = False) -> bool:
218
 
219
  cbm_bin = str(cbm_path)
220
 
 
 
 
221
  # Index current project (fast — ~1s for most repos, idempotent)
222
  project_dir = str(Path.cwd())
223
  try:
@@ -320,6 +361,11 @@ rtk pip list rtk pnpm install rtk npm run <script>
320
  # Marker used to detect if instructions are already injected
321
  _RTK_MARKER = "<!-- headroom:rtk-instructions -->"
322
 
 
 
 
 
 
323
 
324
  def _ensure_rtk_binary(verbose: bool = False) -> Path | None:
325
  """Ensure rtk binary is installed (download if needed). No hook registration."""
@@ -364,11 +410,12 @@ def _inject_codex_provider_config(port: int) -> None:
364
  config_dir = Path.home() / ".codex"
365
  config_file = config_dir / "config.toml"
366
 
367
- headroom_section = (
368
- f"\n# --- Headroom proxy (auto-injected by headroom wrap codex) ---\n"
369
- f'model_provider = "headroom"\n'
370
- f"\n"
371
- f"[model_providers.headroom]\n"
 
372
  f'name = "OpenAI via Headroom proxy"\n'
373
  f'base_url = "http://127.0.0.1:{port}/v1"\n'
374
  f'env_key = "OPENAI_API_KEY"\n'
@@ -377,7 +424,7 @@ def _inject_codex_provider_config(port: int) -> None:
377
  f"# --- end Headroom ---\n"
378
  )
379
 
380
- marker = "# --- Headroom proxy (auto-injected by headroom wrap codex) ---"
381
  end_marker = "# --- end Headroom ---"
382
 
383
  try:
@@ -386,14 +433,20 @@ def _inject_codex_provider_config(port: int) -> None:
386
  if config_file.exists():
387
  content = config_file.read_text()
388
  if marker in content:
389
- # Replace existing section
390
  start = content.index(marker)
391
  end = content.index(end_marker) + len(end_marker)
392
- content = content[:start].rstrip() + headroom_section + content[end:].lstrip("\n")
393
- else:
394
- content = content.rstrip() + "\n" + headroom_section
 
 
 
 
 
 
395
  else:
396
- content = headroom_section
397
 
398
  config_file.write_text(content)
399
  click.echo(f" Codex config: injected Headroom provider (WS + HTTP) into {config_file}")
@@ -424,6 +477,81 @@ def _inject_rtk_instructions(file_path: Path, verbose: bool = False) -> bool:
424
  return True
425
 
426
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
427
  def _resolve_copilot_provider_type(backend: str | None, provider_type: str) -> str:
428
  """Resolve Copilot BYOK provider type for the current proxy backend."""
429
  if provider_type != "auto":
@@ -1088,6 +1216,48 @@ def claude(
1088
  signal.signal(signal.SIGINT, cleanup)
1089
  signal.signal(signal.SIGTERM, cleanup)
1090
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1091
  try:
1092
  click.echo()
1093
  click.echo(" ╔═══════════════════════════════════════════════╗")
@@ -1241,6 +1411,20 @@ def copilot(
1241
  env["COPILOT_PROVIDER_TYPE"] = effective_provider_type
1242
  env.pop("COPILOT_PROVIDER_WIRE_API", None)
1243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1244
  env_vars_display: list[str]
1245
  if effective_provider_type == "anthropic":
1246
  env["COPILOT_PROVIDER_BASE_URL"] = f"http://127.0.0.1:{port}"
@@ -1258,6 +1442,19 @@ def copilot(
1258
  f"COPILOT_PROVIDER_WIRE_API={effective_wire_api}",
1259
  ]
1260
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1261
  if not _copilot_model_configured(copilot_args, env):
1262
  click.echo(
1263
  " Note: Copilot BYOK requires a model. Pass `--model <name>` "
@@ -1357,6 +1554,49 @@ def codex(
1357
  global_agents = Path.home() / ".codex" / "AGENTS.md"
1358
  _inject_rtk_instructions(global_agents, verbose=verbose)
1359
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1360
  if prepare_only:
1361
  _inject_codex_provider_config(port)
1362
  return
@@ -1373,7 +1613,15 @@ def codex(
1373
  # Inject Headroom provider into Codex config so WebSocket traffic also
1374
  # routes through the proxy. Codex ignores OPENAI_BASE_URL for its WS
1375
  # transport unless a custom provider declares supports_websockets = true.
 
 
1376
  _inject_codex_provider_config(port)
 
 
 
 
 
 
1377
 
1378
  _launch_tool(
1379
  binary=codex_bin,
 
192
  return rtk_path
193
 
194
 
195
+ _CBM_MCP_SERVER_NAME = "codebase-memory-mcp"
196
+
197
+
198
def _register_cbm_mcp_server(cbm_bin: str) -> None:
    """Register codebase-memory-mcp as an MCP server in Claude Code.

    Uses ``claude mcp add`` so the tools appear in ``/mcp`` automatically.
    Idempotent — skips if already registered.

    Registration is best-effort: if the ``claude`` CLI is missing, cannot be
    launched, hangs, or reports a failure, the function returns quietly. The
    tools then will not appear in ``/mcp`` but the graph still works.

    Args:
        cbm_bin: Path to the codebase-memory-mcp executable, passed to
            ``claude mcp add`` as the server command.
    """
    claude_cli = shutil.which("claude")
    if not claude_cli:
        return

    # Bound each CLI call so a stuck `claude` process cannot block startup.
    cli_timeout_sec = 30

    try:
        # Check if already registered
        check = subprocess.run(
            [claude_cli, "mcp", "get", _CBM_MCP_SERVER_NAME],
            capture_output=True,
            text=True,
            timeout=cli_timeout_sec,
        )
        if check.returncode == 0:
            return  # Already registered

        result = subprocess.run(
            [claude_cli, "mcp", "add", _CBM_MCP_SERVER_NAME, "-s", "user", "--", cbm_bin],
            capture_output=True,
            text=True,
            timeout=cli_timeout_sec,
        )
    except (OSError, subprocess.SubprocessError):
        # Non-critical — covers launch failures and TimeoutExpired.
        return

    if result.returncode == 0:
        click.echo(f" Code graph: registered {_CBM_MCP_SERVER_NAME} MCP server")
226
+
227
+
228
  def _setup_code_graph(verbose: bool = False) -> bool:
229
+ """Ensure codebase-memory-mcp is installed, registered as MCP server, and project is indexed.
230
 
231
  codebase-memory-mcp builds a knowledge graph of the codebase using
232
  tree-sitter, enabling the LLM to query code structure (call chains,
233
  function definitions, impact analysis) instead of reading entire files.
234
 
235
+ Steps:
236
+ 1. Download the binary if not already present.
237
+ 2. Register as an MCP server in Claude Code (``claude mcp add``).
238
+ 3. Index the current project (fast, idempotent).
239
+
240
  With Claude Code's MCP Tool Search, the 14 graph tools add ~200 tokens
241
  overhead per request (not the full ~1,915) — they're lazy-loaded.
242
 
 
256
 
257
  cbm_bin = str(cbm_path)
258
 
259
+ # Register as MCP server so tools appear in /mcp
260
+ _register_cbm_mcp_server(cbm_bin)
261
+
262
  # Index current project (fast — ~1s for most repos, idempotent)
263
  project_dir = str(Path.cwd())
264
  try:
 
361
  # Marker used to detect if instructions are already injected
362
  _RTK_MARKER = "<!-- headroom:rtk-instructions -->"
363
 
364
+ # Memory MCP markers
365
+ _MEMORY_MCP_MARKER = "# --- Headroom memory MCP (auto-injected) ---"
366
+ _MEMORY_MCP_END = "# --- end Headroom memory ---"
367
+ _MEMORY_AGENTS_MARKER = "<!-- headroom:memory-instructions -->"
368
+
369
 
370
  def _ensure_rtk_binary(verbose: bool = False) -> Path | None:
371
  """Ensure rtk binary is installed (download if needed). No hook registration."""
 
410
  config_dir = Path.home() / ".codex"
411
  config_file = config_dir / "config.toml"
412
 
413
+ # model_provider must be a top-level TOML key (before any [section]).
414
+ # The [model_providers.headroom] table can go at the end.
415
+ top_level_marker = "# --- Headroom proxy (auto-injected by headroom wrap codex) ---"
416
+ top_level_block = f'{top_level_marker}\nmodel_provider = "headroom"\n'
417
+ provider_section = (
418
+ f"\n[model_providers.headroom]\n"
419
  f'name = "OpenAI via Headroom proxy"\n'
420
  f'base_url = "http://127.0.0.1:{port}/v1"\n'
421
  f'env_key = "OPENAI_API_KEY"\n'
 
424
  f"# --- end Headroom ---\n"
425
  )
426
 
427
+ marker = top_level_marker
428
  end_marker = "# --- end Headroom ---"
429
 
430
  try:
 
433
  if config_file.exists():
434
  content = config_file.read_text()
435
  if marker in content:
436
+ # Remove existing Headroom blocks entirely
437
  start = content.index(marker)
438
  end = content.index(end_marker) + len(end_marker)
439
+ content = content[:start].rstrip("\n") + content[end:].lstrip("\n")
440
+
441
+ # Strip any stale top-level model_provider left behind
442
+ import re
443
+
444
+ content = re.sub(r'\nmodel_provider\s*=\s*"headroom"\n', "\n", content)
445
+
446
+ # Place top-level key at the very beginning, provider table at the end
447
+ content = top_level_block + "\n" + content.strip() + "\n" + provider_section
448
  else:
449
+ content = top_level_block + "\n" + provider_section
450
 
451
  config_file.write_text(content)
452
  click.echo(f" Codex config: injected Headroom provider (WS + HTTP) into {config_file}")
 
477
  return True
478
 
479
 
480
def _inject_memory_mcp_config(db_path: str, user_id: str) -> None:
    """Register headroom memory as an MCP server in Codex's config.toml.

    Idempotent — replaces existing section if present.

    The section is delimited by ``_MEMORY_MCP_MARKER`` / ``_MEMORY_MCP_END``
    and written to ``~/.codex/config.toml``. Failures are reported as a
    warning via ``click.echo`` and never raised.

    Args:
        db_path: Path of the memory database passed to the server via ``--db``.
        user_id: User identifier passed to the server via ``--user``.
    """
    import json
    import sys

    def _toml_str(value: str) -> str:
        # JSON string escapes (\", \\, \n, \uXXXX) are also valid TOML
        # basic-string escapes, so json.dumps yields a safe quoted literal.
        # ensure_ascii=False avoids surrogate-pair escapes, which TOML rejects.
        return json.dumps(value, ensure_ascii=False)

    config_dir = Path.home() / ".codex"
    config_file = config_dir / "config.toml"

    # Use forward slashes in TOML paths (works on all platforms, avoids
    # backslash escaping issues on Windows)
    python_bin = sys.executable.replace("\\", "/")
    db_path_toml = db_path.replace("\\", "/")
    server_args = [
        "-m",
        "headroom.memory.mcp_server",
        "--db",
        db_path_toml,
        "--user",
        user_id,
    ]
    args_toml = ", ".join(_toml_str(arg) for arg in server_args)
    mcp_section = (
        f"\n{_MEMORY_MCP_MARKER}\n"
        "[mcp_servers.headroom_memory]\n"
        f"command = {_toml_str(python_bin)}\n"
        f"args = [{args_toml}]\n"
        "startup_timeout_sec = 30\n"
        "tool_timeout_sec = 30\n"
        f"{_MEMORY_MCP_END}\n"
    )

    try:
        config_dir.mkdir(parents=True, exist_ok=True)

        if config_file.exists():
            # TOML files are UTF-8 by specification; do not rely on the locale.
            content = config_file.read_text(encoding="utf-8")
            start = content.find(_MEMORY_MCP_MARKER)
            if start != -1:
                # Look for the end marker only after the start marker so an
                # unrelated earlier occurrence cannot produce a bogus slice.
                end = content.index(_MEMORY_MCP_END, start) + len(_MEMORY_MCP_END)
                content = content[:start].rstrip("\n") + mcp_section + content[end:].lstrip("\n")
            else:
                content = content.rstrip() + "\n" + mcp_section
        else:
            content = mcp_section

        config_file.write_text(content, encoding="utf-8")
        click.echo(f" Memory MCP: registered in {config_file}")
    except Exception as e:
        # Deliberately broad: registration is best-effort and must not abort
        # the wrap command.
        click.echo(f" Warning: could not register memory MCP: {e}")
522
+
523
+
524
def _inject_memory_agents_md(file_path: Path) -> bool:
    """Inject memory usage guidance into AGENTS.md.

    Idempotent — skips if marker already present.

    The guidance block is appended to an existing file, or written as a new
    file (creating parent directories) when the file does not exist. All file
    access uses UTF-8 because the block contains non-ASCII punctuation.

    Args:
        file_path: Location of the AGENTS.md file to update.

    Returns:
        Always ``True``: the guidance is present once the function returns.
    """
    memory_block = (
        f"{_MEMORY_AGENTS_MARKER}\n"
        "## Memory\n\n"
        "Use the `headroom_memory` MCP server for persistent cross-session knowledge.\n\n"
        "**Before** answering questions about prior decisions, conventions, project context,\n"
        "architecture, user preferences, org info, codenames, debugging history, or anything\n"
        "from past sessions — call `memory_search` first.\n\n"
        "**After** making durable decisions, discovering conventions, or learning important\n"
        "facts — call `memory_save` to persist them for future sessions.\n\n"
        "Memory is your first source of truth for anything not visible in the current conversation.\n"
    )

    if file_path.exists():
        # The text is only scanned for the ASCII marker, so undecodable bytes
        # in a legacy file are replaced rather than allowed to raise.
        existing = file_path.read_text(encoding="utf-8", errors="replace")
        if _MEMORY_AGENTS_MARKER in existing:
            return True  # Already injected
        with file_path.open("a", encoding="utf-8") as f:
            f.write("\n\n" + memory_block)
    else:
        file_path.parent.mkdir(parents=True, exist_ok=True)
        file_path.write_text(memory_block, encoding="utf-8")

    click.echo(f" Memory guidance injected into {file_path.name}")
    return True
553
+
554
+
555
  def _resolve_copilot_provider_type(backend: str | None, provider_type: str) -> str:
556
  """Resolve Copilot BYOK provider type for the current proxy backend."""
557
  if provider_type != "auto":
 
1216
  signal.signal(signal.SIGINT, cleanup)
1217
  signal.signal(signal.SIGTERM, cleanup)
1218
 
1219
+ # Memory sync BEFORE proxy startup — sync headroom DB ↔ Claude's files
1220
+ if memory:
1221
+ try:
1222
+ import subprocess as _sp
1223
+
1224
+ mem_dir = Path.cwd() / ".headroom"
1225
+ mem_dir.mkdir(parents=True, exist_ok=True)
1226
+ _sync_db = str(mem_dir / "memory.db")
1227
+ _sync_user = os.environ.get("USER", os.environ.get("USERNAME", "default"))
1228
+
1229
+ click.echo(f" Syncing memory (user={_sync_user})...")
1230
+ sync_result = _sp.run(
1231
+ [
1232
+ sys.executable,
1233
+ "-m",
1234
+ "headroom.memory.sync",
1235
+ "--db",
1236
+ _sync_db,
1237
+ "--user",
1238
+ _sync_user,
1239
+ "--agent",
1240
+ "claude",
1241
+ "--force",
1242
+ ],
1243
+ capture_output=True,
1244
+ text=True,
1245
+ timeout=30,
1246
+ )
1247
+ if sync_result.returncode == 0 and sync_result.stdout.strip():
1248
+ import json as _json
1249
+
1250
+ stats = _json.loads(sync_result.stdout.strip().split("\n")[-1])
1251
+ imp, exp, ms = stats["imported"], stats["exported"], stats["ms"]
1252
+ if imp or exp:
1253
+ click.echo(f" Memory synced: {imp} imported, {exp} exported ({ms}ms)")
1254
+ else:
1255
+ click.echo(f" Memory: up to date ({ms}ms)")
1256
+ elif sync_result.returncode != 0:
1257
+ click.echo(f" Warning: memory sync error: {sync_result.stderr[-200:]}")
1258
+ except Exception as e:
1259
+ click.echo(f" Warning: memory sync failed: {e}")
1260
+
1261
  try:
1262
  click.echo()
1263
  click.echo(" ╔═══════════════════════════════════════════════╗")
 
1411
  env["COPILOT_PROVIDER_TYPE"] = effective_provider_type
1412
  env.pop("COPILOT_PROVIDER_WIRE_API", None)
1413
 
1414
+ # Copilot BYOK requires COPILOT_PROVIDER_API_KEY — propagate from the
1415
+ # user's existing provider key so they don't have to set it twice.
1416
+ # Note: `headroom wrap copilot` uses Copilot's BYOK mode, which bypasses
1417
+ # GitHub's Copilot API and talks directly to the model provider through
1418
+ # the Headroom proxy. This requires the provider's own API key — a GitHub
1419
+ # Copilot subscription alone is not sufficient for BYOK mode.
1420
+ if not env.get("COPILOT_PROVIDER_API_KEY"):
1421
+ if effective_provider_type == "anthropic":
1422
+ _key = env.get("ANTHROPIC_API_KEY", "")
1423
+ else:
1424
+ _key = env.get("OPENAI_API_KEY", "")
1425
+ if _key:
1426
+ env["COPILOT_PROVIDER_API_KEY"] = _key
1427
+
1428
  env_vars_display: list[str]
1429
  if effective_provider_type == "anthropic":
1430
  env["COPILOT_PROVIDER_BASE_URL"] = f"http://127.0.0.1:{port}"
 
1442
  f"COPILOT_PROVIDER_WIRE_API={effective_wire_api}",
1443
  ]
1444
 
1445
+ if not env.get("COPILOT_PROVIDER_API_KEY"):
1446
+ src = "ANTHROPIC_API_KEY" if effective_provider_type == "anthropic" else "OPENAI_API_KEY"
1447
+ click.echo(
1448
+ f"\n Error: Copilot BYOK mode requires a provider API key.\n"
1449
+ f" `headroom wrap copilot` uses Copilot's BYOK mode, which bypasses GitHub's\n"
1450
+ f" Copilot API and routes requests directly to the model provider through the\n"
1451
+ f" Headroom proxy. A GitHub Copilot subscription alone is not sufficient.\n\n"
1452
+ f" Set one of:\n"
1453
+ f" export {src}=sk-... # recommended\n"
1454
+ f" export COPILOT_PROVIDER_API_KEY=sk-... # also works\n"
1455
+ )
1456
+ raise SystemExit(1)
1457
+
1458
  if not _copilot_model_configured(copilot_args, env):
1459
  click.echo(
1460
  " Note: Copilot BYOK requires a model. Pass `--model <name>` "
 
1554
  global_agents = Path.home() / ".codex" / "AGENTS.md"
1555
  _inject_rtk_instructions(global_agents, verbose=verbose)
1556
 
1557
+ # Setup memory MCP server for Codex (native tool integration)
1558
+ if memory:
1559
+ click.echo(" Setting up memory for Codex...")
1560
+ mem_dir = Path.cwd() / ".headroom"
1561
+ mem_dir.mkdir(parents=True, exist_ok=True)
1562
+ db_path = str(mem_dir / "memory.db")
1563
+ mem_user = os.environ.get("USER", os.environ.get("USERNAME", "default"))
1564
+
1565
+ # Register MCP server in Codex config
1566
+ _inject_memory_mcp_config(db_path, mem_user)
1567
+
1568
+ # Inject memory guidance into project AGENTS.md
1569
+ agents_md = Path.cwd() / "AGENTS.md"
1570
+ _inject_memory_agents_md(agents_md)
1571
+
1572
+ # Sync Claude's memories → DB so MCP search finds them
1573
+ try:
1574
+ import asyncio
1575
+
1576
+ from headroom.memory.backends.local import LocalBackend, LocalBackendConfig
1577
+ from headroom.memory.sync import sync_import
1578
+ from headroom.memory.sync_adapters.claude_code import (
1579
+ ClaudeCodeAdapter,
1580
+ get_claude_memory_dir,
1581
+ )
1582
+
1583
+ claude_memory_dir = get_claude_memory_dir()
1584
+
1585
+ async def _import_claude_memories() -> int:
1586
+ config = LocalBackendConfig(db_path=db_path)
1587
+ backend = LocalBackend(config)
1588
+ await backend._ensure_initialized()
1589
+ adapter = ClaudeCodeAdapter(claude_memory_dir)
1590
+ count = await sync_import(backend, adapter, mem_user)
1591
+ await backend.close()
1592
+ return count
1593
+
1594
+ imported = asyncio.run(_import_claude_memories())
1595
+ if imported:
1596
+ click.echo(f" Memory: imported {imported} memories from Claude")
1597
+ except Exception as e:
1598
+ click.echo(f" Warning: Claude memory import failed: {e}")
1599
+
1600
  if prepare_only:
1601
  _inject_codex_provider_config(port)
1602
  return
 
1613
  # Inject Headroom provider into Codex config so WebSocket traffic also
1614
  # routes through the proxy. Codex ignores OPENAI_BASE_URL for its WS
1615
  # transport unless a custom provider declares supports_websockets = true.
1616
+ # NOTE: this must run BEFORE _inject_memory_mcp_config because it rewrites
1617
+ # the config file. Re-inject MCP config after if memory is enabled.
1618
  _inject_codex_provider_config(port)
1619
+ if memory:
1620
+ mem_dir = Path.cwd() / ".headroom"
1621
+ _inject_memory_mcp_config(
1622
+ str(mem_dir / "memory.db"),
1623
+ os.environ.get("USER", os.environ.get("USERNAME", "default")),
1624
+ )
1625
 
1626
  _launch_tool(
1627
  binary=codex_bin,
headroom/install/state.py CHANGED
@@ -3,21 +3,30 @@
3
  from __future__ import annotations
4
 
5
  import json
 
6
  import shutil
7
  from dataclasses import asdict
8
 
9
  from .models import ArtifactRecord, DeploymentManifest, ManagedMutation, iso_utc_now
10
  from .paths import deploy_root, manifest_path, profile_root
11
 
 
12
 
13
- def save_manifest(manifest: DeploymentManifest) -> None:
14
- """Persist a deployment manifest to disk."""
15
 
16
- root = profile_root(manifest.profile)
17
- root.mkdir(parents=True, exist_ok=True)
18
- manifest.updated_at = iso_utc_now()
19
- path = manifest_path(manifest.profile)
20
- path.write_text(json.dumps(asdict(manifest), indent=2) + "\n")
 
 
 
 
 
 
 
 
 
21
 
22
 
23
  def load_manifest(profile: str = "default") -> DeploymentManifest | None:
 
3
  from __future__ import annotations
4
 
5
  import json
6
+ import logging
7
  import shutil
8
  from dataclasses import asdict
9
 
10
  from .models import ArtifactRecord, DeploymentManifest, ManagedMutation, iso_utc_now
11
  from .paths import deploy_root, manifest_path, profile_root
12
 
13
+ logger = logging.getLogger(__name__)
14
 
 
 
15
 
16
def save_manifest(manifest: DeploymentManifest) -> None:
    """Write *manifest* to its profile's manifest file as indented JSON.

    The profile directory is created on demand and ``updated_at`` is
    refreshed just before serialising. An ``OSError`` (for example a
    read-only filesystem) is logged as a warning and swallowed, so callers
    keep running without persistence.
    """
    try:
        profile_dir = profile_root(manifest.profile)
        profile_dir.mkdir(parents=True, exist_ok=True)

        manifest.updated_at = iso_utc_now()

        target = manifest_path(manifest.profile)
        serialized = json.dumps(asdict(manifest), indent=2)
        target.write_text(serialized + "\n")
    except OSError as e:
        logger.warning("Cannot save deployment manifest: %s — continuing without persistence", e)
30
 
31
 
32
  def load_manifest(profile: str = "default") -> DeploymentManifest | None:
headroom/learn/base.py CHANGED
@@ -25,7 +25,7 @@ class ConversationScanner(ABC):
25
  ...
26
 
27
  @abstractmethod
28
- def scan_project(self, project: ProjectInfo) -> list[SessionData]:
29
  """Scan all sessions for a project, returning normalized tool calls."""
30
  ...
31
 
@@ -99,8 +99,14 @@ class LearnPlugin(ABC):
99
  ...
100
 
101
  @abstractmethod
102
- def scan_project(self, project: ProjectInfo) -> list[SessionData]:
103
- """Scan all sessions for a project, returning normalized data."""
 
 
 
 
 
 
104
  ...
105
 
106
  # --- Writing ---
 
25
  ...
26
 
27
  @abstractmethod
28
def scan_project(self, project: ProjectInfo, max_workers: int = 1) -> list[SessionData]:
    """Scan all sessions for a project, returning normalized tool calls.

    Args:
        project: The project to scan.
        max_workers: Number of threads for parallel file scanning.
            Values <= 1 (the default is 1) scan serially.
    """
    ...
31
 
 
99
  ...
100
 
101
  @abstractmethod
102
def scan_project(self, project: ProjectInfo, max_workers: int = 1) -> list[SessionData]:
    """Scan all sessions for a project, returning normalized data.

    Args:
        project: The project to scan.
        max_workers: Number of threads for parallel file scanning.
            Any value <= 1 (default 1) scans serially; larger values
            scan session files concurrently.
    """
    ...
111
 
112
  # --- Writing ---
headroom/learn/plugins/claude.py CHANGED
@@ -106,16 +106,24 @@ class ClaudeCodePlugin(LearnPlugin, ConversationScanner):
106
 
107
  return projects
108
 
109
- def scan_project(self, project: ProjectInfo) -> list[SessionData]:
110
  """Scan all conversation JSONL files for a project."""
111
- sessions = []
112
  jsonl_files = sorted(project.data_path.glob("*.jsonl"))
 
 
 
 
 
113
 
114
- for jsonl_path in jsonl_files:
115
- session = self._scan_session(jsonl_path)
116
- if session and session.tool_calls:
117
- sessions.append(session)
118
 
 
 
 
 
 
 
 
119
  return sessions
120
 
121
  def _scan_session(self, jsonl_path: Path) -> SessionData | None:
@@ -375,9 +383,9 @@ def _component_tokenizations(component: str) -> list[list[str]]:
375
 
376
  add([component])
377
 
378
- for separator in ("-", ".", None):
379
  if separator is None:
380
- tokens = [token for token in re.split(r"[-.]", component) if token]
381
  else:
382
  tokens = [token for token in component.split(separator) if token]
383
  add(tokens)
@@ -385,9 +393,9 @@ def _component_tokenizations(component: str) -> list[list[str]]:
385
  if component.startswith(".") and len(component) > 1:
386
  hidden_component = component[1:]
387
  add(["", hidden_component])
388
- for separator in ("-", ".", None):
389
  if separator is None:
390
- tokens = [token for token in re.split(r"[-.]", hidden_component) if token]
391
  else:
392
  tokens = [token for token in hidden_component.split(separator) if token]
393
  add(["", *tokens])
 
106
 
107
  return projects
108
 
109
def scan_project(self, project: ProjectInfo, max_workers: int = 1) -> list[SessionData]:
    """Scan all conversation JSONL files for a project.

    Files are taken from ``project.data_path`` (``*.jsonl``) in sorted order.
    Sessions that fail to parse or contain no tool calls are dropped.

    Args:
        project: The project whose ``data_path`` is scanned.
        max_workers: Thread count for scanning. Values <= 1, or a single
            file, scan serially.

    Returns:
        Sessions with at least one tool call, in sorted-file order in both
        the serial and the threaded path.
    """
    jsonl_files = sorted(project.data_path.glob("*.jsonl"))
    if not jsonl_files:
        return []

    if max_workers <= 1 or len(jsonl_files) <= 1:
        scanned = [self._scan_session(f) for f in jsonl_files]
    else:
        from concurrent.futures import ThreadPoolExecutor

        # executor.map yields results in input order, so the output does not
        # depend on which worker thread happens to finish first.
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            scanned = list(executor.map(self._scan_session, jsonl_files))

    return [s for s in scanned if s and s.tool_calls]
128
 
129
  def _scan_session(self, jsonl_path: Path) -> SessionData | None:
 
383
 
384
  add([component])
385
 
386
+ for separator in ("-", ".", "_", None):
387
  if separator is None:
388
+ tokens = [token for token in re.split(r"[-._]", component) if token]
389
  else:
390
  tokens = [token for token in component.split(separator) if token]
391
  add(tokens)
 
393
  if component.startswith(".") and len(component) > 1:
394
  hidden_component = component[1:]
395
  add(["", hidden_component])
396
+ for separator in ("-", ".", "_", None):
397
  if separator is None:
398
+ tokens = [token for token in re.split(r"[-._]", hidden_component) if token]
399
  else:
400
  tokens = [token for token in hidden_component.split(separator) if token]
401
  add(["", *tokens])
headroom/learn/plugins/codex.py CHANGED
@@ -91,13 +91,24 @@ class CodexPlugin(LearnPlugin, ConversationScanner):
91
  )
92
  ]
93
 
94
- def scan_project(self, project: ProjectInfo) -> list[SessionData]:
95
  """Scan all Codex session JSON files."""
96
- sessions = []
97
- for json_path in self._iter_session_files(project.data_path):
98
- session = self._scan_session(json_path)
99
- if session and session.tool_calls:
100
- sessions.append(session)
 
 
 
 
 
 
 
 
 
 
 
101
  return sessions
102
 
103
  def _scan_session(self, json_path: Path) -> SessionData | None:
 
91
  )
92
  ]
93
 
94
def scan_project(self, project: ProjectInfo, max_workers: int = 1) -> list[SessionData]:
    """Scan all Codex session JSON files.

    Session files come from ``self._iter_session_files(project.data_path)``.
    Sessions that fail to parse or contain no tool calls are dropped.

    Args:
        project: The project whose ``data_path`` is scanned.
        max_workers: Thread count for scanning. Values <= 1, or a single
            file, scan serially.

    Returns:
        Sessions with at least one tool call, in the order the session files
        were listed, in both the serial and the threaded path.
    """
    session_files = self._iter_session_files(project.data_path)
    if not session_files:
        return []

    if max_workers <= 1 or len(session_files) <= 1:
        scanned = [self._scan_session(f) for f in session_files]
    else:
        from concurrent.futures import ThreadPoolExecutor

        # executor.map yields results in input order, so the output does not
        # depend on which worker thread happens to finish first.
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            scanned = list(executor.map(self._scan_session, session_files))

    return [s for s in scanned if s and s.tool_calls]
113
 
114
  def _scan_session(self, json_path: Path) -> SessionData | None:
headroom/learn/plugins/gemini.py CHANGED
@@ -106,18 +106,26 @@ class GeminiPlugin(LearnPlugin, ConversationScanner):
106
 
107
  return projects
108
 
109
- def scan_project(self, project: ProjectInfo) -> list[SessionData]:
110
  """Scan all Gemini session files for a project."""
111
- sessions = []
112
  session_files = sorted(project.data_path.glob("session-*.json")) + sorted(
113
  project.data_path.glob("session-*.jsonl")
114
  )
 
 
 
 
 
115
 
116
- for session_path in session_files:
117
- session = self._scan_session(session_path)
118
- if session and session.tool_calls:
119
- sessions.append(session)
120
 
 
 
 
 
 
 
 
121
  return sessions
122
 
123
  def _scan_session(self, session_path: Path) -> SessionData | None:
 
106
 
107
  return projects
108
 
109
def scan_project(self, project: ProjectInfo, max_workers: int = 1) -> list[SessionData]:
    """Scan all Gemini session files for a project.

    Collects ``session-*.json`` files (sorted) followed by ``session-*.jsonl``
    files (sorted) from ``project.data_path``. Sessions that fail to parse or
    contain no tool calls are dropped.

    Args:
        project: The project whose ``data_path`` is scanned.
        max_workers: Thread count for scanning. Values <= 1, or a single
            file, scan serially.

    Returns:
        Sessions with at least one tool call, in file-listing order in both
        the serial and the threaded path.
    """
    session_files = sorted(project.data_path.glob("session-*.json")) + sorted(
        project.data_path.glob("session-*.jsonl")
    )
    if not session_files:
        return []

    if max_workers <= 1 or len(session_files) <= 1:
        scanned = [self._scan_session(f) for f in session_files]
    else:
        from concurrent.futures import ThreadPoolExecutor

        # executor.map yields results in input order, so the output does not
        # depend on which worker thread happens to finish first.
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            scanned = list(executor.map(self._scan_session, session_files))

    return [s for s in scanned if s and s.tool_calls]
130
 
131
  def _scan_session(self, session_path: Path) -> SessionData | None:
headroom/memory/factory.py CHANGED
@@ -7,6 +7,7 @@ and proper wiring between components.
7
 
8
  from __future__ import annotations
9
 
 
10
  from typing import TYPE_CHECKING
11
 
12
  from headroom.memory.config import (
@@ -199,12 +200,21 @@ def _create_vector_index(config: MemoryConfig) -> VectorIndex:
199
 
200
  from headroom.memory.adapters.hnsw import HNSWVectorIndex
201
 
 
 
 
 
 
 
 
202
  return HNSWVectorIndex(
203
  dimension=config.vector_dimension,
204
  ef_construction=config.hnsw_ef_construction,
205
  m=config.hnsw_m,
206
  ef_search=config.hnsw_ef_search,
207
  max_entries=config.hnsw_max_entries,
 
 
208
  )
209
 
210
  raise ValueError(f"Unknown vector backend: {config.vector_backend}")
 
7
 
8
  from __future__ import annotations
9
 
10
+ from pathlib import Path
11
  from typing import TYPE_CHECKING
12
 
13
  from headroom.memory.config import (
 
200
 
201
  from headroom.memory.adapters.hnsw import HNSWVectorIndex
202
 
203
+ # Derive persistent save path from the main DB path so the HNSW
204
+ # index survives across process restarts (critical for cross-agent
205
+ # interop: memories saved by Codex MCP must be searchable by Claude).
206
+ hnsw_save_path: str | Path | None = None
207
+ if config.db_path:
208
+ hnsw_save_path = config.db_path.parent / f"{config.db_path.stem}_hnsw"
209
+
210
  return HNSWVectorIndex(
211
  dimension=config.vector_dimension,
212
  ef_construction=config.hnsw_ef_construction,
213
  m=config.hnsw_m,
214
  ef_search=config.hnsw_ef_search,
215
  max_entries=config.hnsw_max_entries,
216
+ save_path=hnsw_save_path,
217
+ auto_save=True,
218
  )
219
 
220
  raise ValueError(f"Unknown vector backend: {config.vector_backend}")
headroom/memory/mcp_server.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Headroom Memory MCP Server.
2
+
3
+ A stdio MCP server that exposes headroom's memory backend as tools
4
+ that Codex (or any MCP-compatible client) can call natively.
5
+
6
+ Tools:
7
+ memory_search — semantic search across stored memories
8
+ memory_save — persist a new fact/decision/convention
9
+
10
+ Design:
11
+ - Embedder is pre-loaded at startup (no cold-start on first query)
12
+ - On startup, any memories missing vector embeddings are re-indexed
13
+ (fixes interop gap when memories were saved via a different path)
14
+ - Save always generates embeddings inline
15
+
16
+ Usage:
17
+ # Standalone (for testing):
18
+ python -m headroom.memory.mcp_server --db /path/to/.headroom/memory.db
19
+
20
+ # Registered in Codex config.toml (done by `headroom wrap codex --memory`):
21
+ [mcp_servers.headroom_memory]
22
+ command = "python"
23
+ args = ["-m", "headroom.memory.mcp_server", "--db", ".headroom/memory.db"]
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ import argparse
29
+ import asyncio
30
+ import logging
31
+ import os
32
+ import sys
33
+ from pathlib import Path
34
+ from typing import Any
35
+
36
+ from mcp.server import Server
37
+ from mcp.server.stdio import stdio_server
38
+ from mcp.types import TextContent, Tool
39
+
40
+ from headroom.memory.backends.local import LocalBackend, LocalBackendConfig
41
+
42
+ logger = logging.getLogger("headroom.memory.mcp")
43
+
44
+ # ---------------------------------------------------------------------------
45
+ # Tool definitions
46
+ # ---------------------------------------------------------------------------
47
+
48
+ _TOOLS = [
49
+ Tool(
50
+ name="memory_search",
51
+ description=(
52
+ "Search persistent memory for relevant knowledge from prior sessions. "
53
+ "Use this for questions about architecture, conventions, prior decisions, "
54
+ "project context, user preferences, org info, codenames, debugging history, "
55
+ "or anything that might have been discussed before."
56
+ ),
57
+ inputSchema={
58
+ "type": "object",
59
+ "properties": {
60
+ "query": {
61
+ "type": "string",
62
+ "description": "Natural-language search query.",
63
+ },
64
+ "top_k": {
65
+ "type": "integer",
66
+ "description": "Max results to return (default 10).",
67
+ "default": 10,
68
+ },
69
+ },
70
+ "required": ["query"],
71
+ },
72
+ ),
73
+ Tool(
74
+ name="memory_save",
75
+ description=(
76
+ "Save information to persistent memory for future sessions. "
77
+ "Use this for decisions, conventions, architecture context, "
78
+ "user preferences, project facts, or anything worth remembering.\n\n"
79
+ "IMPORTANT: Break information into atomic facts — one fact per "
80
+ "entry in the 'facts' array. Each fact should be a single, "
81
+ "self-contained statement that answers one question. "
82
+ "Do NOT combine multiple facts into one string.\n\n"
83
+ "Good: facts: ['Repo owner is Tejas C.', 'User prefers dark mode']\n"
84
+ "Bad: facts: ['Repo owner is Tejas C. Prefers dark mode.']"
85
+ ),
86
+ inputSchema={
87
+ "type": "object",
88
+ "properties": {
89
+ "facts": {
90
+ "type": "array",
91
+ "items": {"type": "string"},
92
+ "description": (
93
+ "Array of atomic facts to save. Each entry should be "
94
+ "one self-contained fact. The system stores and indexes "
95
+ "each fact separately for precise retrieval."
96
+ ),
97
+ },
98
+ "importance": {
99
+ "type": "number",
100
+ "description": "0.0 (low) to 1.0 (critical). Default 0.7.",
101
+ "default": 0.7,
102
+ },
103
+ },
104
+ "required": [],
105
+ },
106
+ ),
107
+ ]
108
+
109
+
110
+ # ---------------------------------------------------------------------------
111
+ # Startup: pre-load embedder + re-index unembedded memories
112
+ # ---------------------------------------------------------------------------
113
+
114
+
115
async def _warm_up_backend(backend: LocalBackend, user_id: str) -> None:
    """Pre-load the embedder and push the user's memories into the vector index.

    Memories saved via other paths (e.g. Claude Code proxy direct SQL) may
    exist in the store without embeddings. Up to 500 of the user's memories
    are fetched; those without an embedding are embedded and saved, and every
    fetched memory is (re)indexed so vector search works across agents.

    Args:
        backend: Backend to initialise and warm up.
        user_id: Owner of the memories to index.
    """
    await backend._ensure_initialized()
    hm = backend._hierarchical_memory
    if hm is None:
        return

    # Force-load the embedder now (not lazily on first search)
    await hm._embedder.embed("warmup")
    logger.info("Memory MCP: embedder pre-loaded")

    all_memories = await backend.get_user_memories(user_id, limit=500)
    if not all_memories:
        return

    indexed = 0
    failed = 0
    for mem in all_memories:
        # A single bad record must not prevent the rest from being indexed.
        try:
            if mem.embedding is None:
                mem.embedding = await hm._embedder.embed(mem.content)
                await hm._store.save(mem)
            await hm._vector_index.index(mem)
        except Exception as e:
            failed += 1
            logger.warning(f"Memory MCP: failed to index memory {getattr(mem, 'id', '?')}: {e}")
            continue
        indexed += 1

    logger.info(f"Memory MCP: indexed {indexed} memories into vector store")
    if failed:
        logger.warning(f"Memory MCP: {failed} memories could not be indexed")
147
+
148
+
149
+ # ---------------------------------------------------------------------------
150
+ # MCP Server
151
+ # ---------------------------------------------------------------------------
152
+
153
+
154
def create_memory_server(db_path: str, user_id: str = "default") -> Server:
    """Create an MCP server backed by headroom's local memory.

    The backend is initialised once, in a background task started on the
    first ``list_tools`` call (or on the first tool call if that never
    happened). Tool calls wait for that task, so they only ever see a
    backend whose warm-up has finished.

    Args:
        db_path: Path to the memory database.
        user_id: User ID that scopes every search and save.

    Returns:
        The configured ``Server`` with ``list_tools`` and ``call_tool`` handlers.
    """

    server = Server("headroom-memory")
    _backend: LocalBackend | None = None
    _init_task: asyncio.Task | None = None

    async def _init_backend() -> LocalBackend:
        """Initialize backend with ONNX embedder (fast, no PyTorch)."""
        nonlocal _backend, _init_task
        config = LocalBackendConfig(db_path=db_path, embedder_backend="onnx")
        backend = LocalBackend(config)
        try:
            await _warm_up_backend(backend, user_id)
        except BaseException:
            # Forget the failed/cancelled task so the next call retries
            # instead of re-raising the same stored exception forever.
            _init_task = None
            raise
        # Publish only after warm-up so callers never get a half-ready backend.
        _backend = backend
        logger.info(f"Memory MCP: ready (db={db_path}, user={user_id})")
        return backend

    async def _get_backend() -> LocalBackend:
        """Return the ready backend, starting or awaiting initialisation as needed."""
        nonlocal _init_task
        if _backend is not None:
            return _backend
        if _init_task is None:
            # Route the fallback through the shared task too, so concurrent
            # tool calls cannot start two initialisations.
            _init_task = asyncio.create_task(_init_backend())
        return await _init_task

    @server.list_tools()
    async def list_tools() -> list[Tool]:
        # Kick off background init on first list_tools (called at MCP handshake)
        nonlocal _init_task
        if _backend is None and _init_task is None:
            _init_task = asyncio.create_task(_init_backend())
        return _TOOLS

    @server.call_tool()
    async def call_tool(name: str, arguments: dict) -> list[TextContent]:
        backend = await _get_backend()

        if name == "memory_search":
            return await _handle_search(backend, arguments, user_id)
        elif name == "memory_save":
            return await _handle_save(backend, arguments, user_id)

        return [TextContent(type="text", text=f"Unknown tool: {name}")]

    return server
201
+
202
+
203
async def _handle_search(
    backend: LocalBackend, arguments: dict[str, Any], user_id: str
) -> list[TextContent]:
    """Handle the ``memory_search`` tool call.

    Args:
        backend: Memory backend to query.
        arguments: Tool arguments; reads ``query`` (required) and ``top_k``
            (optional, defaults to 10 when missing, non-numeric or < 1).
        user_id: User ID the search is scoped to.

    Returns:
        A single ``TextContent`` holding the numbered results, a
        "No memories found." notice, or an error message.
    """
    query = arguments.get("query", "")
    if not isinstance(query, str) or not query.strip():
        return [TextContent(type="text", text="Error: query is required")]

    # Tool arguments come from a model; tolerate strings and nonsense values.
    try:
        top_k = int(arguments.get("top_k", 10))
    except (TypeError, ValueError):
        top_k = 10
    if top_k < 1:
        top_k = 10

    try:
        # Over-fetch to compensate for filtering out superseded memories
        results = await backend.search_memories(
            query=query,
            user_id=user_id,
            top_k=top_k * 3,
            include_related=True,
        )

        if not results:
            return [TextContent(type="text", text="No memories found.")]

        # Filter out superseded memories — only return current/active ones.
        # Re-check the store because in-memory HNSW metadata may be stale.
        active_results = []
        for r in results:
            if getattr(r.memory, "superseded_by", None):
                continue
            try:
                stored = await backend.get_memory(r.memory.id)
            except Exception as e:
                # Best effort: keep the result if the store cannot be consulted.
                logger.debug(f"memory_search: store re-check failed for {r.memory.id}: {e}")
                stored = None
            if stored and getattr(stored, "superseded_by", None):
                continue
            active_results.append(r)

        if not active_results:
            return [TextContent(type="text", text="No memories found.")]

        # Trim to requested top_k
        active_results = active_results[:top_k]

        lines = []
        for i, r in enumerate(active_results, 1):
            score = f"{r.score:.2f}" if hasattr(r, "score") else "?"
            lines.append(f"{i}. [relevance={score}] {r.memory.content}")
            related = getattr(r, "related_entities", None)
            if related:
                lines.append(f"   Related: {', '.join(related[:3])}")

        return [TextContent(type="text", text="\n".join(lines))]
    except Exception as e:
        logger.error(f"memory_search failed: {e}")
        return [TextContent(type="text", text=f"Search error: {e}")]
256
+
257
+
258
+ # Similarity threshold for auto-supersession: if a new memory is this
259
+ # similar to an existing one, it replaces (supersedes) the old one.
260
+ _SUPERSEDE_SIMILARITY = 0.70
261
+
262
+
263
async def _handle_save(
    backend: LocalBackend, arguments: dict[str, Any], user_id: str
) -> list[TextContent]:
    """Handle the ``memory_save`` tool call.

    Each fact is saved separately. If an existing, non-superseded memory
    scores at or above ``_SUPERSEDE_SIMILARITY`` against a fact, that memory
    is updated with the new text instead of a new one being created.

    Args:
        backend: Memory backend to write to.
        arguments: Tool arguments; reads ``facts`` (list of strings, or a
            single string), the legacy ``content`` string, and ``importance``.
        user_id: User ID the memories are stored under.

    Returns:
        A single ``TextContent`` with a summary line plus one line per fact,
        or an error message.
    """
    facts = arguments.get("facts", [])

    # Backward compat: accept single "content" string too
    if not facts:
        content = arguments.get("content", "")
        if content:
            facts = [content]

    # A bare string would otherwise be iterated character by character.
    if isinstance(facts, str):
        facts = [facts]
    if not isinstance(facts, (list, tuple)):
        facts = []

    # Drop blanks and non-string entries up front so one bad entry cannot
    # abort the whole batch.
    cleaned = [f.strip() for f in facts if isinstance(f, str) and f.strip()]
    if not cleaned:
        return [TextContent(type="text", text="Error: facts array is required")]

    # The tool schema documents importance as 0.0-1.0 with a 0.7 default.
    try:
        importance = float(arguments.get("importance", 0.7))
    except (TypeError, ValueError):
        importance = 0.7
    importance = min(1.0, max(0.0, importance))

    try:
        saved = 0
        superseded = 0
        results_lines: list[str] = []

        for fact in cleaned:
            # Check for semantically similar existing memory to auto-supersede
            superseded_id: str | None = None
            try:
                existing = await backend.search_memories(
                    query=fact,
                    user_id=user_id,
                    top_k=3,
                )
                for r in existing:
                    if getattr(r.memory, "superseded_by", None):
                        continue
                    if r.score >= _SUPERSEDE_SIMILARITY:
                        superseded_id = r.memory.id
                        logger.info(
                            f"Memory MCP: auto-superseding [{r.memory.id[:8]}] "
                            f"(similarity={r.score:.2f}): {r.memory.content[:60]}"
                        )
                        break
            except Exception as e:
                # Best effort: if the similarity lookup fails, save as new.
                logger.debug(f"memory_save: supersede lookup failed: {e}")

            if superseded_id:
                memory = await backend.update_memory(
                    memory_id=superseded_id,
                    new_content=fact,
                )
                results_lines.append(
                    f"  updated [{superseded_id[:8]}→{memory.id[:8]}]: {fact[:60]}"
                )
                superseded += 1
            else:
                memory = await backend.save_memory(
                    content=fact,
                    user_id=user_id,
                    importance=importance,
                )
                results_lines.append(f"  saved [{memory.id[:8]}]: {fact[:60]}")
                saved += 1

        summary = f"Saved {saved} new, updated {superseded} existing ({saved + superseded} total)"
        return [TextContent(type="text", text=summary + "\n" + "\n".join(results_lines))]
    except Exception as e:
        logger.error(f"memory_save failed: {e}")
        return [TextContent(type="text", text=f"Save error: {e}")]
332
+
333
+
334
+ # ---------------------------------------------------------------------------
335
+ # Entry point
336
+ # ---------------------------------------------------------------------------
337
+
338
+
339
async def _run(db_path: str, user_id: str) -> None:
    """Build the memory server and serve it over the stdio transport."""
    memory_server = create_memory_server(db_path, user_id)
    async with stdio_server() as streams:
        reader, writer = streams
        init_options = memory_server.create_initialization_options()
        await memory_server.run(reader, writer, init_options)
343
+
344
+
345
def main() -> None:
    """Parse CLI arguments, configure the environment and logging, run the server."""
    # Same lookup chain as a nested get(): USER, then USERNAME, then "default".
    fallback_user = os.environ.get("USERNAME", "default")
    default_user = os.environ.get("USER", fallback_user)
    default_db = str(Path.cwd() / ".headroom" / "memory.db")

    parser = argparse.ArgumentParser(description="Headroom Memory MCP Server")
    parser.add_argument("--db", default=default_db, help="Path to memory SQLite database")
    parser.add_argument("--user", default=default_user, help="User ID for memory scoping")
    args = parser.parse_args()

    # Skip HuggingFace model freshness checks — use cached models only.
    # This eliminates 1-2s of HTTP HEAD requests on every startup.
    for offline_flag in ("HF_HUB_OFFLINE", "TRANSFORMERS_OFFLINE"):
        os.environ.setdefault(offline_flag, "1")

    # Log to stderr (MCP uses stdout for protocol)
    logging.basicConfig(
        level=logging.INFO,
        stream=sys.stderr,
        format="%(name)s: %(message)s",
    )

    asyncio.run(_run(args.db, args.user))
372
+
373
+
374
+ if __name__ == "__main__":
375
+ main()
headroom/memory/sync.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Universal memory sync engine for cross-agent interoperability.
2
+
3
+ Provides bidirectional sync between headroom's memory DB and any
4
+ agent's native memory format via pluggable adapters.
5
+
6
+ Architecture:
7
+ DB ← sync_import → Agent files (agent's knowledge enters the shared DB)
8
+ DB → sync_export → Agent files (shared knowledge flows to the agent)
9
+ sync() = import + export (bidirectional, fast no-op when unchanged)
10
+
11
+ Usage:
12
+ from headroom.memory.sync import sync, SyncResult
13
+ from headroom.memory.sync_adapters.claude_code import ClaudeCodeAdapter
14
+
15
+ adapter = ClaudeCodeAdapter(memory_dir=Path("~/.claude/projects/.../memory").expanduser())
16
+ backend = LocalBackend(config)
17
+
18
+ result: SyncResult = await sync(backend, adapter, user_id="tcms")
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import hashlib
24
+ import json
25
+ import logging
26
+ import time
27
+ from abc import ABC, abstractmethod
28
+ from dataclasses import dataclass, field
29
+ from datetime import datetime, timezone
30
+ from pathlib import Path
31
+ from typing import Any
32
+
33
+ logger = logging.getLogger("headroom.memory.sync")
34
+
35
+ # State file for fast no-op detection
36
+ _DEFAULT_STATE_PATH = Path.home() / ".headroom" / "sync_state.json"
37
+
38
+
39
+ # ---------------------------------------------------------------------------
40
+ # Data models
41
+ # ---------------------------------------------------------------------------
42
+
43
+
44
@dataclass
class SyncResult:
    """Counters and timing reported by one sync run."""

    # Memories copied from the agent's files into the DB.
    imported: int = 0
    # Memories copied from the DB into the agent's files.
    exported: int = 0
    # Entries skipped because they were unchanged.
    skipped_unchanged: int = 0
    # Entries skipped as duplicates.
    skipped_dedup: int = 0
    # Wall-clock duration of the run, in milliseconds.
    duration_ms: float = 0
53
+
54
+
55
@dataclass
class AgentMemory:
    """One memory entry as read from an agent's own files.

    ``content_hash`` defaults to the first 16 hex characters of the SHA-256
    of ``content``; pass it explicitly to override.
    """

    content: str
    category: str = ""
    source_file: str = ""
    content_hash: str = ""
    metadata: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Respect a caller-supplied hash; otherwise derive it from the content.
        if self.content_hash:
            return
        digest = hashlib.sha256(self.content.encode())
        self.content_hash = digest.hexdigest()[:16]
68
+
69
+
70
+ # ---------------------------------------------------------------------------
71
+ # Adapter interface
72
+ # ---------------------------------------------------------------------------
73
+
74
+
75
class AgentMemoryAdapter(ABC):
    """Interface every agent-specific memory adapter implements.

    A subclass knows how to read and write one agent's native memory format
    (Claude Code, Codex, Aider, Cursor, ...) and sets ``agent_name``.
    """

    # Identifier used in log messages and as part of the sync-state key.
    agent_name: str = "unknown"

    @abstractmethod
    async def read_memories(self) -> list[AgentMemory]:
        """Return the memory entries found in the agent's own files."""
        ...

    @abstractmethod
    async def write_memories(self, memories: list[dict[str, Any]]) -> int:
        """Write memories in the agent's native format.

        Args:
            memories: Dicts with the keys content, category, importance,
                headroom_id, source_agent and content_hash.

        Returns:
            How many memories were written.
        """
        ...

    @abstractmethod
    def fingerprint(self) -> str:
        """Return a cheap hash of the agent's current memory state.

        Sync compares it with the value stored after the previous run; an
        unchanged fingerprint lets it skip the full read/compare cycle.
        """
        ...
113
+
114
+
115
+ # ---------------------------------------------------------------------------
116
+ # Sync state persistence
117
+ # ---------------------------------------------------------------------------
118
+
119
+
120
+ def _load_sync_state(state_path: Path) -> dict[str, Any]:
121
+ """Load sync state from disk."""
122
+ if state_path.exists():
123
+ try:
124
+ result: dict[str, Any] = json.loads(state_path.read_text())
125
+ return result
126
+ except (json.JSONDecodeError, OSError):
127
+ pass
128
+ return {}
129
+
130
+
131
+ def _save_sync_state(state_path: Path, state: dict[str, Any]) -> None:
132
+ """Save sync state to disk."""
133
+ state_path.parent.mkdir(parents=True, exist_ok=True)
134
+ state_path.write_text(json.dumps(state, indent=2))
135
+
136
+
137
+ def _db_fingerprint(memories: list[Any]) -> str:
138
+ """Compute a fast fingerprint of DB state."""
139
+ if not memories:
140
+ return "empty"
141
+ # Hash: count + most recent created_at
142
+ parts = [str(len(memories))]
143
+ for m in memories[:5]: # Sample first 5 for speed
144
+ parts.append(getattr(m, "id", "")[:8])
145
+ return hashlib.sha256("|".join(parts).encode()).hexdigest()[:16]
146
+
147
+
148
+ # ---------------------------------------------------------------------------
149
+ # Sync engine
150
+ # ---------------------------------------------------------------------------
151
+
152
+
153
async def sync(
    backend: Any,
    adapter: AgentMemoryAdapter,
    user_id: str,
    state_path: Path = _DEFAULT_STATE_PATH,
    force: bool = False,
) -> SyncResult:
    """Run a bidirectional sync between the headroom DB and one agent.

    Steps:
        1. Unless ``force`` is set, compare the agent and DB fingerprints
           with the ones stored by the previous run and return early when
           both match.
        2. Import: agent files → DB.
        3. Export: DB → agent files.
        4. Store the new fingerprints and counts in the state file.

    Args:
        backend: LocalBackend instance (must have save_memory, get_user_memories).
        adapter: Agent-specific memory adapter.
        user_id: User ID for memory scoping.
        state_path: Path to sync state file.
        force: Skip no-op check and always sync.

    Returns:
        SyncResult with import/export counts and timing.
    """
    started = time.monotonic()
    outcome = SyncResult()
    state_key = f"{adapter.agent_name}:{user_id}"

    # --- Fast no-op check ---
    if force:
        db_memories = await backend.get_user_memories(user_id, limit=500)
    else:
        previous = _load_sync_state(state_path).get(state_key, {})
        agent_fp = adapter.fingerprint()
        db_memories = await backend.get_user_memories(user_id, limit=500)
        db_fp = _db_fingerprint(db_memories)

        agent_same = previous.get("agent_fingerprint") == agent_fp
        db_same = previous.get("db_fingerprint") == db_fp
        if agent_same and db_same:
            outcome.duration_ms = (time.monotonic() - started) * 1000
            logger.info(
                f"Sync [{adapter.agent_name}]: no-op — nothing changed ({outcome.duration_ms:.1f}ms)"
            )
            return outcome

    # --- Phase 1: Import (agent files → DB) ---
    outcome.imported = await sync_import(backend, adapter, user_id, db_memories)

    # --- Phase 2: Export (DB → agent files) ---
    # Imports added rows, so the snapshot has to be refreshed first.
    if outcome.imported > 0:
        db_memories = await backend.get_user_memories(user_id, limit=500)
    outcome.exported = await sync_export(backend, adapter, user_id, db_memories)

    # --- Update sync state ---
    persisted = _load_sync_state(state_path)
    persisted[state_key] = {
        "agent_fingerprint": adapter.fingerprint(),
        "db_fingerprint": _db_fingerprint(db_memories),
        "last_sync": datetime.now(timezone.utc).isoformat(),
        "last_imported": outcome.imported,
        "last_exported": outcome.exported,
    }
    _save_sync_state(state_path, persisted)

    outcome.duration_ms = (time.monotonic() - started) * 1000
    logger.info(
        f"Sync [{adapter.agent_name}]: imported={outcome.imported}, "
        f"exported={outcome.exported} ({outcome.duration_ms:.1f}ms)"
    )
    return outcome
228
+
229
+
230
async def sync_import(
    backend: Any,
    adapter: AgentMemoryAdapter,
    user_id: str,
    existing_memories: list[Any] | None = None,
) -> int:
    """Import agent files → DB and return how many memories were saved.

    An agent memory is skipped when its ``content_hash`` matches either the
    ``content_hash`` stored in an existing memory's metadata or the hash of
    an existing memory's content.

    Args:
        backend: Backend providing ``save_memory`` / ``get_user_memories``.
        adapter: Adapter to read the agent's memories from.
        user_id: User ID the memories are saved under.
        existing_memories: Already-fetched DB memories; fetched here if None.
    """
    from_agent = await adapter.read_memories()
    if not from_agent:
        return 0

    if existing_memories is None:
        existing_memories = await backend.get_user_memories(user_id, limit=500)

    # Hashes already present in the DB, used for dedup.
    known_hashes: set[str] = set()
    for record in existing_memories:
        recorded = (record.metadata or {}).get("content_hash", "")
        if recorded:
            known_hashes.add(recorded)
        # Also hash the content directly for safety
        known_hashes.add(hashlib.sha256(record.content.encode()).hexdigest()[:16])

    count = 0
    for entry in from_agent:
        if entry.content_hash in known_hashes:
            continue

        # Lineage metadata first; the adapter's own metadata is merged last.
        lineage = {
            "source_agent": adapter.agent_name,
            "source_file": entry.source_file,
            "content_hash": entry.content_hash,
            "synced_at": datetime.now(timezone.utc).isoformat(),
            "sync_direction": "import",
            **entry.metadata,
        }
        await backend.save_memory(
            content=entry.content,
            user_id=user_id,
            importance=0.6,
            metadata=lineage,
        )
        known_hashes.add(entry.content_hash)
        count += 1

    if count:
        logger.info(f"Sync [{adapter.agent_name}]: imported {count} memories from agent files")
    return count
278
+
279
+
280
async def sync_export(
    backend: Any,
    adapter: AgentMemoryAdapter,
    user_id: str,
    existing_memories: list[Any] | None = None,
) -> int:
    """Export DB → agent files and return the count the adapter reports.

    A DB memory is skipped when it is marked superseded, when the agent
    already holds content with the same hash, or when it was originally
    imported from this same agent.

    Args:
        backend: Backend providing ``get_user_memories``.
        adapter: Adapter to write the memories with.
        user_id: User ID whose memories are exported.
        existing_memories: Already-fetched DB memories; fetched here if None.
    """
    if existing_memories is None:
        existing_memories = await backend.get_user_memories(user_id, limit=500)

    if not existing_memories:
        return 0

    # Read what the agent already has (to avoid re-exporting)
    agent_memories = await adapter.read_memories()
    agent_hashes: set[str] = {am.content_hash for am in agent_memories}

    to_export: list[dict[str, Any]] = []
    for mem in existing_memories:
        # Superseded memories are treated as inactive by search; do not
        # push stale facts into the agent's files either.
        if getattr(mem, "superseded_by", None):
            continue

        content_hash = hashlib.sha256(mem.content.encode()).hexdigest()[:16]
        # Adapters may hash the text as read back from disk, i.e. stripped;
        # compare against that form too so padded content is not re-exported.
        stripped_hash = hashlib.sha256(mem.content.strip().encode()).hexdigest()[:16]

        # Skip if agent already has it
        if content_hash in agent_hashes or stripped_hash in agent_hashes:
            continue

        # Skip if this memory was originally imported FROM this same agent
        # (prevents echo: agent → DB → agent)
        meta = mem.metadata or {}
        if (
            meta.get("source_agent") == adapter.agent_name
            and meta.get("sync_direction") == "import"
        ):
            continue

        to_export.append(
            {
                "content": mem.content,
                "category": getattr(mem, "category", "") or "",
                "importance": getattr(mem, "importance", 0.5),
                "headroom_id": mem.id,
                "source_agent": meta.get("source_agent", "unknown"),
                "content_hash": content_hash,
                "created_at": mem.created_at.isoformat()
                if hasattr(mem.created_at, "isoformat")
                else str(mem.created_at),
            }
        )

    if not to_export:
        return 0

    exported = await adapter.write_memories(to_export)
    if exported:
        logger.info(f"Sync [{adapter.agent_name}]: exported {exported} memories to agent files")
    return exported
336
+
337
+
338
+ # ---------------------------------------------------------------------------
339
+ # CLI entry point: python -m headroom.memory.sync --db ... --user ... --agent ...
340
+ # ---------------------------------------------------------------------------
341
+
342
+
343
def main() -> None:
    """CLI entry point for running sync from a subprocess.

    Prints one JSON object to stdout: ``imported``, ``exported`` and ``ms``
    on success, or ``error`` for an unknown agent.
    """
    import argparse
    import asyncio

    parser = argparse.ArgumentParser(description="Headroom memory sync")
    parser.add_argument("--db", required=True, help="Path to memory DB")
    parser.add_argument("--user", required=True, help="User ID")
    parser.add_argument("--agent", required=True, choices=["claude", "codex"], help="Agent to sync")
    parser.add_argument("--force", action="store_true", help="Skip no-op check")
    args = parser.parse_args()

    async def _run() -> None:
        from headroom.memory.backends.local import LocalBackend, LocalBackendConfig

        # Pick the adapter before opening the backend so the error path
        # below has nothing to clean up.
        adapter: AgentMemoryAdapter
        if args.agent == "claude":
            from headroom.memory.sync_adapters.claude_code import (
                ClaudeCodeAdapter,
                get_claude_memory_dir,
            )

            adapter = ClaudeCodeAdapter(get_claude_memory_dir())
        elif args.agent == "codex":
            from headroom.memory.sync_adapters.codex_agent import CodexAdapter

            adapter = CodexAdapter()
        else:
            # argparse `choices` already rejects this; kept as a safeguard.
            print(json.dumps({"error": f"Unknown agent: {args.agent}"}))
            return

        config = LocalBackendConfig(db_path=args.db)
        backend = LocalBackend(config)
        try:
            await backend._ensure_initialized()
            result = await sync(backend, adapter, args.user, force=args.force)
        finally:
            # Release the backend even when initialisation or sync fails.
            await backend.close()

        print(
            json.dumps(
                {
                    "imported": result.imported,
                    "exported": result.exported,
                    "ms": round(result.duration_ms),
                }
            )
        )

    asyncio.run(_run())
392
+
393
+
394
+ if __name__ == "__main__":
395
+ main()
headroom/memory/sync_adapters/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Agent memory sync adapters for cross-agent interoperability."""
headroom/memory/sync_adapters/claude_code.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Claude Code memory sync adapter.
2
+
3
+ Reads/writes Claude Code's native memory format:
4
+ ~/.claude/projects/<sanitized-path>/memory/
5
+ MEMORY.md — index file (first 200 lines always in context)
6
+ user_role.md — individual memory files with YAML frontmatter
7
+ project_codename.md
8
+ ...
9
+
10
+ Each .md file has:
11
+ ---
12
+ name: <title>
13
+ description: <one-line summary>
14
+ type: <user|project|reference|feedback>
15
+ headroom_id: <uuid> (added by sync for cross-reference)
16
+ source_agent: <agent name> (added by sync for lineage)
17
+ ---
18
+ <body content>
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import hashlib
24
+ import re
25
+ from pathlib import Path
26
+ from typing import Any
27
+
28
+ from headroom.memory.sync import AgentMemory, AgentMemoryAdapter
29
+
30
+
31
+ def _sanitize_for_filename(text: str) -> str:
32
+ """Convert text to a safe filename slug."""
33
+ slug = re.sub(r"[^a-z0-9]+", "_", text.lower().strip())
34
+ slug = slug.strip("_")[:50]
35
+ return slug or "memory"
36
+
37
+
38
+ def _parse_frontmatter(content: str) -> tuple[dict[str, str], str]:
39
+ """Parse YAML frontmatter from a markdown file.
40
+
41
+ Returns (frontmatter_dict, body).
42
+ """
43
+ if not content.startswith("---"):
44
+ return {}, content
45
+
46
+ end = content.find("---", 3)
47
+ if end == -1:
48
+ return {}, content
49
+
50
+ fm_text = content[3:end].strip()
51
+ body = content[end + 3 :].strip()
52
+
53
+ fm: dict[str, str] = {}
54
+ for line in fm_text.split("\n"):
55
+ if ":" in line:
56
+ key, _, value = line.partition(":")
57
+ fm[key.strip()] = value.strip().strip('"').strip("'")
58
+
59
+ return fm, body
60
+
61
+
62
+ def _build_frontmatter(fields: dict[str, str]) -> str:
63
+ """Build YAML frontmatter block."""
64
+ lines = ["---"]
65
+ for key, value in fields.items():
66
+ if value:
67
+ lines.append(f"{key}: {value}")
68
+ lines.append("---")
69
+ return "\n".join(lines)
70
+
71
+
72
class ClaudeCodeAdapter(AgentMemoryAdapter):
    """Sync adapter for Claude Code's native memory files.

    Each memory is one ``.md`` file with frontmatter in ``memory_dir``;
    ``MEMORY.md`` in the same directory is the index and is never read as
    a memory.
    """

    agent_name = "claude"

    def __init__(self, memory_dir: Path | str) -> None:
        self._memory_dir = Path(memory_dir)

    async def read_memories(self) -> list[AgentMemory]:
        """Read all .md memory files (except the MEMORY.md index).

        Files that cannot be read or decoded, and files with an empty body,
        are skipped.
        """
        if not self._memory_dir.exists():
            return []

        memories: list[AgentMemory] = []
        for md_file in sorted(self._memory_dir.glob("*.md")):
            if md_file.name == "MEMORY.md":
                continue  # Index file, not a memory

            try:
                content = md_file.read_text(encoding="utf-8")
            except (OSError, UnicodeDecodeError):
                continue

            fm, body = _parse_frontmatter(content)
            if not body.strip():
                continue

            memories.append(
                AgentMemory(
                    content=body.strip(),
                    category=fm.get("type", ""),
                    source_file=md_file.name,
                    metadata={
                        "name": fm.get("name", ""),
                        "description": fm.get("description", ""),
                        "headroom_id": fm.get("headroom_id", ""),
                        "source_agent": fm.get("source_agent", "claude"),
                    },
                )
            )

        return memories

    async def write_memories(self, memories: list[dict[str, Any]]) -> int:
        """Write memories as individual .md files with frontmatter.

        The filename is derived from the first line of the content. If that
        file already exists it is left alone when it holds the same content,
        overwritten when it carries the same non-empty ``headroom_id``, and
        otherwise a hash-suffixed filename is used so two different memories
        never clobber each other. MEMORY.md gets one index entry per newly
        created file.

        Returns:
            Number of files written.
        """
        if not memories:
            return 0

        self._memory_dir.mkdir(parents=True, exist_ok=True)

        written = 0
        new_index_entries: list[str] = []

        for mem in memories:
            content = mem["content"]
            category = mem.get("category", "project")
            headroom_id = mem.get("headroom_id", "")
            source_agent = mem.get("source_agent", "unknown")
            content_hash = mem.get("content_hash", "")
            # Files are read back with the body stripped, so compare on the
            # stripped form as well; otherwise content with surrounding
            # whitespace would be rewritten on every sync.
            body_hash = hashlib.sha256(content.strip().encode()).hexdigest()[:16]

            # Generate filename from the first non-blank line of the content
            first_line = content.strip().split("\n")[0][:60].strip()
            slug = _sanitize_for_filename(first_line)
            filename = f"headroom_{slug}.md"
            target = self._memory_dir / filename
            is_new_file = True

            if target.exists():
                existing_fm, existing_body = _parse_frontmatter(target.read_text(encoding="utf-8"))
                existing_hash = hashlib.sha256(existing_body.strip().encode()).hexdigest()[:16]
                if existing_hash in (content_hash, body_hash):
                    continue  # Same content already on disk
                if headroom_id and existing_fm.get("headroom_id", "") == headroom_id:
                    # Same memory, new text: update in place, index entry exists.
                    is_new_file = False
                else:
                    # A different memory shares this slug: do not overwrite it.
                    filename = f"headroom_{slug}_{body_hash[:8]}.md"
                    target = self._memory_dir / filename
                    if target.exists():
                        continue  # Already exported under the suffixed name

            # Build description (first 100 chars)
            description = content.replace("\n", " ")[:100]

            fm = _build_frontmatter(
                {
                    "name": first_line[:60],
                    "description": description,
                    "type": category or "project",
                    "headroom_id": headroom_id,
                    "source_agent": source_agent,
                }
            )
            target.write_text(f"{fm}\n\n{content}\n", encoding="utf-8")
            written += 1

            # Track for MEMORY.md index (only files the index does not list yet)
            if is_new_file:
                new_index_entries.append(f"- [{first_line[:60]}]({filename}) — {description[:80]}")

        if new_index_entries:
            self._update_memory_md_index(new_index_entries)

        return written

    def _update_memory_md_index(self, new_entries: list[str]) -> None:
        """Append new entries to MEMORY.md under a Headroom section.

        The entries go at the end of the existing "## Headroom Shared Memory"
        section (before the next "## " heading, if any); the section, or the
        whole file, is created when missing.
        """
        memory_md = self._memory_dir / "MEMORY.md"

        section_marker = "## Headroom Shared Memory"
        entries_text = "\n".join(new_entries)
        new_section = f"\n{section_marker}\n" + entries_text + "\n"

        if not memory_md.exists():
            content = "# Memory\n" + new_section
        else:
            content = memory_md.read_text(encoding="utf-8")
            if section_marker not in content:
                content = content.rstrip() + "\n" + new_section
            else:
                idx = content.index(section_marker)
                # Section ends at the next "## " heading or at end of file
                next_section = content.find("\n## ", idx + len(section_marker))
                if next_section == -1:
                    content = content.rstrip() + "\n" + entries_text + "\n"
                else:
                    content = (
                        content[:next_section].rstrip()
                        + "\n"
                        + entries_text
                        + "\n"
                        + content[next_section:]
                    )

        memory_md.write_text(content, encoding="utf-8")

    def fingerprint(self) -> str:
        """Hash of all .md filenames + mtimes for fast change detection."""
        if not self._memory_dir.exists():
            return "empty"

        parts: list[str] = []
        for md_file in sorted(self._memory_dir.glob("*.md")):
            try:
                stat = md_file.stat()
                parts.append(f"{md_file.name}:{stat.st_mtime_ns}")
            except OSError:
                continue

        if not parts:
            return "empty"
        return hashlib.sha256("|".join(parts).encode()).hexdigest()[:16]
222
+
223
+
224
+ def get_claude_memory_dir(project_path: Path | None = None) -> Path:
225
+ """Get the Claude Code memory directory for a project.
226
+
227
+ Claude Code stores per-project memory at:
228
+ ~/.claude/projects/-<sanitized-path>/memory/
229
+ """
230
+ project = project_path or Path.cwd()
231
+ # Replace both Unix and Windows path separators (Claude Code does the same)
232
+ sanitized = str(project).replace("/", "-").replace("\\", "-")
233
+ return Path.home() / ".claude" / "projects" / sanitized / "memory"
headroom/memory/sync_adapters/codex_agent.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Codex CLI memory sync adapter.
2
+
3
+ Syncs memories to/from a headroom-managed section in AGENTS.md.
4
+ Codex reads AGENTS.md automatically before every task.
5
+
6
+ Note: Codex primarily uses the MCP server for memory (memory_search/save).
7
+ This adapter provides supplementary context injection via AGENTS.md so
8
+ Codex has key memories even without explicit tool calls.
9
+
10
+ Format in AGENTS.md:
11
+ <!-- headroom:memory:start -->
12
+ ## Headroom Shared Memory
13
+ - fact 1
14
+ - fact 2
15
+ <!-- headroom:memory:end -->
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import hashlib
21
+ import re
22
+ from pathlib import Path
23
+ from typing import Any
24
+
25
+ from headroom.memory.sync import AgentMemory, AgentMemoryAdapter
26
+
27
+ _MARKER_START = "<!-- headroom:memory:start -->"
28
+ _MARKER_END = "<!-- headroom:memory:end -->"
29
+ _MARKER_PATTERN = re.compile(
30
+ re.escape(_MARKER_START) + r"(.*?)" + re.escape(_MARKER_END),
31
+ re.DOTALL,
32
+ )
33
+
34
+
35
class CodexAdapter(AgentMemoryAdapter):
    """Sync adapter for Codex's AGENTS.md.

    Memories live as ``- fact`` bullets between the headroom start/end
    markers; everything outside the markers is left untouched.
    """

    agent_name = "codex"

    def __init__(self, agents_md_path: Path | str | None = None) -> None:
        self._path = Path(agents_md_path) if agents_md_path else Path.cwd() / "AGENTS.md"

    @staticmethod
    def _section_facts(document: str) -> list[str]:
        """Return the bullet facts found in the first headroom section."""
        match = _MARKER_PATTERN.search(document)
        if not match:
            return []

        facts: list[str] = []
        for line in match.group(1).strip().split("\n"):
            line = line.strip()
            if line.startswith("- "):
                fact = line[2:].strip()
                if fact:
                    facts.append(fact)
        return facts

    async def read_memories(self) -> list[AgentMemory]:
        """Read memories from the headroom section of AGENTS.md."""
        if not self._path.exists():
            return []

        document = self._path.read_text(encoding="utf-8")
        return [
            AgentMemory(content=fact, source_file=self._path.name)
            for fact in self._section_facts(document)
        ]

    async def write_memories(self, memories: list[dict[str, Any]]) -> int:
        """Add memories to the headroom section of AGENTS.md.

        Only the first line of each memory is written. Facts already in the
        section are kept and new ones are appended, so a sync that exports
        only new memories does not erase earlier ones.

        Returns:
            Number of facts newly added (0 means the file was not touched).
        """
        if not memories:
            return 0

        document = self._path.read_text(encoding="utf-8") if self._path.exists() else None
        facts = self._section_facts(document) if document is not None else []
        seen = set(facts)

        added = 0
        for mem in memories:
            fact = mem["content"].strip().split("\n")[0].strip()  # First line only
            if not fact or fact in seen:
                continue
            facts.append(fact)
            seen.add(fact)
            added += 1

        if not added:
            return 0

        # Build section content
        lines = ["## Headroom Shared Memory", ""]
        lines.extend(f"- {fact}" for fact in facts)
        lines.append("")
        section = f"{_MARKER_START}\n" + "\n".join(lines) + f"{_MARKER_END}"

        # Merge into AGENTS.md
        if document is None:
            self._path.parent.mkdir(parents=True, exist_ok=True)
            document = section + "\n"
        elif _MARKER_PATTERN.search(document):
            # A callable replacement is inserted literally; a plain string
            # would have its backslashes interpreted as regex escapes.
            document = _MARKER_PATTERN.sub(lambda _match: section, document, count=1)
        else:
            # No complete start/end marker pair: append a fresh section
            # rather than silently dropping the write.
            document = document.rstrip() + "\n\n" + section + "\n"

        self._path.write_text(document, encoding="utf-8")
        return added

    def fingerprint(self) -> str:
        """Hash of the AGENTS.md filename and mtime ("empty"/"error" if unavailable)."""
        if not self._path.exists():
            return "empty"
        try:
            stat = self._path.stat()
            return hashlib.sha256(f"{self._path.name}:{stat.st_mtime_ns}".encode()).hexdigest()[:16]
        except OSError:
            return "error"
headroom/memory/writers/claude_writer.py CHANGED
@@ -64,7 +64,8 @@ class ClaudeCodeMemoryWriter(AgentWriter):
64
  project_path = self._project_path
65
  # Claude Code stores per-project memory at:
66
  # ~/.claude/projects/-<sanitized-path>/memory/MEMORY.md
67
- sanitized = str(project_path).replace("/", "-")
 
68
  claude_memory_dir = Path.home() / ".claude" / "projects" / sanitized / "memory"
69
  return claude_memory_dir / "MEMORY.md"
70
 
 
64
  project_path = self._project_path
65
  # Claude Code stores per-project memory at:
66
  # ~/.claude/projects/-<sanitized-path>/memory/MEMORY.md
67
+ # Replace both Unix and Windows path separators
68
+ sanitized = str(project_path).replace("/", "-").replace("\\", "-")
69
  claude_memory_dir = Path.home() / ".claude" / "projects" / sanitized / "memory"
70
  return claude_memory_dir / "MEMORY.md"
71
 
headroom/proxy/handlers/openai.py CHANGED
@@ -672,7 +672,9 @@ class OpenAIHandlerMixin:
672
  uncached_tokens=uncached_input_tokens,
673
  )
674
 
675
- # Memory: handle memory tool calls in OpenAI response
 
 
676
  if (
677
  self.memory_handler
678
  and memory_user_id
@@ -684,10 +686,29 @@ class OpenAIHandlerMixin:
684
  tool_results = await self.memory_handler.handle_memory_tool_calls(
685
  resp_json, memory_user_id, "openai"
686
  )
687
- logger.info(
688
- f"[{request_id}] Memory: Handled {len(tool_results)} "
689
- f"tool call(s) for user {memory_user_id}"
690
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
691
  except Exception as e:
692
  logger.warning(f"[{request_id}] Memory tool handling failed: {e}")
693
 
@@ -963,14 +984,29 @@ class OpenAIHandlerMixin:
963
  f"of context into instructions for user {memory_user_id}"
964
  )
965
 
966
- # Inject memory tools
967
  if self.memory_handler.config.inject_tools:
968
  resp_tools = body.get("tools") or []
969
  resp_tools, mem_tools_injected = self.memory_handler.inject_tools(
970
  resp_tools, "openai"
971
  )
972
  if mem_tools_injected:
973
- body["tools"] = resp_tools
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
974
  logger.info(
975
  f"[{request_id}] Memory: Injected memory tools (openai/responses)"
976
  )
@@ -1037,13 +1073,67 @@ class OpenAIHandlerMixin:
1037
  and self.memory_handler.has_memory_tool_calls(resp_json, "openai")
1038
  ):
1039
  try:
1040
- tool_results = await self.memory_handler.handle_memory_tool_calls(
1041
- resp_json, memory_user_id, "openai"
1042
- )
1043
- logger.info(
1044
- f"[{request_id}] Memory: Handled {len(tool_results)} "
1045
- f"tool call(s) for user {memory_user_id} (responses)"
1046
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1047
  except Exception as e:
1048
  logger.warning(
1049
  f"[{request_id}] Memory tool handling failed (responses): {e}"
@@ -1279,6 +1369,116 @@ class OpenAIHandlerMixin:
1279
  # Not JSON — pass through as-is
1280
  tokens_saved = 0
1281
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1282
  # --- Connect to upstream OpenAI WebSocket ---
1283
  logger.info(f"[{request_id}] WS /v1/responses connecting to {upstream_url}")
1284
 
@@ -1303,7 +1503,7 @@ class OpenAIHandlerMixin:
1303
  # Send (potentially compressed) first message
1304
  await upstream.send(first_msg_raw)
1305
 
1306
- # Bidirectional relay
1307
  async def _client_to_upstream() -> None:
1308
  try:
1309
  while True:
@@ -1318,14 +1518,164 @@ class OpenAIHandlerMixin:
1318
  await upstream.close()
1319
 
1320
  async def _upstream_to_client() -> None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1321
  try:
1322
  async for msg in upstream:
1323
- if isinstance(msg, str):
1324
- await websocket.send_text(msg)
1325
- elif isinstance(msg, bytes):
1326
  await websocket.send_bytes(msg)
1327
- else:
1328
- await websocket.send_text(str(msg))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1329
  except Exception as relay_err:
1330
  if "WebSocketDisconnect" not in type(relay_err).__name__:
1331
  logger.debug(
 
672
  uncached_tokens=uncached_input_tokens,
673
  )
674
 
675
+ # Memory: handle memory tool calls in OpenAI Chat Completions response.
676
+ # After executing tools, send a continuation request so the model
677
+ # can produce a final user-facing response (not just tool_calls).
678
  if (
679
  self.memory_handler
680
  and memory_user_id
 
686
  tool_results = await self.memory_handler.handle_memory_tool_calls(
687
  resp_json, memory_user_id, "openai"
688
  )
689
+ if tool_results:
690
+ # Build continuation: original messages + assistant tool_calls + tool results
691
+ assistant_msg = resp_json.get("choices", [{}])[0].get("message", {})
692
+ continuation_messages = list(optimized_messages)
693
+ continuation_messages.append(assistant_msg)
694
+ continuation_messages.extend(tool_results)
695
+
696
+ continuation_body = {
697
+ **body,
698
+ "messages": continuation_messages,
699
+ }
700
+
701
+ cont_response = await self._retry_request(
702
+ "POST", url, headers, continuation_body
703
+ )
704
+ if cont_response.status_code == 200:
705
+ resp_json = cont_response.json()
706
+ response = cont_response
707
+
708
+ logger.info(
709
+ f"[{request_id}] Memory: Handled {len(tool_results)} "
710
+ f"tool call(s) with continuation for user {memory_user_id}"
711
+ )
712
  except Exception as e:
713
  logger.warning(f"[{request_id}] Memory tool handling failed: {e}")
714
 
 
984
  f"of context into instructions for user {memory_user_id}"
985
  )
986
 
987
+ # Inject memory tools (Responses API format)
988
  if self.memory_handler.config.inject_tools:
989
  resp_tools = body.get("tools") or []
990
  resp_tools, mem_tools_injected = self.memory_handler.inject_tools(
991
  resp_tools, "openai"
992
  )
993
  if mem_tools_injected:
994
+ # Convert Chat Completions format to Responses API format
995
+ converted_tools = []
996
+ for t in resp_tools:
997
+ if t.get("type") == "function" and "function" in t:
998
+ fn = t["function"]
999
+ converted_tools.append(
1000
+ {
1001
+ "type": "function",
1002
+ "name": fn.get("name"),
1003
+ "description": fn.get("description", ""),
1004
+ "parameters": fn.get("parameters", {}),
1005
+ }
1006
+ )
1007
+ else:
1008
+ converted_tools.append(t)
1009
+ body["tools"] = converted_tools
1010
  logger.info(
1011
  f"[{request_id}] Memory: Injected memory tools (openai/responses)"
1012
  )
 
1073
  and self.memory_handler.has_memory_tool_calls(resp_json, "openai")
1074
  ):
1075
  try:
1076
+ # Extract function_call items from output
1077
+ from headroom.proxy.memory_handler import MEMORY_TOOL_NAMES
1078
+
1079
+ output_items = resp_json.get("output", [])
1080
+ memory_fc_items = [
1081
+ item
1082
+ for item in output_items
1083
+ if isinstance(item, dict)
1084
+ and item.get("type") == "function_call"
1085
+ and item.get("name") in MEMORY_TOOL_NAMES
1086
+ ]
1087
+
1088
+ # Execute memory tool calls
1089
+ tool_outputs: list[dict[str, Any]] = []
1090
+ for fc in memory_fc_items:
1091
+ call_id = fc.get("call_id", fc.get("id", ""))
1092
+ name = fc.get("name", "")
1093
+ args_str = fc.get("arguments", "{}")
1094
+ try:
1095
+ args = json.loads(args_str)
1096
+ except json.JSONDecodeError:
1097
+ args = {}
1098
+
1099
+ await self.memory_handler._ensure_initialized()
1100
+ if self.memory_handler._backend:
1101
+ result = await self.memory_handler._execute_memory_tool(
1102
+ name, args, memory_user_id, "openai"
1103
+ )
1104
+ else:
1105
+ result = json.dumps({"error": "Memory backend not initialized"})
1106
+
1107
+ tool_outputs.append(
1108
+ {
1109
+ "type": "function_call_output",
1110
+ "call_id": call_id,
1111
+ "output": result,
1112
+ }
1113
+ )
1114
+
1115
+ if tool_outputs:
1116
+ # Make continuation request with tool results
1117
+ response_id = resp_json.get("id")
1118
+ continuation_body = {
1119
+ "model": model,
1120
+ "input": tool_outputs,
1121
+ }
1122
+ if response_id:
1123
+ continuation_body["previous_response_id"] = response_id
1124
+ existing_tools = body.get("tools")
1125
+ if existing_tools:
1126
+ continuation_body["tools"] = existing_tools
1127
+
1128
+ cont_response = await self._retry_request(
1129
+ "POST", url, headers, continuation_body
1130
+ )
1131
+ resp_json = cont_response.json()
1132
+ response = cont_response
1133
+ logger.info(
1134
+ f"[{request_id}] Memory: Handled {len(tool_outputs)} "
1135
+ f"tool call(s) with continuation for user {memory_user_id} (responses)"
1136
+ )
1137
  except Exception as e:
1138
  logger.warning(
1139
  f"[{request_id}] Memory tool handling failed (responses): {e}"
 
1369
  # Not JSON — pass through as-is
1370
  tokens_saved = 0
1371
 
1372
+ # --- Memory: inject context, tools, and instructions ---
1373
+ memory_user_id: str | None = None
1374
+ if self.memory_handler and body:
1375
+ memory_user_id = ws_headers.get(
1376
+ "x-headroom-user-id",
1377
+ os.environ.get("USER", os.environ.get("USERNAME", "default")),
1378
+ )
1379
+ try:
1380
+ # Unwrap response.create envelope to access the response body
1381
+ ws_response_body = body.get("response", body)
1382
+
1383
+ # Debug: log what Codex sends so we can see the full tool list
1384
+ existing_tool_names = [
1385
+ t.get("name") or t.get("function", {}).get("name", "?")
1386
+ for t in (ws_response_body.get("tools") or [])
1387
+ ]
1388
+ instr_preview = (ws_response_body.get("instructions") or "")[:200]
1389
+ logger.info(
1390
+ f"[{request_id}] WS Memory: Codex tools={existing_tool_names}, "
1391
+ f"instructions_len={len(ws_response_body.get('instructions') or '')}, "
1392
+ f"instructions_preview={instr_preview!r}"
1393
+ )
1394
+
1395
+ # Inject memory context into instructions
1396
+ if self.memory_handler.config.inject_context:
1397
+ ws_input = ws_response_body.get("input", "")
1398
+ ws_instructions = ws_response_body.get("instructions")
1399
+ ws_msgs: list[dict[str, Any]] = []
1400
+ if ws_instructions:
1401
+ ws_msgs.append({"role": "system", "content": ws_instructions})
1402
+ if isinstance(ws_input, str) and ws_input:
1403
+ ws_msgs.append({"role": "user", "content": ws_input})
1404
+ elif isinstance(ws_input, list):
1405
+ from headroom.proxy.responses_converter import (
1406
+ responses_items_to_messages,
1407
+ )
1408
+
1409
+ converted_msgs, _ = responses_items_to_messages(ws_input)
1410
+ ws_msgs.extend(converted_msgs)
1411
+
1412
+ memory_context = await self.memory_handler.search_and_format_context(
1413
+ memory_user_id, ws_msgs
1414
+ )
1415
+ if memory_context:
1416
+ existing = ws_response_body.get("instructions") or ""
1417
+ if existing:
1418
+ ws_response_body["instructions"] = f"{existing}\n\n{memory_context}"
1419
+ else:
1420
+ ws_response_body["instructions"] = memory_context
1421
+ logger.info(
1422
+ f"[{request_id}] WS Memory: Injected {len(memory_context)} chars "
1423
+ f"of context into instructions"
1424
+ )
1425
+
1426
+ # Inject memory tools (Responses API format)
1427
+ if self.memory_handler.config.inject_tools:
1428
+ ws_tools = ws_response_body.get("tools") or []
1429
+ ws_tools, mem_injected = self.memory_handler.inject_tools(
1430
+ ws_tools, "openai"
1431
+ )
1432
+ if mem_injected:
1433
+ converted_tools = []
1434
+ for t in ws_tools:
1435
+ if t.get("type") == "function" and "function" in t:
1436
+ fn = t["function"]
1437
+ converted_tools.append(
1438
+ {
1439
+ "type": "function",
1440
+ "name": fn.get("name"),
1441
+ "description": fn.get("description", ""),
1442
+ "parameters": fn.get("parameters", {}),
1443
+ }
1444
+ )
1445
+ else:
1446
+ converted_tools.append(t)
1447
+ ws_response_body["tools"] = converted_tools
1448
+
1449
+ # Add memory instruction so the model uses
1450
+ # memory tools as persistent cross-session knowledge.
1451
+ mem_instruction = (
1452
+ "\n\n## Memory\n"
1453
+ "You have persistent memory via memory_search and "
1454
+ "memory_save tools. Memory stores knowledge across "
1455
+ "sessions — user info, project details, org context, "
1456
+ "decisions, architecture, conventions, anything worth "
1457
+ "remembering.\n\n"
1458
+ "- ALWAYS call memory_search BEFORE searching files "
1459
+ "when the user asks a question that could be answered "
1460
+ "from prior knowledge.\n"
1461
+ "- Call memory_save to store important facts, decisions, "
1462
+ "or context that would be useful in future sessions.\n"
1463
+ "- Memory is your first source of truth for anything "
1464
+ "not visible in the current conversation."
1465
+ )
1466
+ existing_instr = ws_response_body.get("instructions") or ""
1467
+ ws_response_body["instructions"] = existing_instr + mem_instruction
1468
+ logger.info(
1469
+ f"[{request_id}] WS Memory: Injected memory tools + instruction"
1470
+ )
1471
+
1472
+ # Write back into envelope if it was wrapped
1473
+ if "response" in body and isinstance(body["response"], dict):
1474
+ body["response"] = ws_response_body
1475
+ else:
1476
+ body = ws_response_body
1477
+
1478
+ first_msg_raw = json.dumps(body)
1479
+ except Exception as e:
1480
+ logger.warning(f"[{request_id}] WS Memory injection failed: {e}")
1481
+
1482
  # --- Connect to upstream OpenAI WebSocket ---
1483
  logger.info(f"[{request_id}] WS /v1/responses connecting to {upstream_url}")
1484
 
 
1503
  # Send (potentially compressed) first message
1504
  await upstream.send(first_msg_raw)
1505
 
1506
+ # Bidirectional relay with memory tool interception
1507
  async def _client_to_upstream() -> None:
1508
  try:
1509
  while True:
 
1518
  await upstream.close()
1519
 
1520
  async def _upstream_to_client() -> None:
1521
+ """Relay upstream→client with transparent memory tool handling.
1522
+
1523
+ Uses a buffer-then-decide approach:
1524
+ 1. Buffer events until first output item arrives
1525
+ 2. If first output is a memory tool → suppress entire response,
1526
+ execute tools silently, send continuation upstream
1527
+ 3. If first output is non-memory → flush buffer, stream normally
1528
+ 4. Continuation response events are relayed to Codex seamlessly
1529
+
1530
+ This prevents orphaned response.created events from confusing Codex.
1531
+ """
1532
+ from headroom.proxy.memory_handler import MEMORY_TOOL_NAMES
1533
+
1534
+ memory_enabled = bool(self.memory_handler and memory_user_id)
1535
+
1536
+ # Per-response state (reset after each response.completed)
1537
+ event_buffer: list[str] = []
1538
+ decided = False
1539
+ suppress_response = False
1540
+ pending_fcs: list[dict[str, Any]] = []
1541
+ resp_id: str | None = None
1542
+
1543
+ def _reset() -> None:
1544
+ nonlocal decided, suppress_response, resp_id
1545
+ event_buffer.clear()
1546
+ decided = False
1547
+ suppress_response = False
1548
+ pending_fcs.clear()
1549
+ resp_id = None
1550
+
1551
  try:
1552
  async for msg in upstream:
1553
+ if isinstance(msg, bytes):
 
 
1554
  await websocket.send_bytes(msg)
1555
+ continue
1556
+ msg_str = msg if isinstance(msg, str) else str(msg)
1557
+
1558
+ if not memory_enabled:
1559
+ await websocket.send_text(msg_str)
1560
+ continue
1561
+
1562
+ # Parse event
1563
+ try:
1564
+ event = json.loads(msg_str)
1565
+ except (json.JSONDecodeError, TypeError):
1566
+ await websocket.send_text(msg_str)
1567
+ continue
1568
+
1569
+ event_type = event.get("type", "")
1570
+
1571
+ # --- Phase 1: Buffer until first output item ---
1572
+ if not decided:
1573
+ event_buffer.append(msg_str)
1574
+
1575
+ if event_type == "response.output_item.added":
1576
+ item = event.get("item", {})
1577
+ if (
1578
+ item.get("type") == "function_call"
1579
+ and item.get("name") in MEMORY_TOOL_NAMES
1580
+ ):
1581
+ # Memory tool first → suppress entire response
1582
+ suppress_response = True
1583
+ decided = True
1584
+ event_buffer.clear()
1585
+ logger.info(
1586
+ f"[{request_id}] WS Memory: Detected "
1587
+ f"{item.get('name')} — suppressing response"
1588
+ )
1589
+ else:
1590
+ # Non-memory first → flush buffer, pass through
1591
+ decided = True
1592
+ for buf in event_buffer:
1593
+ await websocket.send_text(buf)
1594
+ event_buffer.clear()
1595
+
1596
+ elif event_type == "response.completed":
1597
+ # No output items at all — flush
1598
+ decided = True
1599
+ for buf in event_buffer:
1600
+ await websocket.send_text(buf)
1601
+ event_buffer.clear()
1602
+ _reset()
1603
+
1604
+ continue
1605
+
1606
+ # --- Phase 2a: Suppress mode (memory response) ---
1607
+ if suppress_response:
1608
+ if event_type == "response.output_item.done":
1609
+ item = event.get("item", {})
1610
+ if (
1611
+ item.get("type") == "function_call"
1612
+ and item.get("name") in MEMORY_TOOL_NAMES
1613
+ ):
1614
+ pending_fcs.append(item)
1615
+
1616
+ elif event_type == "response.completed":
1617
+ resp = event.get("response", {})
1618
+ resp_id = resp.get("id")
1619
+
1620
+ if pending_fcs:
1621
+ logger.info(
1622
+ f"[{request_id}] WS Memory: Executing "
1623
+ f"{len(pending_fcs)} tool(s) transparently"
1624
+ )
1625
+
1626
+ # Execute memory tool calls
1627
+ tool_outputs: list[dict[str, Any]] = []
1628
+ for fc in pending_fcs:
1629
+ call_id = fc.get("call_id", fc.get("id", ""))
1630
+ fc_name = fc.get("name", "")
1631
+ args_str = fc.get("arguments", "{}")
1632
+ try:
1633
+ fc_args = json.loads(args_str)
1634
+ except json.JSONDecodeError:
1635
+ fc_args = {}
1636
+
1637
+ await self.memory_handler._ensure_initialized()
1638
+ if self.memory_handler._backend:
1639
+ result = await self.memory_handler._execute_memory_tool(
1640
+ fc_name, fc_args, memory_user_id, "openai"
1641
+ )
1642
+ else:
1643
+ result = json.dumps(
1644
+ {"error": "backend not ready"}
1645
+ )
1646
+
1647
+ tool_outputs.append(
1648
+ {
1649
+ "type": "function_call_output",
1650
+ "call_id": call_id,
1651
+ "output": result,
1652
+ }
1653
+ )
1654
+ logger.info(
1655
+ f"[{request_id}] WS Memory: Executed "
1656
+ f"{fc_name} for user {memory_user_id}"
1657
+ )
1658
+
1659
+ # Send continuation upstream
1660
+ cont: dict[str, Any] = {
1661
+ "type": "response.create",
1662
+ "response": {"input": tool_outputs},
1663
+ }
1664
+ if resp_id:
1665
+ cont["response"]["previous_response_id"] = resp_id
1666
+ await upstream.send(json.dumps(cont))
1667
+ logger.info(
1668
+ f"[{request_id}] WS Memory: Sent continuation "
1669
+ f"with {len(tool_outputs)} result(s)"
1670
+ )
1671
+
1672
+ _reset()
1673
+ # All events suppressed in this mode
1674
+ continue
1675
+
1676
+ # --- Phase 2b: Pass-through mode ---
1677
+ await websocket.send_text(msg_str)
1678
+
1679
  except Exception as relay_err:
1680
  if "WebSocketDisconnect" not in type(relay_err).__name__:
1681
  logger.debug(
headroom/proxy/handlers/streaming.py CHANGED
@@ -614,15 +614,18 @@ class StreamingMixin:
614
  f"[{request_id}] Memory: Detected tool calls in streaming response"
615
  )
616
 
617
- # Execute memory tool calls silently — response already
618
- # streamed so we cannot make a continuation request.
 
 
619
  tool_results = await self.memory_handler.handle_memory_tool_calls(
620
  parsed_response, memory_user_id, provider
621
  )
622
  if tool_results:
623
  logger.info(
624
- f"[{request_id}] Memory: Tool calls executed silently "
625
- "(streaming mode no continuation)"
 
626
  )
627
 
628
  # CCR Feedback: Record headroom_retrieve tool calls for TOIN learning.
 
614
  f"[{request_id}] Memory: Detected tool calls in streaming response"
615
  )
616
 
617
+ # Execute memory tool calls — response already streamed
618
+ # so results are saved but continuation is not possible
619
+ # in SSE streaming mode. The WS and non-streaming paths
620
+ # handle continuation properly.
621
  tool_results = await self.memory_handler.handle_memory_tool_calls(
622
  parsed_response, memory_user_id, provider
623
  )
624
  if tool_results:
625
  logger.info(
626
+ f"[{request_id}] Memory: Tool calls executed "
627
+ f"({len(tool_results)} results saved, SSE streaming — "
628
+ "continuation handled by client)"
629
  )
630
 
631
  # CCR Feedback: Record headroom_retrieve tool calls for TOIN learning.
headroom/proxy/memory_handler.py CHANGED
@@ -380,7 +380,14 @@ class MemoryHandler:
380
  entities_str = ", ".join(result.related_entities[:3])
381
  memory_lines.append(f" (Related: {entities_str})")
382
 
383
- context = f"""## Relevant Memories for This User
 
 
 
 
 
 
 
384
 
385
  The following information was previously saved about this user:
386
 
@@ -388,15 +395,11 @@ The following information was previously saved about this user:
388
 
389
  Use this context to provide personalized and contextually relevant responses."""
390
 
391
- logger.info(
392
- f"Memory: Injecting {len(filtered_results)} memories "
393
- f"({len(context)} chars) for user {user_id}"
394
- )
395
- return context
396
-
397
- except Exception as e:
398
- logger.warning(f"Memory: Search failed for user {user_id}: {e}")
399
- return None
400
 
401
  def _extract_user_query(self, messages: list[dict[str, Any]]) -> str:
402
  """Extract the user query from the last user message."""
@@ -445,10 +448,25 @@ Use this context to provide personalized and contextually relevant responses."""
445
  return []
446
 
447
  elif provider == "openai":
 
448
  choices = response.get("choices", [])
449
  if choices:
450
  message = choices[0].get("message", {})
451
- return list(message.get("tool_calls", []) or [])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
452
  return []
453
 
454
  return []
@@ -474,13 +492,15 @@ Use this context to provide personalized and contextually relevant responses."""
474
 
475
  for tc in tool_calls:
476
  tool_name = tc.get("name") or tc.get("function", {}).get("name")
477
- tool_id = tc.get("id", "")
478
 
479
  # Parse input data
480
  if provider == "anthropic":
481
  input_data = tc.get("input", {})
482
  else:
483
- args_str = tc.get("function", {}).get("arguments", "{}")
 
 
484
  try:
485
  input_data = json.loads(args_str)
486
  except json.JSONDecodeError:
 
380
  entities_str = ", ".join(result.related_entities[:3])
381
  memory_lines.append(f" (Related: {entities_str})")
382
 
383
+ except Exception as e:
384
+ logger.warning(f"Memory: Search failed for user {user_id}: {e}")
385
+ return None
386
+
387
+ if not memory_lines:
388
+ return None
389
+
390
+ context = f"""## Relevant Memories for This User
391
 
392
  The following information was previously saved about this user:
393
 
 
395
 
396
  Use this context to provide personalized and contextually relevant responses."""
397
 
398
+ logger.info(
399
+ f"Memory: Injecting {len(memory_lines)} memories "
400
+ f"({len(context)} chars) for user {user_id}"
401
+ )
402
+ return context
 
 
 
 
403
 
404
  def _extract_user_query(self, messages: list[dict[str, Any]]) -> str:
405
  """Extract the user query from the last user message."""
 
448
  return []
449
 
450
  elif provider == "openai":
451
+ # Chat Completions format: choices[0].message.tool_calls
452
  choices = response.get("choices", [])
453
  if choices:
454
  message = choices[0].get("message", {})
455
+ tc_list = list(message.get("tool_calls", []) or [])
456
+ if tc_list:
457
+ return tc_list
458
+
459
+ # Responses API format: output[] with type=function_call
460
+ output = response.get("output", [])
461
+ if isinstance(output, list):
462
+ fc_items = [
463
+ item
464
+ for item in output
465
+ if isinstance(item, dict) and item.get("type") == "function_call"
466
+ ]
467
+ if fc_items:
468
+ return fc_items
469
+
470
  return []
471
 
472
  return []
 
492
 
493
  for tc in tool_calls:
494
  tool_name = tc.get("name") or tc.get("function", {}).get("name")
495
+ tool_id = tc.get("id") or tc.get("call_id", "")
496
 
497
  # Parse input data
498
  if provider == "anthropic":
499
  input_data = tc.get("input", {})
500
  else:
501
+ # Chat Completions format: function.arguments
502
+ # Responses API format: arguments (top-level string)
503
+ args_str = tc.get("arguments") or tc.get("function", {}).get("arguments") or "{}"
504
  try:
505
  input_data = json.loads(args_str)
506
  except json.JSONDecodeError:
headroom/proxy/models.py CHANGED
@@ -207,3 +207,6 @@ class ProxyConfig:
207
  subscription_tracking_enabled: bool = True
208
  subscription_poll_interval_s: int = 10
209
  subscription_active_window_s: int = 60
 
 
 
 
207
  subscription_tracking_enabled: bool = True
208
  subscription_poll_interval_s: int = 10
209
  subscription_active_window_s: int = 60
210
+
211
+ # Stateless mode — disable all filesystem writes for read-only / container deployments
212
+ stateless: bool = False
headroom/proxy/request_logger.py CHANGED
@@ -8,6 +8,7 @@ Extracted from server.py for maintainability.
8
  from __future__ import annotations
9
 
10
  import json
 
11
  import sys
12
  from collections import deque
13
  from dataclasses import asdict
@@ -19,11 +20,15 @@ if TYPE_CHECKING:
19
 
20
  from headroom.proxy.models import RequestLog
21
 
 
 
22
 
23
  class RequestLogger:
24
  """Log requests to JSONL file.
25
 
26
  Uses a deque with max 10,000 entries to prevent unbounded memory growth.
 
 
27
  """
28
 
29
  MAX_LOG_ENTRIES = 10_000
@@ -35,19 +40,30 @@ class RequestLogger:
35
  self._logs: deque[RequestLog] = deque(maxlen=self.MAX_LOG_ENTRIES)
36
 
37
  if self.log_file:
38
- self.log_file.parent.mkdir(parents=True, exist_ok=True)
 
 
 
 
 
 
 
 
39
 
40
  def log(self, entry: RequestLog):
41
  """Log a request. Oldest entries are automatically removed when limit reached."""
42
  self._logs.append(entry)
43
 
44
  if self.log_file:
45
- with open(self.log_file, "a") as f:
46
- log_dict = asdict(entry)
47
- if not self.log_full_messages:
48
- log_dict.pop("request_messages", None)
49
- log_dict.pop("response_content", None)
50
- f.write(json.dumps(log_dict) + "\n")
 
 
 
51
 
52
  def get_recent(self, n: int = 100) -> list[dict]:
53
  """Get recent log entries."""
 
8
  from __future__ import annotations
9
 
10
  import json
11
+ import logging
12
  import sys
13
  from collections import deque
14
  from dataclasses import asdict
 
20
 
21
  from headroom.proxy.models import RequestLog
22
 
23
+ logger = logging.getLogger(__name__)
24
+
25
 
26
  class RequestLogger:
27
  """Log requests to JSONL file.
28
 
29
  Uses a deque with max 10,000 entries to prevent unbounded memory growth.
30
+ Gracefully degrades to in-memory-only if the log file cannot be written
31
+ (read-only filesystem, permissions error, etc.).
32
  """
33
 
34
  MAX_LOG_ENTRIES = 10_000
 
40
  self._logs: deque[RequestLog] = deque(maxlen=self.MAX_LOG_ENTRIES)
41
 
42
  if self.log_file:
43
+ try:
44
+ self.log_file.parent.mkdir(parents=True, exist_ok=True)
45
+ except OSError as e:
46
+ logger.warning(
47
+ "Cannot create log directory %s: %s — logging to memory only",
48
+ self.log_file.parent,
49
+ e,
50
+ )
51
+ self.log_file = None
52
 
53
def log(self, entry: RequestLog):
    """Record a request in memory and, if a log file is set, append it as JSONL.

    The entry is always appended to the in-memory deque. When
    ``self.log_file`` is set, the entry is also serialized to one JSON
    line; ``request_messages`` and ``response_content`` are dropped from
    that line unless ``self.log_full_messages`` is true.

    A failed file write never propagates: the in-memory log keeps working
    and the failure is reported at debug level.
    """
    self._logs.append(entry)

    if not self.log_file:
        return

    # Serialize before touching the file so the try block only guards I/O.
    log_dict = asdict(entry)
    if not self.log_full_messages:
        log_dict.pop("request_messages", None)
        log_dict.pop("response_content", None)
    line = json.dumps(log_dict) + "\n"

    try:
        with open(self.log_file, "a", encoding="utf-8") as f:
            f.write(line)
    except OSError as e:
        # Graceful degradation: memory-only logging continues. Debug level
        # keeps a read-only filesystem from flooding the log per request.
        logger.debug("Request log write to %s failed: %s", self.log_file, e)
67
 
68
  def get_recent(self, n: int = 100) -> list[dict]:
69
  """Get recent log entries."""
headroom/telemetry/toin.py CHANGED
@@ -1174,11 +1174,11 @@ class ToolIntelligenceNetwork:
1174
  canonical = "null"
1175
  elif isinstance(value, bool):
1176
  canonical = "true" if value else "false"
1177
- elif isinstance(value, (int, float)):
1178
  canonical = str(value)
1179
  elif isinstance(value, str):
1180
  canonical = value
1181
- elif isinstance(value, (list, dict)):
1182
  # For complex types, use JSON serialization
1183
  try:
1184
  canonical = json.dumps(value, sort_keys=True, default=str)
@@ -1573,6 +1573,8 @@ def _create_default_toin_backend() -> Any:
1573
  backend_type = (os.environ.get(TOIN_BACKEND_ENV_VAR) or "").strip().lower()
1574
  if not backend_type or backend_type == "filesystem":
1575
  return None
 
 
1576
  try:
1577
  from importlib.metadata import entry_points
1578
 
 
1174
  canonical = "null"
1175
  elif isinstance(value, bool):
1176
  canonical = "true" if value else "false"
1177
+ elif isinstance(value, int | float):
1178
  canonical = str(value)
1179
  elif isinstance(value, str):
1180
  canonical = value
1181
+ elif isinstance(value, list | dict):
1182
  # For complex types, use JSON serialization
1183
  try:
1184
  canonical = json.dumps(value, sort_keys=True, default=str)
 
1573
  backend_type = (os.environ.get(TOIN_BACKEND_ENV_VAR) or "").strip().lower()
1574
  if not backend_type or backend_type == "filesystem":
1575
  return None
1576
+ if backend_type == "none":
1577
+ return None # Explicit in-memory-only (e.g. --stateless mode)
1578
  try:
1579
  from importlib.metadata import entry_points
1580
 
headroom/transforms/kompress_compressor.py CHANGED
@@ -25,12 +25,12 @@ from .base import Transform
25
 
26
  logger = logging.getLogger(__name__)
27
 
28
- # HuggingFace model ID
29
  HF_MODEL_ID = "chopratejas/kompress-base"
30
 
31
- # Lazy singleton
32
- _kompress_model = None
33
- _kompress_tokenizer = None
34
  _kompress_lock = threading.Lock()
35
 
36
 
@@ -132,9 +132,6 @@ def _get_model_class() -> type:
132
 
133
  # ── Model Loading ─────────────────────────────────────────────────────
134
 
135
- # Backend tag: "onnx" or "pytorch"
136
- _kompress_backend: str | None = None
137
-
138
 
139
  class _OnnxModel:
140
  """Thin wrapper so ONNX session has the same interface as PyTorch model."""
@@ -163,48 +160,42 @@ class _OnnxModel:
163
  return (np.array(scores) > 0.5).tolist()
164
 
165
 
166
- def _load_kompress_onnx() -> tuple[Any, Any]:
167
  """Download ONNX INT8 model from HuggingFace and load with onnxruntime."""
168
  import onnxruntime as ort
169
  from transformers import AutoTokenizer
170
 
171
- global _kompress_model, _kompress_tokenizer, _kompress_backend
172
-
173
  with _kompress_lock:
174
- if _kompress_model is not None:
175
- return _kompress_model, _kompress_tokenizer
176
 
177
  from huggingface_hub import hf_hub_download
178
 
179
- logger.info("Downloading Kompress ONNX model from %s ...", HF_MODEL_ID)
180
- onnx_path = hf_hub_download(HF_MODEL_ID, "onnx/kompress-int8.onnx")
181
 
182
  session = ort.InferenceSession(onnx_path)
183
  model = _OnnxModel(session)
184
  tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
185
 
186
- _kompress_model = model
187
- _kompress_tokenizer = tokenizer
188
- _kompress_backend = "onnx"
189
- logger.info("Kompress ONNX INT8 loaded (no torch dependency)")
190
- return model, tokenizer
191
 
192
 
193
- def _load_kompress_pytorch(device: str = "auto") -> tuple[Any, Any]:
194
  """Download PyTorch model from HuggingFace and load with torch."""
195
  import torch
196
  from transformers import AutoTokenizer
197
 
198
- global _kompress_model, _kompress_tokenizer, _kompress_backend
199
-
200
  with _kompress_lock:
201
- if _kompress_model is not None:
202
- return _kompress_model, _kompress_tokenizer
203
 
204
  from huggingface_hub import hf_hub_download
205
 
206
- logger.info("Downloading Kompress PyTorch model from %s ...", HF_MODEL_ID)
207
- weights_path = hf_hub_download(HF_MODEL_ID, "model.safetensors")
208
 
209
  HeadroomCompressorModel = _get_model_class()
210
  model = HeadroomCompressorModel()
@@ -227,50 +218,60 @@ def _load_kompress_pytorch(device: str = "auto") -> tuple[Any, Any]:
227
 
228
  tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
229
 
230
- _kompress_model = model
231
- _kompress_tokenizer = tokenizer
232
- _kompress_backend = "pytorch"
233
- logger.info("Kompress PyTorch loaded on %s (%s)", device, HF_MODEL_ID)
234
- return model, tokenizer
235
 
236
 
237
- def _load_kompress(device: str = "auto") -> tuple[Any, Any]:
238
- """Load Kompress model: try ONNX first (lightweight), fall back to PyTorch."""
239
- global _kompress_model
240
- if _kompress_model is not None:
241
- return _kompress_model, _kompress_tokenizer
 
 
 
242
 
243
  # Prefer ONNX (50MB onnxruntime vs 800MB torch)
244
  if _is_onnx_available():
245
  try:
246
- return _load_kompress_onnx()
247
  except Exception as e:
248
- logger.warning("ONNX load failed, trying PyTorch: %s", e)
249
 
250
  if _is_pytorch_available():
251
- return _load_kompress_pytorch(device)
252
 
253
  raise ImportError(
254
  "Kompress requires onnxruntime or torch. Install with: pip install headroom-ai[proxy]"
255
  )
256
 
257
 
258
- def unload_kompress_model() -> bool:
259
- """Unload the Kompress model to free memory."""
260
- global _kompress_model, _kompress_tokenizer
 
 
 
261
  with _kompress_lock:
262
- if _kompress_model is not None:
263
- _kompress_model = None
264
- _kompress_tokenizer = None
265
- try:
266
- import torch
 
 
 
 
267
 
268
- if torch.cuda.is_available():
269
- torch.cuda.empty_cache()
270
- except ImportError:
271
- pass
272
- return True
273
- return False
 
 
274
 
275
 
276
  # ── Compressor ────────────────────────────────────────────────────────
@@ -278,10 +279,26 @@ def unload_kompress_model() -> bool:
278
 
279
  @dataclass
280
  class KompressConfig:
281
- """Minimal config. The model decides what's important — not us."""
 
 
 
 
 
 
 
 
 
 
 
 
 
282
 
283
  device: str = "auto"
284
  enable_ccr: bool = True
 
 
 
285
 
286
 
287
  @dataclass
@@ -308,9 +325,10 @@ class KompressResult:
308
 
309
 
310
  class KompressCompressor(Transform):
311
- """Kompress: ModernBERT token compressor for structured tool outputs.
312
 
313
- Auto-downloads chopratejas/kompress-base from HuggingFace on first use.
 
314
  """
315
 
316
  name: str = "kompress_compressor"
@@ -347,11 +365,10 @@ class KompressCompressor(Transform):
347
  return self._passthrough(content, n_words)
348
 
349
  try:
350
- model, tokenizer = _load_kompress(self.config.device)
351
- is_onnx = _kompress_backend == "onnx"
352
 
353
- # Chunk at 512 tokens ≈ 350 words (matches training max_length)
354
- max_chunk_words = 350
355
  kept_ids: set[int] = set()
356
 
357
  for chunk_start in range(0, n_words, max_chunk_words):
@@ -423,6 +440,7 @@ class KompressCompressor(Transform):
423
  original_tokens=n_words,
424
  compressed_tokens=compressed_count,
425
  compression_ratio=ratio,
 
426
  )
427
 
428
  # CCR marker
@@ -542,7 +560,7 @@ class KompressCompressor(Transform):
542
  word_lists: list[list[str]] = [c.split() for c in contents]
543
 
544
  # Short texts short-circuit to passthrough — no model call needed.
545
- max_chunk_words = 350
546
  chunk_queue: list[tuple[int, int, list[str], float | None]] = []
547
  for i, (words, ratio) in enumerate(zip(word_lists, ratios, strict=True)):
548
  if len(words) < 10:
@@ -558,7 +576,7 @@ class KompressCompressor(Transform):
558
 
559
  # Load model once for the whole batch.
560
  try:
561
- model, tokenizer = _load_kompress(self.config.device)
562
  except Exception as e:
563
  logger.warning("Kompress load failed for batch: %s — passthrough all", e)
564
  for i in range(n):
@@ -566,7 +584,7 @@ class KompressCompressor(Transform):
566
  results[i] = self._passthrough(contents[i], len(word_lists[i]))
567
  return [r for r in results if r is not None]
568
 
569
- is_onnx = _kompress_backend == "onnx"
570
  kept_ids_per_text: dict[int, set[int]] = {i: set() for i in range(n) if results[i] is None}
571
 
572
  for batch_start in range(0, len(chunk_queue), batch_size):
@@ -620,9 +638,9 @@ class KompressCompressor(Transform):
620
  for wid in sorted_wids[:num_keep]:
621
  kept_ids_per_text[text_idx].add(wid + chunk_start)
622
  else:
623
- # Threshold at 0.5 (matches ONNX get_keep_mask behavior).
624
  for wid, score in word_scores.items():
625
- if score > 0.5:
626
  kept_ids_per_text[text_idx].add(wid + chunk_start)
627
 
628
  except Exception as e:
@@ -659,6 +677,7 @@ class KompressCompressor(Transform):
659
  original_tokens=n_words,
660
  compressed_tokens=compressed_count,
661
  compression_ratio=comp_ratio,
 
662
  )
663
 
664
  if self.config.enable_ccr and comp_ratio < 0.8:
@@ -692,28 +711,28 @@ class KompressCompressor(Transform):
692
  If the model isn't loaded yet, we trigger loading so the backend
693
  is known. This is a no-op if the model is already in cache.
694
  """
695
- global _kompress_model, _kompress_backend
696
- if _kompress_model is None:
697
  try:
698
- _load_kompress(self.config.device)
699
  except Exception:
700
- # If load fails, caller will see the error downstream.
701
  return True
702
 
703
- if _kompress_backend == "onnx":
 
 
 
 
 
704
  return True # ONNX CPU provider doesn't parallelize batch dim
705
- if _kompress_backend == "pytorch":
706
  try:
707
  import torch
708
 
709
- # Check the model's actual device
710
- if _kompress_model is not None and hasattr(_kompress_model, "parameters"):
711
- device = next(_kompress_model.parameters()).device
712
- if device.type == "cuda":
713
- return False # GPU benefits from batching
714
- if device.type == "mps":
715
- return False # MPS (Apple Silicon) also benefits
716
- # Fall through for CPU
717
  _ = torch
718
  except ImportError:
719
  return True
 
25
 
26
  logger = logging.getLogger(__name__)
27
 
28
+ # Default HuggingFace model ID
29
  HF_MODEL_ID = "chopratejas/kompress-base"
30
 
31
+ # Model cache: model_id -> (model, tokenizer, backend)
32
+ # Supports multiple models loaded simultaneously.
33
+ _kompress_cache: dict[str, tuple[Any, Any, str]] = {}
34
  _kompress_lock = threading.Lock()
35
 
36
 
 
132
 
133
  # ── Model Loading ─────────────────────────────────────────────────────
134
 
 
 
 
135
 
136
  class _OnnxModel:
137
  """Thin wrapper so ONNX session has the same interface as PyTorch model."""
 
160
  return (np.array(scores) > 0.5).tolist()
161
 
162
 
163
+ def _load_kompress_onnx(model_id: str) -> tuple[Any, Any, str]:
164
  """Download ONNX INT8 model from HuggingFace and load with onnxruntime."""
165
  import onnxruntime as ort
166
  from transformers import AutoTokenizer
167
 
 
 
168
  with _kompress_lock:
169
+ if model_id in _kompress_cache:
170
+ return _kompress_cache[model_id]
171
 
172
  from huggingface_hub import hf_hub_download
173
 
174
+ logger.info("Downloading Kompress ONNX model from %s ...", model_id)
175
+ onnx_path = hf_hub_download(model_id, "onnx/kompress-int8.onnx")
176
 
177
  session = ort.InferenceSession(onnx_path)
178
  model = _OnnxModel(session)
179
  tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
180
 
181
+ _kompress_cache[model_id] = (model, tokenizer, "onnx")
182
+ logger.info("Kompress ONNX INT8 loaded: %s", model_id)
183
+ return model, tokenizer, "onnx"
 
 
184
 
185
 
186
+ def _load_kompress_pytorch(model_id: str, device: str = "auto") -> tuple[Any, Any, str]:
187
  """Download PyTorch model from HuggingFace and load with torch."""
188
  import torch
189
  from transformers import AutoTokenizer
190
 
 
 
191
  with _kompress_lock:
192
+ if model_id in _kompress_cache:
193
+ return _kompress_cache[model_id]
194
 
195
  from huggingface_hub import hf_hub_download
196
 
197
+ logger.info("Downloading Kompress PyTorch model from %s ...", model_id)
198
+ weights_path = hf_hub_download(model_id, "model.safetensors")
199
 
200
  HeadroomCompressorModel = _get_model_class()
201
  model = HeadroomCompressorModel()
 
218
 
219
  tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
220
 
221
+ _kompress_cache[model_id] = (model, tokenizer, "pytorch")
222
+ logger.info("Kompress PyTorch loaded on %s (%s)", device, model_id)
223
+ return model, tokenizer, "pytorch"
 
 
224
 
225
 
226
+ def _load_kompress(model_id: str = HF_MODEL_ID, device: str = "auto") -> tuple[Any, Any, str]:
227
+ """Load Kompress model, returns (model, tokenizer, backend).
228
+
229
+ Try ONNX first (lightweight), fall back to PyTorch.
230
+ Models are cached by model_id — multiple models can coexist.
231
+ """
232
+ if model_id in _kompress_cache:
233
+ return _kompress_cache[model_id]
234
 
235
  # Prefer ONNX (50MB onnxruntime vs 800MB torch)
236
  if _is_onnx_available():
237
  try:
238
+ return _load_kompress_onnx(model_id)
239
  except Exception as e:
240
+ logger.warning("ONNX load failed for %s, trying PyTorch: %s", model_id, e)
241
 
242
  if _is_pytorch_available():
243
+ return _load_kompress_pytorch(model_id, device)
244
 
245
  raise ImportError(
246
  "Kompress requires onnxruntime or torch. Install with: pip install headroom-ai[proxy]"
247
  )
248
 
249
 
250
+ def unload_kompress_model(model_id: str | None = None) -> bool:
251
+ """Unload Kompress model(s) to free memory.
252
+
253
+ Args:
254
+ model_id: Specific model to unload. If None, unloads all cached models.
255
+ """
256
  with _kompress_lock:
257
+ if model_id is not None:
258
+ if model_id in _kompress_cache:
259
+ del _kompress_cache[model_id]
260
+ else:
261
+ return False
262
+ elif _kompress_cache:
263
+ _kompress_cache.clear()
264
+ else:
265
+ return False
266
 
267
+ try:
268
+ import torch
269
+
270
+ if torch.cuda.is_available():
271
+ torch.cuda.empty_cache()
272
+ except ImportError:
273
+ pass
274
+ return True
275
 
276
 
277
  # ── Compressor ────────────────────────────────────────────────────────
 
279
 
280
  @dataclass
281
  class KompressConfig:
282
+ """Configuration for Kompress compression.
283
+
284
+ The model_id, chunk_words, and score_threshold are coupled: a model
285
+ trained on 50-word chunks needs chunk_words=50 at inference. The
286
+ defaults match kompress-base. For domain-specific models, set all three.
287
+
288
+ Example — financial documents::
289
+
290
+ KompressConfig(
291
+ model_id="chopratejas/kompress-finance",
292
+ chunk_words=50,
293
+ score_threshold=0.5,
294
+ )
295
+ """
296
 
297
  device: str = "auto"
298
  enable_ccr: bool = True
299
+ model_id: str = HF_MODEL_ID
300
+ chunk_words: int = 350
301
+ score_threshold: float = 0.5
302
 
303
 
304
  @dataclass
 
325
 
326
 
327
  class KompressCompressor(Transform):
328
+ """Kompress: ModernBERT token compressor.
329
 
330
+ Auto-downloads the model from HuggingFace on first use.
331
+ Configure via KompressConfig to select model, chunk size, and threshold.
332
  """
333
 
334
  name: str = "kompress_compressor"
 
365
  return self._passthrough(content, n_words)
366
 
367
  try:
368
+ model, tokenizer, backend = _load_kompress(self.config.model_id, self.config.device)
369
+ is_onnx = backend == "onnx"
370
 
371
+ max_chunk_words = self.config.chunk_words
 
372
  kept_ids: set[int] = set()
373
 
374
  for chunk_start in range(0, n_words, max_chunk_words):
 
440
  original_tokens=n_words,
441
  compressed_tokens=compressed_count,
442
  compression_ratio=ratio,
443
+ model_used=self.config.model_id,
444
  )
445
 
446
  # CCR marker
 
560
  word_lists: list[list[str]] = [c.split() for c in contents]
561
 
562
  # Short texts short-circuit to passthrough — no model call needed.
563
+ max_chunk_words = self.config.chunk_words
564
  chunk_queue: list[tuple[int, int, list[str], float | None]] = []
565
  for i, (words, ratio) in enumerate(zip(word_lists, ratios, strict=True)):
566
  if len(words) < 10:
 
576
 
577
  # Load model once for the whole batch.
578
  try:
579
+ model, tokenizer, backend = _load_kompress(self.config.model_id, self.config.device)
580
  except Exception as e:
581
  logger.warning("Kompress load failed for batch: %s — passthrough all", e)
582
  for i in range(n):
 
584
  results[i] = self._passthrough(contents[i], len(word_lists[i]))
585
  return [r for r in results if r is not None]
586
 
587
+ is_onnx = backend == "onnx"
588
  kept_ids_per_text: dict[int, set[int]] = {i: set() for i in range(n) if results[i] is None}
589
 
590
  for batch_start in range(0, len(chunk_queue), batch_size):
 
638
  for wid in sorted_wids[:num_keep]:
639
  kept_ids_per_text[text_idx].add(wid + chunk_start)
640
  else:
641
+ # Threshold from config (default 0.5, matches ONNX get_keep_mask).
642
  for wid, score in word_scores.items():
643
+ if score > self.config.score_threshold:
644
  kept_ids_per_text[text_idx].add(wid + chunk_start)
645
 
646
  except Exception as e:
 
677
  original_tokens=n_words,
678
  compressed_tokens=compressed_count,
679
  compression_ratio=comp_ratio,
680
+ model_used=self.config.model_id,
681
  )
682
 
683
  if self.config.enable_ccr and comp_ratio < 0.8:
 
711
  If the model isn't loaded yet, we trigger loading so the backend
712
  is known. This is a no-op if the model is already in cache.
713
  """
714
+ model_id = self.config.model_id
715
+ if model_id not in _kompress_cache:
716
  try:
717
+ _load_kompress(model_id, self.config.device)
718
  except Exception:
 
719
  return True
720
 
721
+ if model_id not in _kompress_cache:
722
+ return True
723
+
724
+ model, _tokenizer, backend = _kompress_cache[model_id]
725
+
726
+ if backend == "onnx":
727
  return True # ONNX CPU provider doesn't parallelize batch dim
728
+ if backend == "pytorch":
729
  try:
730
  import torch
731
 
732
+ if hasattr(model, "parameters"):
733
+ device = next(model.parameters()).device
734
+ if device.type in ("cuda", "mps"):
735
+ return False # GPU/MPS benefits from batching
 
 
 
 
736
  _ = torch
737
  except ImportError:
738
  return True
headroom/transforms/smart_crusher.py CHANGED
@@ -180,27 +180,31 @@ def _hash_field_name(field_name: str) -> str:
180
  # Minimum chars for a text field to be worth compressing within an item
181
  _MIN_FIELD_CHARS_FOR_WITHIN = 200
182
 
183
- # Lazy-loaded compressor for within-item text compression
184
  _within_compressor: Any = None
185
  _within_compressor_checked = False
 
186
 
187
 
188
  def _get_within_compressor() -> Any:
189
  """Get a text compressor for within-item field compression.
190
 
191
  Returns Kompress if available (requires [ml] extra), else None.
 
192
  """
193
  global _within_compressor, _within_compressor_checked
194
  if not _within_compressor_checked:
195
- _within_compressor_checked = True
196
- try:
197
- from .kompress_compressor import KompressCompressor, is_kompress_available
198
-
199
- if is_kompress_available():
200
- _within_compressor = KompressCompressor()
201
- logger.debug("Within-item compression: using Kompress")
202
- except ImportError:
203
- pass
 
 
204
  return _within_compressor
205
 
206
 
@@ -435,7 +439,7 @@ def _detect_sequential_pattern(values: list[Any], check_order: bool = True) -> b
435
  # Get numeric values
436
  nums = []
437
  for v in values:
438
- if isinstance(v, (int, float)) and not isinstance(v, bool):
439
  nums.append(v)
440
  elif isinstance(v, str):
441
  try:
@@ -546,7 +550,6 @@ def _detect_score_field_statistically(stats: FieldStats, items: list[dict]) -> t
546
  confidence = 0.0
547
 
548
  # Check for bounded range typical of scores
549
- stats.max_val - stats.min_val
550
  min_val, max_val = stats.min_val, stats.max_val
551
 
552
  # Common score ranges: [0,1], [0,10], [0,100], [-1,1], [0,5]
@@ -578,7 +581,7 @@ def _detect_score_field_statistically(stats: FieldStats, items: list[dict]) -> t
578
  for item in items:
579
  if stats.name in item:
580
  val = item.get(stats.name)
581
- if isinstance(val, (int, float)) and math.isfinite(val):
582
  values_in_order.append(float(val))
583
  if len(values_in_order) >= 5:
584
  # Check for descending sort
@@ -804,11 +807,11 @@ def _detect_items_by_learned_semantics(
804
  value_canonical = "null"
805
  elif isinstance(value, bool):
806
  value_canonical = "true" if value else "false"
807
- elif isinstance(value, (int, float)):
808
  value_canonical = str(value)
809
  elif isinstance(value, str):
810
  value_canonical = value
811
- elif isinstance(value, (list, dict)):
812
  try:
813
  value_canonical = json.dumps(value, sort_keys=True, default=str)
814
  except (TypeError, ValueError):
@@ -1030,7 +1033,7 @@ class SmartAnalyzer:
1030
  first_val = non_null_values[0]
1031
  if isinstance(first_val, bool):
1032
  field_type = "boolean"
1033
- elif isinstance(first_val, (int, float)):
1034
  field_type = "numeric"
1035
  elif isinstance(first_val, str):
1036
  field_type = "string"
@@ -1064,7 +1067,7 @@ class SmartAnalyzer:
1064
  # Numeric-specific analysis
1065
  if field_type == "numeric":
1066
  # Filter out NaN and Infinity which break statistics functions
1067
- nums = [v for v in non_null_values if isinstance(v, (int, float)) and math.isfinite(v)]
1068
  if nums:
1069
  try:
1070
  stats.min_val = min(nums)
@@ -1283,7 +1286,7 @@ class SmartAnalyzer:
1283
  threshold = self.config.variance_threshold * std
1284
  for i, item in enumerate(items):
1285
  val = item.get(stats.name)
1286
- if isinstance(val, (int, float)):
1287
  if abs(val - stats.mean_val) > threshold:
1288
  anomaly_indices.add(i)
1289
 
@@ -1953,9 +1956,9 @@ class SmartCrusher(Transform):
1953
  if len(keep_indices) <= effective_max:
1954
  return keep_indices
1955
 
1956
- # Use provided field_semantics or fall back to instance variable (set by crush())
1957
  effective_field_semantics = field_semantics or getattr(
1958
- self, "_current_field_semantics", None
1959
  )
1960
 
1961
  # Identify error items using KEYWORD detection (preservation guarantee)
@@ -1976,7 +1979,7 @@ class SmartCrusher(Transform):
1976
  threshold = self.config.variance_threshold * std
1977
  for i, item in enumerate(items):
1978
  val = item.get(field_name)
1979
- if isinstance(val, (int, float)):
1980
  if abs(val - stats.mean_val) > threshold:
1981
  anomaly_indices.add(i)
1982
 
@@ -2297,6 +2300,10 @@ class SmartCrusher(Transform):
2297
 
2298
  return result, was_modified, info
2299
 
 
 
 
 
2300
  def _process_value(
2301
  self,
2302
  value: Any,
@@ -2311,6 +2318,10 @@ class SmartCrusher(Transform):
2311
  Tuple of (processed_value, info_string, ccr_markers).
2312
  ccr_markers is a list of (hash, original_count, compressed_count, summary) tuples.
2313
  """
 
 
 
 
2314
  info_parts = []
2315
  ccr_markers: list[tuple] = []
2316
 
@@ -2495,9 +2506,12 @@ class SmartCrusher(Transform):
2495
  )
2496
 
2497
  # === TOIN Evolution: Extract field semantics for signal detection ===
2498
- # Store temporarily on instance for use in _prioritize_indices
2499
  # This enables learned signal detection without changing all method signatures
2500
- self._current_field_semantics = (
 
 
 
2501
  toin_hint.field_semantics if toin_hint.field_semantics else None
2502
  )
2503
 
@@ -2661,12 +2675,14 @@ class SmartCrusher(Transform):
2661
  )
2662
 
2663
  # Clean up temporary instance variable
2664
- self._current_field_semantics = None
 
2665
  return result, strategy_info, ccr_hash, dropped_summary
2666
 
2667
  except Exception:
2668
  # Clean up temporary instance variable
2669
- self._current_field_semantics = None
 
2670
  # Re-raise any exceptions (removed finally block since we no longer mutate config)
2671
  raise
2672
 
@@ -2814,7 +2830,7 @@ class SmartCrusher(Transform):
2814
  return items, "number:passthrough"
2815
 
2816
  # Filter out non-finite values for statistics
2817
- finite = [x for x in items if isinstance(x, (int, float)) and math.isfinite(x)]
2818
  if not finite:
2819
  return items, "number:no_finite"
2820
 
@@ -2832,7 +2848,7 @@ class SmartCrusher(Transform):
2832
  outlier_indices: set[int] = set()
2833
  if std_val > 0:
2834
  for i, val in enumerate(items):
2835
- if isinstance(val, (int, float)) and math.isfinite(val):
2836
  if abs(val - mean_val) > self.config.variance_threshold * std_val:
2837
  outlier_indices.add(i)
2838
 
@@ -2844,12 +2860,12 @@ class SmartCrusher(Transform):
2844
  left = [
2845
  items[j]
2846
  for j in range(i - window, i)
2847
- if isinstance(items[j], (int, float)) and math.isfinite(items[j])
2848
  ]
2849
  right = [
2850
  items[j]
2851
  for j in range(i, i + window)
2852
- if isinstance(items[j], (int, float)) and math.isfinite(items[j])
2853
  ]
2854
  if left and right:
2855
  left_mean = statistics.mean(left)
@@ -2877,27 +2893,23 @@ class SmartCrusher(Transform):
2877
  if i not in keep_indices:
2878
  keep_indices.add(i)
2879
 
2880
- # Build output: summary string + kept values in original order
2881
- stats_summary = (
2882
- f"[{n} numbers: min={min(finite)}, max={max(finite)}, "
2883
- f"mean={mean_val:.4g}, median={median_val:.4g}, "
2884
- f"stddev={std_val:.4g}, p25={p25:.4g}, p75={p75:.4g}"
2885
- )
2886
- if outlier_indices:
2887
- stats_summary += f", outliers={len(outlier_indices)}"
2888
- if change_indices:
2889
- stats_summary += f", change_points={len(change_indices)}"
2890
- stats_summary += "]"
2891
-
2892
  kept_values = [items[i] for i in sorted(keep_indices)]
2893
- result: list = [stats_summary] + kept_values
2894
 
2895
- strategy = f"number:adaptive({n}->{len(kept_values)}"
 
 
 
 
 
 
2896
  if outlier_indices:
2897
  strategy += f",outliers={len(outlier_indices)}"
 
 
2898
  strategy += ")"
2899
 
2900
- return result, strategy
2901
 
2902
  def _crush_mixed_array(
2903
  self,
@@ -2930,7 +2942,7 @@ class SmartCrusher(Transform):
2930
  key = "str"
2931
  elif isinstance(item, bool):
2932
  key = "bool"
2933
- elif isinstance(item, (int, float)):
2934
  key = "number"
2935
  elif isinstance(item, list):
2936
  key = "list"
@@ -2979,13 +2991,13 @@ class SmartCrusher(Transform):
2979
  last_idx = set(indices[-k_last:])
2980
  keep_indices.update(first_idx | last_idx)
2981
  # Outliers
2982
- finite = [v for v in values if isinstance(v, (int, float)) and math.isfinite(v)]
2983
  if len(finite) > 1:
2984
  mean_v = statistics.mean(finite)
2985
  std_v = statistics.stdev(finite)
2986
  if std_v > 0:
2987
  for idx, val in group_items:
2988
- if isinstance(val, (int, float)) and math.isfinite(val):
2989
  if abs(val - mean_v) > self.config.variance_threshold * std_v:
2990
  keep_indices.add(idx)
2991
  strategy_parts.append(f"num:{len(values)}")
@@ -3553,7 +3565,7 @@ class SmartCrusher(Transform):
3553
  threshold = self.config.variance_threshold * std
3554
  for i, item in enumerate(items):
3555
  val = item.get(name)
3556
- if isinstance(val, (int, float)):
3557
  if abs(val - stats.mean_val) > threshold:
3558
  keep_indices.add(i)
3559
 
 
180
  # Minimum chars for a text field to be worth compressing within an item
181
  _MIN_FIELD_CHARS_FOR_WITHIN = 200
182
 
183
+ # Lazy-loaded compressor for within-item text compression (thread-safe)
184
  _within_compressor: Any = None
185
  _within_compressor_checked = False
186
+ _within_compressor_lock = threading.Lock()
187
 
188
 
189
  def _get_within_compressor() -> Any:
190
  """Get a text compressor for within-item field compression.
191
 
192
  Returns Kompress if available (requires [ml] extra), else None.
193
+ Thread-safe via double-checked locking.
194
  """
195
  global _within_compressor, _within_compressor_checked
196
  if not _within_compressor_checked:
197
+ with _within_compressor_lock:
198
+ if not _within_compressor_checked:
199
+ try:
200
+ from .kompress_compressor import KompressCompressor, is_kompress_available
201
+
202
+ if is_kompress_available():
203
+ _within_compressor = KompressCompressor()
204
+ logger.debug("Within-item compression: using Kompress")
205
+ except ImportError:
206
+ pass
207
+ _within_compressor_checked = True
208
  return _within_compressor
209
 
210
 
 
439
  # Get numeric values
440
  nums = []
441
  for v in values:
442
+ if isinstance(v, int | float) and not isinstance(v, bool):
443
  nums.append(v)
444
  elif isinstance(v, str):
445
  try:
 
550
  confidence = 0.0
551
 
552
  # Check for bounded range typical of scores
 
553
  min_val, max_val = stats.min_val, stats.max_val
554
 
555
  # Common score ranges: [0,1], [0,10], [0,100], [-1,1], [0,5]
 
581
  for item in items:
582
  if stats.name in item:
583
  val = item.get(stats.name)
584
+ if isinstance(val, int | float) and math.isfinite(val):
585
  values_in_order.append(float(val))
586
  if len(values_in_order) >= 5:
587
  # Check for descending sort
 
807
  value_canonical = "null"
808
  elif isinstance(value, bool):
809
  value_canonical = "true" if value else "false"
810
+ elif isinstance(value, int | float):
811
  value_canonical = str(value)
812
  elif isinstance(value, str):
813
  value_canonical = value
814
+ elif isinstance(value, list | dict):
815
  try:
816
  value_canonical = json.dumps(value, sort_keys=True, default=str)
817
  except (TypeError, ValueError):
 
1033
  first_val = non_null_values[0]
1034
  if isinstance(first_val, bool):
1035
  field_type = "boolean"
1036
+ elif isinstance(first_val, int | float):
1037
  field_type = "numeric"
1038
  elif isinstance(first_val, str):
1039
  field_type = "string"
 
1067
  # Numeric-specific analysis
1068
  if field_type == "numeric":
1069
  # Filter out NaN and Infinity which break statistics functions
1070
+ nums = [v for v in non_null_values if isinstance(v, int | float) and math.isfinite(v)]
1071
  if nums:
1072
  try:
1073
  stats.min_val = min(nums)
 
1286
  threshold = self.config.variance_threshold * std
1287
  for i, item in enumerate(items):
1288
  val = item.get(stats.name)
1289
+ if isinstance(val, int | float):
1290
  if abs(val - stats.mean_val) > threshold:
1291
  anomaly_indices.add(i)
1292
 
 
1956
  if len(keep_indices) <= effective_max:
1957
  return keep_indices
1958
 
1959
+ # Use provided field_semantics or fall back to thread-local (set by _crush_array)
1960
  effective_field_semantics = field_semantics or getattr(
1961
+ getattr(self, "_thread_local", None), "field_semantics", None
1962
  )
1963
 
1964
  # Identify error items using KEYWORD detection (preservation guarantee)
 
1979
  threshold = self.config.variance_threshold * std
1980
  for i, item in enumerate(items):
1981
  val = item.get(field_name)
1982
+ if isinstance(val, int | float):
1983
  if abs(val - stats.mean_val) > threshold:
1984
  anomaly_indices.add(i)
1985
 
 
2300
 
2301
  return result, was_modified, info
2302
 
2303
+ # Maximum recursion depth for nested JSON processing.
2304
+ # Prevents RecursionError on adversarial/deeply-nested input.
2305
+ _MAX_PROCESS_DEPTH = 50
2306
+
2307
  def _process_value(
2308
  self,
2309
  value: Any,
 
2318
  Tuple of (processed_value, info_string, ccr_markers).
2319
  ccr_markers is a list of (hash, original_count, compressed_count, summary) tuples.
2320
  """
2321
+ # Guard against deeply nested JSON causing RecursionError
2322
+ if depth >= self._MAX_PROCESS_DEPTH:
2323
+ return value, "", []
2324
+
2325
  info_parts = []
2326
  ccr_markers: list[tuple] = []
2327
 
 
2506
  )
2507
 
2508
  # === TOIN Evolution: Extract field semantics for signal detection ===
2509
+ # Store in thread-local storage for use in _prioritize_indices.
2510
  # This enables learned signal detection without changing all method signatures
2511
+ # while remaining thread-safe (no cross-thread contamination).
2512
+ if not hasattr(self, "_thread_local"):
2513
+ self._thread_local = threading.local()
2514
+ self._thread_local.field_semantics = (
2515
  toin_hint.field_semantics if toin_hint.field_semantics else None
2516
  )
2517
 
 
2675
  )
2676
 
2677
  # Clean up temporary instance variable
2678
+ if hasattr(self, "_thread_local"):
2679
+ self._thread_local.field_semantics = None
2680
  return result, strategy_info, ccr_hash, dropped_summary
2681
 
2682
  except Exception:
2683
  # Clean up temporary instance variable
2684
+ if hasattr(self, "_thread_local"):
2685
+ self._thread_local.field_semantics = None
2686
  # Re-raise any exceptions (removed finally block since we no longer mutate config)
2687
  raise
2688
 
 
2830
  return items, "number:passthrough"
2831
 
2832
  # Filter out non-finite values for statistics
2833
+ finite = [x for x in items if isinstance(x, int | float) and math.isfinite(x)]
2834
  if not finite:
2835
  return items, "number:no_finite"
2836
 
 
2848
  outlier_indices: set[int] = set()
2849
  if std_val > 0:
2850
  for i, val in enumerate(items):
2851
+ if isinstance(val, int | float) and math.isfinite(val):
2852
  if abs(val - mean_val) > self.config.variance_threshold * std_val:
2853
  outlier_indices.add(i)
2854
 
 
2860
  left = [
2861
  items[j]
2862
  for j in range(i - window, i)
2863
+ if isinstance(items[j], int | float) and math.isfinite(items[j])
2864
  ]
2865
  right = [
2866
  items[j]
2867
  for j in range(i, i + window)
2868
+ if isinstance(items[j], int | float) and math.isfinite(items[j])
2869
  ]
2870
  if left and right:
2871
  left_mean = statistics.mean(left)
 
2893
  if i not in keep_indices:
2894
  keep_indices.add(i)
2895
 
2896
+ # Build output: kept values only (schema-preserving no generated text)
 
 
 
 
 
 
 
 
 
 
 
2897
  kept_values = [items[i] for i in sorted(keep_indices)]
 
2898
 
2899
+ # Encode statistics into the strategy string (not the array itself)
2900
+ strategy = (
2901
+ f"number:adaptive({n}->{len(kept_values)}"
2902
+ f",min={min(finite)},max={max(finite)}"
2903
+ f",mean={mean_val:.4g},median={median_val:.4g}"
2904
+ f",stddev={std_val:.4g},p25={p25:.4g},p75={p75:.4g}"
2905
+ )
2906
  if outlier_indices:
2907
  strategy += f",outliers={len(outlier_indices)}"
2908
+ if change_indices:
2909
+ strategy += f",change_points={len(change_indices)}"
2910
  strategy += ")"
2911
 
2912
+ return kept_values, strategy
2913
 
2914
  def _crush_mixed_array(
2915
  self,
 
2942
  key = "str"
2943
  elif isinstance(item, bool):
2944
  key = "bool"
2945
+ elif isinstance(item, int | float):
2946
  key = "number"
2947
  elif isinstance(item, list):
2948
  key = "list"
 
2991
  last_idx = set(indices[-k_last:])
2992
  keep_indices.update(first_idx | last_idx)
2993
  # Outliers
2994
+ finite = [v for v in values if isinstance(v, int | float) and math.isfinite(v)]
2995
  if len(finite) > 1:
2996
  mean_v = statistics.mean(finite)
2997
  std_v = statistics.stdev(finite)
2998
  if std_v > 0:
2999
  for idx, val in group_items:
3000
+ if isinstance(val, int | float) and math.isfinite(val):
3001
  if abs(val - mean_v) > self.config.variance_threshold * std_v:
3002
  keep_indices.add(idx)
3003
  strategy_parts.append(f"num:{len(values)}")
 
3565
  threshold = self.config.variance_threshold * std
3566
  for i, item in enumerate(items):
3567
  val = item.get(name)
3568
+ if isinstance(val, int | float):
3569
  if abs(val - stats.mean_val) > threshold:
3570
  keep_indices.add(i)
3571
 
plugins/openclaw/package.json CHANGED
@@ -1,54 +1,54 @@
1
- {
2
- "name": "headroom-openclaw",
3
- "version": "0.1.0",
4
- "description": "Headroom context compression plugin for OpenClaw — 70-90% token savings with zero LLM calls",
5
- "type": "module",
6
- "main": "./dist/index.js",
7
- "types": "./dist/index.d.ts",
8
- "files": [
9
- "dist",
10
- "hook-shim",
11
- "openclaw.plugin.json",
12
- "README.md"
13
- ],
14
- "scripts": {
15
- "build": "tsup && node prepare-dist.mjs",
16
- "test": "vitest run",
17
- "test:watch": "vitest",
18
- "typecheck": "tsc --noEmit"
19
- },
20
- "dependencies": {
21
- "headroom-ai": "^0.1.0"
22
- },
23
- "peerDependencies": {
24
- "openclaw": "*"
25
- },
26
- "peerDependenciesMeta": {
27
- "openclaw": {
28
- "optional": true
29
- }
30
- },
31
- "devDependencies": {
32
- "@types/node": "^22.10.0",
33
- "tsup": "^8.0.0",
34
- "typescript": "^5.5.0",
35
- "vitest": "^2.0.0"
36
- },
37
- "openclaw": {
38
- "hooks": [
39
- "./hook-shim"
40
- ],
41
- "extensions": [
42
- "./dist/index.js"
43
- ],
44
- "capabilities": {
45
- "network": {
46
- "allow": [
47
- "http://*:*",
48
- "https://*:*"
49
- ]
50
- }
51
- }
52
- },
53
- "license": "Apache-2.0"
54
- }
 
1
+ {
2
+ "name": "headroom-openclaw",
3
+ "version": "0.1.1",
4
+ "description": "Headroom context compression plugin for OpenClaw — 70-90% token savings with zero LLM calls",
5
+ "type": "module",
6
+ "main": "./dist/index.js",
7
+ "types": "./dist/index.d.ts",
8
+ "files": [
9
+ "dist",
10
+ "hook-shim",
11
+ "openclaw.plugin.json",
12
+ "README.md"
13
+ ],
14
+ "scripts": {
15
+ "build": "tsup && node prepare-dist.mjs",
16
+ "test": "vitest run",
17
+ "test:watch": "vitest",
18
+ "typecheck": "tsc --noEmit"
19
+ },
20
+ "dependencies": {
21
+ "headroom-ai": "^0.1.0"
22
+ },
23
+ "peerDependencies": {
24
+ "openclaw": "*"
25
+ },
26
+ "peerDependenciesMeta": {
27
+ "openclaw": {
28
+ "optional": true
29
+ }
30
+ },
31
+ "devDependencies": {
32
+ "@types/node": "^22.10.0",
33
+ "tsup": "^8.0.0",
34
+ "typescript": "^5.5.0",
35
+ "vitest": "^2.0.0"
36
+ },
37
+ "openclaw": {
38
+ "hooks": [
39
+ "./hook-shim"
40
+ ],
41
+ "extensions": [
42
+ "./dist/index.js"
43
+ ],
44
+ "capabilities": {
45
+ "network": {
46
+ "allow": [
47
+ "http://*:*",
48
+ "https://*:*"
49
+ ]
50
+ }
51
+ }
52
+ },
53
+ "license": "Apache-2.0"
54
+ }
tests/test_cli/test_wrap_copilot.py CHANGED
@@ -21,6 +21,7 @@ def test_wrap_copilot_auto_anthropic_injects_instructions(
21
  runner: CliRunner, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
22
  ) -> None:
23
  monkeypatch.chdir(tmp_path)
 
24
  captured: dict[str, object] = {}
25
 
26
  def fake_launch_tool(**kwargs): # noqa: ANN003
@@ -51,7 +52,10 @@ def test_wrap_copilot_auto_anthropic_injects_instructions(
51
  assert captured["args"] == ("--model", "claude-sonnet-4-20250514")
52
 
53
 
54
- def test_wrap_copilot_openai_backend_sets_completions_env(runner: CliRunner) -> None:
 
 
 
55
  captured: dict[str, object] = {}
56
 
57
  def fake_launch_tool(**kwargs): # noqa: ANN003
@@ -90,7 +94,10 @@ def test_wrap_copilot_openai_backend_sets_completions_env(runner: CliRunner) ->
90
  assert captured["args"] == ("--model", "gpt-4o")
91
 
92
 
93
- def test_wrap_copilot_auto_detects_running_proxy_backend(runner: CliRunner) -> None:
 
 
 
94
  captured: dict[str, object] = {}
95
 
96
  def fake_launch_tool(**kwargs): # noqa: ANN003
@@ -153,7 +160,10 @@ def test_wrap_copilot_rejects_responses_for_translated_backends(runner: CliRunne
153
  assert "not supported with translated backends" in result.output
154
 
155
 
156
- def test_wrap_copilot_clears_stale_wire_api_in_anthropic_mode(runner: CliRunner) -> None:
 
 
 
157
  captured: dict[str, object] = {}
158
 
159
  def fake_launch_tool(**kwargs): # noqa: ANN003
@@ -164,7 +174,10 @@ def test_wrap_copilot_clears_stale_wire_api_in_anthropic_mode(runner: CliRunner)
164
  result = runner.invoke(
165
  main,
166
  ["wrap", "copilot", "--no-rtk", "--", "--model", "claude-sonnet-4-20250514"],
167
- env={"COPILOT_PROVIDER_WIRE_API": "responses"},
 
 
 
168
  )
169
 
170
  assert result.exit_code == 0, result.output
 
21
  runner: CliRunner, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
22
  ) -> None:
23
  monkeypatch.chdir(tmp_path)
24
+ monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-test-dummy")
25
  captured: dict[str, object] = {}
26
 
27
  def fake_launch_tool(**kwargs): # noqa: ANN003
 
52
  assert captured["args"] == ("--model", "claude-sonnet-4-20250514")
53
 
54
 
55
+ def test_wrap_copilot_openai_backend_sets_completions_env(
56
+ runner: CliRunner, monkeypatch: pytest.MonkeyPatch
57
+ ) -> None:
58
+ monkeypatch.setenv("OPENAI_API_KEY", "sk-test-dummy")
59
  captured: dict[str, object] = {}
60
 
61
  def fake_launch_tool(**kwargs): # noqa: ANN003
 
94
  assert captured["args"] == ("--model", "gpt-4o")
95
 
96
 
97
+ def test_wrap_copilot_auto_detects_running_proxy_backend(
98
+ runner: CliRunner, monkeypatch: pytest.MonkeyPatch
99
+ ) -> None:
100
+ monkeypatch.setenv("OPENAI_API_KEY", "sk-test-dummy")
101
  captured: dict[str, object] = {}
102
 
103
  def fake_launch_tool(**kwargs): # noqa: ANN003
 
160
  assert "not supported with translated backends" in result.output
161
 
162
 
163
+ def test_wrap_copilot_clears_stale_wire_api_in_anthropic_mode(
164
+ runner: CliRunner, monkeypatch: pytest.MonkeyPatch
165
+ ) -> None:
166
+ monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-test-dummy")
167
  captured: dict[str, object] = {}
168
 
169
  def fake_launch_tool(**kwargs): # noqa: ANN003
 
174
  result = runner.invoke(
175
  main,
176
  ["wrap", "copilot", "--no-rtk", "--", "--model", "claude-sonnet-4-20250514"],
177
+ env={
178
+ "COPILOT_PROVIDER_WIRE_API": "responses",
179
+ "ANTHROPIC_API_KEY": "sk-test-dummy",
180
+ },
181
  )
182
 
183
  assert result.exit_code == 0, result.output
tests/test_learn/test_scanner.py CHANGED
@@ -103,6 +103,38 @@ class TestGreedyPathDecode:
103
  result = _greedy_path_decode(tmp_path, ["my", "cool", "project", "nosync", "headroom"])
104
  assert result == tmp_path / "my-cool-project.nosync" / "headroom"
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  def test_nonexistent_path_returns_none(self, tmp_path: Path) -> None:
107
  result = _greedy_path_decode(tmp_path, ["does", "not", "exist"])
108
  assert result is None
@@ -236,6 +268,38 @@ class TestDecodeProjectPath:
236
  else:
237
  assert result is None or result == project
238
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
  def test_windows_drive_letter_pattern(self) -> None:
240
  """Encoded name -C-MQ2-macros should detect Windows drive letter."""
241
  import sys
 
103
  result = _greedy_path_decode(tmp_path, ["my", "cool", "project", "nosync", "headroom"])
104
  assert result == tmp_path / "my-cool-project.nosync" / "headroom"
105
 
106
+ # ---- Underscore tests (issue #159) ----
107
+
108
+ def test_single_underscore_in_dirname(self, tmp_path: Path) -> None:
109
+ """Directory name contains one literal underscore (e.g. my_project)."""
110
+ _make_dirs(tmp_path, "my_project")
111
+ result = _greedy_path_decode(tmp_path, ["my", "project"])
112
+ assert result == tmp_path / "my_project"
113
+
114
+ def test_multiple_underscores_in_dirname(self, tmp_path: Path) -> None:
115
+ """Directory name contains multiple underscores (e.g. my_cool_project)."""
116
+ _make_dirs(tmp_path, "my_cool_project")
117
+ result = _greedy_path_decode(tmp_path, ["my", "cool", "project"])
118
+ assert result == tmp_path / "my_cool_project"
119
+
120
+ def test_underscore_nested_path(self, tmp_path: Path) -> None:
121
+ """Nested path like org/my_project should decode correctly."""
122
+ _make_dirs(tmp_path, "org/my_project")
123
+ result = _greedy_path_decode(tmp_path, ["org", "my", "project"])
124
+ assert result == tmp_path / "org" / "my_project"
125
+
126
+ def test_mixed_underscore_and_hyphen_in_dirname(self, tmp_path: Path) -> None:
127
+ """Directory with both hyphens and underscores (e.g. my-cool_project)."""
128
+ _make_dirs(tmp_path, "my-cool_project")
129
+ result = _greedy_path_decode(tmp_path, ["my", "cool", "project"])
130
+ assert result == tmp_path / "my-cool_project"
131
+
132
+ def test_underscore_dir_containing_hyphen_subdir(self, tmp_path: Path) -> None:
133
+ """Path like my_app/sub-module — underscore parent + hyphen child."""
134
+ _make_dirs(tmp_path, "my_app/sub-module")
135
+ result = _greedy_path_decode(tmp_path, ["my", "app", "sub", "module"])
136
+ assert result == tmp_path / "my_app" / "sub-module"
137
+
138
  def test_nonexistent_path_returns_none(self, tmp_path: Path) -> None:
139
  result = _greedy_path_decode(tmp_path, ["does", "not", "exist"])
140
  assert result is None
 
268
  else:
269
  assert result is None or result == project
270
 
271
+ def test_underscore_dirname_via_greedy(self, users_tmp: Path) -> None:
272
+ """my_project — underscore in directory name (issue #159).
273
+
274
+ Claude Code encodes /Users/foo/org/my_project as
275
+ -Users-foo-org-my-project. Simple replace gives
276
+ …/org/my/project which does not exist, so the greedy decoder
277
+ must reconstruct my_project from tokens ['my', 'project'].
278
+ """
279
+ project = users_tmp / "org" / "my_project"
280
+ project.mkdir(parents=True)
281
+
282
+ encoded = "-" + str(project)[1:].replace("/", "-")
283
+ result = _decode_project_path(encoded)
284
+
285
+ if str(users_tmp).startswith("/Users/"):
286
+ assert result == project
287
+ else:
288
+ assert result is None or result == project
289
+
290
+ def test_multi_underscore_dirname_via_greedy(self, users_tmp: Path) -> None:
291
+ """my_cool_project — multiple underscores (issue #159)."""
292
+ project = users_tmp / "my_cool_project"
293
+ project.mkdir(parents=True)
294
+
295
+ encoded = "-" + str(project)[1:].replace("/", "-")
296
+ result = _decode_project_path(encoded)
297
+
298
+ if str(users_tmp).startswith("/Users/"):
299
+ assert result == project
300
+ else:
301
+ assert result is None or result == project
302
+
303
  def test_windows_drive_letter_pattern(self) -> None:
304
  """Encoded name -C-MQ2-macros should detect Windows drive letter."""
305
  import sys
tests/test_memory_sync.py ADDED
@@ -0,0 +1,647 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Comprehensive tests for the universal memory sync engine.
2
+
3
+ Tests cover:
4
+ - Core sync: import, export, bidirectional
5
+ - Idempotency and deduplication
6
+ - Fast no-op detection
7
+ - Lineage and governance metadata
8
+ - Claude Code adapter: read/write frontmatter files
9
+ - Codex adapter: read/write AGENTS.md sections
10
+ - Cross-agent interop: save in one agent, find in another
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import hashlib
16
+ import json
17
+ import time
18
+ from dataclasses import dataclass, field
19
+ from datetime import datetime, timezone
20
+ from pathlib import Path
21
+ from typing import Any
22
+
23
+ import pytest
24
+
25
+ from headroom.memory.sync import (
26
+ sync,
27
+ sync_export,
28
+ sync_import,
29
+ )
30
+ from headroom.memory.sync_adapters.claude_code import (
31
+ ClaudeCodeAdapter,
32
+ _parse_frontmatter,
33
+ )
34
+ from headroom.memory.sync_adapters.codex_agent import CodexAdapter
35
+
36
+ # ---------------------------------------------------------------------------
37
+ # Fake backend for testing (no real DB/embeddings needed)
38
+ # ---------------------------------------------------------------------------
39
+
40
+
41
+ @dataclass
42
+ class FakeMemory:
43
+ id: str = ""
44
+ content: str = ""
45
+ user_id: str = ""
46
+ category: str = ""
47
+ importance: float = 0.5
48
+ created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
49
+ metadata: dict[str, Any] = field(default_factory=dict)
50
+
51
+
52
+ class FakeBackend:
53
+ """In-memory backend for testing sync without real DB."""
54
+
55
+ def __init__(self) -> None:
56
+ self._memories: list[FakeMemory] = []
57
+ self._next_id = 1
58
+
59
+ async def get_user_memories(self, user_id: str, limit: int = 500) -> list[FakeMemory]:
60
+ return [m for m in self._memories if m.user_id == user_id][:limit]
61
+
62
+ async def save_memory(
63
+ self,
64
+ content: str,
65
+ user_id: str,
66
+ importance: float = 0.5,
67
+ metadata: dict[str, Any] | None = None,
68
+ **kwargs: Any,
69
+ ) -> FakeMemory:
70
+ mem = FakeMemory(
71
+ id=f"mem_{self._next_id:04d}",
72
+ content=content,
73
+ user_id=user_id,
74
+ importance=importance,
75
+ metadata=metadata or {},
76
+ )
77
+ self._next_id += 1
78
+ self._memories.append(mem)
79
+ return mem
80
+
81
+ def add_memory(self, content: str, user_id: str = "tcms", **kwargs: Any) -> FakeMemory:
82
+ """Synchronous (non-async) helper to pre-populate memories."""
83
+ mem = FakeMemory(
84
+ id=f"mem_{self._next_id:04d}",
85
+ content=content,
86
+ user_id=user_id,
87
+ metadata=kwargs.get("metadata", {}),
88
+ importance=kwargs.get("importance", 0.5),
89
+ )
90
+ self._next_id += 1
91
+ self._memories.append(mem)
92
+ return mem
93
+
94
+
95
+ # ---------------------------------------------------------------------------
96
+ # Core sync tests
97
+ # ---------------------------------------------------------------------------
98
+
99
+
100
+ class TestSyncImport:
101
+ """Test importing from agent files into DB."""
102
+
103
+ @pytest.fixture
104
+ def backend(self):
105
+ return FakeBackend()
106
+
107
+ @pytest.fixture
108
+ def claude_dir(self, tmp_path):
109
+ d = tmp_path / "memory"
110
+ d.mkdir()
111
+ return d
112
+
113
+ def _write_claude_memory(
114
+ self, memory_dir: Path, name: str, content: str, **fm_fields: str
115
+ ) -> None:
116
+ slug = name.lower().replace(" ", "_")
117
+ fields = {"name": name, "description": content[:80], "type": "project", **fm_fields}
118
+ fm_lines = ["---"]
119
+ for k, v in fields.items():
120
+ fm_lines.append(f"{k}: {v}")
121
+ fm_lines.append("---")
122
+ (memory_dir / f"{slug}.md").write_text("\n".join(fm_lines) + f"\n\n{content}\n")
123
+
124
+ @pytest.mark.asyncio
125
+ async def test_import_claude_files_to_db(self, backend, claude_dir):
126
+ self._write_claude_memory(claude_dir, "Project codename", "The secret name is TC")
127
+ self._write_claude_memory(claude_dir, "Dark mode", "User prefers dark mode")
128
+
129
+ adapter = ClaudeCodeAdapter(claude_dir)
130
+ imported = await sync_import(backend, adapter, "tcms")
131
+
132
+ assert imported == 2
133
+ mems = await backend.get_user_memories("tcms")
134
+ contents = {m.content for m in mems}
135
+ assert "The secret name is TC" in contents
136
+ assert "User prefers dark mode" in contents
137
+
138
+ @pytest.mark.asyncio
139
+ async def test_import_skips_existing(self, backend, claude_dir):
140
+ """Memories already in DB are not re-imported."""
141
+ backend.add_memory(
142
+ "The secret name is TC",
143
+ metadata={"content_hash": hashlib.sha256(b"The secret name is TC").hexdigest()[:16]},
144
+ )
145
+
146
+ self._write_claude_memory(claude_dir, "Project codename", "The secret name is TC")
147
+ self._write_claude_memory(claude_dir, "New fact", "Something new")
148
+
149
+ adapter = ClaudeCodeAdapter(claude_dir)
150
+ imported = await sync_import(backend, adapter, "tcms")
151
+
152
+ assert imported == 1 # Only "Something new"
153
+
154
+ @pytest.mark.asyncio
155
+ async def test_import_preserves_lineage(self, backend, claude_dir):
156
+ self._write_claude_memory(claude_dir, "Fact", "Important fact")
157
+
158
+ adapter = ClaudeCodeAdapter(claude_dir)
159
+ await sync_import(backend, adapter, "tcms")
160
+
161
+ mems = await backend.get_user_memories("tcms")
162
+ assert len(mems) == 1
163
+ assert mems[0].metadata["source_agent"] == "claude"
164
+ assert mems[0].metadata["source_file"] == "fact.md"
165
+ assert "content_hash" in mems[0].metadata
166
+ assert mems[0].metadata["sync_direction"] == "import"
167
+
168
+
169
+ class TestSyncExport:
170
+ """Test exporting from DB to agent files."""
171
+
172
+ @pytest.fixture
173
+ def backend(self):
174
+ return FakeBackend()
175
+
176
+ @pytest.fixture
177
+ def claude_dir(self, tmp_path):
178
+ d = tmp_path / "memory"
179
+ d.mkdir()
180
+ return d
181
+
182
+ @pytest.mark.asyncio
183
+ async def test_export_new_memory_to_claude_files(self, backend, claude_dir):
184
+ backend.add_memory(
185
+ "Project uses Python 3.12",
186
+ metadata={
187
+ "source_agent": "codex",
188
+ "sync_direction": "export", # Not from claude import
189
+ },
190
+ )
191
+
192
+ adapter = ClaudeCodeAdapter(claude_dir)
193
+ exported = await sync_export(backend, adapter, "tcms")
194
+
195
+ assert exported == 1
196
+ # Check file was created
197
+ md_files = list(claude_dir.glob("headroom_*.md"))
198
+ assert len(md_files) == 1
199
+
200
+ content = md_files[0].read_text()
201
+ assert "Python 3.12" in content
202
+ assert "headroom_id: mem_0001" in content
203
+ assert "source_agent: codex" in content
204
+
205
+ @pytest.mark.asyncio
206
+ async def test_export_skips_claude_originated(self, backend, claude_dir):
207
+ """Don't re-export memories that were imported FROM claude (anti-echo)."""
208
+ backend.add_memory(
209
+ "From claude",
210
+ metadata={
211
+ "source_agent": "claude",
212
+ "sync_direction": "import",
213
+ },
214
+ )
215
+ backend.add_memory(
216
+ "From codex",
217
+ metadata={
218
+ "source_agent": "codex",
219
+ },
220
+ )
221
+
222
+ adapter = ClaudeCodeAdapter(claude_dir)
223
+ exported = await sync_export(backend, adapter, "tcms")
224
+
225
+ assert exported == 1 # Only "From codex"
226
+
227
+ @pytest.mark.asyncio
228
+ async def test_export_updates_memory_md_index(self, backend, claude_dir):
229
+ # Create an existing MEMORY.md
230
+ (claude_dir / "MEMORY.md").write_text("# Memory\n\n## User\n- Some existing entry\n")
231
+
232
+ backend.add_memory("New fact from codex", metadata={"source_agent": "codex"})
233
+
234
+ adapter = ClaudeCodeAdapter(claude_dir)
235
+ await sync_export(backend, adapter, "tcms")
236
+
237
+ memory_md = (claude_dir / "MEMORY.md").read_text()
238
+ assert "Headroom Shared Memory" in memory_md
239
+ assert "New fact from codex" in memory_md
240
+ assert "Some existing entry" in memory_md # Preserved
241
+
242
+
243
+ class TestBidirectionalSync:
244
+ """Test full bidirectional sync."""
245
+
246
+ @pytest.fixture
247
+ def backend(self):
248
+ return FakeBackend()
249
+
250
+ @pytest.fixture
251
+ def claude_dir(self, tmp_path):
252
+ d = tmp_path / "memory"
253
+ d.mkdir()
254
+ return d
255
+
256
+ @pytest.fixture
257
+ def state_path(self, tmp_path):
258
+ return tmp_path / "sync_state.json"
259
+
260
+ def _write_claude_memory(self, memory_dir: Path, name: str, content: str) -> None:
261
+ slug = name.lower().replace(" ", "_")
262
+ fm = f"---\nname: {name}\ndescription: {content[:80]}\ntype: project\n---"
263
+ (memory_dir / f"{slug}.md").write_text(f"{fm}\n\n{content}\n")
264
+
265
+ @pytest.mark.asyncio
266
+ async def test_bidirectional_sync(self, backend, claude_dir, state_path):
267
+ # Claude has a memory file
268
+ self._write_claude_memory(claude_dir, "Convention", "Always use ruff for linting")
269
+
270
+ # DB has a memory from Codex
271
+ backend.add_memory("Secret name is TC", metadata={"source_agent": "codex"})
272
+
273
+ adapter = ClaudeCodeAdapter(claude_dir)
274
+ result = await sync(backend, adapter, "tcms", state_path=state_path, force=True)
275
+
276
+ assert result.imported == 1 # Claude file → DB
277
+ assert result.exported == 1 # Codex memory → Claude file
278
+
279
+ # Verify DB has both
280
+ mems = await backend.get_user_memories("tcms")
281
+ contents = {m.content for m in mems}
282
+ assert "Always use ruff for linting" in contents
283
+ assert "Secret name is TC" in contents
284
+
285
+ # Verify Claude dir has the exported file
286
+ all_files = list(claude_dir.glob("headroom_*.md"))
287
+ assert len(all_files) >= 1
288
+ exported_content = " ".join(f.read_text() for f in all_files)
289
+ assert "TC" in exported_content
290
+
291
+ @pytest.mark.asyncio
292
+ async def test_sync_idempotent(self, backend, claude_dir, state_path):
293
+ """Running sync twice produces no duplicates."""
294
+ self._write_claude_memory(claude_dir, "Fact", "Python 3.12 is required")
295
+ backend.add_memory("Port 8787 is default", metadata={"source_agent": "codex"})
296
+
297
+ adapter = ClaudeCodeAdapter(claude_dir)
298
+
299
+ r1 = await sync(backend, adapter, "tcms", state_path=state_path, force=True)
300
+ assert r1.imported == 1
301
+ assert r1.exported == 1
302
+
303
+ r2 = await sync(backend, adapter, "tcms", state_path=state_path, force=True)
304
+ assert r2.imported == 0 # Already imported
305
+ assert r2.exported == 0 # Already exported
306
+
307
+ # No duplicates in DB
308
+ mems = await backend.get_user_memories("tcms")
309
+ assert len(mems) == 2
310
+
311
+ @pytest.mark.asyncio
312
+ async def test_fast_noop_when_unchanged(self, backend, claude_dir, state_path):
313
+ """Second sync with no changes is a fast no-op (target < 10ms; asserted < 50ms for CI)."""
314
+ self._write_claude_memory(claude_dir, "Fact", "Some fact")
315
+
316
+ adapter = ClaudeCodeAdapter(claude_dir)
317
+
318
+ # First sync (populates state)
319
+ await sync(backend, adapter, "tcms", state_path=state_path, force=True)
320
+
321
+ # Second sync (should be fast no-op)
322
+ start = time.monotonic()
323
+ r = await sync(backend, adapter, "tcms", state_path=state_path)
324
+ elapsed = (time.monotonic() - start) * 1000
325
+
326
+ assert r.imported == 0
327
+ assert r.exported == 0
328
+ assert elapsed < 50 # Generous threshold for CI
329
+
330
+
331
+ class TestLineageAndGovernance:
332
+ """Test metadata tracking for audit and lineage."""
333
+
334
+ @pytest.fixture
335
+ def backend(self):
336
+ return FakeBackend()
337
+
338
+ @pytest.fixture
339
+ def claude_dir(self, tmp_path):
340
+ d = tmp_path / "memory"
341
+ d.mkdir()
342
+ return d
343
+
344
+ @pytest.mark.asyncio
345
+ async def test_lineage_tracks_source_agent(self, backend, claude_dir):
346
+ fm = "---\nname: test\ndescription: test\ntype: project\n---"
347
+ (claude_dir / "test.md").write_text(f"{fm}\n\nClaude discovered this\n")
348
+
349
+ adapter = ClaudeCodeAdapter(claude_dir)
350
+ await sync_import(backend, adapter, "tcms")
351
+
352
+ mems = await backend.get_user_memories("tcms")
353
+ assert mems[0].metadata["source_agent"] == "claude"
354
+
355
+ @pytest.mark.asyncio
356
+ async def test_exported_files_have_headroom_id(self, backend, claude_dir):
357
+ backend.add_memory("From codex", metadata={"source_agent": "codex"})
358
+
359
+ adapter = ClaudeCodeAdapter(claude_dir)
360
+ await sync_export(backend, adapter, "tcms")
361
+
362
+ md_files = list(claude_dir.glob("headroom_*.md"))
363
+ assert len(md_files) == 1
364
+ content = md_files[0].read_text()
365
+ assert "headroom_id:" in content
366
+
367
+ @pytest.mark.asyncio
368
+ async def test_sync_state_records_timestamps(self, backend, claude_dir, tmp_path):
369
+ state_path = tmp_path / "state.json"
370
+ fm = "---\nname: t\ndescription: t\ntype: project\n---"
371
+ (claude_dir / "t.md").write_text(f"{fm}\n\nFact\n")
372
+
373
+ adapter = ClaudeCodeAdapter(claude_dir)
374
+ await sync(backend, adapter, "tcms", state_path=state_path, force=True)
375
+
376
+ state = json.loads(state_path.read_text())
377
+ key = "claude:tcms"
378
+ assert key in state
379
+ assert "last_sync" in state[key]
380
+ assert "agent_fingerprint" in state[key]
381
+ assert "db_fingerprint" in state[key]
382
+
383
+
384
+ # ---------------------------------------------------------------------------
385
+ # Claude Code adapter tests
386
+ # ---------------------------------------------------------------------------
387
+
388
+
389
+ class TestClaudeCodeAdapter:
390
+ """Test Claude Code adapter read/write."""
391
+
392
+ @pytest.fixture
393
+ def memory_dir(self, tmp_path):
394
+ d = tmp_path / "memory"
395
+ d.mkdir()
396
+ return d
397
+
398
+ def test_parse_frontmatter(self):
399
+ content = "---\nname: Test\ntype: project\n---\n\nBody content here."
400
+ fm, body = _parse_frontmatter(content)
401
+ assert fm["name"] == "Test"
402
+ assert fm["type"] == "project"
403
+ assert body == "Body content here."
404
+
405
+ def test_parse_frontmatter_no_frontmatter(self):
406
+ content = "Just plain content."
407
+ fm, body = _parse_frontmatter(content)
408
+ assert fm == {}
409
+ assert body == "Just plain content."
410
+
411
+ @pytest.mark.asyncio
412
+ async def test_read_memories_skips_memory_md(self, memory_dir):
413
+ (memory_dir / "MEMORY.md").write_text("# Index\n- entry")
414
+ (memory_dir / "fact.md").write_text(
415
+ "---\nname: Fact\ntype: project\n---\n\nImportant fact."
416
+ )
417
+
418
+ adapter = ClaudeCodeAdapter(memory_dir)
419
+ mems = await adapter.read_memories()
420
+
421
+ assert len(mems) == 1
422
+ assert mems[0].content == "Important fact."
423
+ assert mems[0].source_file == "fact.md"
424
+
425
+ @pytest.mark.asyncio
426
+ async def test_write_creates_valid_md(self, memory_dir):
427
+ adapter = ClaudeCodeAdapter(memory_dir)
428
+ written = await adapter.write_memories(
429
+ [
430
+ {
431
+ "content": "Project uses FastAPI",
432
+ "category": "architecture",
433
+ "headroom_id": "mem_001",
434
+ "source_agent": "codex",
435
+ "content_hash": "abc123",
436
+ }
437
+ ]
438
+ )
439
+
440
+ assert written == 1
441
+ files = list(memory_dir.glob("headroom_*.md"))
442
+ assert len(files) == 1
443
+
444
+ content = files[0].read_text()
445
+ fm, body = _parse_frontmatter(content)
446
+ assert fm["type"] == "architecture"
447
+ assert fm["headroom_id"] == "mem_001"
448
+ assert fm["source_agent"] == "codex"
449
+ assert "FastAPI" in body
450
+
451
+ def test_fingerprint_changes_on_modification(self, memory_dir):
452
+ (memory_dir / "test.md").write_text("content 1")
453
+
454
+ adapter = ClaudeCodeAdapter(memory_dir)
455
+ fp1 = adapter.fingerprint()
456
+
457
+ (memory_dir / "test.md").write_text("content 2")
458
+ fp2 = adapter.fingerprint()
459
+
460
+ assert fp1 != fp2
461
+
462
+ def test_fingerprint_stable_when_unchanged(self, memory_dir):
463
+ (memory_dir / "test.md").write_text("stable content")
464
+
465
+ adapter = ClaudeCodeAdapter(memory_dir)
466
+ assert adapter.fingerprint() == adapter.fingerprint()
467
+
468
+ def test_fingerprint_empty_dir(self, tmp_path):
469
+ empty = tmp_path / "empty"
470
+ empty.mkdir()
471
+ adapter = ClaudeCodeAdapter(empty)
472
+ assert adapter.fingerprint() == "empty"
473
+
474
+
475
+ # ---------------------------------------------------------------------------
476
+ # Codex adapter tests
477
+ # ---------------------------------------------------------------------------
478
+
479
+
480
+ class TestCodexAdapter:
481
+ """Test Codex AGENTS.md adapter."""
482
+
483
+ @pytest.fixture
484
+ def agents_md(self, tmp_path):
485
+ return tmp_path / "AGENTS.md"
486
+
487
+ @pytest.mark.asyncio
488
+ async def test_read_from_agents_md(self, agents_md):
489
+ agents_md.write_text(
490
+ "# Instructions\n\n"
491
+ "<!-- headroom:memory:start -->\n"
492
+ "## Headroom Shared Memory\n\n"
493
+ "- Secret name is TC\n"
494
+ "- Uses Python 3.12\n"
495
+ "<!-- headroom:memory:end -->\n"
496
+ )
497
+
498
+ adapter = CodexAdapter(agents_md)
499
+ mems = await adapter.read_memories()
500
+
501
+ assert len(mems) == 2
502
+ assert mems[0].content == "Secret name is TC"
503
+ assert mems[1].content == "Uses Python 3.12"
504
+
505
+ @pytest.mark.asyncio
506
+ async def test_write_to_agents_md(self, agents_md):
507
+ agents_md.write_text("# Existing instructions\n")
508
+
509
+ adapter = CodexAdapter(agents_md)
510
+ written = await adapter.write_memories(
511
+ [
512
+ {"content": "Port 8787 is default"},
513
+ {"content": "Uses ruff for linting"},
514
+ ]
515
+ )
516
+
517
+ assert written == 2
518
+ content = agents_md.read_text()
519
+ assert "headroom:memory:start" in content
520
+ assert "Port 8787 is default" in content
521
+ assert "Uses ruff for linting" in content
522
+ assert "Existing instructions" in content # Preserved
523
+
524
+ @pytest.mark.asyncio
525
+ async def test_write_replaces_existing_section(self, agents_md):
526
+ agents_md.write_text(
527
+ "# Instructions\n\n"
528
+ "<!-- headroom:memory:start -->\n"
529
+ "## Old\n- old fact\n"
530
+ "<!-- headroom:memory:end -->\n"
531
+ )
532
+
533
+ adapter = CodexAdapter(agents_md)
534
+ await adapter.write_memories([{"content": "new fact"}])
535
+
536
+ content = agents_md.read_text()
537
+ assert "new fact" in content
538
+ assert "old fact" not in content
539
+
540
+ @pytest.mark.asyncio
541
+ async def test_read_empty_agents_md(self, agents_md):
542
+ agents_md.write_text("# No memory section\n")
543
+ adapter = CodexAdapter(agents_md)
544
+ mems = await adapter.read_memories()
545
+ assert mems == []
546
+
547
+ @pytest.mark.asyncio
548
+ async def test_read_nonexistent_file(self, tmp_path):
549
+ adapter = CodexAdapter(tmp_path / "nonexistent.md")
550
+ mems = await adapter.read_memories()
551
+ assert mems == []
552
+
553
+
554
+ # ---------------------------------------------------------------------------
555
+ # Cross-agent integration tests
556
+ # ---------------------------------------------------------------------------
557
+
558
+
559
+ class TestCrossAgentInterop:
560
+ """Test that memories flow between agents via sync."""
561
+
562
+ @pytest.fixture
563
+ def backend(self):
564
+ return FakeBackend()
565
+
566
+ @pytest.fixture
567
+ def claude_dir(self, tmp_path):
568
+ d = tmp_path / "claude_memory"
569
+ d.mkdir()
570
+ return d
571
+
572
+ @pytest.fixture
573
+ def agents_md(self, tmp_path):
574
+ return tmp_path / "AGENTS.md"
575
+
576
+ @pytest.fixture
577
+ def state_path(self, tmp_path):
578
+ return tmp_path / "state.json"
579
+
580
+ @pytest.mark.asyncio
581
+ async def test_codex_saves_claude_finds(self, backend, claude_dir, state_path):
582
+ """Memory saved via Codex MCP appears in Claude's files after sync."""
583
+ # Simulate Codex saving via MCP (directly to backend)
584
+ backend.add_memory(
585
+ "Secret name is TC",
586
+ metadata={"source_agent": "codex", "content_hash": "x"},
587
+ )
588
+
589
+ # Sync to Claude
590
+ adapter = ClaudeCodeAdapter(claude_dir)
591
+ result = await sync(backend, adapter, "tcms", state_path=state_path, force=True)
592
+
593
+ assert result.exported == 1
594
+
595
+ # Claude's memory dir should have the file
596
+ files = list(claude_dir.glob("headroom_*.md"))
597
+ assert len(files) == 1
598
+ assert "TC" in files[0].read_text()
599
+
600
+ @pytest.mark.asyncio
601
+ async def test_claude_saves_codex_finds(self, backend, claude_dir, agents_md, state_path):
602
+ """Memory saved in Claude's files appears in Codex AGENTS.md after sync."""
603
+ # Claude has a memory
604
+ fm = "---\nname: Linting\ndescription: use ruff\ntype: project\n---"
605
+ (claude_dir / "linting.md").write_text(f"{fm}\n\nAlways use ruff for linting\n")
606
+
607
+ # Sync Claude → DB
608
+ claude_adapter = ClaudeCodeAdapter(claude_dir)
609
+ await sync(backend, claude_adapter, "tcms", state_path=state_path, force=True)
610
+
611
+ # Sync DB → Codex AGENTS.md
612
+ codex_adapter = CodexAdapter(agents_md)
613
+ result = await sync(backend, codex_adapter, "tcms", state_path=state_path, force=True)
614
+
615
+ assert result.exported >= 1
616
+ assert "ruff" in agents_md.read_text()
617
+
618
+ @pytest.mark.asyncio
619
+ async def test_full_round_trip(self, backend, claude_dir, agents_md, state_path):
620
+ """Full round trip: Claude → DB → Codex, Codex → DB → Claude."""
621
+ # Claude has a memory
622
+ fm = "---\nname: Framework\ntype: project\n---"
623
+ (claude_dir / "framework.md").write_text(f"{fm}\n\nUses FastAPI\n")
624
+
625
+ # Codex has a memory (in DB via MCP)
626
+ backend.add_memory("Port is 8787", metadata={"source_agent": "codex"})
627
+
628
+ # Sync both adapters
629
+ claude_adapter = ClaudeCodeAdapter(claude_dir)
630
+ codex_adapter = CodexAdapter(agents_md)
631
+
632
+ await sync(backend, claude_adapter, "tcms", state_path=state_path, force=True)
633
+ await sync(backend, codex_adapter, "tcms", state_path=state_path, force=True)
634
+
635
+ # DB has both memories
636
+ mems = await backend.get_user_memories("tcms")
637
+ contents = {m.content for m in mems}
638
+ assert "Uses FastAPI" in contents
639
+ assert "Port is 8787" in contents
640
+
641
+ # Claude files have Codex's memory
642
+ all_claude = " ".join(f.read_text() for f in claude_dir.glob("headroom_*.md"))
643
+ assert "8787" in all_claude
644
+
645
+ # AGENTS.md has at least one of the two memories (from DB)
646
+ agents_content = agents_md.read_text()
647
+ assert "FastAPI" in agents_content or "8787" in agents_content
tests/test_package_init_lazy.py CHANGED
@@ -33,7 +33,8 @@ def test_headroom_import_stays_lazy() -> None:
33
  )
34
 
35
  data = json.loads(result.stdout.strip())
36
- assert data["version"] == "0.5.21"
 
37
  assert data["cache_loaded"] is False
38
  assert data["models_registry_loaded"] is False
39
  assert data["memory_loaded"] is False
 
33
  )
34
 
35
  data = json.loads(result.stdout.strip())
36
+ # Version is a non-empty string; don't hardcode a specific value.
37
+ assert isinstance(data["version"], str) and data["version"]
38
  assert data["cache_loaded"] is False
39
  assert data["models_registry_loaded"] is False
40
  assert data["memory_loaded"] is False
tests/test_transforms/test_kompress_compressor.py CHANGED
@@ -352,9 +352,8 @@ class TestUnloadKompressModel:
352
  import headroom.transforms.kompress_compressor as kmod
353
  from headroom.transforms.kompress_compressor import unload_kompress_model
354
 
355
- # Ensure no model is loaded (previous tests may have set the global)
356
- kmod._kompress_model = None
357
- kmod._kompress_tokenizer = None
358
 
359
  # Should return False when no model is loaded
360
  assert unload_kompress_model() is False
 
352
  import headroom.transforms.kompress_compressor as kmod
353
  from headroom.transforms.kompress_compressor import unload_kompress_model
354
 
355
+ # Ensure no model is loaded (previous tests may have set the cache)
356
+ kmod._kompress_cache.clear()
 
357
 
358
  # Should return False when no model is loaded
359
  assert unload_kompress_model() is False
tests/test_transforms/test_smart_crusher_bugs.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Regression tests for SmartCrusher bugs.

Bug 1: _crush_number_array mixes types (string summary + numbers),
       violating the schema-preserving guarantee.
Bug 2: _current_field_semantics is shared instance state, creating
       a race condition when crushing concurrently.
Issue 7: _process_value must not crash on deeply nested JSON
         (recursion depth limit).
"""
8
+
9
+ from __future__ import annotations
10
+
11
+ import json
12
+ from concurrent.futures import ThreadPoolExecutor, as_completed
13
+
14
+ from headroom import SmartCrusherConfig
15
+ from headroom.transforms.smart_crusher import SmartCrusher
16
+
17
+ # ---------------------------------------------------------------------------
18
+ # Fixtures
19
+ # ---------------------------------------------------------------------------
20
+
21
+
22
def _make_crusher(max_items: int = 10, min_items: int = 3) -> SmartCrusher:
    """Build an enabled SmartCrusher for these tests.

    Args:
        max_items: Passed as ``max_items_after_crush``.
        min_items: Passed as ``min_items_to_analyze``.

    Returns:
        A SmartCrusher with ``min_tokens_to_crush=0`` and
        ``variance_threshold=2.0``.
    """
    return SmartCrusher(
        config=SmartCrusherConfig(
            enabled=True,
            min_items_to_analyze=min_items,
            min_tokens_to_crush=0,
            max_items_after_crush=max_items,
            variance_threshold=2.0,
        )
    )
31
+
32
+
33
+ # ---------------------------------------------------------------------------
34
+ # Bug 1: Number array type mixing
35
+ # ---------------------------------------------------------------------------
36
+
37
+
38
class TestNumberArraySchemaPreservation:
    """Guard the schema-preserving contract of ``_crush_number_array``.

    The crushed output may only contain numeric values taken from the input.
    An earlier implementation prepended a stats summary string, yielding
    ``[str, int, int, ...]`` and breaking type-aware JSON consumers.
    """

    def test_crushed_number_array_contains_only_numbers(self) -> None:
        """Each crushed element is an ``int`` or a ``float``."""
        source = list(range(50))  # 0..49, well above the n<=8 passthrough
        crushed, _strategy = _make_crusher(max_items=10)._crush_number_array(source)

        for position, element in enumerate(crushed):
            assert isinstance(element, int | float), (
                f"Item {position} is {type(element).__name__} = {element!r}, expected int/float. "
                f"Schema-preserving guarantee violated."
            )

    def test_crushed_number_array_subset_of_original(self) -> None:
        """No crushed value may be absent from the input array."""
        source = list(range(10, 130, 10))  # 10, 20, ..., 120
        crushed, _strategy = _make_crusher(max_items=10)._crush_number_array(source)

        allowed = set(source)
        for element in crushed:
            assert element in allowed, (
                f"Value {element!r} not in original array — generated content detected"
            )

    def test_stats_summary_in_strategy_not_in_array(self) -> None:
        """Stats travel in the strategy string, never as array elements."""
        crushed, strategy = _make_crusher(max_items=5)._crush_number_array(list(range(100)))

        # The strategy string carries the stats info.
        assert "number:" in strategy

        # No string may appear among the crushed values.
        stray_strings = [element for element in crushed if isinstance(element, str)]
        assert stray_strings == [], f"Found string(s) in numeric array: {stray_strings}"

    def test_number_array_passthrough_for_small(self) -> None:
        """A short array (n <= 8) comes back unchanged with the passthrough strategy."""
        tiny = [1, 2, 3, 4, 5]
        crushed, strategy = _make_crusher()._crush_number_array(tiny)
        assert crushed == tiny
        assert strategy == "number:passthrough"

    def test_number_array_preserves_outliers(self) -> None:
        """An extreme outlier survives crushing."""
        # Twenty identical values followed by one extreme outlier.
        source = [10] * 20 + [10000]
        crushed, _strategy = _make_crusher(max_items=10)._crush_number_array(source)
        assert 10000 in crushed, "Outlier value 10000 was dropped"

    def test_number_array_preserves_boundaries(self) -> None:
        """The first value stays first and the last value is retained."""
        source = list(range(100))
        crushed, _strategy = _make_crusher(max_items=5)._crush_number_array(source)
        assert crushed[0] == 0, "First value not preserved"
        assert source[-1] in crushed, "Last value not preserved"

    def test_non_finite_passthrough(self) -> None:
        """An all-NaN array keeps its length and reports ``number:no_finite``."""
        all_nan = [float("nan")] * 10
        crushed, strategy = _make_crusher()._crush_number_array(all_nan)
        assert strategy == "number:no_finite"
        assert len(crushed) == 10

    def test_full_crush_pipeline_number_array_types(self) -> None:
        """End-to-end through ``_smart_crush_content``: output stays purely numeric."""
        payload = json.dumps(list(range(50)))
        output, was_modified, _info = _make_crusher(max_items=10)._smart_crush_content(payload)

        # Nothing to verify when the content was left untouched.
        if not was_modified:
            return

        decoded = json.loads(output)
        assert isinstance(decoded, list)
        for element in decoded:
            assert isinstance(element, int | float), (
                f"Public API returned non-numeric item {element!r} in number array"
            )
128
+
129
+
130
+ # ---------------------------------------------------------------------------
131
+ # Bug 2: Race condition on _current_field_semantics
132
+ # ---------------------------------------------------------------------------
133
+
134
+
135
class TestFieldSemanticsThreadSafety:
    """_current_field_semantics must not leak between concurrent crushes.

    Previously it was stored as instance state (self._current_field_semantics)
    which created a race condition when the same SmartCrusher instance
    was used from multiple threads.
    """

    def test_concurrent_crushes_no_cross_contamination(self) -> None:
        """Concurrent crushes on one instance succeed and leave no per-thread state."""
        crusher = _make_crusher(max_items=5)

        # Two payloads with disjoint field names.
        payload_a = json.dumps([{"name": f"item_{i}", "value": i} for i in range(20)])
        payload_b = json.dumps([{"key": f"k_{i}", "score": i * 0.1} for i in range(20)])

        results: dict[str, str] = {}
        errors: list[Exception] = []
        leaks: list[str] = []

        def crush_task(label: str, content: str) -> None:
            try:
                result, _modified, _info = crusher._smart_crush_content(content)
                results[label] = result
            except Exception as e:
                errors.append(e)

            # Inspect the state from the worker thread that just crushed.
            # Presumably `_thread_local` is a threading.local (verify against
            # SmartCrusher); if so, attributes set by a worker are invisible
            # from the main thread, so checking there would always pass.
            tl = getattr(crusher, "_thread_local", None)
            semantics = getattr(tl, "field_semantics", None) if tl is not None else None
            if semantics is not None:
                leaks.append(f"{label}: {semantics}")

        submitted = 0
        with ThreadPoolExecutor(max_workers=4) as executor:
            futures = []
            # Run many concurrent crushes to increase race probability
            for i in range(20):
                futures.append(executor.submit(crush_task, f"a_{i}", payload_a))
                futures.append(executor.submit(crush_task, f"b_{i}", payload_b))
            submitted = len(futures)
            for f in as_completed(futures):
                f.result()  # Surfaces anything crush_task itself did not catch

        assert not errors, f"Concurrent crushes raised errors: {errors}"

        # Every task must have produced a result under its own label.
        assert len(results) == submitted, (
            f"Expected {submitted} results, got {len(results)}"
        )

        # After each crush, the worker's thread-local state must be clean.
        assert not leaks, f"field_semantics leaked in thread-local: {leaks}"
177
+
178
+
179
+ # ---------------------------------------------------------------------------
180
+ # Issue 7: Recursion depth limit
181
+ # ---------------------------------------------------------------------------
182
+
183
+
184
class TestRecursionDepthLimit:
    """Deeply nested JSON must be handled without crashing."""

    def test_deeply_nested_json_does_not_crash(self) -> None:
        """A 100-level nested dict goes through the crusher and stays a dict."""
        # Wrap a leaf dict in 100 single-key dict layers.
        tree: dict = {"leaf": "value"}
        for _ in range(100):
            tree = {"level": tree}

        # Must complete without raising RecursionError.
        output, _was_modified, _info = _make_crusher()._smart_crush_content(json.dumps(tree))

        # The output must still decode to a dict at the top level.
        assert isinstance(json.loads(output), dict)

    def test_deeply_nested_list_does_not_crash(self) -> None:
        """A 100-level nested list goes through the crusher and stays a list."""
        tree: list = ["leaf"]
        for _ in range(100):
            tree = [tree]

        output, _was_modified, _info = _make_crusher()._smart_crush_content(json.dumps(tree))
        assert isinstance(json.loads(output), list)
tests/test_transforms/test_universal_json_crush.py CHANGED
@@ -172,14 +172,15 @@ class TestCrushNumberArray:
172
  assert len(crushed) < len(numbers)
173
  assert "number:adaptive" in strategy
174
 
175
- def test_summary_prepended(self, crusher):
176
  numbers = list(range(100))
177
  crushed, strategy = crusher._crush_number_array(numbers)
178
- # First element should be the stats summary string
179
- assert isinstance(crushed[0], str)
180
- assert "numbers:" in crushed[0]
181
- assert "min=" in crushed[0]
182
- assert "max=" in crushed[0]
 
183
 
184
  def test_outliers_preserved(self, crusher):
185
  # Normal values around 50 with one extreme outlier
@@ -193,13 +194,13 @@ class TestCrushNumberArray:
193
  crushed, strategy = crusher._crush_number_array(numbers)
194
  # With all identical, should compress heavily
195
  # Summary + a few representatives
196
- numeric_values = [v for v in crushed if isinstance(v, (int, float))]
197
  assert all(v == 42.0 for v in numeric_values)
198
 
199
  def test_first_last_kept(self, crusher):
200
  numbers = list(range(50))
201
  crushed, strategy = crusher._crush_number_array(numbers)
202
- numeric_values = [v for v in crushed if isinstance(v, (int, float))]
203
  assert 0 in numeric_values # First
204
  assert 49 in numeric_values # Last
205
 
@@ -207,7 +208,7 @@ class TestCrushNumberArray:
207
  # Stable at 10, then jumps to 100
208
  numbers = [10.0] * 50 + [100.0] * 50
209
  crushed, strategy = crusher_large_k._crush_number_array(numbers)
210
- numeric_values = [v for v in crushed if isinstance(v, (int, float))]
211
  # Both 10.0 and 100.0 should be present
212
  assert 10.0 in numeric_values
213
  assert 100.0 in numeric_values
@@ -215,23 +216,24 @@ class TestCrushNumberArray:
215
  def test_nan_inf_filtered(self, crusher):
216
  numbers = [1.0, 2.0, float("nan"), float("inf"), 3.0] * 10
217
  crushed, strategy = crusher._crush_number_array(numbers)
218
- # Should not crash; stats should be based on finite values
219
- assert isinstance(crushed[0], str)
 
220
 
221
  def test_integers_preserved_as_int(self, crusher):
222
  numbers = list(range(50))
223
  crushed, strategy = crusher._crush_number_array(numbers)
224
- numeric_values = [v for v in crushed if isinstance(v, (int, float))]
225
  # Integers should remain integers (not converted to float)
226
  assert any(isinstance(v, int) for v in numeric_values)
227
 
228
  def test_statistics_accuracy(self, crusher):
229
  numbers = list(range(1, 101)) # 1 to 100
230
  crushed, strategy = crusher._crush_number_array(numbers)
231
- summary = crushed[0]
232
- assert "min=1" in summary
233
- assert "max=100" in summary
234
- assert "mean=50.5" in summary
235
 
236
 
237
  # =====================================================================
@@ -382,7 +384,7 @@ class TestSafetyGuarantees:
382
  assert items[-1] in crushed
383
  else:
384
  crushed, _ = crusher._crush_number_array(items)
385
- numeric = [v for v in crushed if isinstance(v, (int, float))]
386
  assert items[0] in numeric
387
  assert items[-1] in numeric
388
 
@@ -399,7 +401,7 @@ class TestSafetyGuarantees:
399
  """Arrays below min_items_to_analyze pass through unchanged."""
400
  if all(isinstance(i, str) for i in items):
401
  crushed, strategy = crusher._crush_string_array(items)
402
- elif all(isinstance(i, (int, float)) for i in items):
403
  crushed, strategy = crusher._crush_number_array(items)
404
  else:
405
  crushed, strategy = crusher._crush_mixed_array(items)
 
172
  assert len(crushed) < len(numbers)
173
  assert "number:adaptive" in strategy
174
 
175
+ def test_stats_in_strategy_not_array(self, crusher):
176
  numbers = list(range(100))
177
  crushed, strategy = crusher._crush_number_array(numbers)
178
+ # Stats should be in the strategy string, not in the array
179
+ assert "min=" in strategy
180
+ assert "max=" in strategy
181
+ # Array should contain only numbers (schema-preserving)
182
+ for item in crushed:
183
+ assert isinstance(item, int | float)
184
 
185
  def test_outliers_preserved(self, crusher):
186
  # Normal values around 50 with one extreme outlier
 
194
  crushed, strategy = crusher._crush_number_array(numbers)
195
  # With all identical, should compress heavily
196
  # Summary + a few representatives
197
+ numeric_values = [v for v in crushed if isinstance(v, int | float)]
198
  assert all(v == 42.0 for v in numeric_values)
199
 
200
  def test_first_last_kept(self, crusher):
201
  numbers = list(range(50))
202
  crushed, strategy = crusher._crush_number_array(numbers)
203
+ numeric_values = [v for v in crushed if isinstance(v, int | float)]
204
  assert 0 in numeric_values # First
205
  assert 49 in numeric_values # Last
206
 
 
208
  # Stable at 10, then jumps to 100
209
  numbers = [10.0] * 50 + [100.0] * 50
210
  crushed, strategy = crusher_large_k._crush_number_array(numbers)
211
+ numeric_values = [v for v in crushed if isinstance(v, int | float)]
212
  # Both 10.0 and 100.0 should be present
213
  assert 10.0 in numeric_values
214
  assert 100.0 in numeric_values
 
216
  def test_nan_inf_filtered(self, crusher):
217
  numbers = [1.0, 2.0, float("nan"), float("inf"), 3.0] * 10
218
  crushed, strategy = crusher._crush_number_array(numbers)
219
+ # Should not crash; stats in strategy based on finite values
220
+ assert "min=" in strategy
221
+ assert "max=" in strategy
222
 
223
  def test_integers_preserved_as_int(self, crusher):
224
  numbers = list(range(50))
225
  crushed, strategy = crusher._crush_number_array(numbers)
226
+ numeric_values = [v for v in crushed if isinstance(v, int | float)]
227
  # Integers should remain integers (not converted to float)
228
  assert any(isinstance(v, int) for v in numeric_values)
229
 
230
  def test_statistics_accuracy(self, crusher):
231
  numbers = list(range(1, 101)) # 1 to 100
232
  crushed, strategy = crusher._crush_number_array(numbers)
233
+ # Stats are in the strategy string
234
+ assert "min=1" in strategy
235
+ assert "max=100" in strategy
236
+ assert "mean=50.5" in strategy
237
 
238
 
239
  # =====================================================================
 
384
  assert items[-1] in crushed
385
  else:
386
  crushed, _ = crusher._crush_number_array(items)
387
+ numeric = [v for v in crushed if isinstance(v, int | float)]
388
  assert items[0] in numeric
389
  assert items[-1] in numeric
390
 
 
401
  """Arrays below min_items_to_analyze pass through unchanged."""
402
  if all(isinstance(i, str) for i in items):
403
  crushed, strategy = crusher._crush_string_array(items)
404
+ elif all(isinstance(i, int | float) for i in items):
405
  crushed, strategy = crusher._crush_number_array(items)
406
  else:
407
  crushed, strategy = crusher._crush_mixed_array(items)
tests/test_ws_memory_relay.py ADDED
@@ -0,0 +1,523 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for WebSocket memory tool interception in the Codex Responses API relay.
2
+
3
+ Verifies that:
4
+ 1. Memory tool events are suppressed from reaching Codex
5
+ 2. response.created is buffered and only flushed for non-memory responses
6
+ 3. Tool execution happens and continuation is sent upstream
7
+ 4. Non-memory responses pass through with normal streaming latency
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import json
13
+ from dataclasses import dataclass, field
14
+ from typing import Any
15
+
16
+ from headroom.proxy.memory_handler import MEMORY_TOOL_NAMES
17
+
18
+ # ---------------------------------------------------------------------------
19
+ # Minimal WS relay state machine (mirrors the logic in openai.py)
20
+ # ---------------------------------------------------------------------------
21
+
22
+
23
+ @dataclass
24
+ class WSMemoryRelayState:
25
+ """State machine for WS event processing with memory tool interception.
26
+
27
+ This mirrors the logic in ``_upstream_to_client`` but is decoupled from
28
+ actual WebSocket I/O so it can be unit-tested.
29
+ """
30
+
31
+ memory_tool_names: set[str] = field(default_factory=lambda: set(MEMORY_TOOL_NAMES))
32
+
33
+ # Per-response state (reset after each response.completed)
34
+ event_buffer: list[str] = field(default_factory=list)
35
+ decided: bool = False
36
+ suppress_response: bool = False
37
+ pending_function_calls: list[dict[str, Any]] = field(default_factory=list)
38
+ last_response_id: str | None = None
39
+
40
+ def process_event(self, msg_str: str) -> dict[str, Any]:
41
+ """Process a single upstream WS event.
42
+
43
+ Returns a dict with possible keys:
44
+ relay: list[str] — events to send to Codex
45
+ execute_tools: list — function_call items to execute
46
+ send_continuation: dict — continuation payload to send upstream
47
+ """
48
+ result: dict[str, Any] = {"relay": [], "execute_tools": [], "send_continuation": None}
49
+
50
+ try:
51
+ event = json.loads(msg_str)
52
+ except (json.JSONDecodeError, TypeError):
53
+ # Not JSON — always relay
54
+ result["relay"].append(msg_str)
55
+ return result
56
+
57
+ event_type = event.get("type", "")
58
+
59
+ # ---- Phase 1: Buffering (before first output item) ----
60
+ if not self.decided:
61
+ self.event_buffer.append(msg_str)
62
+
63
+ if event_type == "response.output_item.added":
64
+ item = event.get("item", {})
65
+ if (
66
+ item.get("type") == "function_call"
67
+ and item.get("name") in self.memory_tool_names
68
+ ):
69
+ # Memory tool is first output → suppress entire response
70
+ self.suppress_response = True
71
+ self.decided = True
72
+ self.event_buffer.clear()
73
+ else:
74
+ # Non-memory item → flush buffer and pass through
75
+ self.decided = True
76
+ result["relay"].extend(self.event_buffer)
77
+ self.event_buffer.clear()
78
+
79
+ elif event_type == "response.completed":
80
+ # Response completed with no output items — flush all
81
+ self.decided = True
82
+ result["relay"].extend(self.event_buffer)
83
+ self.event_buffer.clear()
84
+
85
+ return result
86
+
87
+ # ---- Phase 2a: Suppress mode (memory tool response) ----
88
+ if self.suppress_response:
89
+ # Capture completed function_call items
90
+ if event_type == "response.output_item.done":
91
+ item = event.get("item", {})
92
+ if (
93
+ item.get("type") == "function_call"
94
+ and item.get("name") in self.memory_tool_names
95
+ ):
96
+ self.pending_function_calls.append(item)
97
+
98
+ if event_type == "response.completed":
99
+ resp = event.get("response", {})
100
+ self.last_response_id = resp.get("id")
101
+
102
+ if self.pending_function_calls:
103
+ result["execute_tools"] = list(self.pending_function_calls)
104
+ # Build continuation payload
105
+ # (actual tool execution + output building done by caller)
106
+ result["send_continuation"] = {
107
+ "response_id": self.last_response_id,
108
+ "function_calls": list(self.pending_function_calls),
109
+ }
110
+
111
+ # Reset for next response (continuation)
112
+ self._reset_response_state()
113
+
114
+ return result # Nothing relayed in suppress mode
115
+
116
+ # ---- Phase 2b: Pass-through mode (normal response) ----
117
+ result["relay"].append(msg_str)
118
+ return result
119
+
120
+ def _reset_response_state(self) -> None:
121
+ """Reset per-response state for the next response."""
122
+ self.event_buffer.clear()
123
+ self.decided = False
124
+ self.suppress_response = False
125
+ self.pending_function_calls.clear()
126
+ self.last_response_id = None
127
+
128
+
129
+ # ---------------------------------------------------------------------------
130
+ # Test helpers
131
+ # ---------------------------------------------------------------------------
132
+
133
+
134
+ def _make_event(event_type: str, **kwargs: Any) -> str:
135
+ data: dict[str, Any] = {"type": event_type}
136
+ data.update(kwargs)
137
+ return json.dumps(data)
138
+
139
+
140
def _response_created(response_id: str = "resp_A") -> str:
    """Serialize a ``response.created`` event carrying the given response id."""
    response = {"id": response_id}
    return _make_event("response.created", response=response)
142
+
143
+
144
def _output_item_added_text(index: int = 0) -> str:
    """Serialize a ``response.output_item.added`` event for an assistant message."""
    message_item = {"type": "message", "role": "assistant"}
    return _make_event("response.output_item.added", output_index=index, item=message_item)
150
+
151
+
152
def _output_item_added_function_call(name: str, index: int = 0, call_id: str = "call_1") -> str:
    """Serialize a ``response.output_item.added`` event for a function call named *name*."""
    call_item = {"type": "function_call", "name": name, "call_id": call_id}
    return _make_event("response.output_item.added", output_index=index, item=call_item)
158
+
159
+
160
def _function_call_args_delta(index: int = 0, delta: str = '{"qu') -> str:
    """Serialize a ``response.function_call_arguments.delta`` event with a partial argument chunk."""
    event_type = "response.function_call_arguments.delta"
    return _make_event(event_type, output_index=index, delta=delta)
166
+
167
+
168
def _function_call_args_done(index: int = 0, arguments: str = '{"query": "codename"}') -> str:
    """Serialize a ``response.function_call_arguments.done`` event with the full argument string."""
    event_type = "response.function_call_arguments.done"
    return _make_event(event_type, output_index=index, arguments=arguments)
174
+
175
+
176
def _output_item_done_function_call(
    name: str,
    index: int = 0,
    call_id: str = "call_1",
    arguments: str = '{"query": "codename"}',
) -> str:
    """Serialize a ``response.output_item.done`` event for a completed function call."""
    finished_call = {
        "type": "function_call",
        "name": name,
        "call_id": call_id,
        "arguments": arguments,
    }
    return _make_event("response.output_item.done", output_index=index, item=finished_call)
192
+
193
+
194
def _output_text_delta(index: int = 0, text: str = "Hello") -> str:
    """Serialize a ``response.output_text.delta`` event whose ``delta`` is *text*."""
    event_type = "response.output_text.delta"
    return _make_event(event_type, output_index=index, delta=text)
200
+
201
+
202
def _output_item_done_text(index: int = 0) -> str:
    """Serialize a ``response.output_item.done`` event for an assistant message."""
    message_item = {"type": "message", "role": "assistant"}
    return _make_event("response.output_item.done", output_index=index, item=message_item)
208
+
209
+
210
def _response_completed(response_id: str = "resp_A") -> str:
    """Serialize a ``response.completed`` event with status ``completed``."""
    response = {"id": response_id, "status": "completed"}
    return _make_event("response.completed", response=response)
215
+
216
+
217
def _output_item_added_shell(index: int = 0) -> str:
    """Serialize an ``output_item.added`` event for a non-memory ``shell`` function call."""
    shell_item = {"type": "function_call", "name": "shell", "call_id": "call_shell"}
    return _make_event("response.output_item.added", output_index=index, item=shell_item)
224
+
225
+
226
+ # ---------------------------------------------------------------------------
227
+ # Tests
228
+ # ---------------------------------------------------------------------------
229
+
230
+
231
class TestWSMemoryRelayNonMemory:
    """Responses that involve no memory tool reach the client untouched."""

    def test_text_response_relayed_immediately(self):
        """A text-only response is relayed in full and in order."""
        state = WSMemoryRelayState()
        stream = [
            _response_created(),
            _output_item_added_text(),
            _output_text_delta(text="The answer is 42"),
            _output_item_done_text(),
            _response_completed(),
        ]

        outcomes = [state.process_event(message) for message in stream]

        # No step may request tool execution or a continuation.
        for outcome in outcomes:
            assert outcome["execute_tools"] == []
            assert outcome["send_continuation"] is None

        forwarded = [message for outcome in outcomes for message in outcome["relay"]]

        # All 5 events reach the client.
        assert len(forwarded) == 5

        # response.created is held back and flushed together with the
        # first output item, so the overall order is preserved.
        assert [json.loads(message)["type"] for message in forwarded] == [
            "response.created",
            "response.output_item.added",
            "response.output_text.delta",
            "response.output_item.done",
            "response.completed",
        ]

    def test_shell_tool_relayed(self):
        """A non-memory function call (shell) is not intercepted."""
        state = WSMemoryRelayState()
        stream = [
            _response_created(),
            _output_item_added_shell(),
            _response_completed(),
        ]

        outcomes = [state.process_event(message) for message in stream]

        for outcome in outcomes:
            assert outcome["execute_tools"] == []
            assert outcome["send_continuation"] is None

        forwarded = [message for outcome in outcomes for message in outcome["relay"]]
        assert len(forwarded) == 3

    def test_empty_response_relayed(self):
        """With no output items, created and completed are both still relayed."""
        state = WSMemoryRelayState()

        forwarded: list[str] = []
        for message in (_response_created(), _response_completed()):
            forwarded += state.process_event(message)["relay"]

        assert len(forwarded) == 2
301
+
302
+
303
+ class TestWSMemoryRelayMemoryTool:
304
+ """Responses with memory tools are intercepted transparently."""
305
+
306
+ def test_memory_search_fully_suppressed(self):
307
+ """memory_search call: ALL events suppressed from Codex."""
308
+ relay = WSMemoryRelayState()
309
+
310
+ events = [
311
+ _response_created("resp_A"),
312
+ _output_item_added_function_call("memory_search", index=0),
313
+ _function_call_args_delta(index=0),
314
+ _function_call_args_done(index=0),
315
+ _output_item_done_function_call("memory_search", index=0),
316
+ _response_completed("resp_A"),
317
+ ]
318
+
319
+ all_relayed: list[str] = []
320
+ tool_executions: list[Any] = []
321
+ continuations: list[Any] = []
322
+
323
+ for ev in events:
324
+ result = relay.process_event(ev)
325
+ all_relayed.extend(result["relay"])
326
+ tool_executions.extend(result["execute_tools"])
327
+ if result["send_continuation"]:
328
+ continuations.append(result["send_continuation"])
329
+
330
+ # ZERO events relayed to Codex
331
+ assert len(all_relayed) == 0, (
332
+ f"Expected 0 relayed events, got {len(all_relayed)}: "
333
+ f"{[json.loads(e)['type'] for e in all_relayed]}"
334
+ )
335
+
336
+ # Tool execution triggered
337
+ assert len(tool_executions) == 1
338
+ assert tool_executions[0]["name"] == "memory_search"
339
+
340
+ # Continuation requested
341
+ assert len(continuations) == 1
342
+ assert continuations[0]["response_id"] == "resp_A"
343
+
344
+ def test_memory_save_also_suppressed(self):
345
+ """memory_save call is also intercepted."""
346
+ relay = WSMemoryRelayState()
347
+
348
+ events = [
349
+ _response_created("resp_B"),
350
+ _output_item_added_function_call("memory_save", index=0, call_id="call_save"),
351
+ _function_call_args_done(index=0, arguments='{"content": "user likes dark mode"}'),
352
+ _output_item_done_function_call(
353
+ "memory_save",
354
+ index=0,
355
+ call_id="call_save",
356
+ arguments='{"content": "user likes dark mode"}',
357
+ ),
358
+ _response_completed("resp_B"),
359
+ ]
360
+
361
+ all_relayed: list[str] = []
362
+ tool_executions: list[Any] = []
363
+
364
+ for ev in events:
365
+ result = relay.process_event(ev)
366
+ all_relayed.extend(result["relay"])
367
+ tool_executions.extend(result["execute_tools"])
368
+
369
+ assert len(all_relayed) == 0
370
+ assert len(tool_executions) == 1
371
+ assert tool_executions[0]["name"] == "memory_save"
372
+
373
def test_continuation_response_relayed_normally(self):
    """A plain-text response following a memory-tool response is relayed in full.

    A first response containing only a ``memory_search`` call is pushed
    through the relay (its results are not inspected here).  A second,
    text-only response is then processed: every one of its five events must
    be relayed, and no event may request a tool execution or a continuation.
    """
    relay = WSMemoryRelayState()

    # Response A: a lone memory_search call (expected to be suppressed;
    # that is not asserted in this test).
    memory_phase = [
        _response_created("resp_A"),
        _output_item_added_function_call("memory_search", index=0),
        _function_call_args_done(index=0),
        _output_item_done_function_call("memory_search", index=0),
        _response_completed("resp_A"),
    ]
    for event in memory_phase:
        relay.process_event(event)

    # Response B: ordinary text output that must pass straight through.
    text_phase = [
        _response_created("resp_B"),
        _output_item_added_text(index=0),
        _output_text_delta(index=0, text="The codename is Pegasus-2"),
        _output_item_done_text(index=0),
        _response_completed("resp_B"),
    ]

    forwarded: list[str] = []
    for event in text_phase:
        outcome = relay.process_event(event)
        forwarded.extend(outcome["relay"])
        assert outcome["execute_tools"] == []
        assert outcome["send_continuation"] is None

    # Every event of response B reached the client, in order.
    assert len(forwarded) == 5
    decoded = [json.loads(raw) for raw in forwarded]
    kinds = [message["type"] for message in decoded]
    assert kinds[0] == "response.created"
    assert kinds[-1] == "response.completed"

    # Exactly one text delta was relayed and its payload is untouched.
    deltas = [
        message
        for message in decoded
        if message["type"] == "response.output_text.delta"
    ]
    assert len(deltas) == 1
    assert deltas[0]["delta"] == "The codename is Pegasus-2"
419
+
420
def test_non_json_message_always_relayed(self):
    """A message that is not valid JSON is relayed unchanged."""
    payload = "not valid json {{{"

    outcome = WSMemoryRelayState().process_event(payload)
    forwarded = outcome["relay"]

    # The single relayed entry is the original message, byte for byte.
    assert len(forwarded) == 1
    assert forwarded[0] == payload
427
+
428
def test_multiple_memory_tools_in_one_response(self):
    """Two memory tool calls in a single response are both suppressed.

    The response carries a ``memory_search`` call followed by a
    ``memory_save`` call.  Nothing may be relayed, both tools must be
    reported for execution, and exactly one continuation must be requested
    for the whole response.
    """
    relay = WSMemoryRelayState()

    stream = [
        _response_created("resp_multi"),
        _output_item_added_function_call("memory_search", index=0, call_id="call_1"),
        _output_item_done_function_call("memory_search", index=0, call_id="call_1"),
        # Second tool call inside the same response.
        _output_item_added_function_call("memory_save", index=1, call_id="call_2"),
        _output_item_done_function_call(
            "memory_save",
            index=1,
            call_id="call_2",
            arguments='{"content": "test"}',
        ),
        _response_completed("resp_multi"),
    ]

    forwarded: list[str] = []
    executions: list[Any] = []
    follow_ups: list[Any] = []

    for event in stream:
        outcome = relay.process_event(event)
        forwarded.extend(outcome["relay"])
        executions.extend(outcome["execute_tools"])
        continuation = outcome["send_continuation"]
        if continuation:
            follow_ups.append(continuation)

    assert len(forwarded) == 0
    assert len(executions) == 2
    assert {call["name"] for call in executions} == {"memory_search", "memory_save"}
    assert len(follow_ups) == 1
462
+
463
+
464
class TestWSMemoryRelayStateReset:
    """Checks that the relay's per-response state is cleared between responses."""

    def test_state_resets_after_memory_response(self):
        """Once a memory-tool response completes, the relay's state fields are cleared."""
        relay = WSMemoryRelayState()

        memory_response = [
            _response_created("resp_A"),
            _output_item_added_function_call("memory_search"),
            _output_item_done_function_call("memory_search"),
            _response_completed("resp_A"),
        ]
        for event in memory_response:
            relay.process_event(event)

        # Flags are back to False and both collections are empty.
        assert relay.decided is False
        assert relay.suppress_response is False
        assert len(relay.pending_function_calls) == 0
        assert len(relay.event_buffer) == 0

    def test_alternating_memory_and_normal(self):
        """A memory response followed by two ordinary responses: the latter are fully relayed."""
        relay = WSMemoryRelayState()

        # Step 1: memory_search response; its results are not inspected.
        memory_response = [
            _response_created("resp_A"),
            _output_item_added_function_call("memory_search"),
            _output_item_done_function_call("memory_search"),
            _response_completed("resp_A"),
        ]
        for event in memory_response:
            relay.process_event(event)

        # Step 2: continuation text response; all five events must be relayed.
        text_response = [
            _response_created("resp_B"),
            _output_item_added_text(),
            _output_text_delta(text="Pegasus-2"),
            _output_item_done_text(),
            _response_completed("resp_B"),
        ]
        text_forwarded: list[str] = []
        for event in text_response:
            outcome = relay.process_event(event)
            text_forwarded.extend(outcome["relay"])

        assert len(text_forwarded) == 5

        # Step 3: a further non-memory response; all three events must be relayed.
        shell_response = [
            _response_created("resp_C"),
            _output_item_added_shell(),
            _response_completed("resp_C"),
        ]
        shell_forwarded: list[str] = []
        for event in shell_response:
            outcome = relay.process_event(event)
            shell_forwarded.extend(outcome["relay"])

        assert len(shell_forwarded) == 3