Spaces:
Running
Running
Merge branch 'chopratejas:main' into ci/release-automation
Browse files- .github/dependabot.yml +27 -0
- .github/workflows/docker.yml +17 -0
- .github/workflows/publish.yml +15 -2
- .gitignore +3 -0
- Dockerfile +7 -3
- README.md +32 -35
- headroom/cli/learn.py +14 -1
- headroom/cli/mcp.py +23 -3
- headroom/cli/proxy.py +35 -4
- headroom/cli/wrap.py +260 -12
- headroom/install/state.py +16 -7
- headroom/learn/base.py +9 -3
- headroom/learn/plugins/claude.py +18 -10
- headroom/learn/plugins/codex.py +17 -6
- headroom/learn/plugins/gemini.py +14 -6
- headroom/memory/factory.py +10 -0
- headroom/memory/mcp_server.py +375 -0
- headroom/memory/sync.py +395 -0
- headroom/memory/sync_adapters/__init__.py +1 -0
- headroom/memory/sync_adapters/claude_code.py +233 -0
- headroom/memory/sync_adapters/codex_agent.py +106 -0
- headroom/memory/writers/claude_writer.py +2 -1
- headroom/proxy/handlers/openai.py +370 -20
- headroom/proxy/handlers/streaming.py +7 -4
- headroom/proxy/memory_handler.py +33 -13
- headroom/proxy/models.py +3 -0
- headroom/proxy/request_logger.py +23 -7
- headroom/telemetry/toin.py +4 -2
- headroom/transforms/kompress_compressor.py +98 -79
- headroom/transforms/smart_crusher.py +60 -48
- plugins/openclaw/package.json +54 -54
- tests/test_cli/test_wrap_copilot.py +17 -4
- tests/test_learn/test_scanner.py +64 -0
- tests/test_memory_sync.py +647 -0
- tests/test_package_init_lazy.py +2 -1
- tests/test_transforms/test_kompress_compressor.py +2 -3
- tests/test_transforms/test_smart_crusher_bugs.py +212 -0
- tests/test_transforms/test_universal_json_crush.py +20 -18
- tests/test_ws_memory_relay.py +523 -0
.github/dependabot.yml
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version: 2
|
| 2 |
+
updates:
|
| 3 |
+
# Docker base image digest updates
|
| 4 |
+
- package-ecosystem: docker
|
| 5 |
+
directory: /
|
| 6 |
+
schedule:
|
| 7 |
+
interval: weekly
|
| 8 |
+
commit-message:
|
| 9 |
+
prefix: "docker"
|
| 10 |
+
|
| 11 |
+
# GitHub Actions version updates
|
| 12 |
+
- package-ecosystem: github-actions
|
| 13 |
+
directory: /
|
| 14 |
+
schedule:
|
| 15 |
+
interval: weekly
|
| 16 |
+
commit-message:
|
| 17 |
+
prefix: "ci"
|
| 18 |
+
|
| 19 |
+
# Python dependency updates (pip)
|
| 20 |
+
- package-ecosystem: pip
|
| 21 |
+
directory: /
|
| 22 |
+
schedule:
|
| 23 |
+
interval: weekly
|
| 24 |
+
commit-message:
|
| 25 |
+
prefix: "deps"
|
| 26 |
+
# Cap concurrently open version-update PRs at 5 to limit noise (does not restrict to security updates; that would require a limit of 0)
|
| 27 |
+
open-pull-requests-limit: 5
|
.github/workflows/docker.yml
CHANGED
|
@@ -12,6 +12,7 @@ env:
|
|
| 12 |
permissions:
|
| 13 |
contents: read
|
| 14 |
packages: write
|
|
|
|
| 15 |
|
| 16 |
jobs:
|
| 17 |
docker-variant-tags:
|
|
@@ -81,3 +82,19 @@ jobs:
|
|
| 81 |
set: |
|
| 82 |
*.cache-from=type=gha
|
| 83 |
*.cache-to=type=gha,mode=max
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
permissions:
|
| 13 |
contents: read
|
| 14 |
packages: write
|
| 15 |
+
id-token: write # For cosign keyless signing via Sigstore OIDC
|
| 16 |
|
| 17 |
jobs:
|
| 18 |
docker-variant-tags:
|
|
|
|
| 82 |
set: |
|
| 83 |
*.cache-from=type=gha
|
| 84 |
*.cache-to=type=gha,mode=max
|
| 85 |
+
|
| 86 |
+
- name: Install cosign
|
| 87 |
+
uses: sigstore/cosign-installer@v3
|
| 88 |
+
|
| 89 |
+
- name: Sign images with cosign (keyless via Sigstore OIDC)
|
| 90 |
+
env:
|
| 91 |
+
BAKE_META: ${{ steps.bake.outputs.metadata }}
|
| 92 |
+
run: |
|
| 93 |
+
# Extract all pushed image digests from bake metadata and sign each
|
| 94 |
+
echo "$BAKE_META" | jq -r '
|
| 95 |
+
to_entries[].value."containerimage.digest" // empty
|
| 96 |
+
' | while read -r digest; do
|
| 97 |
+
image="${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}@${digest}"
|
| 98 |
+
echo "Signing ${image}"
|
| 99 |
+
cosign sign --yes "${image}"
|
| 100 |
+
done
|
.github/workflows/publish.yml
CHANGED
|
@@ -11,7 +11,8 @@ jobs:
|
|
| 11 |
runs-on: ubuntu-latest
|
| 12 |
environment: pypi
|
| 13 |
permissions:
|
| 14 |
-
id-token: write
|
|
|
|
| 15 |
|
| 16 |
steps:
|
| 17 |
- uses: actions/checkout@v4
|
|
@@ -23,11 +24,23 @@ jobs:
|
|
| 23 |
|
| 24 |
- name: Install build tools
|
| 25 |
run: |
|
| 26 |
-
python -m pip install --upgrade pip build
|
| 27 |
|
| 28 |
- name: Build package
|
| 29 |
run: |
|
| 30 |
python -m build
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
- name: Publish to PyPI
|
| 33 |
uses: pypa/gh-action-pypi-publish@release/v1
|
|
|
|
| 11 |
runs-on: ubuntu-latest
|
| 12 |
environment: pypi
|
| 13 |
permissions:
|
| 14 |
+
id-token: write # For trusted publishing
|
| 15 |
+
contents: write # For uploading release assets
|
| 16 |
|
| 17 |
steps:
|
| 18 |
- uses: actions/checkout@v4
|
|
|
|
| 24 |
|
| 25 |
- name: Install build tools
|
| 26 |
run: |
|
| 27 |
+
python -m pip install --upgrade pip build cyclonedx-bom
|
| 28 |
|
| 29 |
- name: Build package
|
| 30 |
run: |
|
| 31 |
python -m build
|
| 32 |
|
| 33 |
+
- name: Generate SBOM (CycloneDX)
|
| 34 |
+
run: |
|
| 35 |
+
pip install -e ".[proxy]"
|
| 36 |
+
cyclonedx-py environment \
|
| 37 |
+
--output-format json \
|
| 38 |
+
--outfile dist/headroom-sbom.cdx.json
|
| 39 |
+
|
| 40 |
+
- name: Upload SBOM to release
|
| 41 |
+
uses: softprops/action-gh-release@v2
|
| 42 |
+
with:
|
| 43 |
+
files: dist/headroom-sbom.cdx.json
|
| 44 |
+
|
| 45 |
- name: Publish to PyPI
|
| 46 |
uses: pypa/gh-action-pypi-publish@release/v1
|
.gitignore
CHANGED
|
@@ -12,6 +12,9 @@ scripts/*
|
|
| 12 |
# Swift SDK (separate repo)
|
| 13 |
swift/
|
| 14 |
|
|
|
|
|
|
|
|
|
|
| 15 |
# Audit/scan outputs (contain security findings — never commit)
|
| 16 |
bandit_result.txt
|
| 17 |
pip_audit_result.txt
|
|
|
|
| 12 |
# Swift SDK (separate repo)
|
| 13 |
swift/
|
| 14 |
|
| 15 |
+
# Local planning docs (never commit)
|
| 16 |
+
ENTERPRISE_HARDENING.md
|
| 17 |
+
|
| 18 |
# Audit/scan outputs (contain security findings — never commit)
|
| 19 |
bandit_result.txt
|
| 20 |
pip_audit_result.txt
|
Dockerfile
CHANGED
|
@@ -1,10 +1,14 @@
|
|
| 1 |
ARG PYTHON_VERSION=3.11
|
| 2 |
ARG UV_VERSION=0.6.17
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
ARG DISTROLESS_IMAGE=gcr.io/distroless/python3-debian13
|
| 4 |
ARG PYTHON_SITE_PACKAGES=/usr/local/lib/python${PYTHON_VERSION}/site-packages
|
| 5 |
|
| 6 |
# ---- Build stage: compile native extensions, build wheel ----
|
| 7 |
-
FROM python:${PYTHON_VERSION}-slim AS builder
|
| 8 |
|
| 9 |
ARG UV_VERSION
|
| 10 |
|
|
@@ -32,7 +36,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
|
| 32 |
uv pip install --system --no-deps --reinstall-package headroom-ai .
|
| 33 |
|
| 34 |
# ---- Runtime stage (python-slim): supports root/nonroot via build arg ----
|
| 35 |
-
FROM python:${PYTHON_VERSION}-slim AS runtime-slim-base
|
| 36 |
|
| 37 |
ARG RUNTIME_USER=nonroot
|
| 38 |
ARG PYTHON_SITE_PACKAGES
|
|
@@ -69,7 +73,7 @@ HEALTHCHECK --interval=30s --timeout=5s --start-period=20s --retries=3 \
|
|
| 69 |
ENTRYPOINT ["headroom", "proxy"]
|
| 70 |
CMD ["--host", "0.0.0.0", "--port", "8787"]
|
| 71 |
|
| 72 |
-
FROM ${DISTROLESS_IMAGE} AS runtime-slim
|
| 73 |
|
| 74 |
ARG RUNTIME_USER=nonroot
|
| 75 |
ARG PYTHON_SITE_PACKAGES
|
|
|
|
| 1 |
ARG PYTHON_VERSION=3.11
|
| 2 |
ARG UV_VERSION=0.6.17
|
| 3 |
+
# Pinned 2026-04-15. Update via Dependabot or: docker pull python:3.11-slim
|
| 4 |
+
ARG PYTHON_DIGEST=sha256:233de06753d30d120b1a3ce359d8d3be8bda78524cd8f520c99883bfe33964cf
|
| 5 |
+
# Pinned 2026-04-15. Update via Dependabot or: docker pull gcr.io/distroless/python3-debian13
|
| 6 |
+
ARG DISTROLESS_DIGEST=sha256:ed3a4beb46f8f8baac068743ba1b1f95ea3f793422129cf6dd23967f779b6018
|
| 7 |
ARG DISTROLESS_IMAGE=gcr.io/distroless/python3-debian13
|
| 8 |
ARG PYTHON_SITE_PACKAGES=/usr/local/lib/python${PYTHON_VERSION}/site-packages
|
| 9 |
|
| 10 |
# ---- Build stage: compile native extensions, build wheel ----
|
| 11 |
+
FROM python:${PYTHON_VERSION}-slim@${PYTHON_DIGEST} AS builder
|
| 12 |
|
| 13 |
ARG UV_VERSION
|
| 14 |
|
|
|
|
| 36 |
uv pip install --system --no-deps --reinstall-package headroom-ai .
|
| 37 |
|
| 38 |
# ---- Runtime stage (python-slim): supports root/nonroot via build arg ----
|
| 39 |
+
FROM python:${PYTHON_VERSION}-slim@${PYTHON_DIGEST} AS runtime-slim-base
|
| 40 |
|
| 41 |
ARG RUNTIME_USER=nonroot
|
| 42 |
ARG PYTHON_SITE_PACKAGES
|
|
|
|
| 73 |
ENTRYPOINT ["headroom", "proxy"]
|
| 74 |
CMD ["--host", "0.0.0.0", "--port", "8787"]
|
| 75 |
|
| 76 |
+
FROM ${DISTROLESS_IMAGE}@${DISTROLESS_DIGEST} AS runtime-slim
|
| 77 |
|
| 78 |
ARG RUNTIME_USER=nonroot
|
| 79 |
ARG PYTHON_SITE_PACKAGES
|
README.md
CHANGED
|
@@ -155,9 +155,9 @@ OPENAI_BASE_URL=http://localhost:8787/v1 your-app
|
|
| 155 |
Use `token` mode for short/medium sessions where raw compression savings matter most.
|
| 156 |
Use `cache` mode for long-running chats where preserving prior-turn bytes improves provider cache reuse.
|
| 157 |
|
| 158 |
-
Works with any language, any tool, any framework. **[Proxy docs](docs/proxy.
|
| 159 |
|
| 160 |
-
Prefer Docker as the runtime provider? See **[
|
| 161 |
|
| 162 |
### Coding agents — one command
|
| 163 |
|
|
@@ -189,7 +189,7 @@ summary = ctx.get("research") # Agent B reads (~80% smaller)
|
|
| 189 |
full = ctx.get("research", full=True) # Agent B gets original if needed
|
| 190 |
```
|
| 191 |
|
| 192 |
-
Compress what moves between agents — any framework. **[SharedContext Guide](docs/shared-context.
|
| 193 |
|
| 194 |
### MCP Tools (Claude Code, Cursor)
|
| 195 |
|
|
@@ -197,7 +197,7 @@ Compress what moves between agents — any framework. **[SharedContext Guide](do
|
|
| 197 |
headroom mcp install && claude
|
| 198 |
```
|
| 199 |
|
| 200 |
-
Gives your AI tool three MCP tools: `headroom_compress`, `headroom_retrieve`, `headroom_stats`. **[MCP Guide](docs/mcp.
|
| 201 |
|
| 202 |
### Drop into your existing stack
|
| 203 |
|
|
@@ -219,7 +219,7 @@ Gives your AI tool three MCP tools: `headroom_compress`, `headroom_retrieve`, `h
|
|
| 219 |
| **Codex / Aider** | Wrap | `headroom wrap codex` or `headroom wrap aider` |
|
| 220 |
| **Always-on local proxy** | Persistent install | `headroom install apply --preset persistent-service --providers auto` |
|
| 221 |
|
| 222 |
-
**[Full Integration Guide](docs/
|
| 223 |
|
| 224 |
---
|
| 225 |
|
|
@@ -292,7 +292,7 @@ python -m headroom.evals suite --tier 1 -o eval_results/
|
|
| 292 |
python -m headroom.evals suite --tier 1 --ci
|
| 293 |
```
|
| 294 |
|
| 295 |
-
Full methodology: [Benchmarks](docs/benchmarks.
|
| 296 |
|
| 297 |
---
|
| 298 |
|
|
@@ -317,7 +317,7 @@ headroom wrap claude --memory # Claude with persistent memory
|
|
| 317 |
headroom wrap codex --memory # Codex shares the SAME memory store
|
| 318 |
```
|
| 319 |
|
| 320 |
-
Claude saves a fact, Codex reads it back. All agents sharing one proxy share one memory — project-scoped, user-isolated, with agent provenance tracking and automatic deduplication. No SDK changes needed. **[Memory docs](docs/memory.
|
| 321 |
|
| 322 |
### Failure Learning
|
| 323 |
|
|
@@ -327,7 +327,7 @@ headroom learn --apply # Write learnings to agent-native files
|
|
| 327 |
headroom learn --agent codex --all # Analyze all Codex sessions
|
| 328 |
```
|
| 329 |
|
| 330 |
-
Plugin-based: reads conversation history from Claude Code, Codex, or Gemini CLI. Finds failure patterns, correlates with successes, writes corrections to CLAUDE.md / AGENTS.md / GEMINI.md. External plugins via entry points. **[Learn docs](docs/
|
| 331 |
|
| 332 |
<p align="center">
|
| 333 |
<img src="headroom_learn.gif" alt="headroom learn demo" width="800">
|
|
@@ -405,7 +405,7 @@ Context compression is a new space. Here's how the approaches differ:
|
|
| 405 |
Originals are in the Compressed Store — nothing is thrown away.
|
| 406 |
```
|
| 407 |
|
| 408 |
-
**Overhead**: 15-200ms compression latency (net positive for Sonnet/Opus). Full data: [
|
| 409 |
|
| 410 |
---
|
| 411 |
|
|
@@ -413,16 +413,16 @@ Context compression is a new space. Here's how the approaches differ:
|
|
| 413 |
|
| 414 |
| Integration | Status | Docs |
|
| 415 |
|-------------|--------|------|
|
| 416 |
-
| `headroom wrap claude/copilot/codex/aider/cursor` | **Stable** | [Proxy Docs](docs/proxy.
|
| 417 |
-
| `compress()` — one function | **Stable** | [Integration Guide](docs/
|
| 418 |
-
| `SharedContext` — multi-agent | **Stable** | [SharedContext Guide](docs/shared-context.
|
| 419 |
-
| LiteLLM callback | **Stable** | [
|
| 420 |
-
| ASGI middleware | **Stable** | [Integration Guide](docs/
|
| 421 |
-
| Proxy server | **Stable** | [Proxy Docs](docs/proxy.
|
| 422 |
-
| Agno | **Stable** | [Agno Guide](docs/agno.
|
| 423 |
-
| MCP (Claude Code, Cursor, etc.) | **Stable** | [MCP Guide](docs/mcp.
|
| 424 |
-
| Strands | **Stable** | [Strands Guide](docs/strands.
|
| 425 |
-
| LangChain | **Stable** | [LangChain Guide](docs/langchain.
|
| 426 |
| **OpenClaw** | **Stable** | [OpenClaw plugin](#openclaw-plugin) |
|
| 427 |
|
| 428 |
---
|
|
@@ -521,23 +521,20 @@ Python 3.10+
|
|
| 521 |
|
| 522 |
| | |
|
| 523 |
|---|---|
|
| 524 |
-
| [Integration Guide](docs/
|
| 525 |
-
| [Proxy Docs](docs/proxy.
|
| 526 |
-
| [Architecture](docs/
|
| 527 |
-
| [CCR Guide](docs/ccr.
|
| 528 |
-
| [Benchmarks](docs/benchmarks.
|
| 529 |
-
| [
|
| 530 |
-
| [Limitations](docs/LIMITATIONS.md) | When compression helps, when it doesn't |
|
| 531 |
| [Evals Framework](headroom/evals/README.md) | Prove compression preserves accuracy |
|
| 532 |
-
| [Memory](docs/memory.
|
| 533 |
-
| [Agno](docs/agno.
|
| 534 |
-
| [MCP](docs/mcp.
|
| 535 |
-
| [SharedContext](docs/shared-context.
|
| 536 |
-
| [Learn](docs/
|
| 537 |
-
| [
|
| 538 |
-
| [
|
| 539 |
-
| [Persistent Installs](docs/persistent-installs.md) | Service/task/docker deployment models and provider scopes |
|
| 540 |
-
| [Configuration](docs/configuration.md) | All options |
|
| 541 |
|
| 542 |
---
|
| 543 |
|
|
|
|
| 155 |
Use `token` mode for short/medium sessions where raw compression savings matter most.
|
| 156 |
Use `cache` mode for long-running chats where preserving prior-turn bytes improves provider cache reuse.
|
| 157 |
|
| 158 |
+
Works with any language, any tool, any framework. **[Proxy docs](docs/content/docs/proxy.mdx)**
|
| 159 |
|
| 160 |
+
Prefer Docker as the runtime provider? See **[Installation — Docker](docs/content/docs/installation.mdx)**.
|
| 161 |
|
| 162 |
### Coding agents — one command
|
| 163 |
|
|
|
|
| 189 |
full = ctx.get("research", full=True) # Agent B gets original if needed
|
| 190 |
```
|
| 191 |
|
| 192 |
+
Compress what moves between agents — any framework. **[SharedContext Guide](docs/content/docs/shared-context.mdx)**
|
| 193 |
|
| 194 |
### MCP Tools (Claude Code, Cursor)
|
| 195 |
|
|
|
|
| 197 |
headroom mcp install && claude
|
| 198 |
```
|
| 199 |
|
| 200 |
+
Gives your AI tool three MCP tools: `headroom_compress`, `headroom_retrieve`, `headroom_stats`. **[MCP Guide](docs/content/docs/mcp.mdx)**
|
| 201 |
|
| 202 |
### Drop into your existing stack
|
| 203 |
|
|
|
|
| 219 |
| **Codex / Aider** | Wrap | `headroom wrap codex` or `headroom wrap aider` |
|
| 220 |
| **Always-on local proxy** | Persistent install | `headroom install apply --preset persistent-service --providers auto` |
|
| 221 |
|
| 222 |
+
**[Full Integration Guide](docs/content/docs/index.mdx)**
|
| 223 |
|
| 224 |
---
|
| 225 |
|
|
|
|
| 292 |
python -m headroom.evals suite --tier 1 --ci
|
| 293 |
```
|
| 294 |
|
| 295 |
+
Full methodology: [Benchmarks](docs/content/docs/benchmarks.mdx) | [Evals Framework](headroom/evals/README.md)
|
| 296 |
|
| 297 |
---
|
| 298 |
|
|
|
|
| 317 |
headroom wrap codex --memory # Codex shares the SAME memory store
|
| 318 |
```
|
| 319 |
|
| 320 |
+
Claude saves a fact, Codex reads it back. All agents sharing one proxy share one memory — project-scoped, user-isolated, with agent provenance tracking and automatic deduplication. No SDK changes needed. **[Memory docs](docs/content/docs/memory.mdx)**
|
| 321 |
|
| 322 |
### Failure Learning
|
| 323 |
|
|
|
|
| 327 |
headroom learn --agent codex --all # Analyze all Codex sessions
|
| 328 |
```
|
| 329 |
|
| 330 |
+
Plugin-based: reads conversation history from Claude Code, Codex, or Gemini CLI. Finds failure patterns, correlates with successes, writes corrections to CLAUDE.md / AGENTS.md / GEMINI.md. External plugins via entry points. **[Learn docs](docs/content/docs/failure-learning.mdx)**
|
| 331 |
|
| 332 |
<p align="center">
|
| 333 |
<img src="headroom_learn.gif" alt="headroom learn demo" width="800">
|
|
|
|
| 405 |
Originals are in the Compressed Store — nothing is thrown away.
|
| 406 |
```
|
| 407 |
|
| 408 |
+
**Overhead**: 15-200ms compression latency (net positive for Sonnet/Opus). Full data: [Benchmarks](docs/content/docs/benchmarks.mdx)
|
| 409 |
|
| 410 |
---
|
| 411 |
|
|
|
|
| 413 |
|
| 414 |
| Integration | Status | Docs |
|
| 415 |
|-------------|--------|------|
|
| 416 |
+
| `headroom wrap claude/copilot/codex/aider/cursor` | **Stable** | [Proxy Docs](docs/content/docs/proxy.mdx) |
|
| 417 |
+
| `compress()` — one function | **Stable** | [Integration Guide](docs/content/docs/index.mdx) |
|
| 418 |
+
| `SharedContext` — multi-agent | **Stable** | [SharedContext Guide](docs/content/docs/shared-context.mdx) |
|
| 419 |
+
| LiteLLM callback | **Stable** | [LiteLLM Guide](docs/content/docs/litellm.mdx) |
|
| 420 |
+
| ASGI middleware | **Stable** | [Integration Guide](docs/content/docs/index.mdx) |
|
| 421 |
+
| Proxy server | **Stable** | [Proxy Docs](docs/content/docs/proxy.mdx) |
|
| 422 |
+
| Agno | **Stable** | [Agno Guide](docs/content/docs/agno.mdx) |
|
| 423 |
+
| MCP (Claude Code, Cursor, etc.) | **Stable** | [MCP Guide](docs/content/docs/mcp.mdx) |
|
| 424 |
+
| Strands | **Stable** | [Strands Guide](docs/content/docs/strands.mdx) |
|
| 425 |
+
| LangChain | **Stable** | [LangChain Guide](docs/content/docs/langchain.mdx) |
|
| 426 |
| **OpenClaw** | **Stable** | [OpenClaw plugin](#openclaw-plugin) |
|
| 427 |
|
| 428 |
---
|
|
|
|
| 521 |
|
| 522 |
| | |
|
| 523 |
|---|---|
|
| 524 |
+
| [Integration Guide](docs/content/docs/index.mdx) | LiteLLM, ASGI, compress(), proxy |
|
| 525 |
+
| [Proxy Docs](docs/content/docs/proxy.mdx) | Proxy server configuration |
|
| 526 |
+
| [Architecture](docs/content/docs/architecture.mdx) | How the pipeline works |
|
| 527 |
+
| [CCR Guide](docs/content/docs/ccr.mdx) | Reversible compression |
|
| 528 |
+
| [Benchmarks](docs/content/docs/benchmarks.mdx) | Accuracy validation |
|
| 529 |
+
| [Limitations](docs/content/docs/limitations.mdx) | When compression helps, when it doesn't |
|
|
|
|
| 530 |
| [Evals Framework](headroom/evals/README.md) | Prove compression preserves accuracy |
|
| 531 |
+
| [Memory](docs/content/docs/memory.mdx) | Cross-agent persistent memory with provenance + dedup |
|
| 532 |
+
| [Agno](docs/content/docs/agno.mdx) | Agno agent framework |
|
| 533 |
+
| [MCP](docs/content/docs/mcp.mdx) | Context engineering toolkit (compress, retrieve, stats) |
|
| 534 |
+
| [SharedContext](docs/content/docs/shared-context.mdx) | Compressed inter-agent context sharing |
|
| 535 |
+
| [Learn](docs/content/docs/failure-learning.mdx) | Plugin-based failure learning (Claude, Codex, Gemini, extensible) |
|
| 536 |
+
| [Installation](docs/content/docs/installation.mdx) | pip, npm, Docker install methods |
|
| 537 |
+
| [Configuration](docs/content/docs/configuration.mdx) | All options |
|
|
|
|
|
|
|
| 538 |
|
| 539 |
---
|
| 540 |
|
headroom/cli/learn.py
CHANGED
|
@@ -90,12 +90,21 @@ Use 'auto' (default) to scan all detected agents."""
|
|
| 90 |
help="LLM model for analysis (e.g., claude-sonnet-4-6, gpt-4o, gemini/gemini-2.0-flash). "
|
| 91 |
"Auto-detected from API keys if not specified.",
|
| 92 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
def learn(
|
| 94 |
project: Path | None,
|
| 95 |
analyze_all: bool,
|
| 96 |
apply: bool,
|
| 97 |
agent: str,
|
| 98 |
model: str | None,
|
|
|
|
| 99 |
) -> None:
|
| 100 |
"""Learn from past tool call failures to prevent future ones.
|
| 101 |
|
|
@@ -115,9 +124,13 @@ def learn(
|
|
| 115 |
headroom learn --all # Analyze all projects
|
| 116 |
headroom learn --agent codex --all # Analyze all Codex sessions
|
| 117 |
"""
|
|
|
|
|
|
|
| 118 |
from ..learn.analyzer import SessionAnalyzer, _detect_default_model
|
| 119 |
from ..learn.registry import auto_detect_plugins, get_plugin
|
| 120 |
|
|
|
|
|
|
|
| 121 |
# Resolve model early to fail fast with a clear message
|
| 122 |
try:
|
| 123 |
resolved_model = model or _detect_default_model()
|
|
@@ -185,7 +198,7 @@ def learn(
|
|
| 185 |
click.echo(f"Path: {proj.project_path}")
|
| 186 |
click.echo(f"{'=' * 60}")
|
| 187 |
|
| 188 |
-
sessions = plugin.scan_project(proj)
|
| 189 |
if not sessions:
|
| 190 |
click.echo(" No conversation data found.")
|
| 191 |
continue
|
|
|
|
| 90 |
help="LLM model for analysis (e.g., claude-sonnet-4-6, gpt-4o, gemini/gemini-2.0-flash). "
|
| 91 |
"Auto-detected from API keys if not specified.",
|
| 92 |
)
|
| 93 |
+
@click.option(
|
| 94 |
+
"--workers",
|
| 95 |
+
"-j",
|
| 96 |
+
type=int,
|
| 97 |
+
default=None,
|
| 98 |
+
help="Parallel workers for session scanning. "
|
| 99 |
+
"Default: auto (min of CPU count, 8). Use 1 for serial.",
|
| 100 |
+
)
|
| 101 |
def learn(
|
| 102 |
project: Path | None,
|
| 103 |
analyze_all: bool,
|
| 104 |
apply: bool,
|
| 105 |
agent: str,
|
| 106 |
model: str | None,
|
| 107 |
+
workers: int | None,
|
| 108 |
) -> None:
|
| 109 |
"""Learn from past tool call failures to prevent future ones.
|
| 110 |
|
|
|
|
| 124 |
headroom learn --all # Analyze all projects
|
| 125 |
headroom learn --agent codex --all # Analyze all Codex sessions
|
| 126 |
"""
|
| 127 |
+
import os
|
| 128 |
+
|
| 129 |
from ..learn.analyzer import SessionAnalyzer, _detect_default_model
|
| 130 |
from ..learn.registry import auto_detect_plugins, get_plugin
|
| 131 |
|
| 132 |
+
max_workers = workers if workers is not None else min(os.cpu_count() or 4, 8)
|
| 133 |
+
|
| 134 |
# Resolve model early to fail fast with a clear message
|
| 135 |
try:
|
| 136 |
resolved_model = model or _detect_default_model()
|
|
|
|
| 198 |
click.echo(f"Path: {proj.project_path}")
|
| 199 |
click.echo(f"{'=' * 60}")
|
| 200 |
|
| 201 |
+
sessions = plugin.scan_project(proj, max_workers=max_workers)
|
| 202 |
if not sessions:
|
| 203 |
click.echo(" No conversation data found.")
|
| 204 |
continue
|
headroom/cli/mcp.py
CHANGED
|
@@ -244,13 +244,33 @@ def mcp_uninstall() -> None:
|
|
| 244 |
err=True,
|
| 245 |
)
|
| 246 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
# Also remove from mcp.json fallback config if present
|
| 248 |
if MCP_CONFIG_PATH.exists():
|
| 249 |
config = load_mcp_config()
|
| 250 |
-
|
| 251 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
save_mcp_config(config)
|
| 253 |
-
click.echo(f"✓
|
| 254 |
removed = True
|
| 255 |
|
| 256 |
if not removed:
|
|
|
|
| 244 |
err=True,
|
| 245 |
)
|
| 246 |
|
| 247 |
+
# Also remove codebase-memory-mcp if registered (installed by --code-graph)
|
| 248 |
+
if claude_cli:
|
| 249 |
+
cbm_check = subprocess.run(
|
| 250 |
+
[claude_cli, "mcp", "get", "codebase-memory-mcp"],
|
| 251 |
+
capture_output=True,
|
| 252 |
+
)
|
| 253 |
+
if cbm_check.returncode == 0:
|
| 254 |
+
cbm_rm = subprocess.run(
|
| 255 |
+
[claude_cli, "mcp", "remove", "codebase-memory-mcp", "-s", "user"],
|
| 256 |
+
capture_output=True,
|
| 257 |
+
text=True,
|
| 258 |
+
)
|
| 259 |
+
if cbm_rm.returncode == 0:
|
| 260 |
+
click.echo("✓ codebase-memory-mcp MCP server removed")
|
| 261 |
+
removed = True
|
| 262 |
+
|
| 263 |
# Also remove from mcp.json fallback config if present
|
| 264 |
if MCP_CONFIG_PATH.exists():
|
| 265 |
config = load_mcp_config()
|
| 266 |
+
changed = False
|
| 267 |
+
for server_name in ("headroom", "codebase-memory-mcp"):
|
| 268 |
+
if server_name in config.get("mcpServers", {}):
|
| 269 |
+
del config["mcpServers"][server_name]
|
| 270 |
+
changed = True
|
| 271 |
+
if changed:
|
| 272 |
save_mcp_config(config)
|
| 273 |
+
click.echo(f"✓ MCP servers removed from {MCP_CONFIG_PATH}")
|
| 274 |
removed = True
|
| 275 |
|
| 276 |
if not removed:
|
headroom/cli/proxy.py
CHANGED
|
@@ -179,6 +179,13 @@ from .main import main
|
|
| 179 |
is_flag=True,
|
| 180 |
help="Disable anonymous usage telemetry (env: HEADROOM_TELEMETRY=off)",
|
| 181 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
@click.pass_context
|
| 183 |
def proxy(
|
| 184 |
ctx: click.Context,
|
|
@@ -213,6 +220,7 @@ def proxy(
|
|
| 213 |
bedrock_region: str | None,
|
| 214 |
bedrock_profile: str | None,
|
| 215 |
no_telemetry: bool,
|
|
|
|
| 216 |
) -> None:
|
| 217 |
"""Start the optimization proxy server.
|
| 218 |
|
|
@@ -251,10 +259,22 @@ def proxy(
|
|
| 251 |
mode or os.environ.get("HEADROOM_MODE") or PROXY_MODE_TOKEN
|
| 252 |
)
|
| 253 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
# Telemetry opt-out: --no-telemetry flag sets the env var
|
| 255 |
if no_telemetry:
|
| 256 |
os.environ["HEADROOM_TELEMETRY"] = "off"
|
| 257 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
# License key for managed/enterprise deployments (optional)
|
| 259 |
license_key = os.environ.get("HEADROOM_LICENSE_KEY")
|
| 260 |
|
|
@@ -272,7 +292,7 @@ def proxy(
|
|
| 272 |
connect_timeout_seconds=connect_timeout_seconds
|
| 273 |
if connect_timeout_seconds is not None
|
| 274 |
else 10,
|
| 275 |
-
log_file=log_file,
|
| 276 |
budget_limit_usd=budget,
|
| 277 |
# Code graph: live file watcher for incremental reindexing
|
| 278 |
code_graph_watcher=code_graph,
|
|
@@ -284,13 +304,15 @@ def proxy(
|
|
| 284 |
intelligent_context_compress_first=not no_compress_first,
|
| 285 |
# Memory System (Multi-Provider with auto-detection)
|
| 286 |
# --learn implies --memory (need backend for storing patterns)
|
| 287 |
-
|
|
|
|
| 288 |
memory_db_path=memory_db_path,
|
| 289 |
memory_inject_tools=not no_memory_tools,
|
| 290 |
memory_inject_context=not no_memory_context,
|
| 291 |
memory_top_k=memory_top_k,
|
| 292 |
# Traffic Learning: only with --learn, never with --no-learn
|
| 293 |
-
|
|
|
|
| 294 |
traffic_learning_agent_type=os.environ.get("HEADROOM_AGENT_TYPE", "unknown"),
|
| 295 |
# Backend (Anthropic direct, Bedrock, LiteLLM, or any-llm)
|
| 296 |
backend=backend,
|
|
@@ -299,6 +321,8 @@ def proxy(
|
|
| 299 |
anyllm_provider=effective_anyllm_provider,
|
| 300 |
# License / Usage Reporting (managed/enterprise)
|
| 301 |
license_key=license_key,
|
|
|
|
|
|
|
| 302 |
)
|
| 303 |
|
| 304 |
memory_status = "DISABLED"
|
|
@@ -355,6 +379,13 @@ Memory (Multi-Provider):
|
|
| 355 |
- Database: {config.memory_db_path}
|
| 356 |
"""
|
| 357 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 358 |
from headroom.telemetry.beacon import is_telemetry_enabled
|
| 359 |
|
| 360 |
# Build telemetry section for the startup banner
|
|
@@ -381,7 +412,7 @@ Starting proxy server...
|
|
| 381 |
Rate Limit: {"ENABLED" if config.rate_limit_enabled else "DISABLED"}
|
| 382 |
Memory: {memory_status}
|
| 383 |
License: {license_status}
|
| 384 |
-
{telemetry_line}
|
| 385 |
{backend_section}
|
| 386 |
Routing:
|
| 387 |
/v1/messages → {anthropic_url}
|
|
|
|
| 179 |
is_flag=True,
|
| 180 |
help="Disable anonymous usage telemetry (env: HEADROOM_TELEMETRY=off)",
|
| 181 |
)
|
| 182 |
+
@click.option(
|
| 183 |
+
"--stateless",
|
| 184 |
+
is_flag=True,
|
| 185 |
+
help="Disable all filesystem writes — run purely in-memory. "
|
| 186 |
+
"For containerized / read-only / load-balanced deployments. "
|
| 187 |
+
"(env: HEADROOM_STATELESS=true)",
|
| 188 |
+
)
|
| 189 |
@click.pass_context
|
| 190 |
def proxy(
|
| 191 |
ctx: click.Context,
|
|
|
|
| 220 |
bedrock_region: str | None,
|
| 221 |
bedrock_profile: str | None,
|
| 222 |
no_telemetry: bool,
|
| 223 |
+
stateless: bool,
|
| 224 |
) -> None:
|
| 225 |
"""Start the optimization proxy server.
|
| 226 |
|
|
|
|
| 259 |
mode or os.environ.get("HEADROOM_MODE") or PROXY_MODE_TOKEN
|
| 260 |
)
|
| 261 |
|
| 262 |
+
# Stateless mode: CLI flag or env var
|
| 263 |
+
is_stateless = stateless or os.environ.get("HEADROOM_STATELESS", "").lower() in (
|
| 264 |
+
"true",
|
| 265 |
+
"1",
|
| 266 |
+
"yes",
|
| 267 |
+
"on",
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
# Telemetry opt-out: --no-telemetry flag sets the env var
|
| 271 |
if no_telemetry:
|
| 272 |
os.environ["HEADROOM_TELEMETRY"] = "off"
|
| 273 |
|
| 274 |
+
# Stateless mode: suppress TOIN filesystem persistence
|
| 275 |
+
if is_stateless:
|
| 276 |
+
os.environ["HEADROOM_TOIN_BACKEND"] = "none"
|
| 277 |
+
|
| 278 |
# License key for managed/enterprise deployments (optional)
|
| 279 |
license_key = os.environ.get("HEADROOM_LICENSE_KEY")
|
| 280 |
|
|
|
|
| 292 |
connect_timeout_seconds=connect_timeout_seconds
|
| 293 |
if connect_timeout_seconds is not None
|
| 294 |
else 10,
|
| 295 |
+
log_file=None if is_stateless else log_file,
|
| 296 |
budget_limit_usd=budget,
|
| 297 |
# Code graph: live file watcher for incremental reindexing
|
| 298 |
code_graph_watcher=code_graph,
|
|
|
|
| 304 |
intelligent_context_compress_first=not no_compress_first,
|
| 305 |
# Memory System (Multi-Provider with auto-detection)
|
| 306 |
# --learn implies --memory (need backend for storing patterns)
|
| 307 |
+
# Stateless mode disables memory (requires SQLite on disk)
|
| 308 |
+
memory_enabled=False if is_stateless else (memory or (learn and not no_learn)),
|
| 309 |
memory_db_path=memory_db_path,
|
| 310 |
memory_inject_tools=not no_memory_tools,
|
| 311 |
memory_inject_context=not no_memory_context,
|
| 312 |
memory_top_k=memory_top_k,
|
| 313 |
# Traffic Learning: only with --learn, never with --no-learn
|
| 314 |
+
# Stateless mode disables learning (requires filesystem)
|
| 315 |
+
traffic_learning_enabled=False if is_stateless else (learn and not no_learn),
|
| 316 |
traffic_learning_agent_type=os.environ.get("HEADROOM_AGENT_TYPE", "unknown"),
|
| 317 |
# Backend (Anthropic direct, Bedrock, LiteLLM, or any-llm)
|
| 318 |
backend=backend,
|
|
|
|
| 321 |
anyllm_provider=effective_anyllm_provider,
|
| 322 |
# License / Usage Reporting (managed/enterprise)
|
| 323 |
license_key=license_key,
|
| 324 |
+
# Stateless mode: disable all filesystem writes
|
| 325 |
+
stateless=is_stateless,
|
| 326 |
)
|
| 327 |
|
| 328 |
memory_status = "DISABLED"
|
|
|
|
| 379 |
- Database: {config.memory_db_path}
|
| 380 |
"""
|
| 381 |
|
| 382 |
+
# Stateless mode warning
|
| 383 |
+
stateless_line = ""
|
| 384 |
+
if is_stateless:
|
| 385 |
+
stateless_line = (
|
| 386 |
+
" Stateless: YES (no filesystem writes — memory, logs, TOIN disabled)\n"
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
from headroom.telemetry.beacon import is_telemetry_enabled
|
| 390 |
|
| 391 |
# Build telemetry section for the startup banner
|
|
|
|
| 412 |
Rate Limit: {"ENABLED" if config.rate_limit_enabled else "DISABLED"}
|
| 413 |
Memory: {memory_status}
|
| 414 |
License: {license_status}
|
| 415 |
+
{stateless_line}{telemetry_line}
|
| 416 |
{backend_section}
|
| 417 |
Routing:
|
| 418 |
/v1/messages → {anthropic_url}
|
headroom/cli/wrap.py
CHANGED
|
@@ -192,13 +192,51 @@ def _setup_rtk(verbose: bool = False) -> Path | None:
|
|
| 192 |
return rtk_path
|
| 193 |
|
| 194 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
def _setup_code_graph(verbose: bool = False) -> bool:
|
| 196 |
-
"""Ensure codebase-memory-mcp is installed and project is indexed.
|
| 197 |
|
| 198 |
codebase-memory-mcp builds a knowledge graph of the codebase using
|
| 199 |
tree-sitter, enabling the LLM to query code structure (call chains,
|
| 200 |
function definitions, impact analysis) instead of reading entire files.
|
| 201 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
With Claude Code's MCP Tool Search, the 14 graph tools add ~200 tokens
|
| 203 |
overhead per request (not the full ~1,915) — they're lazy-loaded.
|
| 204 |
|
|
@@ -218,6 +256,9 @@ def _setup_code_graph(verbose: bool = False) -> bool:
|
|
| 218 |
|
| 219 |
cbm_bin = str(cbm_path)
|
| 220 |
|
|
|
|
|
|
|
|
|
|
| 221 |
# Index current project (fast — ~1s for most repos, idempotent)
|
| 222 |
project_dir = str(Path.cwd())
|
| 223 |
try:
|
|
@@ -320,6 +361,11 @@ rtk pip list rtk pnpm install rtk npm run <script>
|
|
| 320 |
# Marker used to detect if instructions are already injected
|
| 321 |
_RTK_MARKER = "<!-- headroom:rtk-instructions -->"
|
| 322 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 323 |
|
| 324 |
def _ensure_rtk_binary(verbose: bool = False) -> Path | None:
|
| 325 |
"""Ensure rtk binary is installed (download if needed). No hook registration."""
|
|
@@ -364,11 +410,12 @@ def _inject_codex_provider_config(port: int) -> None:
|
|
| 364 |
config_dir = Path.home() / ".codex"
|
| 365 |
config_file = config_dir / "config.toml"
|
| 366 |
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
|
|
|
| 372 |
f'name = "OpenAI via Headroom proxy"\n'
|
| 373 |
f'base_url = "http://127.0.0.1:{port}/v1"\n'
|
| 374 |
f'env_key = "OPENAI_API_KEY"\n'
|
|
@@ -377,7 +424,7 @@ def _inject_codex_provider_config(port: int) -> None:
|
|
| 377 |
f"# --- end Headroom ---\n"
|
| 378 |
)
|
| 379 |
|
| 380 |
-
marker =
|
| 381 |
end_marker = "# --- end Headroom ---"
|
| 382 |
|
| 383 |
try:
|
|
@@ -386,14 +433,20 @@ def _inject_codex_provider_config(port: int) -> None:
|
|
| 386 |
if config_file.exists():
|
| 387 |
content = config_file.read_text()
|
| 388 |
if marker in content:
|
| 389 |
-
#
|
| 390 |
start = content.index(marker)
|
| 391 |
end = content.index(end_marker) + len(end_marker)
|
| 392 |
-
content = content[:start].rstrip() +
|
| 393 |
-
|
| 394 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 395 |
else:
|
| 396 |
-
content =
|
| 397 |
|
| 398 |
config_file.write_text(content)
|
| 399 |
click.echo(f" Codex config: injected Headroom provider (WS + HTTP) into {config_file}")
|
|
@@ -424,6 +477,81 @@ def _inject_rtk_instructions(file_path: Path, verbose: bool = False) -> bool:
|
|
| 424 |
return True
|
| 425 |
|
| 426 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 427 |
def _resolve_copilot_provider_type(backend: str | None, provider_type: str) -> str:
|
| 428 |
"""Resolve Copilot BYOK provider type for the current proxy backend."""
|
| 429 |
if provider_type != "auto":
|
|
@@ -1088,6 +1216,48 @@ def claude(
|
|
| 1088 |
signal.signal(signal.SIGINT, cleanup)
|
| 1089 |
signal.signal(signal.SIGTERM, cleanup)
|
| 1090 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1091 |
try:
|
| 1092 |
click.echo()
|
| 1093 |
click.echo(" ╔═══════════════════════════════════════════════╗")
|
|
@@ -1241,6 +1411,20 @@ def copilot(
|
|
| 1241 |
env["COPILOT_PROVIDER_TYPE"] = effective_provider_type
|
| 1242 |
env.pop("COPILOT_PROVIDER_WIRE_API", None)
|
| 1243 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1244 |
env_vars_display: list[str]
|
| 1245 |
if effective_provider_type == "anthropic":
|
| 1246 |
env["COPILOT_PROVIDER_BASE_URL"] = f"http://127.0.0.1:{port}"
|
|
@@ -1258,6 +1442,19 @@ def copilot(
|
|
| 1258 |
f"COPILOT_PROVIDER_WIRE_API={effective_wire_api}",
|
| 1259 |
]
|
| 1260 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1261 |
if not _copilot_model_configured(copilot_args, env):
|
| 1262 |
click.echo(
|
| 1263 |
" Note: Copilot BYOK requires a model. Pass `--model <name>` "
|
|
@@ -1357,6 +1554,49 @@ def codex(
|
|
| 1357 |
global_agents = Path.home() / ".codex" / "AGENTS.md"
|
| 1358 |
_inject_rtk_instructions(global_agents, verbose=verbose)
|
| 1359 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1360 |
if prepare_only:
|
| 1361 |
_inject_codex_provider_config(port)
|
| 1362 |
return
|
|
@@ -1373,7 +1613,15 @@ def codex(
|
|
| 1373 |
# Inject Headroom provider into Codex config so WebSocket traffic also
|
| 1374 |
# routes through the proxy. Codex ignores OPENAI_BASE_URL for its WS
|
| 1375 |
# transport unless a custom provider declares supports_websockets = true.
|
|
|
|
|
|
|
| 1376 |
_inject_codex_provider_config(port)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1377 |
|
| 1378 |
_launch_tool(
|
| 1379 |
binary=codex_bin,
|
|
|
|
| 192 |
return rtk_path
|
| 193 |
|
| 194 |
|
| 195 |
+
_CBM_MCP_SERVER_NAME = "codebase-memory-mcp"
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def _register_cbm_mcp_server(cbm_bin: str) -> None:
|
| 199 |
+
"""Register codebase-memory-mcp as an MCP server in Claude Code.
|
| 200 |
+
|
| 201 |
+
Uses ``claude mcp add`` so the tools appear in ``/mcp`` automatically.
|
| 202 |
+
Idempotent — skips if already registered.
|
| 203 |
+
"""
|
| 204 |
+
claude_cli = shutil.which("claude")
|
| 205 |
+
if not claude_cli:
|
| 206 |
+
return
|
| 207 |
+
|
| 208 |
+
# Check if already registered
|
| 209 |
+
check = subprocess.run(
|
| 210 |
+
[claude_cli, "mcp", "get", _CBM_MCP_SERVER_NAME],
|
| 211 |
+
capture_output=True,
|
| 212 |
+
text=True,
|
| 213 |
+
)
|
| 214 |
+
if check.returncode == 0:
|
| 215 |
+
return # Already registered
|
| 216 |
+
|
| 217 |
+
result = subprocess.run(
|
| 218 |
+
[claude_cli, "mcp", "add", _CBM_MCP_SERVER_NAME, "-s", "user", "--", cbm_bin],
|
| 219 |
+
capture_output=True,
|
| 220 |
+
text=True,
|
| 221 |
+
)
|
| 222 |
+
if result.returncode == 0:
|
| 223 |
+
click.echo(f" Code graph: registered {_CBM_MCP_SERVER_NAME} MCP server")
|
| 224 |
+
else:
|
| 225 |
+
pass # Non-critical — tools won't appear in /mcp but graph still works
|
| 226 |
+
|
| 227 |
+
|
| 228 |
def _setup_code_graph(verbose: bool = False) -> bool:
|
| 229 |
+
"""Ensure codebase-memory-mcp is installed, registered as MCP server, and project is indexed.
|
| 230 |
|
| 231 |
codebase-memory-mcp builds a knowledge graph of the codebase using
|
| 232 |
tree-sitter, enabling the LLM to query code structure (call chains,
|
| 233 |
function definitions, impact analysis) instead of reading entire files.
|
| 234 |
|
| 235 |
+
Steps:
|
| 236 |
+
1. Download the binary if not already present.
|
| 237 |
+
2. Register as an MCP server in Claude Code (``claude mcp add``).
|
| 238 |
+
3. Index the current project (fast, idempotent).
|
| 239 |
+
|
| 240 |
With Claude Code's MCP Tool Search, the 14 graph tools add ~200 tokens
|
| 241 |
overhead per request (not the full ~1,915) — they're lazy-loaded.
|
| 242 |
|
|
|
|
| 256 |
|
| 257 |
cbm_bin = str(cbm_path)
|
| 258 |
|
| 259 |
+
# Register as MCP server so tools appear in /mcp
|
| 260 |
+
_register_cbm_mcp_server(cbm_bin)
|
| 261 |
+
|
| 262 |
# Index current project (fast — ~1s for most repos, idempotent)
|
| 263 |
project_dir = str(Path.cwd())
|
| 264 |
try:
|
|
|
|
| 361 |
# Marker used to detect if instructions are already injected
|
| 362 |
_RTK_MARKER = "<!-- headroom:rtk-instructions -->"
|
| 363 |
|
| 364 |
+
# Memory MCP markers
|
| 365 |
+
_MEMORY_MCP_MARKER = "# --- Headroom memory MCP (auto-injected) ---"
|
| 366 |
+
_MEMORY_MCP_END = "# --- end Headroom memory ---"
|
| 367 |
+
_MEMORY_AGENTS_MARKER = "<!-- headroom:memory-instructions -->"
|
| 368 |
+
|
| 369 |
|
| 370 |
def _ensure_rtk_binary(verbose: bool = False) -> Path | None:
|
| 371 |
"""Ensure rtk binary is installed (download if needed). No hook registration."""
|
|
|
|
| 410 |
config_dir = Path.home() / ".codex"
|
| 411 |
config_file = config_dir / "config.toml"
|
| 412 |
|
| 413 |
+
# model_provider must be a top-level TOML key (before any [section]).
|
| 414 |
+
# The [model_providers.headroom] table can go at the end.
|
| 415 |
+
top_level_marker = "# --- Headroom proxy (auto-injected by headroom wrap codex) ---"
|
| 416 |
+
top_level_block = f'{top_level_marker}\nmodel_provider = "headroom"\n'
|
| 417 |
+
provider_section = (
|
| 418 |
+
f"\n[model_providers.headroom]\n"
|
| 419 |
f'name = "OpenAI via Headroom proxy"\n'
|
| 420 |
f'base_url = "http://127.0.0.1:{port}/v1"\n'
|
| 421 |
f'env_key = "OPENAI_API_KEY"\n'
|
|
|
|
| 424 |
f"# --- end Headroom ---\n"
|
| 425 |
)
|
| 426 |
|
| 427 |
+
marker = top_level_marker
|
| 428 |
end_marker = "# --- end Headroom ---"
|
| 429 |
|
| 430 |
try:
|
|
|
|
| 433 |
if config_file.exists():
|
| 434 |
content = config_file.read_text()
|
| 435 |
if marker in content:
|
| 436 |
+
# Remove existing Headroom blocks entirely
|
| 437 |
start = content.index(marker)
|
| 438 |
end = content.index(end_marker) + len(end_marker)
|
| 439 |
+
content = content[:start].rstrip("\n") + content[end:].lstrip("\n")
|
| 440 |
+
|
| 441 |
+
# Strip any stale top-level model_provider left behind
|
| 442 |
+
import re
|
| 443 |
+
|
| 444 |
+
content = re.sub(r'\nmodel_provider\s*=\s*"headroom"\n', "\n", content)
|
| 445 |
+
|
| 446 |
+
# Place top-level key at the very beginning, provider table at the end
|
| 447 |
+
content = top_level_block + "\n" + content.strip() + "\n" + provider_section
|
| 448 |
else:
|
| 449 |
+
content = top_level_block + "\n" + provider_section
|
| 450 |
|
| 451 |
config_file.write_text(content)
|
| 452 |
click.echo(f" Codex config: injected Headroom provider (WS + HTTP) into {config_file}")
|
|
|
|
| 477 |
return True
|
| 478 |
|
| 479 |
|
| 480 |
+
def _inject_memory_mcp_config(db_path: str, user_id: str) -> None:
    """Register headroom memory as an MCP server in Codex's config.toml.

    Writes an ``[mcp_servers.headroom_memory]`` table, wrapped in the
    Headroom memory markers, into ``~/.codex/config.toml``. Idempotent: an
    existing marked section is replaced in place, otherwise the section is
    appended (or the file is created). Failures are reported as a warning
    instead of being raised.

    Args:
        db_path: Path of the memory database passed to the server via ``--db``.
        user_id: User identifier passed to the server via ``--user``.
    """
    import json
    import sys

    def _toml_str(value: str) -> str:
        """Quote *value* as a TOML basic string.

        JSON string escaping (quotes, backslashes, control characters) is
        also valid inside TOML basic strings; ``ensure_ascii=False`` keeps
        non-ASCII text literal so no surrogate escapes are emitted.
        """
        return json.dumps(value, ensure_ascii=False)

    config_dir = Path.home() / ".codex"
    config_file = config_dir / "config.toml"

    # Use forward slashes in TOML paths (works on all platforms, avoids
    # backslash escaping issues on Windows)
    python_bin = sys.executable.replace("\\", "/")
    db_path_toml = db_path.replace("\\", "/")
    server_args = [
        "-m",
        "headroom.memory.mcp_server",
        "--db",
        db_path_toml,
        "--user",
        user_id,
    ]
    args_toml = ", ".join(_toml_str(arg) for arg in server_args)
    mcp_section = (
        f"\n{_MEMORY_MCP_MARKER}\n"
        "[mcp_servers.headroom_memory]\n"
        f"command = {_toml_str(python_bin)}\n"
        f"args = [{args_toml}]\n"
        "startup_timeout_sec = 30\n"
        "tool_timeout_sec = 30\n"
        f"{_MEMORY_MCP_END}\n"
    )

    try:
        config_dir.mkdir(parents=True, exist_ok=True)

        if config_file.exists():
            # TOML documents are UTF-8; do not depend on the locale encoding.
            content = config_file.read_text(encoding="utf-8")
            if _MEMORY_MCP_MARKER in content:
                start = content.index(_MEMORY_MCP_MARKER)
                # Look for the end marker only after the start marker so an
                # unrelated earlier occurrence cannot produce a bad slice.
                end = content.index(_MEMORY_MCP_END, start) + len(_MEMORY_MCP_END)
                content = content[:start].rstrip("\n") + mcp_section + content[end:].lstrip("\n")
            else:
                content = content.rstrip() + "\n" + mcp_section
        else:
            content = mcp_section

        config_file.write_text(content, encoding="utf-8")
        click.echo(f" Memory MCP: registered in {config_file}")
    except Exception as e:
        # Best-effort setup step: warn and let the caller continue.
        click.echo(f" Warning: could not register memory MCP: {e}")
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
def _inject_memory_agents_md(file_path: Path) -> bool:
    """Add the Headroom memory guidance block to an AGENTS.md file.

    The block starts with the memory-instructions marker. If the file already
    contains that marker nothing is written; if the file exists without it the
    block is appended after a blank line; if the file is missing it is created
    (including parent directories) with the block as its only content.

    Args:
        file_path: Location of the AGENTS.md file to update.

    Returns:
        Always ``True``.
    """
    guidance_parts = (
        f"{_MEMORY_AGENTS_MARKER}\n",
        "## Memory\n\n",
        "Use the `headroom_memory` MCP server for persistent cross-session knowledge.\n\n",
        "**Before** answering questions about prior decisions, conventions, project context,\n",
        "architecture, user preferences, org info, codenames, debugging history, or anything\n",
        "from past sessions — call `memory_search` first.\n\n",
        "**After** making durable decisions, discovering conventions, or learning important\n",
        "facts — call `memory_save` to persist them for future sessions.\n\n",
        "Memory is your first source of truth for anything not visible in the current conversation.\n",
    )
    guidance = "".join(guidance_parts)

    if not file_path.exists():
        # Fresh file: the guidance block becomes the whole document.
        file_path.parent.mkdir(parents=True, exist_ok=True)
        file_path.write_text(guidance)
    else:
        if _MEMORY_AGENTS_MARKER in file_path.read_text():
            return True  # Marker found, so the block was injected earlier
        with file_path.open("a") as handle:
            handle.write("\n\n" + guidance)

    click.echo(f" Memory guidance injected into {file_path.name}")
    return True
|
| 553 |
+
|
| 554 |
+
|
| 555 |
def _resolve_copilot_provider_type(backend: str | None, provider_type: str) -> str:
|
| 556 |
"""Resolve Copilot BYOK provider type for the current proxy backend."""
|
| 557 |
if provider_type != "auto":
|
|
|
|
| 1216 |
signal.signal(signal.SIGINT, cleanup)
|
| 1217 |
signal.signal(signal.SIGTERM, cleanup)
|
| 1218 |
|
| 1219 |
+
# Memory sync BEFORE proxy startup — sync headroom DB ↔ Claude's files
|
| 1220 |
+
if memory:
|
| 1221 |
+
try:
|
| 1222 |
+
import subprocess as _sp
|
| 1223 |
+
|
| 1224 |
+
mem_dir = Path.cwd() / ".headroom"
|
| 1225 |
+
mem_dir.mkdir(parents=True, exist_ok=True)
|
| 1226 |
+
_sync_db = str(mem_dir / "memory.db")
|
| 1227 |
+
_sync_user = os.environ.get("USER", os.environ.get("USERNAME", "default"))
|
| 1228 |
+
|
| 1229 |
+
click.echo(f" Syncing memory (user={_sync_user})...")
|
| 1230 |
+
sync_result = _sp.run(
|
| 1231 |
+
[
|
| 1232 |
+
sys.executable,
|
| 1233 |
+
"-m",
|
| 1234 |
+
"headroom.memory.sync",
|
| 1235 |
+
"--db",
|
| 1236 |
+
_sync_db,
|
| 1237 |
+
"--user",
|
| 1238 |
+
_sync_user,
|
| 1239 |
+
"--agent",
|
| 1240 |
+
"claude",
|
| 1241 |
+
"--force",
|
| 1242 |
+
],
|
| 1243 |
+
capture_output=True,
|
| 1244 |
+
text=True,
|
| 1245 |
+
timeout=30,
|
| 1246 |
+
)
|
| 1247 |
+
if sync_result.returncode == 0 and sync_result.stdout.strip():
|
| 1248 |
+
import json as _json
|
| 1249 |
+
|
| 1250 |
+
stats = _json.loads(sync_result.stdout.strip().split("\n")[-1])
|
| 1251 |
+
imp, exp, ms = stats["imported"], stats["exported"], stats["ms"]
|
| 1252 |
+
if imp or exp:
|
| 1253 |
+
click.echo(f" Memory synced: {imp} imported, {exp} exported ({ms}ms)")
|
| 1254 |
+
else:
|
| 1255 |
+
click.echo(f" Memory: up to date ({ms}ms)")
|
| 1256 |
+
elif sync_result.returncode != 0:
|
| 1257 |
+
click.echo(f" Warning: memory sync error: {sync_result.stderr[-200:]}")
|
| 1258 |
+
except Exception as e:
|
| 1259 |
+
click.echo(f" Warning: memory sync failed: {e}")
|
| 1260 |
+
|
| 1261 |
try:
|
| 1262 |
click.echo()
|
| 1263 |
click.echo(" ╔═══════════════════════════════════════════════╗")
|
|
|
|
| 1411 |
env["COPILOT_PROVIDER_TYPE"] = effective_provider_type
|
| 1412 |
env.pop("COPILOT_PROVIDER_WIRE_API", None)
|
| 1413 |
|
| 1414 |
+
# Copilot BYOK requires COPILOT_PROVIDER_API_KEY — propagate from the
|
| 1415 |
+
# user's existing provider key so they don't have to set it twice.
|
| 1416 |
+
# Note: `headroom wrap copilot` uses Copilot's BYOK mode, which bypasses
|
| 1417 |
+
# GitHub's Copilot API and talks directly to the model provider through
|
| 1418 |
+
# the Headroom proxy. This requires the provider's own API key — a GitHub
|
| 1419 |
+
# Copilot subscription alone is not sufficient for BYOK mode.
|
| 1420 |
+
if not env.get("COPILOT_PROVIDER_API_KEY"):
|
| 1421 |
+
if effective_provider_type == "anthropic":
|
| 1422 |
+
_key = env.get("ANTHROPIC_API_KEY", "")
|
| 1423 |
+
else:
|
| 1424 |
+
_key = env.get("OPENAI_API_KEY", "")
|
| 1425 |
+
if _key:
|
| 1426 |
+
env["COPILOT_PROVIDER_API_KEY"] = _key
|
| 1427 |
+
|
| 1428 |
env_vars_display: list[str]
|
| 1429 |
if effective_provider_type == "anthropic":
|
| 1430 |
env["COPILOT_PROVIDER_BASE_URL"] = f"http://127.0.0.1:{port}"
|
|
|
|
| 1442 |
f"COPILOT_PROVIDER_WIRE_API={effective_wire_api}",
|
| 1443 |
]
|
| 1444 |
|
| 1445 |
+
if not env.get("COPILOT_PROVIDER_API_KEY"):
|
| 1446 |
+
src = "ANTHROPIC_API_KEY" if effective_provider_type == "anthropic" else "OPENAI_API_KEY"
|
| 1447 |
+
click.echo(
|
| 1448 |
+
f"\n Error: Copilot BYOK mode requires a provider API key.\n"
|
| 1449 |
+
f" `headroom wrap copilot` uses Copilot's BYOK mode, which bypasses GitHub's\n"
|
| 1450 |
+
f" Copilot API and routes requests directly to the model provider through the\n"
|
| 1451 |
+
f" Headroom proxy. A GitHub Copilot subscription alone is not sufficient.\n\n"
|
| 1452 |
+
f" Set one of:\n"
|
| 1453 |
+
f" export {src}=sk-... # recommended\n"
|
| 1454 |
+
f" export COPILOT_PROVIDER_API_KEY=sk-... # also works\n"
|
| 1455 |
+
)
|
| 1456 |
+
raise SystemExit(1)
|
| 1457 |
+
|
| 1458 |
if not _copilot_model_configured(copilot_args, env):
|
| 1459 |
click.echo(
|
| 1460 |
" Note: Copilot BYOK requires a model. Pass `--model <name>` "
|
|
|
|
| 1554 |
global_agents = Path.home() / ".codex" / "AGENTS.md"
|
| 1555 |
_inject_rtk_instructions(global_agents, verbose=verbose)
|
| 1556 |
|
| 1557 |
+
# Setup memory MCP server for Codex (native tool integration)
|
| 1558 |
+
if memory:
|
| 1559 |
+
click.echo(" Setting up memory for Codex...")
|
| 1560 |
+
mem_dir = Path.cwd() / ".headroom"
|
| 1561 |
+
mem_dir.mkdir(parents=True, exist_ok=True)
|
| 1562 |
+
db_path = str(mem_dir / "memory.db")
|
| 1563 |
+
mem_user = os.environ.get("USER", os.environ.get("USERNAME", "default"))
|
| 1564 |
+
|
| 1565 |
+
# Register MCP server in Codex config
|
| 1566 |
+
_inject_memory_mcp_config(db_path, mem_user)
|
| 1567 |
+
|
| 1568 |
+
# Inject memory guidance into project AGENTS.md
|
| 1569 |
+
agents_md = Path.cwd() / "AGENTS.md"
|
| 1570 |
+
_inject_memory_agents_md(agents_md)
|
| 1571 |
+
|
| 1572 |
+
# Sync Claude's memories → DB so MCP search finds them
|
| 1573 |
+
try:
|
| 1574 |
+
import asyncio
|
| 1575 |
+
|
| 1576 |
+
from headroom.memory.backends.local import LocalBackend, LocalBackendConfig
|
| 1577 |
+
from headroom.memory.sync import sync_import
|
| 1578 |
+
from headroom.memory.sync_adapters.claude_code import (
|
| 1579 |
+
ClaudeCodeAdapter,
|
| 1580 |
+
get_claude_memory_dir,
|
| 1581 |
+
)
|
| 1582 |
+
|
| 1583 |
+
claude_memory_dir = get_claude_memory_dir()
|
| 1584 |
+
|
| 1585 |
+
async def _import_claude_memories() -> int:
|
| 1586 |
+
config = LocalBackendConfig(db_path=db_path)
|
| 1587 |
+
backend = LocalBackend(config)
|
| 1588 |
+
await backend._ensure_initialized()
|
| 1589 |
+
adapter = ClaudeCodeAdapter(claude_memory_dir)
|
| 1590 |
+
count = await sync_import(backend, adapter, mem_user)
|
| 1591 |
+
await backend.close()
|
| 1592 |
+
return count
|
| 1593 |
+
|
| 1594 |
+
imported = asyncio.run(_import_claude_memories())
|
| 1595 |
+
if imported:
|
| 1596 |
+
click.echo(f" Memory: imported {imported} memories from Claude")
|
| 1597 |
+
except Exception as e:
|
| 1598 |
+
click.echo(f" Warning: Claude memory import failed: {e}")
|
| 1599 |
+
|
| 1600 |
if prepare_only:
|
| 1601 |
_inject_codex_provider_config(port)
|
| 1602 |
return
|
|
|
|
| 1613 |
# Inject Headroom provider into Codex config so WebSocket traffic also
|
| 1614 |
# routes through the proxy. Codex ignores OPENAI_BASE_URL for its WS
|
| 1615 |
# transport unless a custom provider declares supports_websockets = true.
|
| 1616 |
+
# NOTE: this must run BEFORE _inject_memory_mcp_config because it rewrites
|
| 1617 |
+
# the config file. Re-inject MCP config after if memory is enabled.
|
| 1618 |
_inject_codex_provider_config(port)
|
| 1619 |
+
if memory:
|
| 1620 |
+
mem_dir = Path.cwd() / ".headroom"
|
| 1621 |
+
_inject_memory_mcp_config(
|
| 1622 |
+
str(mem_dir / "memory.db"),
|
| 1623 |
+
os.environ.get("USER", os.environ.get("USERNAME", "default")),
|
| 1624 |
+
)
|
| 1625 |
|
| 1626 |
_launch_tool(
|
| 1627 |
binary=codex_bin,
|
headroom/install/state.py
CHANGED
|
@@ -3,21 +3,30 @@
|
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
import json
|
|
|
|
| 6 |
import shutil
|
| 7 |
from dataclasses import asdict
|
| 8 |
|
| 9 |
from .models import ArtifactRecord, DeploymentManifest, ManagedMutation, iso_utc_now
|
| 10 |
from .paths import deploy_root, manifest_path, profile_root
|
| 11 |
|
|
|
|
| 12 |
|
| 13 |
-
def save_manifest(manifest: DeploymentManifest) -> None:
|
| 14 |
-
"""Persist a deployment manifest to disk."""
|
| 15 |
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
def load_manifest(profile: str = "default") -> DeploymentManifest | None:
|
|
|
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
import json
|
| 6 |
+
import logging
|
| 7 |
import shutil
|
| 8 |
from dataclasses import asdict
|
| 9 |
|
| 10 |
from .models import ArtifactRecord, DeploymentManifest, ManagedMutation, iso_utc_now
|
| 11 |
from .paths import deploy_root, manifest_path, profile_root
|
| 12 |
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
|
|
|
|
|
|
|
| 15 |
|
| 16 |
+
def save_manifest(manifest: DeploymentManifest) -> None:
    """Write a deployment manifest to its profile's manifest file as JSON.

    Creates the profile directory if needed, refreshes ``manifest.updated_at``
    and writes the manifest as indented JSON followed by a newline. An
    ``OSError`` (for example from a read-only filesystem) is logged as a
    warning rather than raised.

    Args:
        manifest: The manifest to persist; its ``updated_at`` field is updated.
    """
    try:
        profile_dir = profile_root(manifest.profile)
        profile_dir.mkdir(parents=True, exist_ok=True)
        manifest.updated_at = iso_utc_now()
        target = manifest_path(manifest.profile)
        serialized = json.dumps(asdict(manifest), indent=2)
        target.write_text(serialized + "\n")
    except OSError as exc:
        logger.warning("Cannot save deployment manifest: %s — continuing without persistence", exc)
|
| 30 |
|
| 31 |
|
| 32 |
def load_manifest(profile: str = "default") -> DeploymentManifest | None:
|
headroom/learn/base.py
CHANGED
|
@@ -25,7 +25,7 @@ class ConversationScanner(ABC):
|
|
| 25 |
...
|
| 26 |
|
| 27 |
@abstractmethod
|
| 28 |
-
def scan_project(self, project: ProjectInfo) -> list[SessionData]:
|
| 29 |
"""Scan all sessions for a project, returning normalized tool calls."""
|
| 30 |
...
|
| 31 |
|
|
@@ -99,8 +99,14 @@ class LearnPlugin(ABC):
|
|
| 99 |
...
|
| 100 |
|
| 101 |
@abstractmethod
|
| 102 |
-
def scan_project(self, project: ProjectInfo) -> list[SessionData]:
|
| 103 |
-
"""Scan all sessions for a project, returning normalized data.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
...
|
| 105 |
|
| 106 |
# --- Writing ---
|
|
|
|
| 25 |
...
|
| 26 |
|
| 27 |
@abstractmethod
|
| 28 |
+
def scan_project(self, project: ProjectInfo, max_workers: int = 1) -> list[SessionData]:
|
| 29 |
"""Scan all sessions for a project, returning normalized tool calls."""
|
| 30 |
...
|
| 31 |
|
|
|
|
| 99 |
...
|
| 100 |
|
| 101 |
@abstractmethod
|
| 102 |
+
def scan_project(self, project: ProjectInfo, max_workers: int = 1) -> list[SessionData]:
|
| 103 |
+
"""Scan all sessions for a project, returning normalized data.
|
| 104 |
+
|
| 105 |
+
Args:
|
| 106 |
+
project: The project to scan.
|
| 107 |
+
max_workers: Number of threads for parallel file scanning.
|
| 108 |
+
1 (default) = serial. >1 = concurrent.
|
| 109 |
+
"""
|
| 110 |
...
|
| 111 |
|
| 112 |
# --- Writing ---
|
headroom/learn/plugins/claude.py
CHANGED
|
@@ -106,16 +106,24 @@ class ClaudeCodePlugin(LearnPlugin, ConversationScanner):
|
|
| 106 |
|
| 107 |
return projects
|
| 108 |
|
| 109 |
-
def scan_project(self, project: ProjectInfo) -> list[SessionData]:
|
| 110 |
"""Scan all conversation JSONL files for a project."""
|
| 111 |
-
sessions = []
|
| 112 |
jsonl_files = sorted(project.data_path.glob("*.jsonl"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
-
|
| 115 |
-
session = self._scan_session(jsonl_path)
|
| 116 |
-
if session and session.tool_calls:
|
| 117 |
-
sessions.append(session)
|
| 118 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
return sessions
|
| 120 |
|
| 121 |
def _scan_session(self, jsonl_path: Path) -> SessionData | None:
|
|
@@ -375,9 +383,9 @@ def _component_tokenizations(component: str) -> list[list[str]]:
|
|
| 375 |
|
| 376 |
add([component])
|
| 377 |
|
| 378 |
-
for separator in ("-", ".", None):
|
| 379 |
if separator is None:
|
| 380 |
-
tokens = [token for token in re.split(r"[-.]", component) if token]
|
| 381 |
else:
|
| 382 |
tokens = [token for token in component.split(separator) if token]
|
| 383 |
add(tokens)
|
|
@@ -385,9 +393,9 @@ def _component_tokenizations(component: str) -> list[list[str]]:
|
|
| 385 |
if component.startswith(".") and len(component) > 1:
|
| 386 |
hidden_component = component[1:]
|
| 387 |
add(["", hidden_component])
|
| 388 |
-
for separator in ("-", ".", None):
|
| 389 |
if separator is None:
|
| 390 |
-
tokens = [token for token in re.split(r"[-.]", hidden_component) if token]
|
| 391 |
else:
|
| 392 |
tokens = [token for token in hidden_component.split(separator) if token]
|
| 393 |
add(["", *tokens])
|
|
|
|
| 106 |
|
| 107 |
return projects
|
| 108 |
|
| 109 |
+
def scan_project(self, project: ProjectInfo, max_workers: int = 1) -> list[SessionData]:
    """Scan all conversation JSONL files for a project.

    Args:
        project: Project whose ``data_path`` directory holds ``*.jsonl`` files.
        max_workers: Number of threads for file scanning. Values <= 1 (or a
            single file) scan serially.

    Returns:
        The scanned sessions that contain at least one tool call, in sorted
        file order for both the serial and the threaded path.
    """
    jsonl_files = sorted(project.data_path.glob("*.jsonl"))
    if not jsonl_files:
        return []

    if max_workers <= 1 or len(jsonl_files) <= 1:
        scanned = [self._scan_session(path) for path in jsonl_files]
    else:
        from concurrent.futures import ThreadPoolExecutor

        worker_count = min(max_workers, len(jsonl_files))
        with ThreadPoolExecutor(max_workers=worker_count) as executor:
            # map() yields results in input order, so the output does not
            # depend on which thread happens to finish first.
            scanned = list(executor.map(self._scan_session, jsonl_files))

    return [session for session in scanned if session and session.tool_calls]
|
| 128 |
|
| 129 |
def _scan_session(self, jsonl_path: Path) -> SessionData | None:
|
|
|
|
| 383 |
|
| 384 |
add([component])
|
| 385 |
|
| 386 |
+
for separator in ("-", ".", "_", None):
|
| 387 |
if separator is None:
|
| 388 |
+
tokens = [token for token in re.split(r"[-._]", component) if token]
|
| 389 |
else:
|
| 390 |
tokens = [token for token in component.split(separator) if token]
|
| 391 |
add(tokens)
|
|
|
|
| 393 |
if component.startswith(".") and len(component) > 1:
|
| 394 |
hidden_component = component[1:]
|
| 395 |
add(["", hidden_component])
|
| 396 |
+
for separator in ("-", ".", "_", None):
|
| 397 |
if separator is None:
|
| 398 |
+
tokens = [token for token in re.split(r"[-._]", hidden_component) if token]
|
| 399 |
else:
|
| 400 |
tokens = [token for token in hidden_component.split(separator) if token]
|
| 401 |
add(["", *tokens])
|
headroom/learn/plugins/codex.py
CHANGED
|
@@ -91,13 +91,24 @@ class CodexPlugin(LearnPlugin, ConversationScanner):
|
|
| 91 |
)
|
| 92 |
]
|
| 93 |
|
| 94 |
-
def scan_project(self, project: ProjectInfo) -> list[SessionData]:
|
| 95 |
"""Scan all Codex session JSON files."""
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
return sessions
|
| 102 |
|
| 103 |
def _scan_session(self, json_path: Path) -> SessionData | None:
|
|
|
|
| 91 |
)
|
| 92 |
]
|
| 93 |
|
| 94 |
+
def scan_project(self, project: ProjectInfo, max_workers: int = 1) -> list[SessionData]:
    """Scan all Codex session JSON files.

    Args:
        project: Project whose ``data_path`` is handed to
            ``_iter_session_files`` to find the session files.
        max_workers: Number of threads for file scanning. Values <= 1 (or a
            single file) scan serially.

    Returns:
        The scanned sessions that contain at least one tool call, in the order
        ``_iter_session_files`` produced the files, for both the serial and
        the threaded path.
    """
    session_files = self._iter_session_files(project.data_path)
    if not session_files:
        return []

    if max_workers <= 1 or len(session_files) <= 1:
        scanned = [self._scan_session(path) for path in session_files]
    else:
        from concurrent.futures import ThreadPoolExecutor

        worker_count = min(max_workers, len(session_files))
        with ThreadPoolExecutor(max_workers=worker_count) as executor:
            # map() yields results in input order, so the output does not
            # depend on which thread happens to finish first.
            scanned = list(executor.map(self._scan_session, session_files))

    return [session for session in scanned if session and session.tool_calls]
|
| 113 |
|
| 114 |
def _scan_session(self, json_path: Path) -> SessionData | None:
|
headroom/learn/plugins/gemini.py
CHANGED
|
@@ -106,18 +106,26 @@ class GeminiPlugin(LearnPlugin, ConversationScanner):
|
|
| 106 |
|
| 107 |
return projects
|
| 108 |
|
| 109 |
-
def scan_project(self, project: ProjectInfo) -> list[SessionData]:
|
| 110 |
"""Scan all Gemini session files for a project."""
|
| 111 |
-
sessions = []
|
| 112 |
session_files = sorted(project.data_path.glob("session-*.json")) + sorted(
|
| 113 |
project.data_path.glob("session-*.jsonl")
|
| 114 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
-
|
| 117 |
-
session = self._scan_session(session_path)
|
| 118 |
-
if session and session.tool_calls:
|
| 119 |
-
sessions.append(session)
|
| 120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
return sessions
|
| 122 |
|
| 123 |
def _scan_session(self, session_path: Path) -> SessionData | None:
|
|
|
|
| 106 |
|
| 107 |
return projects
|
| 108 |
|
| 109 |
+
def scan_project(self, project: ProjectInfo, max_workers: int = 1) -> list[SessionData]:
    """Scan all Gemini session files for a project.

    Args:
        project: Project whose ``data_path`` directory holds
            ``session-*.json`` and ``session-*.jsonl`` files.
        max_workers: Number of threads for file scanning. Values <= 1 (or a
            single file) scan serially.

    Returns:
        The scanned sessions that contain at least one tool call. Files are
        ordered as sorted ``.json`` files followed by sorted ``.jsonl`` files,
        and the result keeps that order for both the serial and the threaded
        path.
    """
    session_files = sorted(project.data_path.glob("session-*.json")) + sorted(
        project.data_path.glob("session-*.jsonl")
    )
    if not session_files:
        return []

    if max_workers <= 1 or len(session_files) <= 1:
        scanned = [self._scan_session(path) for path in session_files]
    else:
        from concurrent.futures import ThreadPoolExecutor

        worker_count = min(max_workers, len(session_files))
        with ThreadPoolExecutor(max_workers=worker_count) as executor:
            # map() yields results in input order, so the output does not
            # depend on which thread happens to finish first.
            scanned = list(executor.map(self._scan_session, session_files))

    return [session for session in scanned if session and session.tool_calls]
|
| 130 |
|
| 131 |
def _scan_session(self, session_path: Path) -> SessionData | None:
|
headroom/memory/factory.py
CHANGED
|
@@ -7,6 +7,7 @@ and proper wiring between components.
|
|
| 7 |
|
| 8 |
from __future__ import annotations
|
| 9 |
|
|
|
|
| 10 |
from typing import TYPE_CHECKING
|
| 11 |
|
| 12 |
from headroom.memory.config import (
|
|
@@ -199,12 +200,21 @@ def _create_vector_index(config: MemoryConfig) -> VectorIndex:
|
|
| 199 |
|
| 200 |
from headroom.memory.adapters.hnsw import HNSWVectorIndex
|
| 201 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
return HNSWVectorIndex(
|
| 203 |
dimension=config.vector_dimension,
|
| 204 |
ef_construction=config.hnsw_ef_construction,
|
| 205 |
m=config.hnsw_m,
|
| 206 |
ef_search=config.hnsw_ef_search,
|
| 207 |
max_entries=config.hnsw_max_entries,
|
|
|
|
|
|
|
| 208 |
)
|
| 209 |
|
| 210 |
raise ValueError(f"Unknown vector backend: {config.vector_backend}")
|
|
|
|
| 7 |
|
| 8 |
from __future__ import annotations
|
| 9 |
|
| 10 |
+
from pathlib import Path
|
| 11 |
from typing import TYPE_CHECKING
|
| 12 |
|
| 13 |
from headroom.memory.config import (
|
|
|
|
| 200 |
|
| 201 |
from headroom.memory.adapters.hnsw import HNSWVectorIndex
|
| 202 |
|
| 203 |
+
# Derive persistent save path from the main DB path so the HNSW
|
| 204 |
+
# index survives across process restarts (critical for cross-agent
|
| 205 |
+
# interop: memories saved by Codex MCP must be searchable by Claude).
|
| 206 |
+
hnsw_save_path: str | Path | None = None
|
| 207 |
+
if config.db_path:
|
| 208 |
+
hnsw_save_path = config.db_path.parent / f"{config.db_path.stem}_hnsw"
|
| 209 |
+
|
| 210 |
return HNSWVectorIndex(
|
| 211 |
dimension=config.vector_dimension,
|
| 212 |
ef_construction=config.hnsw_ef_construction,
|
| 213 |
m=config.hnsw_m,
|
| 214 |
ef_search=config.hnsw_ef_search,
|
| 215 |
max_entries=config.hnsw_max_entries,
|
| 216 |
+
save_path=hnsw_save_path,
|
| 217 |
+
auto_save=True,
|
| 218 |
)
|
| 219 |
|
| 220 |
raise ValueError(f"Unknown vector backend: {config.vector_backend}")
|
headroom/memory/mcp_server.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Headroom Memory MCP Server.
|
| 2 |
+
|
| 3 |
+
A stdio MCP server that exposes headroom's memory backend as tools
|
| 4 |
+
that Codex (or any MCP-compatible client) can call natively.
|
| 5 |
+
|
| 6 |
+
Tools:
|
| 7 |
+
memory_search — semantic search across stored memories
|
| 8 |
+
memory_save — persist a new fact/decision/convention
|
| 9 |
+
|
| 10 |
+
Design:
|
| 11 |
+
- Embedder is pre-loaded at startup (no cold-start on first query)
|
| 12 |
+
- On startup, any memories missing vector embeddings are re-indexed
|
| 13 |
+
(fixes interop gap when memories were saved via a different path)
|
| 14 |
+
- Save always generates embeddings inline
|
| 15 |
+
|
| 16 |
+
Usage:
|
| 17 |
+
# Standalone (for testing):
|
| 18 |
+
python -m headroom.memory.mcp_server --db /path/to/.headroom/memory.db
|
| 19 |
+
|
| 20 |
+
# Registered in Codex config.toml (done by `headroom wrap codex --memory`):
|
| 21 |
+
[mcp_servers.headroom_memory]
|
| 22 |
+
command = "python"
|
| 23 |
+
args = ["-m", "headroom.memory.mcp_server", "--db", ".headroom/memory.db"]
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
from __future__ import annotations
|
| 27 |
+
|
| 28 |
+
import argparse
|
| 29 |
+
import asyncio
|
| 30 |
+
import logging
|
| 31 |
+
import os
|
| 32 |
+
import sys
|
| 33 |
+
from pathlib import Path
|
| 34 |
+
from typing import Any
|
| 35 |
+
|
| 36 |
+
from mcp.server import Server
|
| 37 |
+
from mcp.server.stdio import stdio_server
|
| 38 |
+
from mcp.types import TextContent, Tool
|
| 39 |
+
|
| 40 |
+
from headroom.memory.backends.local import LocalBackend, LocalBackendConfig
|
| 41 |
+
|
| 42 |
+
logger = logging.getLogger("headroom.memory.mcp")
|
| 43 |
+
|
| 44 |
+
# ---------------------------------------------------------------------------
|
| 45 |
+
# Tool definitions
|
| 46 |
+
# ---------------------------------------------------------------------------
|
| 47 |
+
|
| 48 |
+
# Tool advertised for semantic recall of previously stored memories.
_MEMORY_SEARCH_TOOL = Tool(
    name="memory_search",
    description=(
        "Search persistent memory for relevant knowledge from prior sessions. "
        "Use this for questions about architecture, conventions, prior decisions, "
        "project context, user preferences, org info, codenames, debugging history, "
        "or anything that might have been discussed before."
    ),
    inputSchema={
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "Natural-language search query.",
            },
            "top_k": {
                "type": "integer",
                "description": "Max results to return (default 10).",
                "default": 10,
            },
        },
        "required": ["query"],
    },
)

# Tool advertised for persisting new facts. "required" stays empty because the
# save handler also accepts a legacy single "content" string instead of "facts".
_MEMORY_SAVE_TOOL = Tool(
    name="memory_save",
    description=(
        "Save information to persistent memory for future sessions. "
        "Use this for decisions, conventions, architecture context, "
        "user preferences, project facts, or anything worth remembering.\n\n"
        "IMPORTANT: Break information into atomic facts — one fact per "
        "entry in the 'facts' array. Each fact should be a single, "
        "self-contained statement that answers one question. "
        "Do NOT combine multiple facts into one string.\n\n"
        "Good: facts: ['Repo owner is Tejas C.', 'User prefers dark mode']\n"
        "Bad: facts: ['Repo owner is Tejas C. Prefers dark mode.']"
    ),
    inputSchema={
        "type": "object",
        "properties": {
            "facts": {
                "type": "array",
                "items": {"type": "string"},
                "description": (
                    "Array of atomic facts to save. Each entry should be "
                    "one self-contained fact. The system stores and indexes "
                    "each fact separately for precise retrieval."
                ),
            },
            "importance": {
                "type": "number",
                "description": "0.0 (low) to 1.0 (critical). Default 0.7.",
                "default": 0.7,
            },
        },
        "required": [],
    },
)

# Tools returned by the server's list_tools handler, in advertised order.
_TOOLS = [_MEMORY_SEARCH_TOOL, _MEMORY_SAVE_TOOL]
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# ---------------------------------------------------------------------------
|
| 111 |
+
# Startup: pre-load embedder + re-index unembedded memories
|
| 112 |
+
# ---------------------------------------------------------------------------
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
async def _warm_up_backend(backend: LocalBackend, user_id: str) -> None:
    """Load the embedder eagerly and index stored memories that have no embedding.

    Steps:
      1. Initialize the backend and embed a throwaway string so the embedder
         is loaded now rather than on the first real query.
      2. Fetch up to 500 of the user's memories and, for each one whose
         ``embedding`` is ``None``, compute the embedding, save the memory
         back to the store and add it to the vector index.

    Memories that already carry an embedding are left untouched.

    Args:
        backend: Local memory backend to warm up.
        user_id: User whose memories are scanned.
    """
    await backend._ensure_initialized()
    memory_layer = backend._hierarchical_memory
    if memory_layer is None:
        return

    # The result is irrelevant; the call exists only to force the model load.
    await memory_layer._embedder.embed("warmup")
    logger.info("Memory MCP: embedder pre-loaded")

    stored = await backend.get_user_memories(user_id, limit=500)
    if not stored:
        return

    reindexed = 0
    for entry in stored:
        if entry.embedding is not None:
            continue
        # Saved through another path (e.g. another agent) without a vector.
        entry.embedding = await memory_layer._embedder.embed(entry.content)
        await memory_layer._store.save(entry)
        await memory_layer._vector_index.index(entry)
        reindexed += 1

    logger.info(f"Memory MCP: indexed {reindexed} memories into vector store")
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
# ---------------------------------------------------------------------------
|
| 150 |
+
# MCP Server
|
| 151 |
+
# ---------------------------------------------------------------------------
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def create_memory_server(db_path: str, user_id: str = "default") -> Server:
    """Create an MCP server backed by headroom's local memory.

    The backend is created lazily by a single shared initialization task.
    The task is started on the first ``list_tools`` request (or on the first
    tool call if that arrives earlier) and every tool call waits for it.

    Args:
        db_path: Path of the memory database passed to ``LocalBackendConfig``.
        user_id: User ID used to scope searches and saves.

    Returns:
        The configured ``Server`` with ``list_tools`` and ``call_tool``
        handlers registered.
    """

    server = Server("headroom-memory")
    _backend: LocalBackend | None = None
    _init_task: asyncio.Task[LocalBackend] | None = None

    async def _init_backend() -> LocalBackend:
        """Build the backend (requesting the "onnx" embedder) and warm it up."""
        nonlocal _backend
        config = LocalBackendConfig(db_path=db_path, embedder_backend="onnx")
        backend = LocalBackend(config)
        try:
            await _warm_up_backend(backend, user_id)
        except Exception:
            # Warm-up is an optimization; a failure here must not make the
            # server unusable, so log it and keep the backend.
            logger.exception("Memory MCP: warm-up failed; continuing without it")
        # Publish only after warm-up so waiters never see a half-ready backend.
        _backend = backend
        logger.info(f"Memory MCP: ready (db={db_path}, user={user_id})")
        return backend

    async def _get_backend() -> LocalBackend:
        """Return the ready backend, starting or awaiting initialization as needed."""
        nonlocal _init_task
        if _backend is not None:
            return _backend
        if _init_task is None:
            # No handshake-triggered init yet: start the one shared task so
            # concurrent first calls do not each build their own backend.
            _init_task = asyncio.create_task(_init_backend())
        task = _init_task
        try:
            # shield: cancelling one waiting request must not cancel the
            # initialization that other requests depend on.
            return await asyncio.shield(task)
        except Exception:
            # Drop the failed task so the next call retries instead of
            # re-raising the same stored exception forever.
            if _init_task is task:
                _init_task = None
            raise

    @server.list_tools()
    async def list_tools() -> list[Tool]:
        # Kick off background init on first list_tools (called at MCP handshake)
        nonlocal _init_task
        if _backend is None and _init_task is None:
            _init_task = asyncio.create_task(_init_backend())
        return _TOOLS

    @server.call_tool()
    async def call_tool(name: str, arguments: dict) -> list[TextContent]:
        backend = await _get_backend()

        if name == "memory_search":
            return await _handle_search(backend, arguments, user_id)
        elif name == "memory_save":
            return await _handle_save(backend, arguments, user_id)

        return [TextContent(type="text", text=f"Unknown tool: {name}")]

    return server
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
async def _handle_search(
    backend: LocalBackend, arguments: dict[str, Any], user_id: str
) -> list[TextContent]:
    """Handle the ``memory_search`` tool call.

    Over-fetches ``top_k * 3`` candidates, drops memories marked as
    superseded (checking both the search hit and the stored record), and
    formats up to ``top_k`` remaining hits as a numbered text list.

    Args:
        backend: Memory backend to query.
        arguments: Tool arguments; ``query`` is required, ``top_k`` is
            optional (default 10; invalid values fall back to 10 and values
            below 1 are raised to 1).
        user_id: User whose memories are searched.

    Returns:
        A single ``TextContent`` holding the results, "No memories found.",
        or an error message. Errors are reported as text, never raised.
    """
    query = arguments.get("query", "")
    if not query:
        return [TextContent(type="text", text="Error: query is required")]

    # Clients may send null, a numeric string, zero or a negative number.
    try:
        top_k = int(arguments.get("top_k", 10))
    except (TypeError, ValueError):
        top_k = 10
    top_k = max(top_k, 1)

    try:
        # Over-fetch to compensate for filtering out superseded memories
        results = await backend.search_memories(
            query=query,
            user_id=user_id,
            top_k=top_k * 3,
            include_related=True,
        )

        if not results:
            return [TextContent(type="text", text="No memories found.")]

        # Filter out superseded memories — only return current/active ones.
        active_results = []
        for r in results:
            if len(active_results) >= top_k:
                # Enough active hits; skip further store lookups.
                break
            if getattr(r.memory, "superseded_by", None):
                continue
            # Re-check the store: the search hit's metadata may be stale for
            # recently superseded memories.
            try:
                stored = await backend.get_memory(r.memory.id)
            except Exception:
                # Best effort: if the lookup fails, trust the search hit.
                logger.debug(
                    "memory_search: store re-check failed for %s",
                    r.memory.id,
                    exc_info=True,
                )
                stored = None
            if stored and getattr(stored, "superseded_by", None):
                continue
            active_results.append(r)

        if not active_results:
            return [TextContent(type="text", text="No memories found.")]

        lines = []
        for i, r in enumerate(active_results, 1):
            score = f"{r.score:.2f}" if hasattr(r, "score") else "?"
            lines.append(f"{i}. [relevance={score}] {r.memory.content}")
            if hasattr(r, "related_entities") and r.related_entities:
                lines.append(f" Related: {', '.join(r.related_entities[:3])}")

        return [TextContent(type="text", text="\n".join(lines))]
    except Exception as e:
        logger.exception("memory_search failed: %s", e)
        return [TextContent(type="text", text=f"Search error: {e}")]
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
# Similarity threshold for auto-supersession: if a new memory is this
|
| 259 |
+
# similar to an existing one, it replaces (supersedes) the old one.
|
| 260 |
+
_SUPERSEDE_SIMILARITY = 0.70
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
async def _handle_save(
    backend: LocalBackend, arguments: dict[str, Any], user_id: str
) -> list[TextContent]:
    """Handle the ``memory_save`` tool call.

    Each non-empty fact is saved as its own memory. Before saving, the
    backend is searched for a similar active memory; if one scores at or
    above ``_SUPERSEDE_SIMILARITY`` it is updated with the new content
    instead of a new memory being created.

    Args:
        backend: Memory backend to write to.
        arguments: Tool arguments. ``facts`` is a list of strings (a single
            string is treated as one fact); a legacy ``content`` string is
            accepted when ``facts`` is missing or empty. ``importance``
            defaults to 0.7 and is clamped to the range 0.0-1.0.
        user_id: User the memories are stored under.

    Returns:
        A single ``TextContent`` with a summary line followed by one line
        per processed fact, or an error message. Errors are reported as
        text, never raised.
    """
    facts = arguments.get("facts") or []
    if isinstance(facts, str):
        # A bare string would otherwise be iterated character by character.
        facts = [facts]

    # Backward compat: accept single "content" string too
    if not facts:
        content = arguments.get("content", "")
        if content:
            facts = [content]

    if not facts or not isinstance(facts, (list, tuple)):
        return [TextContent(type="text", text="Error: facts array is required")]

    try:
        importance = float(arguments.get("importance", 0.7))
    except (TypeError, ValueError):
        importance = 0.7
    importance = min(max(importance, 0.0), 1.0)

    try:
        saved = 0
        superseded = 0
        results_lines: list[str] = []

        for fact in facts:
            if not isinstance(fact, str):
                # Skip instead of aborting the whole batch on one bad entry.
                results_lines.append(f" skipped (not a string): {str(fact)[:60]}")
                continue
            fact = fact.strip()
            if not fact:
                continue

            # Check for semantically similar existing memory to auto-supersede
            superseded_id: str | None = None
            try:
                existing = await backend.search_memories(
                    query=fact,
                    user_id=user_id,
                    top_k=3,
                )
                for r in existing:
                    if getattr(r.memory, "superseded_by", None):
                        continue
                    if r.score >= _SUPERSEDE_SIMILARITY:
                        superseded_id = r.memory.id
                        logger.info(
                            f"Memory MCP: auto-superseding [{r.memory.id[:8]}] "
                            f"(similarity={r.score:.2f}): {r.memory.content[:60]}"
                        )
                        break
            except Exception:
                # Best effort: if the similarity check fails, save as new.
                logger.debug("memory_save: similarity check failed", exc_info=True)

            if superseded_id:
                memory = await backend.update_memory(
                    memory_id=superseded_id,
                    new_content=fact,
                )
                results_lines.append(
                    f" updated [{superseded_id[:8]}→{memory.id[:8]}]: {fact[:60]}"
                )
                superseded += 1
            else:
                memory = await backend.save_memory(
                    content=fact,
                    user_id=user_id,
                    importance=importance,
                )
                results_lines.append(f" saved [{memory.id[:8]}]: {fact[:60]}")
                saved += 1

        summary = f"Saved {saved} new, updated {superseded} existing ({saved + superseded} total)"
        return [TextContent(type="text", text=summary + "\n" + "\n".join(results_lines))]
    except Exception as e:
        logger.exception("memory_save failed: %s", e)
        return [TextContent(type="text", text=f"Save error: {e}")]
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
# ---------------------------------------------------------------------------
|
| 335 |
+
# Entry point
|
| 336 |
+
# ---------------------------------------------------------------------------
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
async def _run(db_path: str, user_id: str) -> None:
    """Build the memory server and run it over the stdio transport.

    Args:
        db_path: Path of the memory database.
        user_id: User ID used to scope memories.
    """
    memory_server = create_memory_server(db_path, user_id)
    async with stdio_server() as streams:
        incoming, outgoing = streams
        init_options = memory_server.create_initialization_options()
        await memory_server.run(incoming, outgoing, init_options)
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def main() -> None:
    """Command-line entry point: parse options, configure the process, serve.

    Options:
        --db:   memory database path (default ``<cwd>/.headroom/memory.db``).
        --user: user ID for memory scoping (default ``$USER``, then
                ``$USERNAME``, then ``"default"``).
    """
    cli = argparse.ArgumentParser(description="Headroom Memory MCP Server")
    default_db = str(Path.cwd() / ".headroom" / "memory.db")
    cli.add_argument("--db", default=default_db, help="Path to memory SQLite database")
    default_user = os.environ.get("USER", os.environ.get("USERNAME", "default"))
    cli.add_argument("--user", default=default_user, help="User ID for memory scoping")
    options = cli.parse_args()

    # Use cached HuggingFace models only, skipping the online freshness checks
    # on startup. setdefault keeps any value the user exported.
    for offline_flag in ("HF_HUB_OFFLINE", "TRANSFORMERS_OFFLINE"):
        os.environ.setdefault(offline_flag, "1")

    # stdout carries the MCP protocol, so all logging goes to stderr.
    logging.basicConfig(
        level=logging.INFO,
        stream=sys.stderr,
        format="%(name)s: %(message)s",
    )

    asyncio.run(_run(options.db, options.user))
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
if __name__ == "__main__":
|
| 375 |
+
main()
|
headroom/memory/sync.py
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Universal memory sync engine for cross-agent interoperability.
|
| 2 |
+
|
| 3 |
+
Provides bidirectional sync between headroom's memory DB and any
|
| 4 |
+
agent's native memory format via pluggable adapters.
|
| 5 |
+
|
| 6 |
+
Architecture:
|
| 7 |
+
DB ← sync_import → Agent files (agent's knowledge enters the shared DB)
|
| 8 |
+
DB → sync_export → Agent files (shared knowledge flows to the agent)
|
| 9 |
+
sync() = import + export (bidirectional, fast no-op when unchanged)
|
| 10 |
+
|
| 11 |
+
Usage:
|
| 12 |
+
from headroom.memory.sync import sync, SyncResult
|
| 13 |
+
from headroom.memory.sync_adapters.claude_code import ClaudeCodeAdapter
|
| 14 |
+
|
| 15 |
+
adapter = ClaudeCodeAdapter(memory_dir=Path("~/.claude/projects/.../memory"))
|
| 16 |
+
backend = LocalBackend(config)
|
| 17 |
+
|
| 18 |
+
result: SyncResult = await sync(backend, adapter, user_id="tcms")
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
import hashlib
|
| 24 |
+
import json
|
| 25 |
+
import logging
|
| 26 |
+
import time
|
| 27 |
+
from abc import ABC, abstractmethod
|
| 28 |
+
from dataclasses import dataclass, field
|
| 29 |
+
from datetime import datetime, timezone
|
| 30 |
+
from pathlib import Path
|
| 31 |
+
from typing import Any
|
| 32 |
+
|
| 33 |
+
logger = logging.getLogger("headroom.memory.sync")
|
| 34 |
+
|
| 35 |
+
# State file for fast no-op detection
|
| 36 |
+
_DEFAULT_STATE_PATH = Path.home() / ".headroom" / "sync_state.json"
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# ---------------------------------------------------------------------------
|
| 40 |
+
# Data models
|
| 41 |
+
# ---------------------------------------------------------------------------
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@dataclass
class SyncResult:
    """Counters and timing reported for one sync run."""

    # Number of memories copied from the agent's files into the DB.
    imported: int = 0
    # Number of memories written from the DB out to the agent's files.
    exported: int = 0
    # NOTE(review): the two skip counters are not assigned by the sync code
    # visible in this file — presumably reserved for adapters/callers; confirm.
    skipped_unchanged: int = 0
    skipped_dedup: int = 0
    # Wall-clock duration of the run, in milliseconds.
    duration_ms: float = 0
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@dataclass
class AgentMemory:
    """One memory entry as read from an agent's own storage format.

    ``content_hash`` may be supplied by the adapter; when left empty it is
    derived from ``content`` (first 16 hex characters of its SHA-256).
    """

    content: str
    category: str = ""
    source_file: str = ""
    content_hash: str = ""
    metadata: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Fill in ``content_hash`` from ``content`` unless one was given."""
        if self.content_hash:
            return
        digest = hashlib.sha256(self.content.encode())
        self.content_hash = digest.hexdigest()[:16]
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# ---------------------------------------------------------------------------
|
| 71 |
+
# Adapter interface
|
| 72 |
+
# ---------------------------------------------------------------------------
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class AgentMemoryAdapter(ABC):
    """Abstract interface between the sync engine and one agent's memory files.

    A concrete subclass exists per agent (Claude Code, Codex, Aider, Cursor)
    and implements reading, writing and fingerprinting of that agent's
    native memory format.
    """

    # Identifier of the agent; the sync engine uses it in state keys, log
    # messages and the ``source_agent`` lineage metadata.
    agent_name: str = "unknown"

    @abstractmethod
    async def read_memories(self) -> list[AgentMemory]:
        """Return every memory entry found in the agent's native files."""

    @abstractmethod
    async def write_memories(self, memories: list[dict[str, Any]]) -> int:
        """Persist memories in the agent's native format.

        Args:
            memories: Dicts with the keys ``content``, ``category``,
                ``importance``, ``headroom_id``, ``source_agent`` and
                ``content_hash``.

        Returns:
            How many memories were written.
        """

    @abstractmethod
    def fingerprint(self) -> str:
        """Return a cheap hash describing the agent's current memory state.

        The sync engine compares it with the value stored at the previous
        sync; an unchanged fingerprint lets it skip the read/compare cycle.
        """
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# ---------------------------------------------------------------------------
|
| 116 |
+
# Sync state persistence
|
| 117 |
+
# ---------------------------------------------------------------------------
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def _load_sync_state(state_path: Path) -> dict[str, Any]:
|
| 121 |
+
"""Load sync state from disk."""
|
| 122 |
+
if state_path.exists():
|
| 123 |
+
try:
|
| 124 |
+
result: dict[str, Any] = json.loads(state_path.read_text())
|
| 125 |
+
return result
|
| 126 |
+
except (json.JSONDecodeError, OSError):
|
| 127 |
+
pass
|
| 128 |
+
return {}
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def _save_sync_state(state_path: Path, state: dict[str, Any]) -> None:
|
| 132 |
+
"""Save sync state to disk."""
|
| 133 |
+
state_path.parent.mkdir(parents=True, exist_ok=True)
|
| 134 |
+
state_path.write_text(json.dumps(state, indent=2))
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def _db_fingerprint(memories: list[Any]) -> str:
|
| 138 |
+
"""Compute a fast fingerprint of DB state."""
|
| 139 |
+
if not memories:
|
| 140 |
+
return "empty"
|
| 141 |
+
# Hash: count + most recent created_at
|
| 142 |
+
parts = [str(len(memories))]
|
| 143 |
+
for m in memories[:5]: # Sample first 5 for speed
|
| 144 |
+
parts.append(getattr(m, "id", "")[:8])
|
| 145 |
+
return hashlib.sha256("|".join(parts).encode()).hexdigest()[:16]
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
# ---------------------------------------------------------------------------
|
| 149 |
+
# Sync engine
|
| 150 |
+
# ---------------------------------------------------------------------------
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
async def sync(
    backend: Any,
    adapter: AgentMemoryAdapter,
    user_id: str,
    state_path: Path = _DEFAULT_STATE_PATH,
    force: bool = False,
) -> SyncResult:
    """Run a two-way sync between the headroom DB and one agent's memory.

    Unless ``force`` is set, the agent and DB fingerprints are first compared
    with those recorded at the previous sync; if both match, nothing else is
    done. Otherwise agent entries are imported into the DB, DB entries are
    exported to the agent, and the new fingerprints are persisted.

    Args:
        backend: Memory backend; ``get_user_memories`` is called here and the
            import step uses ``save_memory``.
        adapter: Adapter for the agent being synced.
        user_id: User ID for memory scoping.
        state_path: File holding the fingerprints of previous syncs.
        force: When true, skip the fingerprint comparison and always sync.

    Returns:
        A ``SyncResult`` with the import/export counts and the duration.
    """
    started = time.monotonic()
    outcome = SyncResult()
    state_key = f"{adapter.agent_name}:{user_id}"

    if force:
        db_memories = await backend.get_user_memories(user_id, limit=500)
    else:
        # Cheap check first: compare current fingerprints with the stored ones.
        previous = _load_sync_state(state_path).get(state_key, {})
        agent_fp = adapter.fingerprint()
        db_memories = await backend.get_user_memories(user_id, limit=500)
        db_fp = _db_fingerprint(db_memories)

        nothing_changed = (
            previous.get("agent_fingerprint") == agent_fp
            and previous.get("db_fingerprint") == db_fp
        )
        if nothing_changed:
            outcome.duration_ms = (time.monotonic() - started) * 1000
            logger.info(
                f"Sync [{adapter.agent_name}]: no-op — nothing changed ({outcome.duration_ms:.1f}ms)"
            )
            return outcome

    # Agent files -> DB.
    outcome.imported = await sync_import(backend, adapter, user_id, db_memories)

    # DB -> agent files; reload first when the import added entries.
    if outcome.imported > 0:
        db_memories = await backend.get_user_memories(user_id, limit=500)
    outcome.exported = await sync_export(backend, adapter, user_id, db_memories)

    # Record the post-sync fingerprints for the next no-op check. The state
    # file is re-read so entries stored under other keys are preserved.
    persisted = _load_sync_state(state_path)
    persisted[state_key] = {
        "agent_fingerprint": adapter.fingerprint(),
        "db_fingerprint": _db_fingerprint(db_memories),
        "last_sync": datetime.now(timezone.utc).isoformat(),
        "last_imported": outcome.imported,
        "last_exported": outcome.exported,
    }
    _save_sync_state(state_path, persisted)

    outcome.duration_ms = (time.monotonic() - started) * 1000
    logger.info(
        f"Sync [{adapter.agent_name}]: imported={outcome.imported}, "
        f"exported={outcome.exported} ({outcome.duration_ms:.1f}ms)"
    )
    return outcome
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
async def sync_import(
    backend: Any,
    adapter: AgentMemoryAdapter,
    user_id: str,
    existing_memories: list[Any] | None = None,
) -> int:
    """Copy memories from the agent's files into the DB.

    An agent entry is skipped when its ``content_hash`` matches either the
    ``content_hash`` recorded in an existing memory's metadata or the hash of
    an existing memory's content. Every other entry is saved with importance
    0.6 and lineage metadata; keys in the entry's own metadata take
    precedence over the lineage keys.

    Args:
        backend: Memory backend providing ``get_user_memories`` and
            ``save_memory``.
        adapter: Adapter the entries are read from.
        user_id: User the memories are stored under.
        existing_memories: Memories already in the DB; fetched from the
            backend (up to 500) when ``None``.

    Returns:
        The number of memories imported.
    """
    incoming = await adapter.read_memories()
    if not incoming:
        return 0

    if existing_memories is None:
        existing_memories = await backend.get_user_memories(user_id, limit=500)

    # Hashes already present in the DB, used to skip duplicates.
    known_hashes: set[str] = set()
    for stored in existing_memories:
        recorded = (stored.metadata or {}).get("content_hash", "")
        if recorded:
            known_hashes.add(recorded)
        # Also cover memories that carry no recorded hash.
        known_hashes.add(hashlib.sha256(stored.content.encode()).hexdigest()[:16])

    count = 0
    for entry in incoming:
        if entry.content_hash in known_hashes:
            continue

        lineage = {
            "source_agent": adapter.agent_name,
            "source_file": entry.source_file,
            "content_hash": entry.content_hash,
            "synced_at": datetime.now(timezone.utc).isoformat(),
            "sync_direction": "import",
            **entry.metadata,
        }
        await backend.save_memory(
            content=entry.content,
            user_id=user_id,
            importance=0.6,
            metadata=lineage,
        )
        # Prevents importing the same content twice within one run.
        known_hashes.add(entry.content_hash)
        count += 1

    if count:
        logger.info(f"Sync [{adapter.agent_name}]: imported {count} memories from agent files")
    return count
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
async def sync_export(
|
| 281 |
+
backend: Any,
|
| 282 |
+
adapter: AgentMemoryAdapter,
|
| 283 |
+
user_id: str,
|
| 284 |
+
existing_memories: list[Any] | None = None,
|
| 285 |
+
) -> int:
|
| 286 |
+
"""Export: DB → agent files. Returns count exported."""
|
| 287 |
+
if existing_memories is None:
|
| 288 |
+
existing_memories = await backend.get_user_memories(user_id, limit=500)
|
| 289 |
+
|
| 290 |
+
if not existing_memories:
|
| 291 |
+
return 0
|
| 292 |
+
|
| 293 |
+
# Read what the agent already has (to avoid re-exporting)
|
| 294 |
+
agent_memories = await adapter.read_memories()
|
| 295 |
+
agent_hashes: set[str] = {am.content_hash for am in agent_memories}
|
| 296 |
+
|
| 297 |
+
# Find memories to export (not already in agent, not imported FROM this agent)
|
| 298 |
+
to_export: list[dict[str, Any]] = []
|
| 299 |
+
for mem in existing_memories:
|
| 300 |
+
content_hash = hashlib.sha256(mem.content.encode()).hexdigest()[:16]
|
| 301 |
+
|
| 302 |
+
# Skip if agent already has it
|
| 303 |
+
if content_hash in agent_hashes:
|
| 304 |
+
continue
|
| 305 |
+
|
| 306 |
+
# Skip if this memory was originally imported FROM this same agent
|
| 307 |
+
# (prevents echo: agent → DB → agent)
|
| 308 |
+
meta = mem.metadata or {}
|
| 309 |
+
if (
|
| 310 |
+
meta.get("source_agent") == adapter.agent_name
|
| 311 |
+
and meta.get("sync_direction") == "import"
|
| 312 |
+
):
|
| 313 |
+
continue
|
| 314 |
+
|
| 315 |
+
to_export.append(
|
| 316 |
+
{
|
| 317 |
+
"content": mem.content,
|
| 318 |
+
"category": getattr(mem, "category", "") or "",
|
| 319 |
+
"importance": getattr(mem, "importance", 0.5),
|
| 320 |
+
"headroom_id": mem.id,
|
| 321 |
+
"source_agent": meta.get("source_agent", "unknown"),
|
| 322 |
+
"content_hash": content_hash,
|
| 323 |
+
"created_at": mem.created_at.isoformat()
|
| 324 |
+
if hasattr(mem.created_at, "isoformat")
|
| 325 |
+
else str(mem.created_at),
|
| 326 |
+
}
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
if not to_export:
|
| 330 |
+
return 0
|
| 331 |
+
|
| 332 |
+
exported = await adapter.write_memories(to_export)
|
| 333 |
+
if exported:
|
| 334 |
+
logger.info(f"Sync [{adapter.agent_name}]: exported {exported} memories to agent files")
|
| 335 |
+
return exported
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
# ---------------------------------------------------------------------------
|
| 339 |
+
# CLI entry point: python -m headroom.memory.sync --db ... --user ... --agent ...
|
| 340 |
+
# ---------------------------------------------------------------------------
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def main() -> None:
    """CLI entry point for running sync from a subprocess.

    Parses ``--db``, ``--user``, ``--agent`` (``claude`` or ``codex``) and the
    optional ``--force`` flag, runs one sync pass and prints a single JSON
    object to stdout: ``{"imported": ..., "exported": ..., "ms": ...}`` on
    success, or ``{"error": ...}`` for an unrecognised agent.
    """
    import argparse
    import asyncio
    import json as _json

    parser = argparse.ArgumentParser(description="Headroom memory sync")
    parser.add_argument("--db", required=True, help="Path to memory DB")
    parser.add_argument("--user", required=True, help="User ID")
    parser.add_argument("--agent", required=True, choices=["claude", "codex"], help="Agent to sync")
    parser.add_argument("--force", action="store_true", help="Skip no-op check")
    args = parser.parse_args()

    async def _run() -> None:
        from headroom.memory.backends.local import LocalBackend, LocalBackendConfig

        config = LocalBackendConfig(db_path=args.db)
        backend = LocalBackend(config)
        await backend._ensure_initialized()

        # From here on the backend is open: close it on every exit path,
        # including a failing sync() and the unknown-agent early return.
        try:
            if args.agent == "claude":
                from headroom.memory.sync_adapters.claude_code import (
                    ClaudeCodeAdapter,
                    get_claude_memory_dir,
                )

                adapter: ClaudeCodeAdapter | Any = ClaudeCodeAdapter(get_claude_memory_dir())
            elif args.agent == "codex":
                from headroom.memory.sync_adapters.codex_agent import CodexAdapter

                adapter = CodexAdapter()
            else:
                # argparse `choices` already rejects other values; kept as a
                # defensive fallback.
                print(_json.dumps({"error": f"Unknown agent: {args.agent}"}))
                return

            result = await sync(backend, adapter, args.user, force=args.force)
        finally:
            await backend.close()

        print(
            _json.dumps(
                {
                    "imported": result.imported,
                    "exported": result.exported,
                    "ms": round(result.duration_ms),
                }
            )
        )

    asyncio.run(_run())


if __name__ == "__main__":
    main()
|
headroom/memory/sync_adapters/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Agent memory sync adapters for cross-agent interoperability."""
|
headroom/memory/sync_adapters/claude_code.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Claude Code memory sync adapter.
|
| 2 |
+
|
| 3 |
+
Reads/writes Claude Code's native memory format:
|
| 4 |
+
~/.claude/projects/<sanitized-path>/memory/
|
| 5 |
+
MEMORY.md — index file (first 200 lines always in context)
|
| 6 |
+
user_role.md — individual memory files with YAML frontmatter
|
| 7 |
+
project_codename.md
|
| 8 |
+
...
|
| 9 |
+
|
| 10 |
+
Each .md file has:
|
| 11 |
+
---
|
| 12 |
+
name: <title>
|
| 13 |
+
description: <one-line summary>
|
| 14 |
+
type: <user|project|reference|feedback>
|
| 15 |
+
headroom_id: <uuid> (added by sync for cross-reference)
|
| 16 |
+
source_agent: <agent name> (added by sync for lineage)
|
| 17 |
+
---
|
| 18 |
+
<body content>
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
import hashlib
|
| 24 |
+
import re
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
from typing import Any
|
| 27 |
+
|
| 28 |
+
from headroom.memory.sync import AgentMemory, AgentMemoryAdapter
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _sanitize_for_filename(text: str) -> str:
|
| 32 |
+
"""Convert text to a safe filename slug."""
|
| 33 |
+
slug = re.sub(r"[^a-z0-9]+", "_", text.lower().strip())
|
| 34 |
+
slug = slug.strip("_")[:50]
|
| 35 |
+
return slug or "memory"
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _parse_frontmatter(content: str) -> tuple[dict[str, str], str]:
|
| 39 |
+
"""Parse YAML frontmatter from a markdown file.
|
| 40 |
+
|
| 41 |
+
Returns (frontmatter_dict, body).
|
| 42 |
+
"""
|
| 43 |
+
if not content.startswith("---"):
|
| 44 |
+
return {}, content
|
| 45 |
+
|
| 46 |
+
end = content.find("---", 3)
|
| 47 |
+
if end == -1:
|
| 48 |
+
return {}, content
|
| 49 |
+
|
| 50 |
+
fm_text = content[3:end].strip()
|
| 51 |
+
body = content[end + 3 :].strip()
|
| 52 |
+
|
| 53 |
+
fm: dict[str, str] = {}
|
| 54 |
+
for line in fm_text.split("\n"):
|
| 55 |
+
if ":" in line:
|
| 56 |
+
key, _, value = line.partition(":")
|
| 57 |
+
fm[key.strip()] = value.strip().strip('"').strip("'")
|
| 58 |
+
|
| 59 |
+
return fm, body
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _build_frontmatter(fields: dict[str, str]) -> str:
|
| 63 |
+
"""Build YAML frontmatter block."""
|
| 64 |
+
lines = ["---"]
|
| 65 |
+
for key, value in fields.items():
|
| 66 |
+
if value:
|
| 67 |
+
lines.append(f"{key}: {value}")
|
| 68 |
+
lines.append("---")
|
| 69 |
+
return "\n".join(lines)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class ClaudeCodeAdapter(AgentMemoryAdapter):
    """Sync adapter for Claude Code's native memory files.

    Each memory is one ``*.md`` file (frontmatter + body) directly inside
    ``memory_dir``. ``MEMORY.md`` in the same directory is an index of links
    and is never read as a memory.
    """

    agent_name = "claude"

    def __init__(self, memory_dir: Path | str) -> None:
        """Remember the directory that holds the memory files."""
        self._memory_dir = Path(memory_dir)

    async def read_memories(self) -> list[AgentMemory]:
        """Read all .md memory files (except the MEMORY.md index).

        Files are visited in sorted name order. Unreadable files and files
        whose body is empty are skipped. Returns an empty list when the
        directory does not exist.
        """
        if not self._memory_dir.exists():
            return []

        memories: list[AgentMemory] = []
        for md_file in sorted(self._memory_dir.glob("*.md")):
            if md_file.name == "MEMORY.md":
                continue  # Index file, not a memory

            try:
                content = md_file.read_text(encoding="utf-8")
            except OSError:
                continue

            fm, body = _parse_frontmatter(content)
            if not body.strip():
                continue

            memories.append(
                AgentMemory(
                    content=body.strip(),
                    category=fm.get("type", ""),
                    source_file=md_file.name,
                    metadata={
                        "name": fm.get("name", ""),
                        "description": fm.get("description", ""),
                        "headroom_id": fm.get("headroom_id", ""),
                        "source_agent": fm.get("source_agent", "claude"),
                    },
                )
            )

        return memories

    async def write_memories(self, memories: list[dict[str, Any]]) -> int:
        """Write memories as individual .md files with frontmatter.

        Also adds an entry to the MEMORY.md index for every newly created file.

        The filename is derived from the first line of the content. If that
        file already exists:

        * same content hash -> the memory is skipped;
        * same ``headroom_id`` -> the file is rewritten in place (the memory
          changed) and no second index entry is added;
        * otherwise it belongs to a different memory that merely shares the
          slug, so the new memory is written to a hash-suffixed filename
          instead of overwriting it.

        Args:
            memories: Dicts with a required ``content`` key and optional
                ``category``, ``headroom_id``, ``source_agent`` and
                ``content_hash`` keys.

        Returns:
            Number of files written.
        """
        if not memories:
            return 0

        self._memory_dir.mkdir(parents=True, exist_ok=True)

        written = 0
        new_index_entries: list[str] = []

        for mem in memories:
            content = mem["content"]
            category = mem.get("category", "project")
            headroom_id = mem.get("headroom_id", "")
            source_agent = mem.get("source_agent", "unknown")
            content_hash = mem.get("content_hash", "")

            # Generate filename from content
            first_line = content.split("\n")[0][:60].strip()
            slug = _sanitize_for_filename(first_line)
            filename = f"headroom_{slug}.md"
            target = self._memory_dir / filename

            rewriting_existing = False
            if target.exists():
                existing_fm, existing_body = _parse_frontmatter(target.read_text(encoding="utf-8"))
                existing_hash = hashlib.sha256(existing_body.strip().encode()).hexdigest()[:16]
                if existing_hash == content_hash:
                    continue  # Already on disk with the same content

                if headroom_id and existing_fm.get("headroom_id") == str(headroom_id):
                    # Same memory, new content: update the file in place.
                    rewriting_existing = True
                else:
                    # Slug collision with a different memory (possibly one
                    # written earlier in this same batch): pick a distinct
                    # name rather than destroying the other memory's file.
                    suffix = content_hash or hashlib.sha256(content.strip().encode()).hexdigest()
                    filename = f"headroom_{slug}_{suffix[:8]}.md"
                    target = self._memory_dir / filename
                    if target.exists():
                        continue  # This content was already exported under the suffixed name

            # Build description (first 100 chars)
            description = content.replace("\n", " ")[:100]

            # Write file
            fm = _build_frontmatter(
                {
                    "name": first_line[:60],
                    "description": description,
                    "type": category or "project",
                    "headroom_id": headroom_id,
                    "source_agent": source_agent,
                }
            )
            target.write_text(f"{fm}\n\n{content}\n", encoding="utf-8")
            written += 1

            # Track for MEMORY.md index; a rewritten file is already linked.
            if not rewriting_existing:
                new_index_entries.append(f"- [{first_line[:60]}]({filename}) — {description[:80]}")

        # Update MEMORY.md index
        if new_index_entries:
            self._update_memory_md_index(new_index_entries)

        return written

    def _update_memory_md_index(self, new_entries: list[str]) -> None:
        """Append new entries to MEMORY.md under a Headroom section.

        Creates MEMORY.md if it is missing, creates the
        ``## Headroom Shared Memory`` section if the file lacks it, and
        otherwise inserts the entries at the end of that section (just before
        the next ``## `` heading, or at end of file).
        """
        memory_md = self._memory_dir / "MEMORY.md"

        section_marker = "## Headroom Shared Memory"
        new_section = f"\n{section_marker}\n" + "\n".join(new_entries) + "\n"

        if memory_md.exists():
            content = memory_md.read_text(encoding="utf-8")
            if section_marker in content:
                # Append to existing section (before next ## or end)
                idx = content.index(section_marker)
                # Find end of section (next ## or end of file)
                next_section = content.find("\n## ", idx + len(section_marker))
                if next_section == -1:
                    # Append at end
                    content = content.rstrip() + "\n" + "\n".join(new_entries) + "\n"
                else:
                    # Insert before next section
                    content = (
                        content[:next_section].rstrip()
                        + "\n"
                        + "\n".join(new_entries)
                        + "\n"
                        + content[next_section:]
                    )
            else:
                content = content.rstrip() + "\n" + new_section
        else:
            content = "# Memory\n" + new_section

        memory_md.write_text(content, encoding="utf-8")

    def fingerprint(self) -> str:
        """Hash of all .md filenames + mtimes for fast change detection.

        Returns ``"empty"`` when the directory is missing or holds no
        stat-able ``*.md`` file; otherwise the first 16 hex characters of a
        SHA-256 over the sorted ``name:mtime_ns`` pairs.
        """
        if not self._memory_dir.exists():
            return "empty"

        parts: list[str] = []
        for md_file in sorted(self._memory_dir.glob("*.md")):
            try:
                stat = md_file.stat()
                parts.append(f"{md_file.name}:{stat.st_mtime_ns}")
            except OSError:
                continue

        if not parts:
            return "empty"
        return hashlib.sha256("|".join(parts).encode()).hexdigest()[:16]
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def get_claude_memory_dir(project_path: Path | None = None) -> Path:
|
| 225 |
+
"""Get the Claude Code memory directory for a project.
|
| 226 |
+
|
| 227 |
+
Claude Code stores per-project memory at:
|
| 228 |
+
~/.claude/projects/-<sanitized-path>/memory/
|
| 229 |
+
"""
|
| 230 |
+
project = project_path or Path.cwd()
|
| 231 |
+
# Replace both Unix and Windows path separators (Claude Code does the same)
|
| 232 |
+
sanitized = str(project).replace("/", "-").replace("\\", "-")
|
| 233 |
+
return Path.home() / ".claude" / "projects" / sanitized / "memory"
|
headroom/memory/sync_adapters/codex_agent.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Codex CLI memory sync adapter.
|
| 2 |
+
|
| 3 |
+
Syncs memories to/from a headroom-managed section in AGENTS.md.
|
| 4 |
+
Codex reads AGENTS.md automatically before every task.
|
| 5 |
+
|
| 6 |
+
Note: Codex primarily uses the MCP server for memory (memory_search/save).
|
| 7 |
+
This adapter provides supplementary context injection via AGENTS.md so
|
| 8 |
+
Codex has key memories even without explicit tool calls.
|
| 9 |
+
|
| 10 |
+
Format in AGENTS.md:
|
| 11 |
+
<!-- headroom:memory:start -->
|
| 12 |
+
## Headroom Shared Memory
|
| 13 |
+
- fact 1
|
| 14 |
+
- fact 2
|
| 15 |
+
<!-- headroom:memory:end -->
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import hashlib
|
| 21 |
+
import re
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import Any
|
| 24 |
+
|
| 25 |
+
from headroom.memory.sync import AgentMemory, AgentMemoryAdapter
|
| 26 |
+
|
| 27 |
+
_MARKER_START = "<!-- headroom:memory:start -->"
|
| 28 |
+
_MARKER_END = "<!-- headroom:memory:end -->"
|
| 29 |
+
_MARKER_PATTERN = re.compile(
|
| 30 |
+
re.escape(_MARKER_START) + r"(.*?)" + re.escape(_MARKER_END),
|
| 31 |
+
re.DOTALL,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class CodexAdapter(AgentMemoryAdapter):
|
| 36 |
+
"""Sync adapter for Codex's AGENTS.md."""
|
| 37 |
+
|
| 38 |
+
agent_name = "codex"
|
| 39 |
+
|
| 40 |
+
def __init__(self, agents_md_path: Path | str | None = None) -> None:
|
| 41 |
+
self._path = Path(agents_md_path) if agents_md_path else Path.cwd() / "AGENTS.md"
|
| 42 |
+
|
| 43 |
+
async def read_memories(self) -> list[AgentMemory]:
|
| 44 |
+
"""Read memories from the headroom section of AGENTS.md."""
|
| 45 |
+
if not self._path.exists():
|
| 46 |
+
return []
|
| 47 |
+
|
| 48 |
+
content = self._path.read_text(encoding="utf-8")
|
| 49 |
+
match = _MARKER_PATTERN.search(content)
|
| 50 |
+
if not match:
|
| 51 |
+
return []
|
| 52 |
+
|
| 53 |
+
section = match.group(1).strip()
|
| 54 |
+
memories: list[AgentMemory] = []
|
| 55 |
+
|
| 56 |
+
for line in section.split("\n"):
|
| 57 |
+
line = line.strip()
|
| 58 |
+
if line.startswith("- "):
|
| 59 |
+
fact = line[2:].strip()
|
| 60 |
+
if fact:
|
| 61 |
+
memories.append(
|
| 62 |
+
AgentMemory(
|
| 63 |
+
content=fact,
|
| 64 |
+
source_file=self._path.name,
|
| 65 |
+
)
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
return memories
|
| 69 |
+
|
| 70 |
+
async def write_memories(self, memories: list[dict[str, Any]]) -> int:
|
| 71 |
+
"""Write memories into the headroom section of AGENTS.md."""
|
| 72 |
+
if not memories:
|
| 73 |
+
return 0
|
| 74 |
+
|
| 75 |
+
# Build section content
|
| 76 |
+
lines = ["## Headroom Shared Memory", ""]
|
| 77 |
+
for mem in memories:
|
| 78 |
+
content = mem["content"].split("\n")[0].strip() # First line only
|
| 79 |
+
lines.append(f"- {content}")
|
| 80 |
+
lines.append("")
|
| 81 |
+
|
| 82 |
+
section = f"{_MARKER_START}\n" + "\n".join(lines) + f"{_MARKER_END}"
|
| 83 |
+
|
| 84 |
+
# Merge into AGENTS.md
|
| 85 |
+
if self._path.exists():
|
| 86 |
+
content = self._path.read_text(encoding="utf-8")
|
| 87 |
+
if _MARKER_START in content:
|
| 88 |
+
content = _MARKER_PATTERN.sub(section, content)
|
| 89 |
+
else:
|
| 90 |
+
content = content.rstrip() + "\n\n" + section + "\n"
|
| 91 |
+
else:
|
| 92 |
+
self._path.parent.mkdir(parents=True, exist_ok=True)
|
| 93 |
+
content = section + "\n"
|
| 94 |
+
|
| 95 |
+
self._path.write_text(content, encoding="utf-8")
|
| 96 |
+
return len(memories)
|
| 97 |
+
|
| 98 |
+
def fingerprint(self) -> str:
|
| 99 |
+
"""Hash of AGENTS.md mtime."""
|
| 100 |
+
if not self._path.exists():
|
| 101 |
+
return "empty"
|
| 102 |
+
try:
|
| 103 |
+
stat = self._path.stat()
|
| 104 |
+
return hashlib.sha256(f"{self._path.name}:{stat.st_mtime_ns}".encode()).hexdigest()[:16]
|
| 105 |
+
except OSError:
|
| 106 |
+
return "error"
|
headroom/memory/writers/claude_writer.py
CHANGED
|
@@ -64,7 +64,8 @@ class ClaudeCodeMemoryWriter(AgentWriter):
|
|
| 64 |
project_path = self._project_path
|
| 65 |
# Claude Code stores per-project memory at:
|
| 66 |
# ~/.claude/projects/-<sanitized-path>/memory/MEMORY.md
|
| 67 |
-
|
|
|
|
| 68 |
claude_memory_dir = Path.home() / ".claude" / "projects" / sanitized / "memory"
|
| 69 |
return claude_memory_dir / "MEMORY.md"
|
| 70 |
|
|
|
|
| 64 |
project_path = self._project_path
|
| 65 |
# Claude Code stores per-project memory at:
|
| 66 |
# ~/.claude/projects/-<sanitized-path>/memory/MEMORY.md
|
| 67 |
+
# Replace both Unix and Windows path separators
|
| 68 |
+
sanitized = str(project_path).replace("/", "-").replace("\\", "-")
|
| 69 |
claude_memory_dir = Path.home() / ".claude" / "projects" / sanitized / "memory"
|
| 70 |
return claude_memory_dir / "MEMORY.md"
|
| 71 |
|
headroom/proxy/handlers/openai.py
CHANGED
|
@@ -672,7 +672,9 @@ class OpenAIHandlerMixin:
|
|
| 672 |
uncached_tokens=uncached_input_tokens,
|
| 673 |
)
|
| 674 |
|
| 675 |
-
# Memory: handle memory tool calls in OpenAI response
|
|
|
|
|
|
|
| 676 |
if (
|
| 677 |
self.memory_handler
|
| 678 |
and memory_user_id
|
|
@@ -684,10 +686,29 @@ class OpenAIHandlerMixin:
|
|
| 684 |
tool_results = await self.memory_handler.handle_memory_tool_calls(
|
| 685 |
resp_json, memory_user_id, "openai"
|
| 686 |
)
|
| 687 |
-
|
| 688 |
-
|
| 689 |
-
|
| 690 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 691 |
except Exception as e:
|
| 692 |
logger.warning(f"[{request_id}] Memory tool handling failed: {e}")
|
| 693 |
|
|
@@ -963,14 +984,29 @@ class OpenAIHandlerMixin:
|
|
| 963 |
f"of context into instructions for user {memory_user_id}"
|
| 964 |
)
|
| 965 |
|
| 966 |
-
# Inject memory tools
|
| 967 |
if self.memory_handler.config.inject_tools:
|
| 968 |
resp_tools = body.get("tools") or []
|
| 969 |
resp_tools, mem_tools_injected = self.memory_handler.inject_tools(
|
| 970 |
resp_tools, "openai"
|
| 971 |
)
|
| 972 |
if mem_tools_injected:
|
| 973 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 974 |
logger.info(
|
| 975 |
f"[{request_id}] Memory: Injected memory tools (openai/responses)"
|
| 976 |
)
|
|
@@ -1037,13 +1073,67 @@ class OpenAIHandlerMixin:
|
|
| 1037 |
and self.memory_handler.has_memory_tool_calls(resp_json, "openai")
|
| 1038 |
):
|
| 1039 |
try:
|
| 1040 |
-
|
| 1041 |
-
|
| 1042 |
-
|
| 1043 |
-
|
| 1044 |
-
|
| 1045 |
-
|
| 1046 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1047 |
except Exception as e:
|
| 1048 |
logger.warning(
|
| 1049 |
f"[{request_id}] Memory tool handling failed (responses): {e}"
|
|
@@ -1279,6 +1369,116 @@ class OpenAIHandlerMixin:
|
|
| 1279 |
# Not JSON — pass through as-is
|
| 1280 |
tokens_saved = 0
|
| 1281 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1282 |
# --- Connect to upstream OpenAI WebSocket ---
|
| 1283 |
logger.info(f"[{request_id}] WS /v1/responses connecting to {upstream_url}")
|
| 1284 |
|
|
@@ -1303,7 +1503,7 @@ class OpenAIHandlerMixin:
|
|
| 1303 |
# Send (potentially compressed) first message
|
| 1304 |
await upstream.send(first_msg_raw)
|
| 1305 |
|
| 1306 |
-
# Bidirectional relay
|
| 1307 |
async def _client_to_upstream() -> None:
|
| 1308 |
try:
|
| 1309 |
while True:
|
|
@@ -1318,14 +1518,164 @@ class OpenAIHandlerMixin:
|
|
| 1318 |
await upstream.close()
|
| 1319 |
|
| 1320 |
async def _upstream_to_client() -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1321 |
try:
|
| 1322 |
async for msg in upstream:
|
| 1323 |
-
if isinstance(msg,
|
| 1324 |
-
await websocket.send_text(msg)
|
| 1325 |
-
elif isinstance(msg, bytes):
|
| 1326 |
await websocket.send_bytes(msg)
|
| 1327 |
-
|
| 1328 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1329 |
except Exception as relay_err:
|
| 1330 |
if "WebSocketDisconnect" not in type(relay_err).__name__:
|
| 1331 |
logger.debug(
|
|
|
|
| 672 |
uncached_tokens=uncached_input_tokens,
|
| 673 |
)
|
| 674 |
|
| 675 |
+
# Memory: handle memory tool calls in OpenAI Chat Completions response.
|
| 676 |
+
# After executing tools, send a continuation request so the model
|
| 677 |
+
# can produce a final user-facing response (not just tool_calls).
|
| 678 |
if (
|
| 679 |
self.memory_handler
|
| 680 |
and memory_user_id
|
|
|
|
| 686 |
tool_results = await self.memory_handler.handle_memory_tool_calls(
|
| 687 |
resp_json, memory_user_id, "openai"
|
| 688 |
)
|
| 689 |
+
if tool_results:
|
| 690 |
+
# Build continuation: original messages + assistant tool_calls + tool results
|
| 691 |
+
assistant_msg = resp_json.get("choices", [{}])[0].get("message", {})
|
| 692 |
+
continuation_messages = list(optimized_messages)
|
| 693 |
+
continuation_messages.append(assistant_msg)
|
| 694 |
+
continuation_messages.extend(tool_results)
|
| 695 |
+
|
| 696 |
+
continuation_body = {
|
| 697 |
+
**body,
|
| 698 |
+
"messages": continuation_messages,
|
| 699 |
+
}
|
| 700 |
+
|
| 701 |
+
cont_response = await self._retry_request(
|
| 702 |
+
"POST", url, headers, continuation_body
|
| 703 |
+
)
|
| 704 |
+
if cont_response.status_code == 200:
|
| 705 |
+
resp_json = cont_response.json()
|
| 706 |
+
response = cont_response
|
| 707 |
+
|
| 708 |
+
logger.info(
|
| 709 |
+
f"[{request_id}] Memory: Handled {len(tool_results)} "
|
| 710 |
+
f"tool call(s) with continuation for user {memory_user_id}"
|
| 711 |
+
)
|
| 712 |
except Exception as e:
|
| 713 |
logger.warning(f"[{request_id}] Memory tool handling failed: {e}")
|
| 714 |
|
|
|
|
| 984 |
f"of context into instructions for user {memory_user_id}"
|
| 985 |
)
|
| 986 |
|
| 987 |
+
# Inject memory tools (Responses API format)
|
| 988 |
if self.memory_handler.config.inject_tools:
|
| 989 |
resp_tools = body.get("tools") or []
|
| 990 |
resp_tools, mem_tools_injected = self.memory_handler.inject_tools(
|
| 991 |
resp_tools, "openai"
|
| 992 |
)
|
| 993 |
if mem_tools_injected:
|
| 994 |
+
# Convert Chat Completions format to Responses API format
|
| 995 |
+
converted_tools = []
|
| 996 |
+
for t in resp_tools:
|
| 997 |
+
if t.get("type") == "function" and "function" in t:
|
| 998 |
+
fn = t["function"]
|
| 999 |
+
converted_tools.append(
|
| 1000 |
+
{
|
| 1001 |
+
"type": "function",
|
| 1002 |
+
"name": fn.get("name"),
|
| 1003 |
+
"description": fn.get("description", ""),
|
| 1004 |
+
"parameters": fn.get("parameters", {}),
|
| 1005 |
+
}
|
| 1006 |
+
)
|
| 1007 |
+
else:
|
| 1008 |
+
converted_tools.append(t)
|
| 1009 |
+
body["tools"] = converted_tools
|
| 1010 |
logger.info(
|
| 1011 |
f"[{request_id}] Memory: Injected memory tools (openai/responses)"
|
| 1012 |
)
|
|
|
|
| 1073 |
and self.memory_handler.has_memory_tool_calls(resp_json, "openai")
|
| 1074 |
):
|
| 1075 |
try:
|
| 1076 |
+
# Extract function_call items from output
|
| 1077 |
+
from headroom.proxy.memory_handler import MEMORY_TOOL_NAMES
|
| 1078 |
+
|
| 1079 |
+
output_items = resp_json.get("output", [])
|
| 1080 |
+
memory_fc_items = [
|
| 1081 |
+
item
|
| 1082 |
+
for item in output_items
|
| 1083 |
+
if isinstance(item, dict)
|
| 1084 |
+
and item.get("type") == "function_call"
|
| 1085 |
+
and item.get("name") in MEMORY_TOOL_NAMES
|
| 1086 |
+
]
|
| 1087 |
+
|
| 1088 |
+
# Execute memory tool calls
|
| 1089 |
+
tool_outputs: list[dict[str, Any]] = []
|
| 1090 |
+
for fc in memory_fc_items:
|
| 1091 |
+
call_id = fc.get("call_id", fc.get("id", ""))
|
| 1092 |
+
name = fc.get("name", "")
|
| 1093 |
+
args_str = fc.get("arguments", "{}")
|
| 1094 |
+
try:
|
| 1095 |
+
args = json.loads(args_str)
|
| 1096 |
+
except json.JSONDecodeError:
|
| 1097 |
+
args = {}
|
| 1098 |
+
|
| 1099 |
+
await self.memory_handler._ensure_initialized()
|
| 1100 |
+
if self.memory_handler._backend:
|
| 1101 |
+
result = await self.memory_handler._execute_memory_tool(
|
| 1102 |
+
name, args, memory_user_id, "openai"
|
| 1103 |
+
)
|
| 1104 |
+
else:
|
| 1105 |
+
result = json.dumps({"error": "Memory backend not initialized"})
|
| 1106 |
+
|
| 1107 |
+
tool_outputs.append(
|
| 1108 |
+
{
|
| 1109 |
+
"type": "function_call_output",
|
| 1110 |
+
"call_id": call_id,
|
| 1111 |
+
"output": result,
|
| 1112 |
+
}
|
| 1113 |
+
)
|
| 1114 |
+
|
| 1115 |
+
if tool_outputs:
|
| 1116 |
+
# Make continuation request with tool results
|
| 1117 |
+
response_id = resp_json.get("id")
|
| 1118 |
+
continuation_body = {
|
| 1119 |
+
"model": model,
|
| 1120 |
+
"input": tool_outputs,
|
| 1121 |
+
}
|
| 1122 |
+
if response_id:
|
| 1123 |
+
continuation_body["previous_response_id"] = response_id
|
| 1124 |
+
existing_tools = body.get("tools")
|
| 1125 |
+
if existing_tools:
|
| 1126 |
+
continuation_body["tools"] = existing_tools
|
| 1127 |
+
|
| 1128 |
+
cont_response = await self._retry_request(
|
| 1129 |
+
"POST", url, headers, continuation_body
|
| 1130 |
+
)
|
| 1131 |
+
resp_json = cont_response.json()
|
| 1132 |
+
response = cont_response
|
| 1133 |
+
logger.info(
|
| 1134 |
+
f"[{request_id}] Memory: Handled {len(tool_outputs)} "
|
| 1135 |
+
f"tool call(s) with continuation for user {memory_user_id} (responses)"
|
| 1136 |
+
)
|
| 1137 |
except Exception as e:
|
| 1138 |
logger.warning(
|
| 1139 |
f"[{request_id}] Memory tool handling failed (responses): {e}"
|
|
|
|
| 1369 |
# Not JSON — pass through as-is
|
| 1370 |
tokens_saved = 0
|
| 1371 |
|
| 1372 |
+
# --- Memory: inject context, tools, and instructions ---
|
| 1373 |
+
memory_user_id: str | None = None
|
| 1374 |
+
if self.memory_handler and body:
|
| 1375 |
+
memory_user_id = ws_headers.get(
|
| 1376 |
+
"x-headroom-user-id",
|
| 1377 |
+
os.environ.get("USER", os.environ.get("USERNAME", "default")),
|
| 1378 |
+
)
|
| 1379 |
+
try:
|
| 1380 |
+
# Unwrap response.create envelope to access the response body
|
| 1381 |
+
ws_response_body = body.get("response", body)
|
| 1382 |
+
|
| 1383 |
+
# Debug: log what Codex sends so we can see the full tool list
|
| 1384 |
+
existing_tool_names = [
|
| 1385 |
+
t.get("name") or t.get("function", {}).get("name", "?")
|
| 1386 |
+
for t in (ws_response_body.get("tools") or [])
|
| 1387 |
+
]
|
| 1388 |
+
instr_preview = (ws_response_body.get("instructions") or "")[:200]
|
| 1389 |
+
logger.info(
|
| 1390 |
+
f"[{request_id}] WS Memory: Codex tools={existing_tool_names}, "
|
| 1391 |
+
f"instructions_len={len(ws_response_body.get('instructions') or '')}, "
|
| 1392 |
+
f"instructions_preview={instr_preview!r}"
|
| 1393 |
+
)
|
| 1394 |
+
|
| 1395 |
+
# Inject memory context into instructions
|
| 1396 |
+
if self.memory_handler.config.inject_context:
|
| 1397 |
+
ws_input = ws_response_body.get("input", "")
|
| 1398 |
+
ws_instructions = ws_response_body.get("instructions")
|
| 1399 |
+
ws_msgs: list[dict[str, Any]] = []
|
| 1400 |
+
if ws_instructions:
|
| 1401 |
+
ws_msgs.append({"role": "system", "content": ws_instructions})
|
| 1402 |
+
if isinstance(ws_input, str) and ws_input:
|
| 1403 |
+
ws_msgs.append({"role": "user", "content": ws_input})
|
| 1404 |
+
elif isinstance(ws_input, list):
|
| 1405 |
+
from headroom.proxy.responses_converter import (
|
| 1406 |
+
responses_items_to_messages,
|
| 1407 |
+
)
|
| 1408 |
+
|
| 1409 |
+
converted_msgs, _ = responses_items_to_messages(ws_input)
|
| 1410 |
+
ws_msgs.extend(converted_msgs)
|
| 1411 |
+
|
| 1412 |
+
memory_context = await self.memory_handler.search_and_format_context(
|
| 1413 |
+
memory_user_id, ws_msgs
|
| 1414 |
+
)
|
| 1415 |
+
if memory_context:
|
| 1416 |
+
existing = ws_response_body.get("instructions") or ""
|
| 1417 |
+
if existing:
|
| 1418 |
+
ws_response_body["instructions"] = f"{existing}\n\n{memory_context}"
|
| 1419 |
+
else:
|
| 1420 |
+
ws_response_body["instructions"] = memory_context
|
| 1421 |
+
logger.info(
|
| 1422 |
+
f"[{request_id}] WS Memory: Injected {len(memory_context)} chars "
|
| 1423 |
+
f"of context into instructions"
|
| 1424 |
+
)
|
| 1425 |
+
|
| 1426 |
+
# Inject memory tools (Responses API format)
|
| 1427 |
+
if self.memory_handler.config.inject_tools:
|
| 1428 |
+
ws_tools = ws_response_body.get("tools") or []
|
| 1429 |
+
ws_tools, mem_injected = self.memory_handler.inject_tools(
|
| 1430 |
+
ws_tools, "openai"
|
| 1431 |
+
)
|
| 1432 |
+
if mem_injected:
|
| 1433 |
+
converted_tools = []
|
| 1434 |
+
for t in ws_tools:
|
| 1435 |
+
if t.get("type") == "function" and "function" in t:
|
| 1436 |
+
fn = t["function"]
|
| 1437 |
+
converted_tools.append(
|
| 1438 |
+
{
|
| 1439 |
+
"type": "function",
|
| 1440 |
+
"name": fn.get("name"),
|
| 1441 |
+
"description": fn.get("description", ""),
|
| 1442 |
+
"parameters": fn.get("parameters", {}),
|
| 1443 |
+
}
|
| 1444 |
+
)
|
| 1445 |
+
else:
|
| 1446 |
+
converted_tools.append(t)
|
| 1447 |
+
ws_response_body["tools"] = converted_tools
|
| 1448 |
+
|
| 1449 |
+
# Add memory instruction so the model uses
|
| 1450 |
+
# memory tools as persistent cross-session knowledge.
|
| 1451 |
+
mem_instruction = (
|
| 1452 |
+
"\n\n## Memory\n"
|
| 1453 |
+
"You have persistent memory via memory_search and "
|
| 1454 |
+
"memory_save tools. Memory stores knowledge across "
|
| 1455 |
+
"sessions — user info, project details, org context, "
|
| 1456 |
+
"decisions, architecture, conventions, anything worth "
|
| 1457 |
+
"remembering.\n\n"
|
| 1458 |
+
"- ALWAYS call memory_search BEFORE searching files "
|
| 1459 |
+
"when the user asks a question that could be answered "
|
| 1460 |
+
"from prior knowledge.\n"
|
| 1461 |
+
"- Call memory_save to store important facts, decisions, "
|
| 1462 |
+
"or context that would be useful in future sessions.\n"
|
| 1463 |
+
"- Memory is your first source of truth for anything "
|
| 1464 |
+
"not visible in the current conversation."
|
| 1465 |
+
)
|
| 1466 |
+
existing_instr = ws_response_body.get("instructions") or ""
|
| 1467 |
+
ws_response_body["instructions"] = existing_instr + mem_instruction
|
| 1468 |
+
logger.info(
|
| 1469 |
+
f"[{request_id}] WS Memory: Injected memory tools + instruction"
|
| 1470 |
+
)
|
| 1471 |
+
|
| 1472 |
+
# Write back into envelope if it was wrapped
|
| 1473 |
+
if "response" in body and isinstance(body["response"], dict):
|
| 1474 |
+
body["response"] = ws_response_body
|
| 1475 |
+
else:
|
| 1476 |
+
body = ws_response_body
|
| 1477 |
+
|
| 1478 |
+
first_msg_raw = json.dumps(body)
|
| 1479 |
+
except Exception as e:
|
| 1480 |
+
logger.warning(f"[{request_id}] WS Memory injection failed: {e}")
|
| 1481 |
+
|
| 1482 |
# --- Connect to upstream OpenAI WebSocket ---
|
| 1483 |
logger.info(f"[{request_id}] WS /v1/responses connecting to {upstream_url}")
|
| 1484 |
|
|
|
|
| 1503 |
# Send (potentially compressed) first message
|
| 1504 |
await upstream.send(first_msg_raw)
|
| 1505 |
|
| 1506 |
+
# Bidirectional relay with memory tool interception
|
| 1507 |
async def _client_to_upstream() -> None:
|
| 1508 |
try:
|
| 1509 |
while True:
|
|
|
|
| 1518 |
await upstream.close()
|
| 1519 |
|
| 1520 |
async def _upstream_to_client() -> None:
|
| 1521 |
+
"""Relay upstream→client with transparent memory tool handling.
|
| 1522 |
+
|
| 1523 |
+
Uses a buffer-then-decide approach:
|
| 1524 |
+
1. Buffer events until first output item arrives
|
| 1525 |
+
2. If first output is a memory tool → suppress entire response,
|
| 1526 |
+
execute tools silently, send continuation upstream
|
| 1527 |
+
3. If first output is non-memory → flush buffer, stream normally
|
| 1528 |
+
4. Continuation response events are relayed to Codex seamlessly
|
| 1529 |
+
|
| 1530 |
+
This prevents orphaned response.created events from confusing Codex.
|
| 1531 |
+
"""
|
| 1532 |
+
from headroom.proxy.memory_handler import MEMORY_TOOL_NAMES
|
| 1533 |
+
|
| 1534 |
+
memory_enabled = bool(self.memory_handler and memory_user_id)
|
| 1535 |
+
|
| 1536 |
+
# Per-response state (reset after each response.completed)
|
| 1537 |
+
event_buffer: list[str] = []
|
| 1538 |
+
decided = False
|
| 1539 |
+
suppress_response = False
|
| 1540 |
+
pending_fcs: list[dict[str, Any]] = []
|
| 1541 |
+
resp_id: str | None = None
|
| 1542 |
+
|
| 1543 |
+
def _reset() -> None:
|
| 1544 |
+
nonlocal decided, suppress_response, resp_id
|
| 1545 |
+
event_buffer.clear()
|
| 1546 |
+
decided = False
|
| 1547 |
+
suppress_response = False
|
| 1548 |
+
pending_fcs.clear()
|
| 1549 |
+
resp_id = None
|
| 1550 |
+
|
| 1551 |
try:
|
| 1552 |
async for msg in upstream:
|
| 1553 |
+
if isinstance(msg, bytes):
|
|
|
|
|
|
|
| 1554 |
await websocket.send_bytes(msg)
|
| 1555 |
+
continue
|
| 1556 |
+
msg_str = msg if isinstance(msg, str) else str(msg)
|
| 1557 |
+
|
| 1558 |
+
if not memory_enabled:
|
| 1559 |
+
await websocket.send_text(msg_str)
|
| 1560 |
+
continue
|
| 1561 |
+
|
| 1562 |
+
# Parse event
|
| 1563 |
+
try:
|
| 1564 |
+
event = json.loads(msg_str)
|
| 1565 |
+
except (json.JSONDecodeError, TypeError):
|
| 1566 |
+
await websocket.send_text(msg_str)
|
| 1567 |
+
continue
|
| 1568 |
+
|
| 1569 |
+
event_type = event.get("type", "")
|
| 1570 |
+
|
| 1571 |
+
# --- Phase 1: Buffer until first output item ---
|
| 1572 |
+
if not decided:
|
| 1573 |
+
event_buffer.append(msg_str)
|
| 1574 |
+
|
| 1575 |
+
if event_type == "response.output_item.added":
|
| 1576 |
+
item = event.get("item", {})
|
| 1577 |
+
if (
|
| 1578 |
+
item.get("type") == "function_call"
|
| 1579 |
+
and item.get("name") in MEMORY_TOOL_NAMES
|
| 1580 |
+
):
|
| 1581 |
+
# Memory tool first → suppress entire response
|
| 1582 |
+
suppress_response = True
|
| 1583 |
+
decided = True
|
| 1584 |
+
event_buffer.clear()
|
| 1585 |
+
logger.info(
|
| 1586 |
+
f"[{request_id}] WS Memory: Detected "
|
| 1587 |
+
f"{item.get('name')} — suppressing response"
|
| 1588 |
+
)
|
| 1589 |
+
else:
|
| 1590 |
+
# Non-memory first → flush buffer, pass through
|
| 1591 |
+
decided = True
|
| 1592 |
+
for buf in event_buffer:
|
| 1593 |
+
await websocket.send_text(buf)
|
| 1594 |
+
event_buffer.clear()
|
| 1595 |
+
|
| 1596 |
+
elif event_type == "response.completed":
|
| 1597 |
+
# No output items at all — flush
|
| 1598 |
+
decided = True
|
| 1599 |
+
for buf in event_buffer:
|
| 1600 |
+
await websocket.send_text(buf)
|
| 1601 |
+
event_buffer.clear()
|
| 1602 |
+
_reset()
|
| 1603 |
+
|
| 1604 |
+
continue
|
| 1605 |
+
|
| 1606 |
+
# --- Phase 2a: Suppress mode (memory response) ---
|
| 1607 |
+
if suppress_response:
|
| 1608 |
+
if event_type == "response.output_item.done":
|
| 1609 |
+
item = event.get("item", {})
|
| 1610 |
+
if (
|
| 1611 |
+
item.get("type") == "function_call"
|
| 1612 |
+
and item.get("name") in MEMORY_TOOL_NAMES
|
| 1613 |
+
):
|
| 1614 |
+
pending_fcs.append(item)
|
| 1615 |
+
|
| 1616 |
+
elif event_type == "response.completed":
|
| 1617 |
+
resp = event.get("response", {})
|
| 1618 |
+
resp_id = resp.get("id")
|
| 1619 |
+
|
| 1620 |
+
if pending_fcs:
|
| 1621 |
+
logger.info(
|
| 1622 |
+
f"[{request_id}] WS Memory: Executing "
|
| 1623 |
+
f"{len(pending_fcs)} tool(s) transparently"
|
| 1624 |
+
)
|
| 1625 |
+
|
| 1626 |
+
# Execute memory tool calls
|
| 1627 |
+
tool_outputs: list[dict[str, Any]] = []
|
| 1628 |
+
for fc in pending_fcs:
|
| 1629 |
+
call_id = fc.get("call_id", fc.get("id", ""))
|
| 1630 |
+
fc_name = fc.get("name", "")
|
| 1631 |
+
args_str = fc.get("arguments", "{}")
|
| 1632 |
+
try:
|
| 1633 |
+
fc_args = json.loads(args_str)
|
| 1634 |
+
except json.JSONDecodeError:
|
| 1635 |
+
fc_args = {}
|
| 1636 |
+
|
| 1637 |
+
await self.memory_handler._ensure_initialized()
|
| 1638 |
+
if self.memory_handler._backend:
|
| 1639 |
+
result = await self.memory_handler._execute_memory_tool(
|
| 1640 |
+
fc_name, fc_args, memory_user_id, "openai"
|
| 1641 |
+
)
|
| 1642 |
+
else:
|
| 1643 |
+
result = json.dumps(
|
| 1644 |
+
{"error": "backend not ready"}
|
| 1645 |
+
)
|
| 1646 |
+
|
| 1647 |
+
tool_outputs.append(
|
| 1648 |
+
{
|
| 1649 |
+
"type": "function_call_output",
|
| 1650 |
+
"call_id": call_id,
|
| 1651 |
+
"output": result,
|
| 1652 |
+
}
|
| 1653 |
+
)
|
| 1654 |
+
logger.info(
|
| 1655 |
+
f"[{request_id}] WS Memory: Executed "
|
| 1656 |
+
f"{fc_name} for user {memory_user_id}"
|
| 1657 |
+
)
|
| 1658 |
+
|
| 1659 |
+
# Send continuation upstream
|
| 1660 |
+
cont: dict[str, Any] = {
|
| 1661 |
+
"type": "response.create",
|
| 1662 |
+
"response": {"input": tool_outputs},
|
| 1663 |
+
}
|
| 1664 |
+
if resp_id:
|
| 1665 |
+
cont["response"]["previous_response_id"] = resp_id
|
| 1666 |
+
await upstream.send(json.dumps(cont))
|
| 1667 |
+
logger.info(
|
| 1668 |
+
f"[{request_id}] WS Memory: Sent continuation "
|
| 1669 |
+
f"with {len(tool_outputs)} result(s)"
|
| 1670 |
+
)
|
| 1671 |
+
|
| 1672 |
+
_reset()
|
| 1673 |
+
# All events suppressed in this mode
|
| 1674 |
+
continue
|
| 1675 |
+
|
| 1676 |
+
# --- Phase 2b: Pass-through mode ---
|
| 1677 |
+
await websocket.send_text(msg_str)
|
| 1678 |
+
|
| 1679 |
except Exception as relay_err:
|
| 1680 |
if "WebSocketDisconnect" not in type(relay_err).__name__:
|
| 1681 |
logger.debug(
|
headroom/proxy/handlers/streaming.py
CHANGED
|
@@ -614,15 +614,18 @@ class StreamingMixin:
|
|
| 614 |
f"[{request_id}] Memory: Detected tool calls in streaming response"
|
| 615 |
)
|
| 616 |
|
| 617 |
-
# Execute memory tool calls
|
| 618 |
-
#
|
|
|
|
|
|
|
| 619 |
tool_results = await self.memory_handler.handle_memory_tool_calls(
|
| 620 |
parsed_response, memory_user_id, provider
|
| 621 |
)
|
| 622 |
if tool_results:
|
| 623 |
logger.info(
|
| 624 |
-
f"[{request_id}] Memory: Tool calls executed
|
| 625 |
-
"(
|
|
|
|
| 626 |
)
|
| 627 |
|
| 628 |
# CCR Feedback: Record headroom_retrieve tool calls for TOIN learning.
|
|
|
|
| 614 |
f"[{request_id}] Memory: Detected tool calls in streaming response"
|
| 615 |
)
|
| 616 |
|
| 617 |
+
# Execute memory tool calls — response already streamed
|
| 618 |
+
# so results are saved but continuation is not possible
|
| 619 |
+
# in SSE streaming mode. The WS and non-streaming paths
|
| 620 |
+
# handle continuation properly.
|
| 621 |
tool_results = await self.memory_handler.handle_memory_tool_calls(
|
| 622 |
parsed_response, memory_user_id, provider
|
| 623 |
)
|
| 624 |
if tool_results:
|
| 625 |
logger.info(
|
| 626 |
+
f"[{request_id}] Memory: Tool calls executed "
|
| 627 |
+
f"({len(tool_results)} results saved, SSE streaming — "
|
| 628 |
+
"continuation handled by client)"
|
| 629 |
)
|
| 630 |
|
| 631 |
# CCR Feedback: Record headroom_retrieve tool calls for TOIN learning.
|
headroom/proxy/memory_handler.py
CHANGED
|
@@ -380,7 +380,14 @@ class MemoryHandler:
|
|
| 380 |
entities_str = ", ".join(result.related_entities[:3])
|
| 381 |
memory_lines.append(f" (Related: {entities_str})")
|
| 382 |
|
| 383 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
|
| 385 |
The following information was previously saved about this user:
|
| 386 |
|
|
@@ -388,15 +395,11 @@ The following information was previously saved about this user:
|
|
| 388 |
|
| 389 |
Use this context to provide personalized and contextually relevant responses."""
|
| 390 |
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
except Exception as e:
|
| 398 |
-
logger.warning(f"Memory: Search failed for user {user_id}: {e}")
|
| 399 |
-
return None
|
| 400 |
|
| 401 |
def _extract_user_query(self, messages: list[dict[str, Any]]) -> str:
|
| 402 |
"""Extract the user query from the last user message."""
|
|
@@ -445,10 +448,25 @@ Use this context to provide personalized and contextually relevant responses."""
|
|
| 445 |
return []
|
| 446 |
|
| 447 |
elif provider == "openai":
|
|
|
|
| 448 |
choices = response.get("choices", [])
|
| 449 |
if choices:
|
| 450 |
message = choices[0].get("message", {})
|
| 451 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 452 |
return []
|
| 453 |
|
| 454 |
return []
|
|
@@ -474,13 +492,15 @@ Use this context to provide personalized and contextually relevant responses."""
|
|
| 474 |
|
| 475 |
for tc in tool_calls:
|
| 476 |
tool_name = tc.get("name") or tc.get("function", {}).get("name")
|
| 477 |
-
tool_id = tc.get("id", "")
|
| 478 |
|
| 479 |
# Parse input data
|
| 480 |
if provider == "anthropic":
|
| 481 |
input_data = tc.get("input", {})
|
| 482 |
else:
|
| 483 |
-
|
|
|
|
|
|
|
| 484 |
try:
|
| 485 |
input_data = json.loads(args_str)
|
| 486 |
except json.JSONDecodeError:
|
|
|
|
| 380 |
entities_str = ", ".join(result.related_entities[:3])
|
| 381 |
memory_lines.append(f" (Related: {entities_str})")
|
| 382 |
|
| 383 |
+
except Exception as e:
|
| 384 |
+
logger.warning(f"Memory: Search failed for user {user_id}: {e}")
|
| 385 |
+
return None
|
| 386 |
+
|
| 387 |
+
if not memory_lines:
|
| 388 |
+
return None
|
| 389 |
+
|
| 390 |
+
context = f"""## Relevant Memories for This User
|
| 391 |
|
| 392 |
The following information was previously saved about this user:
|
| 393 |
|
|
|
|
| 395 |
|
| 396 |
Use this context to provide personalized and contextually relevant responses."""
|
| 397 |
|
| 398 |
+
logger.info(
|
| 399 |
+
f"Memory: Injecting {len(memory_lines)} memories "
|
| 400 |
+
f"({len(context)} chars) for user {user_id}"
|
| 401 |
+
)
|
| 402 |
+
return context
|
|
|
|
|
|
|
|
|
|
|
|
|
| 403 |
|
| 404 |
def _extract_user_query(self, messages: list[dict[str, Any]]) -> str:
|
| 405 |
"""Extract the user query from the last user message."""
|
|
|
|
| 448 |
return []
|
| 449 |
|
| 450 |
elif provider == "openai":
|
| 451 |
+
# Chat Completions format: choices[0].message.tool_calls
|
| 452 |
choices = response.get("choices", [])
|
| 453 |
if choices:
|
| 454 |
message = choices[0].get("message", {})
|
| 455 |
+
tc_list = list(message.get("tool_calls", []) or [])
|
| 456 |
+
if tc_list:
|
| 457 |
+
return tc_list
|
| 458 |
+
|
| 459 |
+
# Responses API format: output[] with type=function_call
|
| 460 |
+
output = response.get("output", [])
|
| 461 |
+
if isinstance(output, list):
|
| 462 |
+
fc_items = [
|
| 463 |
+
item
|
| 464 |
+
for item in output
|
| 465 |
+
if isinstance(item, dict) and item.get("type") == "function_call"
|
| 466 |
+
]
|
| 467 |
+
if fc_items:
|
| 468 |
+
return fc_items
|
| 469 |
+
|
| 470 |
return []
|
| 471 |
|
| 472 |
return []
|
|
|
|
| 492 |
|
| 493 |
for tc in tool_calls:
|
| 494 |
tool_name = tc.get("name") or tc.get("function", {}).get("name")
|
| 495 |
+
tool_id = tc.get("id") or tc.get("call_id", "")
|
| 496 |
|
| 497 |
# Parse input data
|
| 498 |
if provider == "anthropic":
|
| 499 |
input_data = tc.get("input", {})
|
| 500 |
else:
|
| 501 |
+
# Chat Completions format: function.arguments
|
| 502 |
+
# Responses API format: arguments (top-level string)
|
| 503 |
+
args_str = tc.get("arguments") or tc.get("function", {}).get("arguments") or "{}"
|
| 504 |
try:
|
| 505 |
input_data = json.loads(args_str)
|
| 506 |
except json.JSONDecodeError:
|
headroom/proxy/models.py
CHANGED
|
@@ -207,3 +207,6 @@ class ProxyConfig:
|
|
| 207 |
subscription_tracking_enabled: bool = True
|
| 208 |
subscription_poll_interval_s: int = 10
|
| 209 |
subscription_active_window_s: int = 60
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
subscription_tracking_enabled: bool = True
|
| 208 |
subscription_poll_interval_s: int = 10
|
| 209 |
subscription_active_window_s: int = 60
|
| 210 |
+
|
| 211 |
+
# Stateless mode — disable all filesystem writes for read-only / container deployments
|
| 212 |
+
stateless: bool = False
|
headroom/proxy/request_logger.py
CHANGED
|
@@ -8,6 +8,7 @@ Extracted from server.py for maintainability.
|
|
| 8 |
from __future__ import annotations
|
| 9 |
|
| 10 |
import json
|
|
|
|
| 11 |
import sys
|
| 12 |
from collections import deque
|
| 13 |
from dataclasses import asdict
|
|
@@ -19,11 +20,15 @@ if TYPE_CHECKING:
|
|
| 19 |
|
| 20 |
from headroom.proxy.models import RequestLog
|
| 21 |
|
|
|
|
|
|
|
| 22 |
|
| 23 |
class RequestLogger:
|
| 24 |
"""Log requests to JSONL file.
|
| 25 |
|
| 26 |
Uses a deque with max 10,000 entries to prevent unbounded memory growth.
|
|
|
|
|
|
|
| 27 |
"""
|
| 28 |
|
| 29 |
MAX_LOG_ENTRIES = 10_000
|
|
@@ -35,19 +40,30 @@ class RequestLogger:
|
|
| 35 |
self._logs: deque[RequestLog] = deque(maxlen=self.MAX_LOG_ENTRIES)
|
| 36 |
|
| 37 |
if self.log_file:
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
def log(self, entry: RequestLog):
|
| 41 |
"""Log a request. Oldest entries are automatically removed when limit reached."""
|
| 42 |
self._logs.append(entry)
|
| 43 |
|
| 44 |
if self.log_file:
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
def get_recent(self, n: int = 100) -> list[dict]:
|
| 53 |
"""Get recent log entries."""
|
|
|
|
| 8 |
from __future__ import annotations
|
| 9 |
|
| 10 |
import json
|
| 11 |
+
import logging
|
| 12 |
import sys
|
| 13 |
from collections import deque
|
| 14 |
from dataclasses import asdict
|
|
|
|
| 20 |
|
| 21 |
from headroom.proxy.models import RequestLog
|
| 22 |
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
|
| 26 |
class RequestLogger:
|
| 27 |
"""Log requests to JSONL file.
|
| 28 |
|
| 29 |
Uses a deque with max 10,000 entries to prevent unbounded memory growth.
|
| 30 |
+
Gracefully degrades to in-memory-only if the log file cannot be written
|
| 31 |
+
(read-only filesystem, permissions error, etc.).
|
| 32 |
"""
|
| 33 |
|
| 34 |
MAX_LOG_ENTRIES = 10_000
|
|
|
|
| 40 |
self._logs: deque[RequestLog] = deque(maxlen=self.MAX_LOG_ENTRIES)
|
| 41 |
|
| 42 |
if self.log_file:
|
| 43 |
+
try:
|
| 44 |
+
self.log_file.parent.mkdir(parents=True, exist_ok=True)
|
| 45 |
+
except OSError as e:
|
| 46 |
+
logger.warning(
|
| 47 |
+
"Cannot create log directory %s: %s — logging to memory only",
|
| 48 |
+
self.log_file.parent,
|
| 49 |
+
e,
|
| 50 |
+
)
|
| 51 |
+
self.log_file = None
|
| 52 |
|
| 53 |
def log(self, entry: RequestLog):
    """Record a request in memory and, if a log file is configured, append it as JSONL.

    The entry is always appended to the in-memory deque (which drops its
    oldest entries once its ``maxlen`` is reached). When ``self.log_file`` is
    set, one JSON object per line is appended to that file; the
    ``request_messages`` and ``response_content`` fields are omitted unless
    ``self.log_full_messages`` is true.

    Args:
        entry: The request record (a dataclass instance) to log.

    File write failures (``OSError``) never propagate: logging degrades to
    memory-only for that entry, and the failure is reported at debug level
    so a misconfigured or read-only log path can still be diagnosed.
    """
    self._logs.append(entry)

    if not self.log_file:
        return

    # Serialise before touching the filesystem so the file is only opened
    # once there is a complete line to write.
    log_dict = asdict(entry)
    if not self.log_full_messages:
        log_dict.pop("request_messages", None)
        log_dict.pop("response_content", None)
    line = json.dumps(log_dict) + "\n"

    try:
        with open(self.log_file, "a", encoding="utf-8") as f:
            f.write(line)
    except OSError as e:
        # Graceful degradation: memory-only logging continues.
        logger.debug("Request log write to %s failed: %s", self.log_file, e)
|
| 67 |
|
| 68 |
def get_recent(self, n: int = 100) -> list[dict]:
|
| 69 |
"""Get recent log entries."""
|
headroom/telemetry/toin.py
CHANGED
|
@@ -1174,11 +1174,11 @@ class ToolIntelligenceNetwork:
|
|
| 1174 |
canonical = "null"
|
| 1175 |
elif isinstance(value, bool):
|
| 1176 |
canonical = "true" if value else "false"
|
| 1177 |
-
elif isinstance(value,
|
| 1178 |
canonical = str(value)
|
| 1179 |
elif isinstance(value, str):
|
| 1180 |
canonical = value
|
| 1181 |
-
elif isinstance(value,
|
| 1182 |
# For complex types, use JSON serialization
|
| 1183 |
try:
|
| 1184 |
canonical = json.dumps(value, sort_keys=True, default=str)
|
|
@@ -1573,6 +1573,8 @@ def _create_default_toin_backend() -> Any:
|
|
| 1573 |
backend_type = (os.environ.get(TOIN_BACKEND_ENV_VAR) or "").strip().lower()
|
| 1574 |
if not backend_type or backend_type == "filesystem":
|
| 1575 |
return None
|
|
|
|
|
|
|
| 1576 |
try:
|
| 1577 |
from importlib.metadata import entry_points
|
| 1578 |
|
|
|
|
| 1174 |
canonical = "null"
|
| 1175 |
elif isinstance(value, bool):
|
| 1176 |
canonical = "true" if value else "false"
|
| 1177 |
+
elif isinstance(value, int | float):
|
| 1178 |
canonical = str(value)
|
| 1179 |
elif isinstance(value, str):
|
| 1180 |
canonical = value
|
| 1181 |
+
elif isinstance(value, list | dict):
|
| 1182 |
# For complex types, use JSON serialization
|
| 1183 |
try:
|
| 1184 |
canonical = json.dumps(value, sort_keys=True, default=str)
|
|
|
|
| 1573 |
backend_type = (os.environ.get(TOIN_BACKEND_ENV_VAR) or "").strip().lower()
|
| 1574 |
if not backend_type or backend_type == "filesystem":
|
| 1575 |
return None
|
| 1576 |
+
if backend_type == "none":
|
| 1577 |
+
return None # Explicit in-memory-only (e.g. --stateless mode)
|
| 1578 |
try:
|
| 1579 |
from importlib.metadata import entry_points
|
| 1580 |
|
headroom/transforms/kompress_compressor.py
CHANGED
|
@@ -25,12 +25,12 @@ from .base import Transform
|
|
| 25 |
|
| 26 |
logger = logging.getLogger(__name__)
|
| 27 |
|
| 28 |
-
# HuggingFace model ID
|
| 29 |
HF_MODEL_ID = "chopratejas/kompress-base"
|
| 30 |
|
| 31 |
-
#
|
| 32 |
-
|
| 33 |
-
|
| 34 |
_kompress_lock = threading.Lock()
|
| 35 |
|
| 36 |
|
|
@@ -132,9 +132,6 @@ def _get_model_class() -> type:
|
|
| 132 |
|
| 133 |
# ── Model Loading ─────────────────────────────────────────────────────
|
| 134 |
|
| 135 |
-
# Backend tag: "onnx" or "pytorch"
|
| 136 |
-
_kompress_backend: str | None = None
|
| 137 |
-
|
| 138 |
|
| 139 |
class _OnnxModel:
|
| 140 |
"""Thin wrapper so ONNX session has the same interface as PyTorch model."""
|
|
@@ -163,48 +160,42 @@ class _OnnxModel:
|
|
| 163 |
return (np.array(scores) > 0.5).tolist()
|
| 164 |
|
| 165 |
|
| 166 |
-
def _load_kompress_onnx() -> tuple[Any, Any]:
|
| 167 |
"""Download ONNX INT8 model from HuggingFace and load with onnxruntime."""
|
| 168 |
import onnxruntime as ort
|
| 169 |
from transformers import AutoTokenizer
|
| 170 |
|
| 171 |
-
global _kompress_model, _kompress_tokenizer, _kompress_backend
|
| 172 |
-
|
| 173 |
with _kompress_lock:
|
| 174 |
-
if
|
| 175 |
-
return
|
| 176 |
|
| 177 |
from huggingface_hub import hf_hub_download
|
| 178 |
|
| 179 |
-
logger.info("Downloading Kompress ONNX model from %s ...",
|
| 180 |
-
onnx_path = hf_hub_download(
|
| 181 |
|
| 182 |
session = ort.InferenceSession(onnx_path)
|
| 183 |
model = _OnnxModel(session)
|
| 184 |
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
|
| 185 |
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
logger.info("Kompress ONNX INT8 loaded (no torch dependency)")
|
| 190 |
-
return model, tokenizer
|
| 191 |
|
| 192 |
|
| 193 |
-
def _load_kompress_pytorch(device: str = "auto") -> tuple[Any, Any]:
|
| 194 |
"""Download PyTorch model from HuggingFace and load with torch."""
|
| 195 |
import torch
|
| 196 |
from transformers import AutoTokenizer
|
| 197 |
|
| 198 |
-
global _kompress_model, _kompress_tokenizer, _kompress_backend
|
| 199 |
-
|
| 200 |
with _kompress_lock:
|
| 201 |
-
if
|
| 202 |
-
return
|
| 203 |
|
| 204 |
from huggingface_hub import hf_hub_download
|
| 205 |
|
| 206 |
-
logger.info("Downloading Kompress PyTorch model from %s ...",
|
| 207 |
-
weights_path = hf_hub_download(
|
| 208 |
|
| 209 |
HeadroomCompressorModel = _get_model_class()
|
| 210 |
model = HeadroomCompressorModel()
|
|
@@ -227,50 +218,60 @@ def _load_kompress_pytorch(device: str = "auto") -> tuple[Any, Any]:
|
|
| 227 |
|
| 228 |
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
|
| 229 |
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
logger.info("Kompress PyTorch loaded on %s (%s)", device, HF_MODEL_ID)
|
| 234 |
-
return model, tokenizer
|
| 235 |
|
| 236 |
|
| 237 |
-
def _load_kompress(device: str = "auto") -> tuple[Any, Any]:
|
| 238 |
-
"""Load Kompress model
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
|
|
|
|
|
|
|
|
|
| 242 |
|
| 243 |
# Prefer ONNX (50MB onnxruntime vs 800MB torch)
|
| 244 |
if _is_onnx_available():
|
| 245 |
try:
|
| 246 |
-
return _load_kompress_onnx()
|
| 247 |
except Exception as e:
|
| 248 |
-
logger.warning("ONNX load failed, trying PyTorch: %s", e)
|
| 249 |
|
| 250 |
if _is_pytorch_available():
|
| 251 |
-
return _load_kompress_pytorch(device)
|
| 252 |
|
| 253 |
raise ImportError(
|
| 254 |
"Kompress requires onnxruntime or torch. Install with: pip install headroom-ai[proxy]"
|
| 255 |
)
|
| 256 |
|
| 257 |
|
| 258 |
-
def unload_kompress_model() -> bool:
|
| 259 |
-
"""Unload
|
| 260 |
-
|
|
|
|
|
|
|
|
|
|
| 261 |
with _kompress_lock:
|
| 262 |
-
if
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
|
|
|
|
|
|
| 274 |
|
| 275 |
|
| 276 |
# ── Compressor ────────────────────────────────────────────────────────
|
|
@@ -278,10 +279,26 @@ def unload_kompress_model() -> bool:
|
|
| 278 |
|
| 279 |
@dataclass
|
| 280 |
class KompressConfig:
|
| 281 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
|
| 283 |
device: str = "auto"
|
| 284 |
enable_ccr: bool = True
|
|
|
|
|
|
|
|
|
|
| 285 |
|
| 286 |
|
| 287 |
@dataclass
|
|
@@ -308,9 +325,10 @@ class KompressResult:
|
|
| 308 |
|
| 309 |
|
| 310 |
class KompressCompressor(Transform):
|
| 311 |
-
"""Kompress: ModernBERT token compressor
|
| 312 |
|
| 313 |
-
Auto-downloads
|
|
|
|
| 314 |
"""
|
| 315 |
|
| 316 |
name: str = "kompress_compressor"
|
|
@@ -347,11 +365,10 @@ class KompressCompressor(Transform):
|
|
| 347 |
return self._passthrough(content, n_words)
|
| 348 |
|
| 349 |
try:
|
| 350 |
-
model, tokenizer = _load_kompress(self.config.device)
|
| 351 |
-
is_onnx =
|
| 352 |
|
| 353 |
-
|
| 354 |
-
max_chunk_words = 350
|
| 355 |
kept_ids: set[int] = set()
|
| 356 |
|
| 357 |
for chunk_start in range(0, n_words, max_chunk_words):
|
|
@@ -423,6 +440,7 @@ class KompressCompressor(Transform):
|
|
| 423 |
original_tokens=n_words,
|
| 424 |
compressed_tokens=compressed_count,
|
| 425 |
compression_ratio=ratio,
|
|
|
|
| 426 |
)
|
| 427 |
|
| 428 |
# CCR marker
|
|
@@ -542,7 +560,7 @@ class KompressCompressor(Transform):
|
|
| 542 |
word_lists: list[list[str]] = [c.split() for c in contents]
|
| 543 |
|
| 544 |
# Short texts short-circuit to passthrough — no model call needed.
|
| 545 |
-
max_chunk_words =
|
| 546 |
chunk_queue: list[tuple[int, int, list[str], float | None]] = []
|
| 547 |
for i, (words, ratio) in enumerate(zip(word_lists, ratios, strict=True)):
|
| 548 |
if len(words) < 10:
|
|
@@ -558,7 +576,7 @@ class KompressCompressor(Transform):
|
|
| 558 |
|
| 559 |
# Load model once for the whole batch.
|
| 560 |
try:
|
| 561 |
-
model, tokenizer = _load_kompress(self.config.device)
|
| 562 |
except Exception as e:
|
| 563 |
logger.warning("Kompress load failed for batch: %s — passthrough all", e)
|
| 564 |
for i in range(n):
|
|
@@ -566,7 +584,7 @@ class KompressCompressor(Transform):
|
|
| 566 |
results[i] = self._passthrough(contents[i], len(word_lists[i]))
|
| 567 |
return [r for r in results if r is not None]
|
| 568 |
|
| 569 |
-
is_onnx =
|
| 570 |
kept_ids_per_text: dict[int, set[int]] = {i: set() for i in range(n) if results[i] is None}
|
| 571 |
|
| 572 |
for batch_start in range(0, len(chunk_queue), batch_size):
|
|
@@ -620,9 +638,9 @@ class KompressCompressor(Transform):
|
|
| 620 |
for wid in sorted_wids[:num_keep]:
|
| 621 |
kept_ids_per_text[text_idx].add(wid + chunk_start)
|
| 622 |
else:
|
| 623 |
-
# Threshold
|
| 624 |
for wid, score in word_scores.items():
|
| 625 |
-
if score >
|
| 626 |
kept_ids_per_text[text_idx].add(wid + chunk_start)
|
| 627 |
|
| 628 |
except Exception as e:
|
|
@@ -659,6 +677,7 @@ class KompressCompressor(Transform):
|
|
| 659 |
original_tokens=n_words,
|
| 660 |
compressed_tokens=compressed_count,
|
| 661 |
compression_ratio=comp_ratio,
|
|
|
|
| 662 |
)
|
| 663 |
|
| 664 |
if self.config.enable_ccr and comp_ratio < 0.8:
|
|
@@ -692,28 +711,28 @@ class KompressCompressor(Transform):
|
|
| 692 |
If the model isn't loaded yet, we trigger loading so the backend
|
| 693 |
is known. This is a no-op if the model is already in cache.
|
| 694 |
"""
|
| 695 |
-
|
| 696 |
-
if
|
| 697 |
try:
|
| 698 |
-
_load_kompress(self.config.device)
|
| 699 |
except Exception:
|
| 700 |
-
# If load fails, caller will see the error downstream.
|
| 701 |
return True
|
| 702 |
|
| 703 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 704 |
return True # ONNX CPU provider doesn't parallelize batch dim
|
| 705 |
-
if
|
| 706 |
try:
|
| 707 |
import torch
|
| 708 |
|
| 709 |
-
|
| 710 |
-
|
| 711 |
-
device
|
| 712 |
-
|
| 713 |
-
return False # GPU benefits from batching
|
| 714 |
-
if device.type == "mps":
|
| 715 |
-
return False # MPS (Apple Silicon) also benefits
|
| 716 |
-
# Fall through for CPU
|
| 717 |
_ = torch
|
| 718 |
except ImportError:
|
| 719 |
return True
|
|
|
|
| 25 |
|
| 26 |
logger = logging.getLogger(__name__)
|
| 27 |
|
| 28 |
+
# Default HuggingFace model ID
|
| 29 |
HF_MODEL_ID = "chopratejas/kompress-base"
|
| 30 |
|
| 31 |
+
# Model cache: model_id -> (model, tokenizer, backend)
|
| 32 |
+
# Supports multiple models loaded simultaneously.
|
| 33 |
+
_kompress_cache: dict[str, tuple[Any, Any, str]] = {}
|
| 34 |
_kompress_lock = threading.Lock()
|
| 35 |
|
| 36 |
|
|
|
|
| 132 |
|
| 133 |
# ── Model Loading ─────────────────────────────────────────────────────
|
| 134 |
|
|
|
|
|
|
|
|
|
|
| 135 |
|
| 136 |
class _OnnxModel:
|
| 137 |
"""Thin wrapper so ONNX session has the same interface as PyTorch model."""
|
|
|
|
| 160 |
return (np.array(scores) > 0.5).tolist()
|
| 161 |
|
| 162 |
|
| 163 |
+
def _load_kompress_onnx(model_id: str) -> tuple[Any, Any, str]:
    """Fetch the ONNX INT8 Kompress model from HuggingFace and wrap it for inference.

    Returns the cached ``(model, tokenizer, "onnx")`` triple for ``model_id``
    when one exists; otherwise downloads ``onnx/kompress-int8.onnx`` from the
    given repo, builds an onnxruntime session around it, and stores the
    result in the module cache. The tokenizer is always loaded from
    ``answerdotai/ModernBERT-base``, independent of ``model_id``.
    """
    import onnxruntime as ort
    from transformers import AutoTokenizer

    with _kompress_lock:
        # Another caller may have populated the cache while we waited on the lock.
        try:
            return _kompress_cache[model_id]
        except KeyError:
            pass

        from huggingface_hub import hf_hub_download

        logger.info("Downloading Kompress ONNX model from %s ...", model_id)
        local_file = hf_hub_download(model_id, "onnx/kompress-int8.onnx")

        wrapped = _OnnxModel(ort.InferenceSession(local_file))
        tok = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")

        entry = (wrapped, tok, "onnx")
        _kompress_cache[model_id] = entry
        logger.info("Kompress ONNX INT8 loaded: %s", model_id)
        return entry
|
|
|
|
|
|
|
| 184 |
|
| 185 |
|
| 186 |
+
def _load_kompress_pytorch(model_id: str, device: str = "auto") -> tuple[Any, Any, str]:
|
| 187 |
"""Download PyTorch model from HuggingFace and load with torch."""
|
| 188 |
import torch
|
| 189 |
from transformers import AutoTokenizer
|
| 190 |
|
|
|
|
|
|
|
| 191 |
with _kompress_lock:
|
| 192 |
+
if model_id in _kompress_cache:
|
| 193 |
+
return _kompress_cache[model_id]
|
| 194 |
|
| 195 |
from huggingface_hub import hf_hub_download
|
| 196 |
|
| 197 |
+
logger.info("Downloading Kompress PyTorch model from %s ...", model_id)
|
| 198 |
+
weights_path = hf_hub_download(model_id, "model.safetensors")
|
| 199 |
|
| 200 |
HeadroomCompressorModel = _get_model_class()
|
| 201 |
model = HeadroomCompressorModel()
|
|
|
|
| 218 |
|
| 219 |
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
|
| 220 |
|
| 221 |
+
_kompress_cache[model_id] = (model, tokenizer, "pytorch")
|
| 222 |
+
logger.info("Kompress PyTorch loaded on %s (%s)", device, model_id)
|
| 223 |
+
return model, tokenizer, "pytorch"
|
|
|
|
|
|
|
| 224 |
|
| 225 |
|
| 226 |
+
def _load_kompress(model_id: str = HF_MODEL_ID, device: str = "auto") -> tuple[Any, Any, str]:
|
| 227 |
+
"""Load Kompress model, returns (model, tokenizer, backend).
|
| 228 |
+
|
| 229 |
+
Try ONNX first (lightweight), fall back to PyTorch.
|
| 230 |
+
Models are cached by model_id — multiple models can coexist.
|
| 231 |
+
"""
|
| 232 |
+
if model_id in _kompress_cache:
|
| 233 |
+
return _kompress_cache[model_id]
|
| 234 |
|
| 235 |
# Prefer ONNX (50MB onnxruntime vs 800MB torch)
|
| 236 |
if _is_onnx_available():
|
| 237 |
try:
|
| 238 |
+
return _load_kompress_onnx(model_id)
|
| 239 |
except Exception as e:
|
| 240 |
+
logger.warning("ONNX load failed for %s, trying PyTorch: %s", model_id, e)
|
| 241 |
|
| 242 |
if _is_pytorch_available():
|
| 243 |
+
return _load_kompress_pytorch(model_id, device)
|
| 244 |
|
| 245 |
raise ImportError(
|
| 246 |
"Kompress requires onnxruntime or torch. Install with: pip install headroom-ai[proxy]"
|
| 247 |
)
|
| 248 |
|
| 249 |
|
| 250 |
+
def unload_kompress_model(model_id: str | None = None) -> bool:
|
| 251 |
+
"""Unload Kompress model(s) to free memory.
|
| 252 |
+
|
| 253 |
+
Args:
|
| 254 |
+
model_id: Specific model to unload. If None, unloads all cached models.
|
| 255 |
+
"""
|
| 256 |
with _kompress_lock:
|
| 257 |
+
if model_id is not None:
|
| 258 |
+
if model_id in _kompress_cache:
|
| 259 |
+
del _kompress_cache[model_id]
|
| 260 |
+
else:
|
| 261 |
+
return False
|
| 262 |
+
elif _kompress_cache:
|
| 263 |
+
_kompress_cache.clear()
|
| 264 |
+
else:
|
| 265 |
+
return False
|
| 266 |
|
| 267 |
+
try:
|
| 268 |
+
import torch
|
| 269 |
+
|
| 270 |
+
if torch.cuda.is_available():
|
| 271 |
+
torch.cuda.empty_cache()
|
| 272 |
+
except ImportError:
|
| 273 |
+
pass
|
| 274 |
+
return True
|
| 275 |
|
| 276 |
|
| 277 |
# ── Compressor ────────────────────────────────────────────────────────
|
|
|
|
| 279 |
|
| 280 |
@dataclass
|
| 281 |
class KompressConfig:
|
| 282 |
+
"""Configuration for Kompress compression.
|
| 283 |
+
|
| 284 |
+
The model_id, chunk_words, and score_threshold are coupled: a model
|
| 285 |
+
trained on 50-word chunks needs chunk_words=50 at inference. The
|
| 286 |
+
defaults match kompress-base. For domain-specific models, set all three.
|
| 287 |
+
|
| 288 |
+
Example — financial documents::
|
| 289 |
+
|
| 290 |
+
KompressConfig(
|
| 291 |
+
model_id="chopratejas/kompress-finance",
|
| 292 |
+
chunk_words=50,
|
| 293 |
+
score_threshold=0.5,
|
| 294 |
+
)
|
| 295 |
+
"""
|
| 296 |
|
| 297 |
device: str = "auto"
|
| 298 |
enable_ccr: bool = True
|
| 299 |
+
model_id: str = HF_MODEL_ID
|
| 300 |
+
chunk_words: int = 350
|
| 301 |
+
score_threshold: float = 0.5
|
| 302 |
|
| 303 |
|
| 304 |
@dataclass
|
|
|
|
| 325 |
|
| 326 |
|
| 327 |
class KompressCompressor(Transform):
|
| 328 |
+
"""Kompress: ModernBERT token compressor.
|
| 329 |
|
| 330 |
+
Auto-downloads the model from HuggingFace on first use.
|
| 331 |
+
Configure via KompressConfig to select model, chunk size, and threshold.
|
| 332 |
"""
|
| 333 |
|
| 334 |
name: str = "kompress_compressor"
|
|
|
|
| 365 |
return self._passthrough(content, n_words)
|
| 366 |
|
| 367 |
try:
|
| 368 |
+
model, tokenizer, backend = _load_kompress(self.config.model_id, self.config.device)
|
| 369 |
+
is_onnx = backend == "onnx"
|
| 370 |
|
| 371 |
+
max_chunk_words = self.config.chunk_words
|
|
|
|
| 372 |
kept_ids: set[int] = set()
|
| 373 |
|
| 374 |
for chunk_start in range(0, n_words, max_chunk_words):
|
|
|
|
| 440 |
original_tokens=n_words,
|
| 441 |
compressed_tokens=compressed_count,
|
| 442 |
compression_ratio=ratio,
|
| 443 |
+
model_used=self.config.model_id,
|
| 444 |
)
|
| 445 |
|
| 446 |
# CCR marker
|
|
|
|
| 560 |
word_lists: list[list[str]] = [c.split() for c in contents]
|
| 561 |
|
| 562 |
# Short texts short-circuit to passthrough — no model call needed.
|
| 563 |
+
max_chunk_words = self.config.chunk_words
|
| 564 |
chunk_queue: list[tuple[int, int, list[str], float | None]] = []
|
| 565 |
for i, (words, ratio) in enumerate(zip(word_lists, ratios, strict=True)):
|
| 566 |
if len(words) < 10:
|
|
|
|
| 576 |
|
| 577 |
# Load model once for the whole batch.
|
| 578 |
try:
|
| 579 |
+
model, tokenizer, backend = _load_kompress(self.config.model_id, self.config.device)
|
| 580 |
except Exception as e:
|
| 581 |
logger.warning("Kompress load failed for batch: %s — passthrough all", e)
|
| 582 |
for i in range(n):
|
|
|
|
| 584 |
results[i] = self._passthrough(contents[i], len(word_lists[i]))
|
| 585 |
return [r for r in results if r is not None]
|
| 586 |
|
| 587 |
+
is_onnx = backend == "onnx"
|
| 588 |
kept_ids_per_text: dict[int, set[int]] = {i: set() for i in range(n) if results[i] is None}
|
| 589 |
|
| 590 |
for batch_start in range(0, len(chunk_queue), batch_size):
|
|
|
|
| 638 |
for wid in sorted_wids[:num_keep]:
|
| 639 |
kept_ids_per_text[text_idx].add(wid + chunk_start)
|
| 640 |
else:
|
| 641 |
+
# Threshold from config (default 0.5, matches ONNX get_keep_mask).
|
| 642 |
for wid, score in word_scores.items():
|
| 643 |
+
if score > self.config.score_threshold:
|
| 644 |
kept_ids_per_text[text_idx].add(wid + chunk_start)
|
| 645 |
|
| 646 |
except Exception as e:
|
|
|
|
| 677 |
original_tokens=n_words,
|
| 678 |
compressed_tokens=compressed_count,
|
| 679 |
compression_ratio=comp_ratio,
|
| 680 |
+
model_used=self.config.model_id,
|
| 681 |
)
|
| 682 |
|
| 683 |
if self.config.enable_ccr and comp_ratio < 0.8:
|
|
|
|
| 711 |
If the model isn't loaded yet, we trigger loading so the backend
|
| 712 |
is known. This is a no-op if the model is already in cache.
|
| 713 |
"""
|
| 714 |
+
model_id = self.config.model_id
|
| 715 |
+
if model_id not in _kompress_cache:
|
| 716 |
try:
|
| 717 |
+
_load_kompress(model_id, self.config.device)
|
| 718 |
except Exception:
|
|
|
|
| 719 |
return True
|
| 720 |
|
| 721 |
+
if model_id not in _kompress_cache:
|
| 722 |
+
return True
|
| 723 |
+
|
| 724 |
+
model, _tokenizer, backend = _kompress_cache[model_id]
|
| 725 |
+
|
| 726 |
+
if backend == "onnx":
|
| 727 |
return True # ONNX CPU provider doesn't parallelize batch dim
|
| 728 |
+
if backend == "pytorch":
|
| 729 |
try:
|
| 730 |
import torch
|
| 731 |
|
| 732 |
+
if hasattr(model, "parameters"):
|
| 733 |
+
device = next(model.parameters()).device
|
| 734 |
+
if device.type in ("cuda", "mps"):
|
| 735 |
+
return False # GPU/MPS benefits from batching
|
|
|
|
|
|
|
|
|
|
|
|
|
| 736 |
_ = torch
|
| 737 |
except ImportError:
|
| 738 |
return True
|
headroom/transforms/smart_crusher.py
CHANGED
|
@@ -180,27 +180,31 @@ def _hash_field_name(field_name: str) -> str:
|
|
| 180 |
# Minimum chars for a text field to be worth compressing within an item
|
| 181 |
_MIN_FIELD_CHARS_FOR_WITHIN = 200
|
| 182 |
|
| 183 |
-
# Lazy-loaded compressor for within-item text compression
|
| 184 |
_within_compressor: Any = None
|
| 185 |
_within_compressor_checked = False
|
|
|
|
| 186 |
|
| 187 |
|
| 188 |
def _get_within_compressor() -> Any:
|
| 189 |
"""Get a text compressor for within-item field compression.
|
| 190 |
|
| 191 |
Returns Kompress if available (requires [ml] extra), else None.
|
|
|
|
| 192 |
"""
|
| 193 |
global _within_compressor, _within_compressor_checked
|
| 194 |
if not _within_compressor_checked:
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
|
|
|
|
|
|
| 204 |
return _within_compressor
|
| 205 |
|
| 206 |
|
|
@@ -435,7 +439,7 @@ def _detect_sequential_pattern(values: list[Any], check_order: bool = True) -> b
|
|
| 435 |
# Get numeric values
|
| 436 |
nums = []
|
| 437 |
for v in values:
|
| 438 |
-
if isinstance(v,
|
| 439 |
nums.append(v)
|
| 440 |
elif isinstance(v, str):
|
| 441 |
try:
|
|
@@ -546,7 +550,6 @@ def _detect_score_field_statistically(stats: FieldStats, items: list[dict]) -> t
|
|
| 546 |
confidence = 0.0
|
| 547 |
|
| 548 |
# Check for bounded range typical of scores
|
| 549 |
-
stats.max_val - stats.min_val
|
| 550 |
min_val, max_val = stats.min_val, stats.max_val
|
| 551 |
|
| 552 |
# Common score ranges: [0,1], [0,10], [0,100], [-1,1], [0,5]
|
|
@@ -578,7 +581,7 @@ def _detect_score_field_statistically(stats: FieldStats, items: list[dict]) -> t
|
|
| 578 |
for item in items:
|
| 579 |
if stats.name in item:
|
| 580 |
val = item.get(stats.name)
|
| 581 |
-
if isinstance(val,
|
| 582 |
values_in_order.append(float(val))
|
| 583 |
if len(values_in_order) >= 5:
|
| 584 |
# Check for descending sort
|
|
@@ -804,11 +807,11 @@ def _detect_items_by_learned_semantics(
|
|
| 804 |
value_canonical = "null"
|
| 805 |
elif isinstance(value, bool):
|
| 806 |
value_canonical = "true" if value else "false"
|
| 807 |
-
elif isinstance(value,
|
| 808 |
value_canonical = str(value)
|
| 809 |
elif isinstance(value, str):
|
| 810 |
value_canonical = value
|
| 811 |
-
elif isinstance(value,
|
| 812 |
try:
|
| 813 |
value_canonical = json.dumps(value, sort_keys=True, default=str)
|
| 814 |
except (TypeError, ValueError):
|
|
@@ -1030,7 +1033,7 @@ class SmartAnalyzer:
|
|
| 1030 |
first_val = non_null_values[0]
|
| 1031 |
if isinstance(first_val, bool):
|
| 1032 |
field_type = "boolean"
|
| 1033 |
-
elif isinstance(first_val,
|
| 1034 |
field_type = "numeric"
|
| 1035 |
elif isinstance(first_val, str):
|
| 1036 |
field_type = "string"
|
|
@@ -1064,7 +1067,7 @@ class SmartAnalyzer:
|
|
| 1064 |
# Numeric-specific analysis
|
| 1065 |
if field_type == "numeric":
|
| 1066 |
# Filter out NaN and Infinity which break statistics functions
|
| 1067 |
-
nums = [v for v in non_null_values if isinstance(v,
|
| 1068 |
if nums:
|
| 1069 |
try:
|
| 1070 |
stats.min_val = min(nums)
|
|
@@ -1283,7 +1286,7 @@ class SmartAnalyzer:
|
|
| 1283 |
threshold = self.config.variance_threshold * std
|
| 1284 |
for i, item in enumerate(items):
|
| 1285 |
val = item.get(stats.name)
|
| 1286 |
-
if isinstance(val,
|
| 1287 |
if abs(val - stats.mean_val) > threshold:
|
| 1288 |
anomaly_indices.add(i)
|
| 1289 |
|
|
@@ -1953,9 +1956,9 @@ class SmartCrusher(Transform):
|
|
| 1953 |
if len(keep_indices) <= effective_max:
|
| 1954 |
return keep_indices
|
| 1955 |
|
| 1956 |
-
# Use provided field_semantics or fall back to
|
| 1957 |
effective_field_semantics = field_semantics or getattr(
|
| 1958 |
-
self, "
|
| 1959 |
)
|
| 1960 |
|
| 1961 |
# Identify error items using KEYWORD detection (preservation guarantee)
|
|
@@ -1976,7 +1979,7 @@ class SmartCrusher(Transform):
|
|
| 1976 |
threshold = self.config.variance_threshold * std
|
| 1977 |
for i, item in enumerate(items):
|
| 1978 |
val = item.get(field_name)
|
| 1979 |
-
if isinstance(val,
|
| 1980 |
if abs(val - stats.mean_val) > threshold:
|
| 1981 |
anomaly_indices.add(i)
|
| 1982 |
|
|
@@ -2297,6 +2300,10 @@ class SmartCrusher(Transform):
|
|
| 2297 |
|
| 2298 |
return result, was_modified, info
|
| 2299 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2300 |
def _process_value(
|
| 2301 |
self,
|
| 2302 |
value: Any,
|
|
@@ -2311,6 +2318,10 @@ class SmartCrusher(Transform):
|
|
| 2311 |
Tuple of (processed_value, info_string, ccr_markers).
|
| 2312 |
ccr_markers is a list of (hash, original_count, compressed_count, summary) tuples.
|
| 2313 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2314 |
info_parts = []
|
| 2315 |
ccr_markers: list[tuple] = []
|
| 2316 |
|
|
@@ -2495,9 +2506,12 @@ class SmartCrusher(Transform):
|
|
| 2495 |
)
|
| 2496 |
|
| 2497 |
# === TOIN Evolution: Extract field semantics for signal detection ===
|
| 2498 |
-
# Store
|
| 2499 |
# This enables learned signal detection without changing all method signatures
|
| 2500 |
-
|
|
|
|
|
|
|
|
|
|
| 2501 |
toin_hint.field_semantics if toin_hint.field_semantics else None
|
| 2502 |
)
|
| 2503 |
|
|
@@ -2661,12 +2675,14 @@ class SmartCrusher(Transform):
|
|
| 2661 |
)
|
| 2662 |
|
| 2663 |
# Clean up temporary instance variable
|
| 2664 |
-
self
|
|
|
|
| 2665 |
return result, strategy_info, ccr_hash, dropped_summary
|
| 2666 |
|
| 2667 |
except Exception:
|
| 2668 |
# Clean up temporary instance variable
|
| 2669 |
-
self
|
|
|
|
| 2670 |
# Re-raise any exceptions (removed finally block since we no longer mutate config)
|
| 2671 |
raise
|
| 2672 |
|
|
@@ -2814,7 +2830,7 @@ class SmartCrusher(Transform):
|
|
| 2814 |
return items, "number:passthrough"
|
| 2815 |
|
| 2816 |
# Filter out non-finite values for statistics
|
| 2817 |
-
finite = [x for x in items if isinstance(x,
|
| 2818 |
if not finite:
|
| 2819 |
return items, "number:no_finite"
|
| 2820 |
|
|
@@ -2832,7 +2848,7 @@ class SmartCrusher(Transform):
|
|
| 2832 |
outlier_indices: set[int] = set()
|
| 2833 |
if std_val > 0:
|
| 2834 |
for i, val in enumerate(items):
|
| 2835 |
-
if isinstance(val,
|
| 2836 |
if abs(val - mean_val) > self.config.variance_threshold * std_val:
|
| 2837 |
outlier_indices.add(i)
|
| 2838 |
|
|
@@ -2844,12 +2860,12 @@ class SmartCrusher(Transform):
|
|
| 2844 |
left = [
|
| 2845 |
items[j]
|
| 2846 |
for j in range(i - window, i)
|
| 2847 |
-
if isinstance(items[j],
|
| 2848 |
]
|
| 2849 |
right = [
|
| 2850 |
items[j]
|
| 2851 |
for j in range(i, i + window)
|
| 2852 |
-
if isinstance(items[j],
|
| 2853 |
]
|
| 2854 |
if left and right:
|
| 2855 |
left_mean = statistics.mean(left)
|
|
@@ -2877,27 +2893,23 @@ class SmartCrusher(Transform):
|
|
| 2877 |
if i not in keep_indices:
|
| 2878 |
keep_indices.add(i)
|
| 2879 |
|
| 2880 |
-
# Build output:
|
| 2881 |
-
stats_summary = (
|
| 2882 |
-
f"[{n} numbers: min={min(finite)}, max={max(finite)}, "
|
| 2883 |
-
f"mean={mean_val:.4g}, median={median_val:.4g}, "
|
| 2884 |
-
f"stddev={std_val:.4g}, p25={p25:.4g}, p75={p75:.4g}"
|
| 2885 |
-
)
|
| 2886 |
-
if outlier_indices:
|
| 2887 |
-
stats_summary += f", outliers={len(outlier_indices)}"
|
| 2888 |
-
if change_indices:
|
| 2889 |
-
stats_summary += f", change_points={len(change_indices)}"
|
| 2890 |
-
stats_summary += "]"
|
| 2891 |
-
|
| 2892 |
kept_values = [items[i] for i in sorted(keep_indices)]
|
| 2893 |
-
result: list = [stats_summary] + kept_values
|
| 2894 |
|
| 2895 |
-
strategy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2896 |
if outlier_indices:
|
| 2897 |
strategy += f",outliers={len(outlier_indices)}"
|
|
|
|
|
|
|
| 2898 |
strategy += ")"
|
| 2899 |
|
| 2900 |
-
return
|
| 2901 |
|
| 2902 |
def _crush_mixed_array(
|
| 2903 |
self,
|
|
@@ -2930,7 +2942,7 @@ class SmartCrusher(Transform):
|
|
| 2930 |
key = "str"
|
| 2931 |
elif isinstance(item, bool):
|
| 2932 |
key = "bool"
|
| 2933 |
-
elif isinstance(item,
|
| 2934 |
key = "number"
|
| 2935 |
elif isinstance(item, list):
|
| 2936 |
key = "list"
|
|
@@ -2979,13 +2991,13 @@ class SmartCrusher(Transform):
|
|
| 2979 |
last_idx = set(indices[-k_last:])
|
| 2980 |
keep_indices.update(first_idx | last_idx)
|
| 2981 |
# Outliers
|
| 2982 |
-
finite = [v for v in values if isinstance(v,
|
| 2983 |
if len(finite) > 1:
|
| 2984 |
mean_v = statistics.mean(finite)
|
| 2985 |
std_v = statistics.stdev(finite)
|
| 2986 |
if std_v > 0:
|
| 2987 |
for idx, val in group_items:
|
| 2988 |
-
if isinstance(val,
|
| 2989 |
if abs(val - mean_v) > self.config.variance_threshold * std_v:
|
| 2990 |
keep_indices.add(idx)
|
| 2991 |
strategy_parts.append(f"num:{len(values)}")
|
|
@@ -3553,7 +3565,7 @@ class SmartCrusher(Transform):
|
|
| 3553 |
threshold = self.config.variance_threshold * std
|
| 3554 |
for i, item in enumerate(items):
|
| 3555 |
val = item.get(name)
|
| 3556 |
-
if isinstance(val,
|
| 3557 |
if abs(val - stats.mean_val) > threshold:
|
| 3558 |
keep_indices.add(i)
|
| 3559 |
|
|
|
|
| 180 |
# Minimum chars for a text field to be worth compressing within an item
|
| 181 |
_MIN_FIELD_CHARS_FOR_WITHIN = 200
|
| 182 |
|
| 183 |
+
# Lazy-loaded compressor for within-item text compression (thread-safe)
|
| 184 |
_within_compressor: Any = None
|
| 185 |
_within_compressor_checked = False
|
| 186 |
+
_within_compressor_lock = threading.Lock()
|
| 187 |
|
| 188 |
|
| 189 |
def _get_within_compressor() -> Any:
|
| 190 |
"""Get a text compressor for within-item field compression.
|
| 191 |
|
| 192 |
Returns Kompress if available (requires [ml] extra), else None.
|
| 193 |
+
Thread-safe via double-checked locking.
|
| 194 |
"""
|
| 195 |
global _within_compressor, _within_compressor_checked
|
| 196 |
if not _within_compressor_checked:
|
| 197 |
+
with _within_compressor_lock:
|
| 198 |
+
if not _within_compressor_checked:
|
| 199 |
+
try:
|
| 200 |
+
from .kompress_compressor import KompressCompressor, is_kompress_available
|
| 201 |
+
|
| 202 |
+
if is_kompress_available():
|
| 203 |
+
_within_compressor = KompressCompressor()
|
| 204 |
+
logger.debug("Within-item compression: using Kompress")
|
| 205 |
+
except ImportError:
|
| 206 |
+
pass
|
| 207 |
+
_within_compressor_checked = True
|
| 208 |
return _within_compressor
|
| 209 |
|
| 210 |
|
|
|
|
| 439 |
# Get numeric values
|
| 440 |
nums = []
|
| 441 |
for v in values:
|
| 442 |
+
if isinstance(v, int | float) and not isinstance(v, bool):
|
| 443 |
nums.append(v)
|
| 444 |
elif isinstance(v, str):
|
| 445 |
try:
|
|
|
|
| 550 |
confidence = 0.0
|
| 551 |
|
| 552 |
# Check for bounded range typical of scores
|
|
|
|
| 553 |
min_val, max_val = stats.min_val, stats.max_val
|
| 554 |
|
| 555 |
# Common score ranges: [0,1], [0,10], [0,100], [-1,1], [0,5]
|
|
|
|
| 581 |
for item in items:
|
| 582 |
if stats.name in item:
|
| 583 |
val = item.get(stats.name)
|
| 584 |
+
if isinstance(val, int | float) and math.isfinite(val):
|
| 585 |
values_in_order.append(float(val))
|
| 586 |
if len(values_in_order) >= 5:
|
| 587 |
# Check for descending sort
|
|
|
|
| 807 |
value_canonical = "null"
|
| 808 |
elif isinstance(value, bool):
|
| 809 |
value_canonical = "true" if value else "false"
|
| 810 |
+
elif isinstance(value, int | float):
|
| 811 |
value_canonical = str(value)
|
| 812 |
elif isinstance(value, str):
|
| 813 |
value_canonical = value
|
| 814 |
+
elif isinstance(value, list | dict):
|
| 815 |
try:
|
| 816 |
value_canonical = json.dumps(value, sort_keys=True, default=str)
|
| 817 |
except (TypeError, ValueError):
|
|
|
|
| 1033 |
first_val = non_null_values[0]
|
| 1034 |
if isinstance(first_val, bool):
|
| 1035 |
field_type = "boolean"
|
| 1036 |
+
elif isinstance(first_val, int | float):
|
| 1037 |
field_type = "numeric"
|
| 1038 |
elif isinstance(first_val, str):
|
| 1039 |
field_type = "string"
|
|
|
|
| 1067 |
# Numeric-specific analysis
|
| 1068 |
if field_type == "numeric":
|
| 1069 |
# Filter out NaN and Infinity which break statistics functions
|
| 1070 |
+
nums = [v for v in non_null_values if isinstance(v, int | float) and math.isfinite(v)]
|
| 1071 |
if nums:
|
| 1072 |
try:
|
| 1073 |
stats.min_val = min(nums)
|
|
|
|
| 1286 |
threshold = self.config.variance_threshold * std
|
| 1287 |
for i, item in enumerate(items):
|
| 1288 |
val = item.get(stats.name)
|
| 1289 |
+
if isinstance(val, int | float):
|
| 1290 |
if abs(val - stats.mean_val) > threshold:
|
| 1291 |
anomaly_indices.add(i)
|
| 1292 |
|
|
|
|
| 1956 |
if len(keep_indices) <= effective_max:
|
| 1957 |
return keep_indices
|
| 1958 |
|
| 1959 |
+
# Use provided field_semantics or fall back to thread-local (set by _crush_array)
|
| 1960 |
effective_field_semantics = field_semantics or getattr(
|
| 1961 |
+
getattr(self, "_thread_local", None), "field_semantics", None
|
| 1962 |
)
|
| 1963 |
|
| 1964 |
# Identify error items using KEYWORD detection (preservation guarantee)
|
|
|
|
| 1979 |
threshold = self.config.variance_threshold * std
|
| 1980 |
for i, item in enumerate(items):
|
| 1981 |
val = item.get(field_name)
|
| 1982 |
+
if isinstance(val, int | float):
|
| 1983 |
if abs(val - stats.mean_val) > threshold:
|
| 1984 |
anomaly_indices.add(i)
|
| 1985 |
|
|
|
|
| 2300 |
|
| 2301 |
return result, was_modified, info
|
| 2302 |
|
| 2303 |
+
# Maximum recursion depth for nested JSON processing.
|
| 2304 |
+
# Prevents RecursionError on adversarial/deeply-nested input.
|
| 2305 |
+
_MAX_PROCESS_DEPTH = 50
|
| 2306 |
+
|
| 2307 |
def _process_value(
|
| 2308 |
self,
|
| 2309 |
value: Any,
|
|
|
|
| 2318 |
Tuple of (processed_value, info_string, ccr_markers).
|
| 2319 |
ccr_markers is a list of (hash, original_count, compressed_count, summary) tuples.
|
| 2320 |
"""
|
| 2321 |
+
# Guard against deeply nested JSON causing RecursionError
|
| 2322 |
+
if depth >= self._MAX_PROCESS_DEPTH:
|
| 2323 |
+
return value, "", []
|
| 2324 |
+
|
| 2325 |
info_parts = []
|
| 2326 |
ccr_markers: list[tuple] = []
|
| 2327 |
|
|
|
|
| 2506 |
)
|
| 2507 |
|
| 2508 |
# === TOIN Evolution: Extract field semantics for signal detection ===
|
| 2509 |
+
# Store in thread-local storage for use in _prioritize_indices.
|
| 2510 |
# This enables learned signal detection without changing all method signatures
|
| 2511 |
+
# while remaining thread-safe (no cross-thread contamination).
|
| 2512 |
+
if not hasattr(self, "_thread_local"):
|
| 2513 |
+
self._thread_local = threading.local()
|
| 2514 |
+
self._thread_local.field_semantics = (
|
| 2515 |
toin_hint.field_semantics if toin_hint.field_semantics else None
|
| 2516 |
)
|
| 2517 |
|
|
|
|
| 2675 |
)
|
| 2676 |
|
| 2677 |
# Clean up temporary instance variable
|
| 2678 |
+
if hasattr(self, "_thread_local"):
|
| 2679 |
+
self._thread_local.field_semantics = None
|
| 2680 |
return result, strategy_info, ccr_hash, dropped_summary
|
| 2681 |
|
| 2682 |
except Exception:
|
| 2683 |
# Clean up temporary instance variable
|
| 2684 |
+
if hasattr(self, "_thread_local"):
|
| 2685 |
+
self._thread_local.field_semantics = None
|
| 2686 |
# Re-raise any exceptions (removed finally block since we no longer mutate config)
|
| 2687 |
raise
|
| 2688 |
|
|
|
|
| 2830 |
return items, "number:passthrough"
|
| 2831 |
|
| 2832 |
# Filter out non-finite values for statistics
|
| 2833 |
+
finite = [x for x in items if isinstance(x, int | float) and math.isfinite(x)]
|
| 2834 |
if not finite:
|
| 2835 |
return items, "number:no_finite"
|
| 2836 |
|
|
|
|
| 2848 |
outlier_indices: set[int] = set()
|
| 2849 |
if std_val > 0:
|
| 2850 |
for i, val in enumerate(items):
|
| 2851 |
+
if isinstance(val, int | float) and math.isfinite(val):
|
| 2852 |
if abs(val - mean_val) > self.config.variance_threshold * std_val:
|
| 2853 |
outlier_indices.add(i)
|
| 2854 |
|
|
|
|
| 2860 |
left = [
|
| 2861 |
items[j]
|
| 2862 |
for j in range(i - window, i)
|
| 2863 |
+
if isinstance(items[j], int | float) and math.isfinite(items[j])
|
| 2864 |
]
|
| 2865 |
right = [
|
| 2866 |
items[j]
|
| 2867 |
for j in range(i, i + window)
|
| 2868 |
+
if isinstance(items[j], int | float) and math.isfinite(items[j])
|
| 2869 |
]
|
| 2870 |
if left and right:
|
| 2871 |
left_mean = statistics.mean(left)
|
|
|
|
| 2893 |
if i not in keep_indices:
|
| 2894 |
keep_indices.add(i)
|
| 2895 |
|
| 2896 |
+
# Build output: kept values only (schema-preserving — no generated text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2897 |
kept_values = [items[i] for i in sorted(keep_indices)]
|
|
|
|
| 2898 |
|
| 2899 |
+
# Encode statistics into the strategy string (not the array itself)
|
| 2900 |
+
strategy = (
|
| 2901 |
+
f"number:adaptive({n}->{len(kept_values)}"
|
| 2902 |
+
f",min={min(finite)},max={max(finite)}"
|
| 2903 |
+
f",mean={mean_val:.4g},median={median_val:.4g}"
|
| 2904 |
+
f",stddev={std_val:.4g},p25={p25:.4g},p75={p75:.4g}"
|
| 2905 |
+
)
|
| 2906 |
if outlier_indices:
|
| 2907 |
strategy += f",outliers={len(outlier_indices)}"
|
| 2908 |
+
if change_indices:
|
| 2909 |
+
strategy += f",change_points={len(change_indices)}"
|
| 2910 |
strategy += ")"
|
| 2911 |
|
| 2912 |
+
return kept_values, strategy
|
| 2913 |
|
| 2914 |
def _crush_mixed_array(
|
| 2915 |
self,
|
|
|
|
| 2942 |
key = "str"
|
| 2943 |
elif isinstance(item, bool):
|
| 2944 |
key = "bool"
|
| 2945 |
+
elif isinstance(item, int | float):
|
| 2946 |
key = "number"
|
| 2947 |
elif isinstance(item, list):
|
| 2948 |
key = "list"
|
|
|
|
| 2991 |
last_idx = set(indices[-k_last:])
|
| 2992 |
keep_indices.update(first_idx | last_idx)
|
| 2993 |
# Outliers
|
| 2994 |
+
finite = [v for v in values if isinstance(v, int | float) and math.isfinite(v)]
|
| 2995 |
if len(finite) > 1:
|
| 2996 |
mean_v = statistics.mean(finite)
|
| 2997 |
std_v = statistics.stdev(finite)
|
| 2998 |
if std_v > 0:
|
| 2999 |
for idx, val in group_items:
|
| 3000 |
+
if isinstance(val, int | float) and math.isfinite(val):
|
| 3001 |
if abs(val - mean_v) > self.config.variance_threshold * std_v:
|
| 3002 |
keep_indices.add(idx)
|
| 3003 |
strategy_parts.append(f"num:{len(values)}")
|
|
|
|
| 3565 |
threshold = self.config.variance_threshold * std
|
| 3566 |
for i, item in enumerate(items):
|
| 3567 |
val = item.get(name)
|
| 3568 |
+
if isinstance(val, int | float):
|
| 3569 |
if abs(val - stats.mean_val) > threshold:
|
| 3570 |
keep_indices.add(i)
|
| 3571 |
|
plugins/openclaw/package.json
CHANGED
|
@@ -1,54 +1,54 @@
|
|
| 1 |
-
{
|
| 2 |
-
"name": "headroom-openclaw",
|
| 3 |
-
"version": "0.1.
|
| 4 |
-
"description": "Headroom context compression plugin for OpenClaw — 70-90% token savings with zero LLM calls",
|
| 5 |
-
"type": "module",
|
| 6 |
-
"main": "./dist/index.js",
|
| 7 |
-
"types": "./dist/index.d.ts",
|
| 8 |
-
"files": [
|
| 9 |
-
"dist",
|
| 10 |
-
"hook-shim",
|
| 11 |
-
"openclaw.plugin.json",
|
| 12 |
-
"README.md"
|
| 13 |
-
],
|
| 14 |
-
"scripts": {
|
| 15 |
-
"build": "tsup && node prepare-dist.mjs",
|
| 16 |
-
"test": "vitest run",
|
| 17 |
-
"test:watch": "vitest",
|
| 18 |
-
"typecheck": "tsc --noEmit"
|
| 19 |
-
},
|
| 20 |
-
"dependencies": {
|
| 21 |
-
"headroom-ai": "^0.1.0"
|
| 22 |
-
},
|
| 23 |
-
"peerDependencies": {
|
| 24 |
-
"openclaw": "*"
|
| 25 |
-
},
|
| 26 |
-
"peerDependenciesMeta": {
|
| 27 |
-
"openclaw": {
|
| 28 |
-
"optional": true
|
| 29 |
-
}
|
| 30 |
-
},
|
| 31 |
-
"devDependencies": {
|
| 32 |
-
"@types/node": "^22.10.0",
|
| 33 |
-
"tsup": "^8.0.0",
|
| 34 |
-
"typescript": "^5.5.0",
|
| 35 |
-
"vitest": "^2.0.0"
|
| 36 |
-
},
|
| 37 |
-
"openclaw": {
|
| 38 |
-
"hooks": [
|
| 39 |
-
"./hook-shim"
|
| 40 |
-
],
|
| 41 |
-
"extensions": [
|
| 42 |
-
"./dist/index.js"
|
| 43 |
-
],
|
| 44 |
-
"capabilities": {
|
| 45 |
-
"network": {
|
| 46 |
-
"allow": [
|
| 47 |
-
"http://*:*",
|
| 48 |
-
"https://*:*"
|
| 49 |
-
]
|
| 50 |
-
}
|
| 51 |
-
}
|
| 52 |
-
},
|
| 53 |
-
"license": "Apache-2.0"
|
| 54 |
-
}
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"name": "headroom-openclaw",
|
| 3 |
+
"version": "0.1.1",
|
| 4 |
+
"description": "Headroom context compression plugin for OpenClaw — 70-90% token savings with zero LLM calls",
|
| 5 |
+
"type": "module",
|
| 6 |
+
"main": "./dist/index.js",
|
| 7 |
+
"types": "./dist/index.d.ts",
|
| 8 |
+
"files": [
|
| 9 |
+
"dist",
|
| 10 |
+
"hook-shim",
|
| 11 |
+
"openclaw.plugin.json",
|
| 12 |
+
"README.md"
|
| 13 |
+
],
|
| 14 |
+
"scripts": {
|
| 15 |
+
"build": "tsup && node prepare-dist.mjs",
|
| 16 |
+
"test": "vitest run",
|
| 17 |
+
"test:watch": "vitest",
|
| 18 |
+
"typecheck": "tsc --noEmit"
|
| 19 |
+
},
|
| 20 |
+
"dependencies": {
|
| 21 |
+
"headroom-ai": "^0.1.0"
|
| 22 |
+
},
|
| 23 |
+
"peerDependencies": {
|
| 24 |
+
"openclaw": "*"
|
| 25 |
+
},
|
| 26 |
+
"peerDependenciesMeta": {
|
| 27 |
+
"openclaw": {
|
| 28 |
+
"optional": true
|
| 29 |
+
}
|
| 30 |
+
},
|
| 31 |
+
"devDependencies": {
|
| 32 |
+
"@types/node": "^22.10.0",
|
| 33 |
+
"tsup": "^8.0.0",
|
| 34 |
+
"typescript": "^5.5.0",
|
| 35 |
+
"vitest": "^2.0.0"
|
| 36 |
+
},
|
| 37 |
+
"openclaw": {
|
| 38 |
+
"hooks": [
|
| 39 |
+
"./hook-shim"
|
| 40 |
+
],
|
| 41 |
+
"extensions": [
|
| 42 |
+
"./dist/index.js"
|
| 43 |
+
],
|
| 44 |
+
"capabilities": {
|
| 45 |
+
"network": {
|
| 46 |
+
"allow": [
|
| 47 |
+
"http://*:*",
|
| 48 |
+
"https://*:*"
|
| 49 |
+
]
|
| 50 |
+
}
|
| 51 |
+
}
|
| 52 |
+
},
|
| 53 |
+
"license": "Apache-2.0"
|
| 54 |
+
}
|
tests/test_cli/test_wrap_copilot.py
CHANGED
|
@@ -21,6 +21,7 @@ def test_wrap_copilot_auto_anthropic_injects_instructions(
|
|
| 21 |
runner: CliRunner, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
| 22 |
) -> None:
|
| 23 |
monkeypatch.chdir(tmp_path)
|
|
|
|
| 24 |
captured: dict[str, object] = {}
|
| 25 |
|
| 26 |
def fake_launch_tool(**kwargs): # noqa: ANN003
|
|
@@ -51,7 +52,10 @@ def test_wrap_copilot_auto_anthropic_injects_instructions(
|
|
| 51 |
assert captured["args"] == ("--model", "claude-sonnet-4-20250514")
|
| 52 |
|
| 53 |
|
| 54 |
-
def test_wrap_copilot_openai_backend_sets_completions_env(
|
|
|
|
|
|
|
|
|
|
| 55 |
captured: dict[str, object] = {}
|
| 56 |
|
| 57 |
def fake_launch_tool(**kwargs): # noqa: ANN003
|
|
@@ -90,7 +94,10 @@ def test_wrap_copilot_openai_backend_sets_completions_env(runner: CliRunner) ->
|
|
| 90 |
assert captured["args"] == ("--model", "gpt-4o")
|
| 91 |
|
| 92 |
|
| 93 |
-
def test_wrap_copilot_auto_detects_running_proxy_backend(
|
|
|
|
|
|
|
|
|
|
| 94 |
captured: dict[str, object] = {}
|
| 95 |
|
| 96 |
def fake_launch_tool(**kwargs): # noqa: ANN003
|
|
@@ -153,7 +160,10 @@ def test_wrap_copilot_rejects_responses_for_translated_backends(runner: CliRunne
|
|
| 153 |
assert "not supported with translated backends" in result.output
|
| 154 |
|
| 155 |
|
| 156 |
-
def test_wrap_copilot_clears_stale_wire_api_in_anthropic_mode(
|
|
|
|
|
|
|
|
|
|
| 157 |
captured: dict[str, object] = {}
|
| 158 |
|
| 159 |
def fake_launch_tool(**kwargs): # noqa: ANN003
|
|
@@ -164,7 +174,10 @@ def test_wrap_copilot_clears_stale_wire_api_in_anthropic_mode(runner: CliRunner)
|
|
| 164 |
result = runner.invoke(
|
| 165 |
main,
|
| 166 |
["wrap", "copilot", "--no-rtk", "--", "--model", "claude-sonnet-4-20250514"],
|
| 167 |
-
env={
|
|
|
|
|
|
|
|
|
|
| 168 |
)
|
| 169 |
|
| 170 |
assert result.exit_code == 0, result.output
|
|
|
|
| 21 |
runner: CliRunner, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
| 22 |
) -> None:
|
| 23 |
monkeypatch.chdir(tmp_path)
|
| 24 |
+
monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-test-dummy")
|
| 25 |
captured: dict[str, object] = {}
|
| 26 |
|
| 27 |
def fake_launch_tool(**kwargs): # noqa: ANN003
|
|
|
|
| 52 |
assert captured["args"] == ("--model", "claude-sonnet-4-20250514")
|
| 53 |
|
| 54 |
|
| 55 |
+
def test_wrap_copilot_openai_backend_sets_completions_env(
|
| 56 |
+
runner: CliRunner, monkeypatch: pytest.MonkeyPatch
|
| 57 |
+
) -> None:
|
| 58 |
+
monkeypatch.setenv("OPENAI_API_KEY", "sk-test-dummy")
|
| 59 |
captured: dict[str, object] = {}
|
| 60 |
|
| 61 |
def fake_launch_tool(**kwargs): # noqa: ANN003
|
|
|
|
| 94 |
assert captured["args"] == ("--model", "gpt-4o")
|
| 95 |
|
| 96 |
|
| 97 |
+
def test_wrap_copilot_auto_detects_running_proxy_backend(
|
| 98 |
+
runner: CliRunner, monkeypatch: pytest.MonkeyPatch
|
| 99 |
+
) -> None:
|
| 100 |
+
monkeypatch.setenv("OPENAI_API_KEY", "sk-test-dummy")
|
| 101 |
captured: dict[str, object] = {}
|
| 102 |
|
| 103 |
def fake_launch_tool(**kwargs): # noqa: ANN003
|
|
|
|
| 160 |
assert "not supported with translated backends" in result.output
|
| 161 |
|
| 162 |
|
| 163 |
+
def test_wrap_copilot_clears_stale_wire_api_in_anthropic_mode(
|
| 164 |
+
runner: CliRunner, monkeypatch: pytest.MonkeyPatch
|
| 165 |
+
) -> None:
|
| 166 |
+
monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-test-dummy")
|
| 167 |
captured: dict[str, object] = {}
|
| 168 |
|
| 169 |
def fake_launch_tool(**kwargs): # noqa: ANN003
|
|
|
|
| 174 |
result = runner.invoke(
|
| 175 |
main,
|
| 176 |
["wrap", "copilot", "--no-rtk", "--", "--model", "claude-sonnet-4-20250514"],
|
| 177 |
+
env={
|
| 178 |
+
"COPILOT_PROVIDER_WIRE_API": "responses",
|
| 179 |
+
"ANTHROPIC_API_KEY": "sk-test-dummy",
|
| 180 |
+
},
|
| 181 |
)
|
| 182 |
|
| 183 |
assert result.exit_code == 0, result.output
|
tests/test_learn/test_scanner.py
CHANGED
|
@@ -103,6 +103,38 @@ class TestGreedyPathDecode:
|
|
| 103 |
result = _greedy_path_decode(tmp_path, ["my", "cool", "project", "nosync", "headroom"])
|
| 104 |
assert result == tmp_path / "my-cool-project.nosync" / "headroom"
|
| 105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
def test_nonexistent_path_returns_none(self, tmp_path: Path) -> None:
|
| 107 |
result = _greedy_path_decode(tmp_path, ["does", "not", "exist"])
|
| 108 |
assert result is None
|
|
@@ -236,6 +268,38 @@ class TestDecodeProjectPath:
|
|
| 236 |
else:
|
| 237 |
assert result is None or result == project
|
| 238 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
def test_windows_drive_letter_pattern(self) -> None:
|
| 240 |
"""Encoded name -C-MQ2-macros should detect Windows drive letter."""
|
| 241 |
import sys
|
|
|
|
| 103 |
result = _greedy_path_decode(tmp_path, ["my", "cool", "project", "nosync", "headroom"])
|
| 104 |
assert result == tmp_path / "my-cool-project.nosync" / "headroom"
|
| 105 |
|
| 106 |
+
# ---- Underscore tests (issue #159) ----
|
| 107 |
+
|
| 108 |
+
def test_single_underscore_in_dirname(self, tmp_path: Path) -> None:
    """Tokens ``my`` + ``project`` decode to the directory ``my_project``."""
    expected = tmp_path / "my_project"
    _make_dirs(tmp_path, "my_project")

    tokens = ["my", "project"]
    decoded = _greedy_path_decode(tmp_path, tokens)

    assert decoded == expected
|
| 113 |
+
|
| 114 |
+
def test_multiple_underscores_in_dirname(self, tmp_path: Path) -> None:
    """Three tokens decode to a single directory joined by underscores."""
    expected = tmp_path / "my_cool_project"
    _make_dirs(tmp_path, "my_cool_project")

    tokens = ["my", "cool", "project"]
    decoded = _greedy_path_decode(tmp_path, tokens)

    assert decoded == expected
|
| 119 |
+
|
| 120 |
+
def test_underscore_nested_path(self, tmp_path: Path) -> None:
    """A plain parent (``org``) followed by an underscore child decodes correctly."""
    expected = tmp_path / "org" / "my_project"
    _make_dirs(tmp_path, "org/my_project")

    tokens = ["org", "my", "project"]
    decoded = _greedy_path_decode(tmp_path, tokens)

    assert decoded == expected
|
| 125 |
+
|
| 126 |
+
def test_mixed_underscore_and_hyphen_in_dirname(self, tmp_path: Path) -> None:
    """One directory name mixing a hyphen and an underscore (``my-cool_project``)."""
    expected = tmp_path / "my-cool_project"
    _make_dirs(tmp_path, "my-cool_project")

    tokens = ["my", "cool", "project"]
    decoded = _greedy_path_decode(tmp_path, tokens)

    assert decoded == expected
|
| 131 |
+
|
| 132 |
+
def test_underscore_dir_containing_hyphen_subdir(self, tmp_path: Path) -> None:
    """Underscore parent (``my_app``) containing a hyphenated child (``sub-module``)."""
    expected = tmp_path / "my_app" / "sub-module"
    _make_dirs(tmp_path, "my_app/sub-module")

    tokens = ["my", "app", "sub", "module"]
    decoded = _greedy_path_decode(tmp_path, tokens)

    assert decoded == expected
|
| 137 |
+
|
| 138 |
def test_nonexistent_path_returns_none(self, tmp_path: Path) -> None:
|
| 139 |
result = _greedy_path_decode(tmp_path, ["does", "not", "exist"])
|
| 140 |
assert result is None
|
|
|
|
| 268 |
else:
|
| 269 |
assert result is None or result == project
|
| 270 |
|
| 271 |
+
def test_underscore_dirname_via_greedy(self, users_tmp: Path) -> None:
    """Decode a project whose directory name contains an underscore (issue #159).

    Claude Code encodes /Users/foo/org/my_project as
    -Users-foo-org-my-project, i.e. both "/" and "_" become "-".
    A plain "-" -> "/" replacement therefore yields …/org/my/project,
    which does not exist, so the greedy decoder must reconstruct
    ``my_project`` from the tokens ['my', 'project'].
    """
    project = users_tmp / "org" / "my_project"
    project.mkdir(parents=True)

    # Mirror the encoding described above.  Replacing only "/" would leave
    # the underscore in the encoded name, so the underscore reconstruction
    # this test is named for would never be exercised.
    encoded = "-" + str(project)[1:].replace("/", "-").replace("_", "-")
    result = _decode_project_path(encoded)

    if str(users_tmp).startswith("/Users/"):
        assert result == project
    else:
        # Outside /Users/ the test accepts either "not decodable" or the
        # correct path; it only rejects a wrong path.
        assert result is None or result == project
|
| 289 |
+
|
| 290 |
+
def test_multi_underscore_dirname_via_greedy(self, users_tmp: Path) -> None:
    """Decode ``my_cool_project``, a name with several underscores (issue #159).

    The encoded form replaces both "/" and "_" with "-", so the decoder
    has to rebuild the underscores from the tokens ['my', 'cool', 'project'].
    """
    project = users_tmp / "my_cool_project"
    project.mkdir(parents=True)

    # Encode underscores as hyphens too; replacing only "/" would keep the
    # literal directory name in the encoded string and make the test trivial.
    encoded = "-" + str(project)[1:].replace("/", "-").replace("_", "-")
    result = _decode_project_path(encoded)

    if str(users_tmp).startswith("/Users/"):
        assert result == project
    else:
        # Outside /Users/ the test accepts either "not decodable" or the
        # correct path; it only rejects a wrong path.
        assert result is None or result == project
|
| 302 |
+
|
| 303 |
def test_windows_drive_letter_pattern(self) -> None:
|
| 304 |
"""Encoded name -C-MQ2-macros should detect Windows drive letter."""
|
| 305 |
import sys
|
tests/test_memory_sync.py
ADDED
|
@@ -0,0 +1,647 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Comprehensive tests for the universal memory sync engine.
|
| 2 |
+
|
| 3 |
+
Tests cover:
|
| 4 |
+
- Core sync: import, export, bidirectional
|
| 5 |
+
- Idempotency and deduplication
|
| 6 |
+
- Fast no-op detection
|
| 7 |
+
- Lineage and governance metadata
|
| 8 |
+
- Claude Code adapter: read/write frontmatter files
|
| 9 |
+
- Codex adapter: read/write AGENTS.md sections
|
| 10 |
+
- Cross-agent interop: save in one agent, find in another
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import hashlib
|
| 16 |
+
import json
|
| 17 |
+
import time
|
| 18 |
+
from dataclasses import dataclass, field
|
| 19 |
+
from datetime import datetime, timezone
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from typing import Any
|
| 22 |
+
|
| 23 |
+
import pytest
|
| 24 |
+
|
| 25 |
+
from headroom.memory.sync import (
|
| 26 |
+
sync,
|
| 27 |
+
sync_export,
|
| 28 |
+
sync_import,
|
| 29 |
+
)
|
| 30 |
+
from headroom.memory.sync_adapters.claude_code import (
|
| 31 |
+
ClaudeCodeAdapter,
|
| 32 |
+
_parse_frontmatter,
|
| 33 |
+
)
|
| 34 |
+
from headroom.memory.sync_adapters.codex_agent import CodexAdapter
|
| 35 |
+
|
| 36 |
+
# ---------------------------------------------------------------------------
|
| 37 |
+
# Fake backend for testing (no real DB/embeddings needed)
|
| 38 |
+
# ---------------------------------------------------------------------------
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@dataclass
class FakeMemory:
    """Minimal stand-in for a stored memory record, as produced by ``FakeBackend``.

    Every field has a default, so a test only spells out the attributes it
    cares about.
    """

    id: str = ""
    content: str = ""
    user_id: str = ""
    category: str = ""
    importance: float = 0.5
    # Timezone-aware UTC timestamp, evaluated once per instance.
    created_at: datetime = field(default_factory=lambda: datetime.now(tz=timezone.utc))
    # default_factory gives each record its own dict.
    metadata: dict[str, Any] = field(default_factory=dict)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class FakeBackend:
    """List-backed memory store used to test sync without a real DB or embeddings."""

    def __init__(self) -> None:
        self._memories: list[FakeMemory] = []
        self._next_id = 1  # counter behind the ``mem_NNNN`` identifiers

    def _store(
        self,
        content: str,
        user_id: str,
        importance: float,
        metadata: dict[str, Any],
    ) -> FakeMemory:
        """Create a record with the next sequential id, keep it, and return it."""
        record = FakeMemory(
            id=f"mem_{self._next_id:04d}",
            content=content,
            user_id=user_id,
            importance=importance,
            metadata=metadata,
        )
        self._next_id += 1
        self._memories.append(record)
        return record

    async def get_user_memories(self, user_id: str, limit: int = 500) -> list[FakeMemory]:
        """Return at most ``limit`` records owned by ``user_id``, in insertion order."""
        owned = [record for record in self._memories if record.user_id == user_id]
        return owned[:limit]

    async def save_memory(
        self,
        content: str,
        user_id: str,
        importance: float = 0.5,
        metadata: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> FakeMemory:
        """Async save entry point; extra keyword arguments are accepted and ignored."""
        return self._store(content, user_id, importance, metadata or {})

    def add_memory(self, content: str, user_id: str = "tcms", **kwargs: Any) -> FakeMemory:
        """Sync helper to pre-populate memories.

        Only the ``metadata`` and ``importance`` keyword arguments are read.
        """
        return self._store(
            content,
            user_id,
            kwargs.get("importance", 0.5),
            kwargs.get("metadata", {}),
        )
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# ---------------------------------------------------------------------------
|
| 96 |
+
# Core sync tests
|
| 97 |
+
# ---------------------------------------------------------------------------
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class TestSyncImport:
    """Agent files -> DB: exercise ``sync_import`` through the Claude Code adapter."""

    @pytest.fixture
    def backend(self):
        """Provide an empty in-memory backend."""
        return FakeBackend()

    @pytest.fixture
    def claude_dir(self, tmp_path):
        """Provide a freshly created ``memory`` directory under ``tmp_path``."""
        memory_dir = tmp_path / "memory"
        memory_dir.mkdir()
        return memory_dir

    def _write_claude_memory(
        self, memory_dir: Path, name: str, content: str, **fm_fields: str
    ) -> None:
        """Write ``<slug>.md`` holding a ``---`` delimited header followed by ``content``.

        The header always carries ``name``, ``description`` (first 80 characters
        of ``content``) and ``type``; ``fm_fields`` may add or override keys.
        """
        front_matter = {
            "name": name,
            "description": content[:80],
            "type": "project",
            **fm_fields,
        }
        header = "\n".join(
            ["---", *(f"{key}: {value}" for key, value in front_matter.items()), "---"]
        )
        slug = name.lower().replace(" ", "_")
        target = memory_dir / f"{slug}.md"
        target.write_text(header + f"\n\n{content}\n")

    @pytest.mark.asyncio
    async def test_import_claude_files_to_db(self, backend, claude_dir):
        """Two memory files on disk become two memories in the backend."""
        seeds = (
            ("Project codename", "The secret name is TC"),
            ("Dark mode", "User prefers dark mode"),
        )
        for title, text in seeds:
            self._write_claude_memory(claude_dir, title, text)

        imported = await sync_import(backend, ClaudeCodeAdapter(claude_dir), "tcms")
        assert imported == 2

        stored = {record.content for record in await backend.get_user_memories("tcms")}
        assert "The secret name is TC" in stored
        assert "User prefers dark mode" in stored

    @pytest.mark.asyncio
    async def test_import_skips_existing(self, backend, claude_dir):
        """Memories already in DB are not re-imported."""
        # Pre-seed the backend with the truncated SHA-256 of the content that
        # one of the files below also carries.
        digest = hashlib.sha256(b"The secret name is TC").hexdigest()[:16]
        backend.add_memory("The secret name is TC", metadata={"content_hash": digest})

        self._write_claude_memory(claude_dir, "Project codename", "The secret name is TC")
        self._write_claude_memory(claude_dir, "New fact", "Something new")

        imported = await sync_import(backend, ClaudeCodeAdapter(claude_dir), "tcms")

        assert imported == 1  # Only "Something new"

    @pytest.mark.asyncio
    async def test_import_preserves_lineage(self, backend, claude_dir):
        """An imported memory records its agent, file, hash and sync direction."""
        self._write_claude_memory(claude_dir, "Fact", "Important fact")

        await sync_import(backend, ClaudeCodeAdapter(claude_dir), "tcms")

        records = await backend.get_user_memories("tcms")
        assert len(records) == 1
        lineage = records[0].metadata
        assert lineage["source_agent"] == "claude"
        assert lineage["source_file"] == "fact.md"
        assert "content_hash" in lineage
        assert lineage["sync_direction"] == "import"
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class TestSyncExport:
    """DB -> agent files: exercise ``sync_export`` through the Claude Code adapter."""

    @pytest.fixture
    def backend(self):
        """Provide an empty in-memory backend."""
        return FakeBackend()

    @pytest.fixture
    def claude_dir(self, tmp_path):
        """Provide a freshly created ``memory`` directory under ``tmp_path``."""
        memory_dir = tmp_path / "memory"
        memory_dir.mkdir()
        return memory_dir

    @pytest.mark.asyncio
    async def test_export_new_memory_to_claude_files(self, backend, claude_dir):
        """A codex-sourced memory is written out as a single ``headroom_*.md`` file."""
        lineage = {
            "source_agent": "codex",
            "sync_direction": "export",  # Not from claude import
        }
        backend.add_memory("Project uses Python 3.12", metadata=lineage)

        exported = await sync_export(backend, ClaudeCodeAdapter(claude_dir), "tcms")
        assert exported == 1

        # Check file was created
        created = list(claude_dir.glob("headroom_*.md"))
        assert len(created) == 1

        text = created[0].read_text()
        assert "Python 3.12" in text
        assert "headroom_id: mem_0001" in text
        assert "source_agent: codex" in text

    @pytest.mark.asyncio
    async def test_export_skips_claude_originated(self, backend, claude_dir):
        """Don't re-export memories that were imported FROM claude (anti-echo)."""
        seeds = (
            ("From claude", {"source_agent": "claude", "sync_direction": "import"}),
            ("From codex", {"source_agent": "codex"}),
        )
        for text, lineage in seeds:
            backend.add_memory(text, metadata=lineage)

        exported = await sync_export(backend, ClaudeCodeAdapter(claude_dir), "tcms")

        assert exported == 1  # Only "From codex"

    @pytest.mark.asyncio
    async def test_export_updates_memory_md_index(self, backend, claude_dir):
        """Export adds a Headroom section to MEMORY.md and keeps existing entries."""
        index_path = claude_dir / "MEMORY.md"
        # Create an existing MEMORY.md
        index_path.write_text("# Memory\n\n## User\n- Some existing entry\n")

        backend.add_memory("New fact from codex", metadata={"source_agent": "codex"})

        await sync_export(backend, ClaudeCodeAdapter(claude_dir), "tcms")

        index_text = index_path.read_text()
        assert "Headroom Shared Memory" in index_text
        assert "New fact from codex" in index_text
        assert "Some existing entry" in index_text  # Preserved
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
class TestBidirectionalSync:
    """Exercise ``sync`` end to end: import and export in a single call."""

    @pytest.fixture
    def backend(self):
        """Provide an empty in-memory backend."""
        return FakeBackend()

    @pytest.fixture
    def claude_dir(self, tmp_path):
        """Provide a freshly created ``memory`` directory under ``tmp_path``."""
        memory_dir = tmp_path / "memory"
        memory_dir.mkdir()
        return memory_dir

    @pytest.fixture
    def state_path(self, tmp_path):
        """Path for the sync state file; the file itself is not created here."""
        return tmp_path / "sync_state.json"

    def _write_claude_memory(self, memory_dir: Path, name: str, content: str) -> None:
        """Write ``<slug>.md`` with name/description/type front matter plus ``content``."""
        header = f"---\nname: {name}\ndescription: {content[:80]}\ntype: project\n---"
        slug = name.lower().replace(" ", "_")
        target = memory_dir / f"{slug}.md"
        target.write_text(f"{header}\n\n{content}\n")

    @pytest.mark.asyncio
    async def test_bidirectional_sync(self, backend, claude_dir, state_path):
        """One forced sync imports the Claude file and exports the codex memory."""
        # Claude has a memory file
        self._write_claude_memory(claude_dir, "Convention", "Always use ruff for linting")

        # DB has a memory from Codex
        backend.add_memory("Secret name is TC", metadata={"source_agent": "codex"})

        outcome = await sync(
            backend,
            ClaudeCodeAdapter(claude_dir),
            "tcms",
            state_path=state_path,
            force=True,
        )

        assert outcome.imported == 1  # Claude file → DB
        assert outcome.exported == 1  # Codex memory → Claude file

        # Verify DB has both
        stored = {record.content for record in await backend.get_user_memories("tcms")}
        assert "Always use ruff for linting" in stored
        assert "Secret name is TC" in stored

        # Verify Claude dir has the exported file
        exported_files = list(claude_dir.glob("headroom_*.md"))
        assert len(exported_files) >= 1
        combined = " ".join(path.read_text() for path in exported_files)
        assert "TC" in combined

    @pytest.mark.asyncio
    async def test_sync_idempotent(self, backend, claude_dir, state_path):
        """Running sync twice produces no duplicates."""
        self._write_claude_memory(claude_dir, "Fact", "Python 3.12 is required")
        backend.add_memory("Port 8787 is default", metadata={"source_agent": "codex"})

        adapter = ClaudeCodeAdapter(claude_dir)

        first = await sync(backend, adapter, "tcms", state_path=state_path, force=True)
        assert first.imported == 1
        assert first.exported == 1

        second = await sync(backend, adapter, "tcms", state_path=state_path, force=True)
        assert second.imported == 0  # Already imported
        assert second.exported == 0  # Already exported

        # No duplicates in DB
        records = await backend.get_user_memories("tcms")
        assert len(records) == 2

    @pytest.mark.asyncio
    async def test_fast_noop_when_unchanged(self, backend, claude_dir, state_path):
        """Second sync with no changes is a no-op that completes in under 50ms."""
        self._write_claude_memory(claude_dir, "Fact", "Some fact")

        adapter = ClaudeCodeAdapter(claude_dir)

        # First sync (populates state)
        await sync(backend, adapter, "tcms", state_path=state_path, force=True)

        # Second sync (should be fast no-op); note that ``force`` is not passed.
        started = time.monotonic()
        outcome = await sync(backend, adapter, "tcms", state_path=state_path)
        elapsed_ms = (time.monotonic() - started) * 1000

        assert outcome.imported == 0
        assert outcome.exported == 0
        assert elapsed_ms < 50  # Generous threshold for CI
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
class TestLineageAndGovernance:
    """Check the audit metadata that sync attaches to memories, files and state."""

    @pytest.fixture
    def backend(self):
        """Provide an empty in-memory backend."""
        return FakeBackend()

    @pytest.fixture
    def claude_dir(self, tmp_path):
        """Provide a freshly created ``memory`` directory under ``tmp_path``."""
        memory_dir = tmp_path / "memory"
        memory_dir.mkdir()
        return memory_dir

    @pytest.mark.asyncio
    async def test_lineage_tracks_source_agent(self, backend, claude_dir):
        """A memory imported from a Claude file is tagged ``source_agent == "claude"``."""
        header = "---\nname: test\ndescription: test\ntype: project\n---"
        (claude_dir / "test.md").write_text(f"{header}\n\nClaude discovered this\n")

        await sync_import(backend, ClaudeCodeAdapter(claude_dir), "tcms")

        records = await backend.get_user_memories("tcms")
        assert records[0].metadata["source_agent"] == "claude"

    @pytest.mark.asyncio
    async def test_exported_files_have_headroom_id(self, backend, claude_dir):
        """The exported file contains a ``headroom_id:`` entry."""
        backend.add_memory("From codex", metadata={"source_agent": "codex"})

        await sync_export(backend, ClaudeCodeAdapter(claude_dir), "tcms")

        exported_files = list(claude_dir.glob("headroom_*.md"))
        assert len(exported_files) == 1
        text = exported_files[0].read_text()
        assert "headroom_id:" in text

    @pytest.mark.asyncio
    async def test_sync_state_records_timestamps(self, backend, claude_dir, tmp_path):
        """The state file has a ``claude:tcms`` entry with timestamp and fingerprints."""
        header = "---\nname: t\ndescription: t\ntype: project\n---"
        (claude_dir / "t.md").write_text(f"{header}\n\nFact\n")
        state_path = tmp_path / "state.json"

        await sync(
            backend,
            ClaudeCodeAdapter(claude_dir),
            "tcms",
            state_path=state_path,
            force=True,
        )

        state = json.loads(state_path.read_text())
        key = "claude:tcms"
        assert key in state
        entry = state[key]
        assert "last_sync" in entry
        assert "agent_fingerprint" in entry
        assert "db_fingerprint" in entry
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
# ---------------------------------------------------------------------------
|
| 385 |
+
# Claude Code adapter tests
|
| 386 |
+
# ---------------------------------------------------------------------------
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
class TestClaudeCodeAdapter:
    """Cover ``ClaudeCodeAdapter`` reading, writing and fingerprinting, plus the parser."""

    @pytest.fixture
    def memory_dir(self, tmp_path):
        """Provide a freshly created ``memory`` directory under ``tmp_path``."""
        directory = tmp_path / "memory"
        directory.mkdir()
        return directory

    def test_parse_frontmatter(self):
        """Header keys land in the mapping; the body is returned without the header."""
        document = "---\nname: Test\ntype: project\n---\n\nBody content here."
        header, body = _parse_frontmatter(document)
        assert header["name"] == "Test"
        assert header["type"] == "project"
        assert body == "Body content here."

    def test_parse_frontmatter_no_frontmatter(self):
        """Text without a header yields an empty mapping and the text unchanged."""
        document = "Just plain content."
        header, body = _parse_frontmatter(document)
        assert header == {}
        assert body == "Just plain content."

    @pytest.mark.asyncio
    async def test_read_memories_skips_memory_md(self, memory_dir):
        """Only ``fact.md`` is returned when ``MEMORY.md`` sits next to it."""
        (memory_dir / "MEMORY.md").write_text("# Index\n- entry")
        (memory_dir / "fact.md").write_text(
            "---\nname: Fact\ntype: project\n---\n\nImportant fact."
        )

        found = await ClaudeCodeAdapter(memory_dir).read_memories()

        assert len(found) == 1
        only = found[0]
        assert only.content == "Important fact."
        assert only.source_file == "fact.md"

    @pytest.mark.asyncio
    async def test_write_creates_valid_md(self, memory_dir):
        """A written memory round-trips through ``_parse_frontmatter``."""
        payload = {
            "content": "Project uses FastAPI",
            "category": "architecture",
            "headroom_id": "mem_001",
            "source_agent": "codex",
            "content_hash": "abc123",
        }
        adapter = ClaudeCodeAdapter(memory_dir)

        written = await adapter.write_memories([payload])
        assert written == 1

        created = list(memory_dir.glob("headroom_*.md"))
        assert len(created) == 1

        header, body = _parse_frontmatter(created[0].read_text())
        assert header["type"] == "architecture"
        assert header["headroom_id"] == "mem_001"
        assert header["source_agent"] == "codex"
        assert "FastAPI" in body

    def test_fingerprint_changes_on_modification(self, memory_dir):
        """Rewriting a file with different content changes the fingerprint."""
        target = memory_dir / "test.md"
        target.write_text("content 1")

        adapter = ClaudeCodeAdapter(memory_dir)
        before = adapter.fingerprint()

        target.write_text("content 2")
        after = adapter.fingerprint()

        assert before != after

    def test_fingerprint_stable_when_unchanged(self, memory_dir):
        """Two consecutive fingerprints of an untouched directory are equal."""
        (memory_dir / "test.md").write_text("stable content")

        adapter = ClaudeCodeAdapter(memory_dir)
        first = adapter.fingerprint()
        second = adapter.fingerprint()
        assert first == second

    def test_fingerprint_empty_dir(self, tmp_path):
        """A directory with no files fingerprints to the literal ``"empty"``."""
        empty_dir = tmp_path / "empty"
        empty_dir.mkdir()
        assert ClaudeCodeAdapter(empty_dir).fingerprint() == "empty"
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
# ---------------------------------------------------------------------------
|
| 476 |
+
# Codex adapter tests
|
| 477 |
+
# ---------------------------------------------------------------------------
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
class TestCodexAdapter:
    """Cover ``CodexAdapter`` reading and writing the marked section of AGENTS.md."""

    @pytest.fixture
    def agents_md(self, tmp_path):
        """Path to an ``AGENTS.md`` under ``tmp_path``; the file is not created here."""
        return tmp_path / "AGENTS.md"

    @pytest.mark.asyncio
    async def test_read_from_agents_md(self, agents_md):
        """Bullets between the start/end markers are returned in file order."""
        document = (
            "# Instructions\n\n"
            "<!-- headroom:memory:start -->\n"
            "## Headroom Shared Memory\n\n"
            "- Secret name is TC\n"
            "- Uses Python 3.12\n"
            "<!-- headroom:memory:end -->\n"
        )
        agents_md.write_text(document)

        found = await CodexAdapter(agents_md).read_memories()

        assert len(found) == 2
        first, second = found
        assert first.content == "Secret name is TC"
        assert second.content == "Uses Python 3.12"

    @pytest.mark.asyncio
    async def test_write_to_agents_md(self, agents_md):
        """Writing adds a marked section and keeps the text that was already there."""
        agents_md.write_text("# Existing instructions\n")
        payload = [
            {"content": "Port 8787 is default"},
            {"content": "Uses ruff for linting"},
        ]

        written = await CodexAdapter(agents_md).write_memories(payload)

        assert written == 2
        text = agents_md.read_text()
        assert "headroom:memory:start" in text
        assert "Port 8787 is default" in text
        assert "Uses ruff for linting" in text
        assert "Existing instructions" in text  # Preserved

    @pytest.mark.asyncio
    async def test_write_replaces_existing_section(self, agents_md):
        """A second write replaces the old marked section instead of appending to it."""
        document = (
            "# Instructions\n\n"
            "<!-- headroom:memory:start -->\n"
            "## Old\n- old fact\n"
            "<!-- headroom:memory:end -->\n"
        )
        agents_md.write_text(document)

        await CodexAdapter(agents_md).write_memories([{"content": "new fact"}])

        text = agents_md.read_text()
        assert "new fact" in text
        assert "old fact" not in text

    @pytest.mark.asyncio
    async def test_read_empty_agents_md(self, agents_md):
        """A file without the marked section yields an empty list."""
        agents_md.write_text("# No memory section\n")
        found = await CodexAdapter(agents_md).read_memories()
        assert found == []

    @pytest.mark.asyncio
    async def test_read_nonexistent_file(self, tmp_path):
        """A missing file yields an empty list."""
        missing = tmp_path / "nonexistent.md"
        found = await CodexAdapter(missing).read_memories()
        assert found == []
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
# ---------------------------------------------------------------------------
|
| 555 |
+
# Cross-agent integration tests
|
| 556 |
+
# ---------------------------------------------------------------------------
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
class TestCrossAgentInterop:
    """Memories written by one agent become visible to the other through sync."""

    @pytest.fixture
    def backend(self):
        """Fresh FakeBackend used as the memory store for each test."""
        return FakeBackend()

    @pytest.fixture
    def claude_dir(self, tmp_path):
        """Empty directory standing in for Claude's memory folder."""
        memory_dir = tmp_path / "claude_memory"
        memory_dir.mkdir()
        return memory_dir

    @pytest.fixture
    def agents_md(self, tmp_path):
        """Path of the Codex AGENTS.md file (the file itself is not created here)."""
        return tmp_path / "AGENTS.md"

    @pytest.fixture
    def state_path(self, tmp_path):
        """Path handed to ``sync`` as its state file."""
        return tmp_path / "state.json"

    @pytest.mark.asyncio
    async def test_codex_saves_claude_finds(self, backend, claude_dir, state_path):
        """A memory stored via Codex MCP is exported into Claude's files by sync."""
        # Stand in for a Codex MCP save by writing straight to the backend.
        codex_metadata = {"source_agent": "codex", "content_hash": "x"}
        backend.add_memory("Secret name is TC", metadata=codex_metadata)

        # Push the backend contents out to Claude.
        claude_adapter = ClaudeCodeAdapter(claude_dir)
        outcome = await sync(
            backend,
            claude_adapter,
            "tcms",
            state_path=state_path,
            force=True,
        )
        assert outcome.exported == 1

        # Exactly one exported file, and it carries the memory text.
        exported_files = list(claude_dir.glob("headroom_*.md"))
        assert len(exported_files) == 1
        assert "TC" in exported_files[0].read_text()

    @pytest.mark.asyncio
    async def test_claude_saves_codex_finds(self, backend, claude_dir, agents_md, state_path):
        """A memory file in Claude's directory ends up in Codex AGENTS.md after sync."""
        # Seed Claude's directory with one memory file.
        front_matter = "---\nname: Linting\ndescription: use ruff\ntype: project\n---"
        memory_file = claude_dir / "linting.md"
        memory_file.write_text(f"{front_matter}\n\nAlways use ruff for linting\n")

        # Claude → DB
        await sync(
            backend,
            ClaudeCodeAdapter(claude_dir),
            "tcms",
            state_path=state_path,
            force=True,
        )

        # DB → Codex AGENTS.md
        outcome = await sync(
            backend,
            CodexAdapter(agents_md),
            "tcms",
            state_path=state_path,
            force=True,
        )

        assert outcome.exported >= 1
        assert "ruff" in agents_md.read_text()

    @pytest.mark.asyncio
    async def test_full_round_trip(self, backend, claude_dir, agents_md, state_path):
        """Full round trip: Claude → DB → Codex, Codex → DB → Claude."""
        # One memory lives in Claude's directory ...
        front_matter = "---\nname: Framework\ntype: project\n---"
        (claude_dir / "framework.md").write_text(f"{front_matter}\n\nUses FastAPI\n")

        # ... and one was stored by Codex (directly in the DB, as via MCP).
        backend.add_memory("Port is 8787", metadata={"source_agent": "codex"})

        # Build both adapters first, then sync Claude followed by Codex.
        claude_adapter = ClaudeCodeAdapter(claude_dir)
        codex_adapter = CodexAdapter(agents_md)
        for adapter in (claude_adapter, codex_adapter):
            await sync(backend, adapter, "tcms", state_path=state_path, force=True)

        # The DB now holds both memories.
        stored = await backend.get_user_memories("tcms")
        stored_contents = {memory.content for memory in stored}
        assert "Uses FastAPI" in stored_contents
        assert "Port is 8787" in stored_contents

        # Claude's exported files include the Codex memory.
        exported_text = " ".join(
            path.read_text() for path in claude_dir.glob("headroom_*.md")
        )
        assert "8787" in exported_text

        # AGENTS.md contains at least one of the two memories.
        agents_text = agents_md.read_text()
        assert "FastAPI" in agents_text or "8787" in agents_text
|
tests/test_package_init_lazy.py
CHANGED
|
@@ -33,7 +33,8 @@ def test_headroom_import_stays_lazy() -> None:
|
|
| 33 |
)
|
| 34 |
|
| 35 |
data = json.loads(result.stdout.strip())
|
| 36 |
-
|
|
|
|
| 37 |
assert data["cache_loaded"] is False
|
| 38 |
assert data["models_registry_loaded"] is False
|
| 39 |
assert data["memory_loaded"] is False
|
|
|
|
| 33 |
)
|
| 34 |
|
| 35 |
data = json.loads(result.stdout.strip())
|
| 36 |
+
# Version is a non-empty string; don't hardcode a specific value.
|
| 37 |
+
assert isinstance(data["version"], str) and data["version"]
|
| 38 |
assert data["cache_loaded"] is False
|
| 39 |
assert data["models_registry_loaded"] is False
|
| 40 |
assert data["memory_loaded"] is False
|
tests/test_transforms/test_kompress_compressor.py
CHANGED
|
@@ -352,9 +352,8 @@ class TestUnloadKompressModel:
|
|
| 352 |
import headroom.transforms.kompress_compressor as kmod
|
| 353 |
from headroom.transforms.kompress_compressor import unload_kompress_model
|
| 354 |
|
| 355 |
-
# Ensure no model is loaded (previous tests may have set the
|
| 356 |
-
kmod.
|
| 357 |
-
kmod._kompress_tokenizer = None
|
| 358 |
|
| 359 |
# Should return False when no model is loaded
|
| 360 |
assert unload_kompress_model() is False
|
|
|
|
| 352 |
import headroom.transforms.kompress_compressor as kmod
|
| 353 |
from headroom.transforms.kompress_compressor import unload_kompress_model
|
| 354 |
|
| 355 |
+
# Ensure no model is loaded (previous tests may have set the cache)
|
| 356 |
+
kmod._kompress_cache.clear()
|
|
|
|
| 357 |
|
| 358 |
# Should return False when no model is loaded
|
| 359 |
assert unload_kompress_model() is False
|
tests/test_transforms/test_smart_crusher_bugs.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Regression tests for SmartCrusher bugs.

Bug 1: _crush_number_array mixes types (string summary + numbers),
       violating the schema-preserving guarantee.
Bug 2: _current_field_semantics is shared instance state, creating
       a race condition when crushing concurrently.
Issue 7: _process_value must not crash on deeply nested JSON
       (recursion depth limit).
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import json
|
| 12 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 13 |
+
|
| 14 |
+
from headroom import SmartCrusherConfig
|
| 15 |
+
from headroom.transforms.smart_crusher import SmartCrusher
|
| 16 |
+
|
| 17 |
+
# ---------------------------------------------------------------------------
|
| 18 |
+
# Fixtures
|
| 19 |
+
# ---------------------------------------------------------------------------
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _make_crusher(max_items: int = 10, min_items: int = 3) -> SmartCrusher:
    """Build an enabled SmartCrusher for the tests in this module.

    ``max_items`` is passed as ``max_items_after_crush`` and ``min_items`` as
    ``min_items_to_analyze``; ``min_tokens_to_crush`` is fixed at 0 and
    ``variance_threshold`` at 2.0.
    """
    settings = {
        "enabled": True,
        "min_items_to_analyze": min_items,
        "min_tokens_to_crush": 0,
        "max_items_after_crush": max_items,
        "variance_threshold": 2.0,
    }
    return SmartCrusher(config=SmartCrusherConfig(**settings))
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# ---------------------------------------------------------------------------
|
| 34 |
+
# Bug 1: Number array type mixing
|
| 35 |
+
# ---------------------------------------------------------------------------
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class TestNumberArraySchemaPreservation:
    """Crushed number arrays must hold only numeric values from the input.

    An earlier version put a stats-summary string at the front of the result,
    yielding [string, int, int, ...]; that violates the schema-preserving
    guarantee and breaks type-aware JSON consumers.
    """

    def test_crushed_number_array_contains_only_numbers(self) -> None:
        """No element of the crushed array may be anything but int or float."""
        values = list(range(50))  # 0..49, well above the n<=8 passthrough
        crushed, _strategy = _make_crusher(max_items=10)._crush_number_array(values)

        for position, element in enumerate(crushed):
            assert isinstance(element, int | float), (
                f"Item {position} is {type(element).__name__} = {element!r}, expected int/float. "
                f"Schema-preserving guarantee violated."
            )

    def test_crushed_number_array_subset_of_original(self) -> None:
        """The crushed array may only contain values that were in the input."""
        values = list(range(10, 121, 10))  # 10, 20, ..., 120
        crushed, _ = _make_crusher(max_items=10)._crush_number_array(values)

        allowed = set(values)
        for element in crushed:
            assert element in allowed, (
                f"Value {element!r} not in original array — generated content detected"
            )

    def test_stats_summary_in_strategy_not_in_array(self) -> None:
        """Stats travel in the strategy string; the array itself holds no strings."""
        values = list(range(100))
        crushed, strategy = _make_crusher(max_items=5)._crush_number_array(values)

        # The strategy string carries the number-array marker.
        assert "number:" in strategy

        # No string may appear among the crushed values.
        stray_strings = [element for element in crushed if isinstance(element, str)]
        assert stray_strings == [], f"Found string(s) in numeric array: {stray_strings}"

    def test_number_array_passthrough_for_small(self) -> None:
        """A five-element array comes back unchanged with the passthrough strategy."""
        tiny = [1, 2, 3, 4, 5]
        crushed, strategy = _make_crusher()._crush_number_array(tiny)
        assert crushed == tiny
        assert strategy == "number:passthrough"

    def test_number_array_preserves_outliers(self) -> None:
        """An extreme value among identical ones survives crushing."""
        # Twenty identical values followed by one extreme outlier.
        values = [10] * 20 + [10000]
        crushed, _strategy = _make_crusher(max_items=10)._crush_number_array(values)
        assert 10000 in crushed, "Outlier value 10000 was dropped"

    def test_number_array_preserves_boundaries(self) -> None:
        """The first input value leads the result and the last one is present."""
        values = list(range(100))
        crushed, _strategy = _make_crusher(max_items=5)._crush_number_array(values)
        assert crushed[0] == 0, "First value not preserved"
        assert values[-1] in crushed, "Last value not preserved"

    def test_non_finite_passthrough(self) -> None:
        """An all-NaN array reports ``number:no_finite`` and keeps its length."""
        all_nan = [float("nan")] * 10
        crushed, strategy = _make_crusher()._crush_number_array(all_nan)
        assert strategy == "number:no_finite"
        assert len(crushed) == 10

    def test_full_crush_pipeline_number_array_types(self) -> None:
        """End-to-end: a JSON number array crushed via ``_smart_crush_content``."""
        crusher = _make_crusher(max_items=10)
        payload = json.dumps(list(range(50)))
        output, was_modified, _info = crusher._smart_crush_content(payload)

        # Nothing to verify when the content came back untouched.
        if not was_modified:
            return

        decoded = json.loads(output)
        assert isinstance(decoded, list)
        for element in decoded:
            assert isinstance(element, int | float), (
                f"Public API returned non-numeric item {element!r} in number array"
            )
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# ---------------------------------------------------------------------------
|
| 131 |
+
# Bug 2: Race condition on _current_field_semantics
|
| 132 |
+
# ---------------------------------------------------------------------------
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class TestFieldSemanticsThreadSafety:
    """_current_field_semantics must not leak between concurrent crushes.

    Previously it was stored as instance state (self._current_field_semantics)
    which created a race condition when the same SmartCrusher instance
    was used from multiple threads.
    """

    def test_concurrent_crushes_no_cross_contamination(self) -> None:
        """Two concurrent crushes must not share field_semantics state.

        One crusher is driven from a four-worker pool with two different
        payloads.  The test requires that no crush raised, that every task
        produced a result, and that no worker still sees ``field_semantics``
        on the crusher's ``_thread_local`` once its crush has finished.
        """
        crusher = _make_crusher(max_items=5)

        # Two payloads with different field names and value types.
        payload_a = json.dumps([{"name": f"item_{i}", "value": i} for i in range(20)])
        payload_b = json.dumps([{"key": f"k_{i}", "score": i * 0.1} for i in range(20)])

        results: dict[str, str] = {}
        errors: list[Exception] = []
        leaked: list[str] = []

        def crush_task(label: str, content: str) -> None:
            try:
                result, _modified, _info = crusher._smart_crush_content(content)
                results[label] = result
            except Exception as e:  # collected here, asserted on after the pool exits
                errors.append(e)
            finally:
                # threading.local values are visible only to the thread that
                # set them, so leftover state has to be looked for here, in
                # the worker that ran the crush — not on the main thread.
                thread_state = getattr(crusher, "_thread_local", None)
                leftover = getattr(thread_state, "field_semantics", None)
                if leftover is not None:
                    leaked.append(f"{label}: {leftover!r}")

        with ThreadPoolExecutor(max_workers=4) as executor:
            futures = []
            # Run many concurrent crushes to increase race probability
            for i in range(20):
                futures.append(executor.submit(crush_task, f"a_{i}", payload_a))
                futures.append(executor.submit(crush_task, f"b_{i}", payload_b))
            for f in as_completed(futures):
                # crush_task records crush errors itself; result() surfaces
                # anything unexpected raised outside its try block.
                f.result()

        assert not errors, f"Concurrent crushes raised errors: {errors}"

        # Every submitted task must have stored a result.
        expected_labels = {f"{kind}_{i}" for kind in ("a", "b") for i in range(20)}
        assert set(results) == expected_labels, (
            f"Missing results for: {sorted(expected_labels - set(results))}"
        )

        # After all crushes, thread-local state must be clean in every worker
        assert not leaked, f"field_semantics leaked in thread-local: {leaked}"

        # ... and on the calling thread as well.
        main_state = getattr(crusher, "_thread_local", None)
        semantics = getattr(main_state, "field_semantics", None)
        assert semantics is None, f"field_semantics leaked in thread-local: {semantics}"
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# ---------------------------------------------------------------------------
|
| 180 |
+
# Issue 7: Recursion depth limit
|
| 181 |
+
# ---------------------------------------------------------------------------
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class TestRecursionDepthLimit:
    """Crushing deeply nested JSON must complete without crashing."""

    def test_deeply_nested_json_does_not_crash(self) -> None:
        """A 100-level nested object is crushed without error and stays a dict."""
        crusher = _make_crusher()

        # Wrap a single leaf object in 100 levels of {"level": ...}.
        payload: dict = {"leaf": "value"}
        for _ in range(100):
            payload = {"level": payload}

        serialized = json.dumps(payload)
        output, _was_modified, _info = crusher._smart_crush_content(serialized)

        # Getting this far means no RecursionError was raised; the output must
        # still decode to a JSON object.
        decoded = json.loads(output)
        assert isinstance(decoded, dict)

    def test_deeply_nested_list_does_not_crash(self) -> None:
        """A 100-level nested list is crushed without error and stays a list."""
        crusher = _make_crusher()

        # Wrap a one-element list in 100 further single-element lists.
        payload: list = ["leaf"]
        for _ in range(100):
            payload = [payload]

        serialized = json.dumps(payload)
        output, _was_modified, _info = crusher._smart_crush_content(serialized)

        decoded = json.loads(output)
        assert isinstance(decoded, list)
|
tests/test_transforms/test_universal_json_crush.py
CHANGED
|
@@ -172,14 +172,15 @@ class TestCrushNumberArray:
|
|
| 172 |
assert len(crushed) < len(numbers)
|
| 173 |
assert "number:adaptive" in strategy
|
| 174 |
|
| 175 |
-
def
|
| 176 |
numbers = list(range(100))
|
| 177 |
crushed, strategy = crusher._crush_number_array(numbers)
|
| 178 |
-
#
|
| 179 |
-
assert
|
| 180 |
-
assert "
|
| 181 |
-
|
| 182 |
-
|
|
|
|
| 183 |
|
| 184 |
def test_outliers_preserved(self, crusher):
|
| 185 |
# Normal values around 50 with one extreme outlier
|
|
@@ -193,13 +194,13 @@ class TestCrushNumberArray:
|
|
| 193 |
crushed, strategy = crusher._crush_number_array(numbers)
|
| 194 |
# With all identical, should compress heavily
|
| 195 |
# Summary + a few representatives
|
| 196 |
-
numeric_values = [v for v in crushed if isinstance(v,
|
| 197 |
assert all(v == 42.0 for v in numeric_values)
|
| 198 |
|
| 199 |
def test_first_last_kept(self, crusher):
|
| 200 |
numbers = list(range(50))
|
| 201 |
crushed, strategy = crusher._crush_number_array(numbers)
|
| 202 |
-
numeric_values = [v for v in crushed if isinstance(v,
|
| 203 |
assert 0 in numeric_values # First
|
| 204 |
assert 49 in numeric_values # Last
|
| 205 |
|
|
@@ -207,7 +208,7 @@ class TestCrushNumberArray:
|
|
| 207 |
# Stable at 10, then jumps to 100
|
| 208 |
numbers = [10.0] * 50 + [100.0] * 50
|
| 209 |
crushed, strategy = crusher_large_k._crush_number_array(numbers)
|
| 210 |
-
numeric_values = [v for v in crushed if isinstance(v,
|
| 211 |
# Both 10.0 and 100.0 should be present
|
| 212 |
assert 10.0 in numeric_values
|
| 213 |
assert 100.0 in numeric_values
|
|
@@ -215,23 +216,24 @@ class TestCrushNumberArray:
|
|
| 215 |
def test_nan_inf_filtered(self, crusher):
|
| 216 |
numbers = [1.0, 2.0, float("nan"), float("inf"), 3.0] * 10
|
| 217 |
crushed, strategy = crusher._crush_number_array(numbers)
|
| 218 |
-
# Should not crash; stats
|
| 219 |
-
assert
|
|
|
|
| 220 |
|
| 221 |
def test_integers_preserved_as_int(self, crusher):
|
| 222 |
numbers = list(range(50))
|
| 223 |
crushed, strategy = crusher._crush_number_array(numbers)
|
| 224 |
-
numeric_values = [v for v in crushed if isinstance(v,
|
| 225 |
# Integers should remain integers (not converted to float)
|
| 226 |
assert any(isinstance(v, int) for v in numeric_values)
|
| 227 |
|
| 228 |
def test_statistics_accuracy(self, crusher):
|
| 229 |
numbers = list(range(1, 101)) # 1 to 100
|
| 230 |
crushed, strategy = crusher._crush_number_array(numbers)
|
| 231 |
-
|
| 232 |
-
assert "min=1" in
|
| 233 |
-
assert "max=100" in
|
| 234 |
-
assert "mean=50.5" in
|
| 235 |
|
| 236 |
|
| 237 |
# =====================================================================
|
|
@@ -382,7 +384,7 @@ class TestSafetyGuarantees:
|
|
| 382 |
assert items[-1] in crushed
|
| 383 |
else:
|
| 384 |
crushed, _ = crusher._crush_number_array(items)
|
| 385 |
-
numeric = [v for v in crushed if isinstance(v,
|
| 386 |
assert items[0] in numeric
|
| 387 |
assert items[-1] in numeric
|
| 388 |
|
|
@@ -399,7 +401,7 @@ class TestSafetyGuarantees:
|
|
| 399 |
"""Arrays below min_items_to_analyze pass through unchanged."""
|
| 400 |
if all(isinstance(i, str) for i in items):
|
| 401 |
crushed, strategy = crusher._crush_string_array(items)
|
| 402 |
-
elif all(isinstance(i,
|
| 403 |
crushed, strategy = crusher._crush_number_array(items)
|
| 404 |
else:
|
| 405 |
crushed, strategy = crusher._crush_mixed_array(items)
|
|
|
|
| 172 |
assert len(crushed) < len(numbers)
|
| 173 |
assert "number:adaptive" in strategy
|
| 174 |
|
| 175 |
+
def test_stats_in_strategy_not_array(self, crusher):
|
| 176 |
numbers = list(range(100))
|
| 177 |
crushed, strategy = crusher._crush_number_array(numbers)
|
| 178 |
+
# Stats should be in the strategy string, not in the array
|
| 179 |
+
assert "min=" in strategy
|
| 180 |
+
assert "max=" in strategy
|
| 181 |
+
# Array should contain only numbers (schema-preserving)
|
| 182 |
+
for item in crushed:
|
| 183 |
+
assert isinstance(item, int | float)
|
| 184 |
|
| 185 |
def test_outliers_preserved(self, crusher):
|
| 186 |
# Normal values around 50 with one extreme outlier
|
|
|
|
| 194 |
crushed, strategy = crusher._crush_number_array(numbers)
|
| 195 |
# With all identical, should compress heavily
|
| 196 |
# Summary + a few representatives
|
| 197 |
+
numeric_values = [v for v in crushed if isinstance(v, int | float)]
|
| 198 |
assert all(v == 42.0 for v in numeric_values)
|
| 199 |
|
| 200 |
def test_first_last_kept(self, crusher):
|
| 201 |
numbers = list(range(50))
|
| 202 |
crushed, strategy = crusher._crush_number_array(numbers)
|
| 203 |
+
numeric_values = [v for v in crushed if isinstance(v, int | float)]
|
| 204 |
assert 0 in numeric_values # First
|
| 205 |
assert 49 in numeric_values # Last
|
| 206 |
|
|
|
|
| 208 |
# Stable at 10, then jumps to 100
|
| 209 |
numbers = [10.0] * 50 + [100.0] * 50
|
| 210 |
crushed, strategy = crusher_large_k._crush_number_array(numbers)
|
| 211 |
+
numeric_values = [v for v in crushed if isinstance(v, int | float)]
|
| 212 |
# Both 10.0 and 100.0 should be present
|
| 213 |
assert 10.0 in numeric_values
|
| 214 |
assert 100.0 in numeric_values
|
|
|
|
| 216 |
def test_nan_inf_filtered(self, crusher):
|
| 217 |
numbers = [1.0, 2.0, float("nan"), float("inf"), 3.0] * 10
|
| 218 |
crushed, strategy = crusher._crush_number_array(numbers)
|
| 219 |
+
# Should not crash; stats in strategy based on finite values
|
| 220 |
+
assert "min=" in strategy
|
| 221 |
+
assert "max=" in strategy
|
| 222 |
|
| 223 |
def test_integers_preserved_as_int(self, crusher):
|
| 224 |
numbers = list(range(50))
|
| 225 |
crushed, strategy = crusher._crush_number_array(numbers)
|
| 226 |
+
numeric_values = [v for v in crushed if isinstance(v, int | float)]
|
| 227 |
# Integers should remain integers (not converted to float)
|
| 228 |
assert any(isinstance(v, int) for v in numeric_values)
|
| 229 |
|
| 230 |
def test_statistics_accuracy(self, crusher):
|
| 231 |
numbers = list(range(1, 101)) # 1 to 100
|
| 232 |
crushed, strategy = crusher._crush_number_array(numbers)
|
| 233 |
+
# Stats are in the strategy string
|
| 234 |
+
assert "min=1" in strategy
|
| 235 |
+
assert "max=100" in strategy
|
| 236 |
+
assert "mean=50.5" in strategy
|
| 237 |
|
| 238 |
|
| 239 |
# =====================================================================
|
|
|
|
| 384 |
assert items[-1] in crushed
|
| 385 |
else:
|
| 386 |
crushed, _ = crusher._crush_number_array(items)
|
| 387 |
+
numeric = [v for v in crushed if isinstance(v, int | float)]
|
| 388 |
assert items[0] in numeric
|
| 389 |
assert items[-1] in numeric
|
| 390 |
|
|
|
|
| 401 |
"""Arrays below min_items_to_analyze pass through unchanged."""
|
| 402 |
if all(isinstance(i, str) for i in items):
|
| 403 |
crushed, strategy = crusher._crush_string_array(items)
|
| 404 |
+
elif all(isinstance(i, int | float) for i in items):
|
| 405 |
crushed, strategy = crusher._crush_number_array(items)
|
| 406 |
else:
|
| 407 |
crushed, strategy = crusher._crush_mixed_array(items)
|
tests/test_ws_memory_relay.py
ADDED
|
@@ -0,0 +1,523 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for WebSocket memory tool interception in the Codex Responses API relay.
|
| 2 |
+
|
| 3 |
+
Verifies that:
|
| 4 |
+
1. Memory tool events are suppressed from reaching Codex
|
| 5 |
+
2. response.created is buffered and only flushed for non-memory responses
|
| 6 |
+
3. Tool execution happens and continuation is sent upstream
|
| 7 |
+
4. Non-memory responses pass through with normal streaming latency
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import json
|
| 13 |
+
from dataclasses import dataclass, field
|
| 14 |
+
from typing import Any
|
| 15 |
+
|
| 16 |
+
from headroom.proxy.memory_handler import MEMORY_TOOL_NAMES
|
| 17 |
+
|
| 18 |
+
# ---------------------------------------------------------------------------
|
| 19 |
+
# Minimal WS relay state machine (mirrors the logic in openai.py)
|
| 20 |
+
# ---------------------------------------------------------------------------
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
class WSMemoryRelayState:
    """State machine for WS event processing with memory tool interception.

    This mirrors the logic in ``_upstream_to_client`` but is decoupled from
    actual WebSocket I/O so it can be unit-tested.

    Events are buffered until the first ``response.output_item.added`` (or a
    ``response.completed`` with no output items) arrives.  If that first item
    is a function call to a memory tool, the whole response is suppressed and
    its completed memory function calls are handed back for execution when
    ``response.completed`` arrives.  Otherwise the buffer is flushed and later
    events are relayed as they come in.
    """

    # Function-call names that are treated as memory tools.
    memory_tool_names: set[str] = field(default_factory=lambda: set(MEMORY_TOOL_NAMES))

    # Per-response state.
    # NOTE(review): only the suppress path calls _reset_response_state().
    # After a passed-through response ``decided`` stays True, so later events
    # on the same state are relayed without inspection — confirm this matches
    # ``_upstream_to_client``.
    event_buffer: list[str] = field(default_factory=list)
    decided: bool = False
    suppress_response: bool = False
    pending_function_calls: list[dict[str, Any]] = field(default_factory=list)
    last_response_id: str | None = None

    def process_event(self, msg_str: str) -> dict[str, Any]:
        """Process a single upstream WS event.

        Args:
            msg_str: Raw text of one upstream message.

        Returns a dict with keys:
            relay: list[str] — events to send to Codex
            execute_tools: list — function_call items to execute
            send_continuation: dict | None — continuation payload to send upstream
        """
        result: dict[str, Any] = {"relay": [], "execute_tools": [], "send_continuation": None}

        try:
            event = json.loads(msg_str)
        except (json.JSONDecodeError, TypeError):
            event = None

        if not isinstance(event, dict):
            # Unparseable text, or valid JSON that is not an object (array,
            # scalar, null): there are no fields to inspect — always relay.
            result["relay"].append(msg_str)
            return result

        event_type = event.get("type", "")

        # ---- Phase 1: Buffering (before first output item) ----
        if not self.decided:
            self.event_buffer.append(msg_str)

            if event_type == "response.output_item.added":
                self.decided = True
                if self._is_memory_call(self._dict_field(event, "item")):
                    # Memory tool is first output → suppress entire response
                    self.suppress_response = True
                else:
                    # Non-memory item → flush buffer and pass through
                    result["relay"].extend(self.event_buffer)
                self.event_buffer.clear()

            elif event_type == "response.completed":
                # Response completed with no output items — flush all
                self.decided = True
                result["relay"].extend(self.event_buffer)
                self.event_buffer.clear()

            return result

        # ---- Phase 2a: Suppress mode (memory tool response) ----
        if self.suppress_response:
            # Capture completed memory function_call items
            if event_type == "response.output_item.done":
                item = self._dict_field(event, "item")
                if self._is_memory_call(item):
                    self.pending_function_calls.append(item)

            if event_type == "response.completed":
                self.last_response_id = self._dict_field(event, "response").get("id")

                if self.pending_function_calls:
                    # Copies are returned because the reset below clears the
                    # pending list in place.
                    result["execute_tools"] = list(self.pending_function_calls)
                    # Build continuation payload
                    # (actual tool execution + output building done by caller)
                    result["send_continuation"] = {
                        "response_id": self.last_response_id,
                        "function_calls": list(self.pending_function_calls),
                    }

                # Reset for next response (continuation)
                self._reset_response_state()

            return result  # Nothing relayed in suppress mode

        # ---- Phase 2b: Pass-through mode (normal response) ----
        result["relay"].append(msg_str)
        return result

    def _is_memory_call(self, item: dict[str, Any]) -> bool:
        """Return True when *item* is a function_call naming a memory tool."""
        return item.get("type") == "function_call" and item.get("name") in self.memory_tool_names

    @staticmethod
    def _dict_field(event: dict[str, Any], key: str) -> dict[str, Any]:
        """Return ``event[key]`` when it is a dict, otherwise an empty dict."""
        value = event.get(key)
        return value if isinstance(value, dict) else {}

    def _reset_response_state(self) -> None:
        """Reset per-response state for the next response."""
        self.event_buffer.clear()
        self.decided = False
        self.suppress_response = False
        self.pending_function_calls.clear()
        self.last_response_id = None
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
# ---------------------------------------------------------------------------
|
| 130 |
+
# Test helpers
|
| 131 |
+
# ---------------------------------------------------------------------------
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def _make_event(event_type: str, **kwargs: Any) -> str:
    """Serialize an event as JSON: a ``type`` field followed by the extra fields."""
    return json.dumps({"type": event_type, **kwargs})
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def _response_created(response_id: str = "resp_A") -> str:
    """Build a ``response.created`` event for the response with id *response_id*."""
    response = {"id": response_id}
    return _make_event("response.created", response=response)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def _output_item_added_text(index: int = 0) -> str:
    """Build a ``response.output_item.added`` event for an assistant message item."""
    message_item = {"type": "message", "role": "assistant"}
    return _make_event("response.output_item.added", output_index=index, item=message_item)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def _output_item_added_function_call(name: str, index: int = 0, call_id: str = "call_1") -> str:
    """Build a ``response.output_item.added`` event for a function call to *name*."""
    call_item = {"type": "function_call", "name": name, "call_id": call_id}
    return _make_event("response.output_item.added", output_index=index, item=call_item)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def _function_call_args_delta(index: int = 0, delta: str = '{"qu') -> str:
    """Build a ``response.function_call_arguments.delta`` event with a partial argument string."""
    event_type = "response.function_call_arguments.delta"
    return _make_event(event_type, output_index=index, delta=delta)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def _function_call_args_done(index: int = 0, arguments: str = '{"query": "codename"}') -> str:
    """Build a ``response.function_call_arguments.done`` event with the full argument string."""
    event_type = "response.function_call_arguments.done"
    return _make_event(event_type, output_index=index, arguments=arguments)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def _output_item_done_function_call(
    name: str,
    index: int = 0,
    call_id: str = "call_1",
    arguments: str = '{"query": "codename"}',
) -> str:
    """Build a ``response.output_item.done`` event for a finished function-call item.

    The item carries the tool ``name``, its ``call_id`` and the serialized
    ``arguments`` string; ``index`` becomes the event's ``output_index``.
    """
    finished_call = {
        "type": "function_call",
        "name": name,
        "call_id": call_id,
        "arguments": arguments,
    }
    return _make_event("response.output_item.done", output_index=index, item=finished_call)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def _output_text_delta(index: int = 0, text: str = "Hello") -> str:
    """Build a ``response.output_text.delta`` event whose ``delta`` field is ``text``."""
    event_type = "response.output_text.delta"
    return _make_event(event_type, output_index=index, delta=text)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def _output_item_done_text(index: int = 0) -> str:
    """Build a ``response.output_item.done`` event for an assistant message item."""
    message_item = {"type": "message", "role": "assistant"}
    return _make_event("response.output_item.done", output_index=index, item=message_item)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def _response_completed(response_id: str = "resp_A") -> str:
    """Build a ``response.completed`` event for ``response_id`` with status ``completed``."""
    response = {"id": response_id, "status": "completed"}
    return _make_event("response.completed", response=response)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def _output_item_added_shell(index: int = 0, call_id: str = "call_shell") -> str:
    """Simulate a Codex built-in tool (shell) that should pass through.

    Delegates to ``_output_item_added_function_call`` so the shell event is
    built the same way as every other function-call announcement instead of
    duplicating the event literal.

    Args:
        index: Value for the event's ``output_index``.
        call_id: Call identifier stored on the item; defaults to the
            previously hard-coded ``"call_shell"``.

    Returns:
        The JSON-encoded ``response.output_item.added`` event.
    """
    return _output_item_added_function_call("shell", index=index, call_id=call_id)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
# ---------------------------------------------------------------------------
|
| 227 |
+
# Tests
|
| 228 |
+
# ---------------------------------------------------------------------------
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
class TestWSMemoryRelayNonMemory:
    """Responses that contain no memory tool calls are forwarded unchanged."""

    def test_text_response_relayed_immediately(self):
        """A text-only response is forwarded completely and in order."""
        state = WSMemoryRelayState()
        forwarded: list[str] = []

        stream = (
            _response_created(),
            _output_item_added_text(),
            _output_text_delta(text="The answer is 42"),
            _output_item_done_text(),
            _response_completed(),
        )
        for message in stream:
            outcome = state.process_event(message)
            forwarded += outcome["relay"]
            # A text-only response must never trigger tool handling.
            assert outcome["execute_tools"] == []
            assert outcome["send_continuation"] is None

        # Every one of the five events reaches the client.
        assert len(forwarded) == 5

        # response.created is held back until the first output item arrives,
        # then flushed together with it, so the overall order is unchanged.
        expected_order = [
            "response.created",
            "response.output_item.added",
            "response.output_text.delta",
            "response.output_item.done",
            "response.completed",
        ]
        observed_order = [json.loads(message)["type"] for message in forwarded]
        assert observed_order == expected_order

    def test_shell_tool_relayed(self):
        """A built-in ``shell`` function call is forwarded, not intercepted."""
        state = WSMemoryRelayState()
        forwarded: list[str] = []

        stream = (
            _response_created(),
            _output_item_added_shell(),
            _response_completed(),
        )
        for message in stream:
            outcome = state.process_event(message)
            forwarded += outcome["relay"]
            assert outcome["execute_tools"] == []
            assert outcome["send_continuation"] is None

        assert len(forwarded) == 3

    def test_empty_response_relayed(self):
        """With no output items, created and completed are still forwarded."""
        state = WSMemoryRelayState()

        stream = (_response_created(), _response_completed())
        outcomes = [state.process_event(message) for message in stream]
        forwarded = [message for outcome in outcomes for message in outcome["relay"]]

        assert len(forwarded) == 2
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
class TestWSMemoryRelayMemoryTool:
    """Responses with memory tools are intercepted transparently."""

    def test_memory_search_fully_suppressed(self):
        """memory_search call: ALL events suppressed from Codex."""
        relay = WSMemoryRelayState()

        events = [
            _response_created("resp_A"),
            _output_item_added_function_call("memory_search", index=0),
            _function_call_args_delta(index=0),
            _function_call_args_done(index=0),
            _output_item_done_function_call("memory_search", index=0),
            _response_completed("resp_A"),
        ]

        all_relayed: list[str] = []
        tool_executions: list[Any] = []
        continuations: list[Any] = []

        for ev in events:
            result = relay.process_event(ev)
            all_relayed.extend(result["relay"])
            tool_executions.extend(result["execute_tools"])
            if result["send_continuation"]:
                continuations.append(result["send_continuation"])

        # ZERO events relayed to Codex
        assert len(all_relayed) == 0, (
            f"Expected 0 relayed events, got {len(all_relayed)}: "
            f"{[json.loads(e)['type'] for e in all_relayed]}"
        )

        # Tool execution triggered
        assert len(tool_executions) == 1
        assert tool_executions[0]["name"] == "memory_search"

        # Continuation requested
        assert len(continuations) == 1
        assert continuations[0]["response_id"] == "resp_A"

    def test_memory_save_also_suppressed(self):
        """memory_save call is also intercepted."""
        relay = WSMemoryRelayState()

        events = [
            _response_created("resp_B"),
            _output_item_added_function_call("memory_save", index=0, call_id="call_save"),
            _function_call_args_done(index=0, arguments='{"content": "user likes dark mode"}'),
            _output_item_done_function_call(
                "memory_save",
                index=0,
                call_id="call_save",
                arguments='{"content": "user likes dark mode"}',
            ),
            _response_completed("resp_B"),
        ]

        all_relayed: list[str] = []
        tool_executions: list[Any] = []

        for ev in events:
            result = relay.process_event(ev)
            all_relayed.extend(result["relay"])
            tool_executions.extend(result["execute_tools"])

        assert len(all_relayed) == 0
        assert len(tool_executions) == 1
        assert tool_executions[0]["name"] == "memory_save"

    def test_continuation_response_relayed_normally(self):
        """After memory tool handling, the continuation response passes through."""
        relay = WSMemoryRelayState()

        # --- First response: memory_search (suppressed) ---
        first_response_events = [
            _response_created("resp_A"),
            _output_item_added_function_call("memory_search", index=0),
            _function_call_args_done(index=0),
            _output_item_done_function_call("memory_search", index=0),
            _response_completed("resp_A"),
        ]

        # Results are ignored here; this only drives the relay through a
        # full suppressed response before the continuation arrives.
        for ev in first_response_events:
            relay.process_event(ev)

        # --- Second response: continuation text (relayed) ---
        continuation_events = [
            _response_created("resp_B"),
            _output_item_added_text(index=0),
            _output_text_delta(index=0, text="The codename is Pegasus-2"),
            _output_item_done_text(index=0),
            _response_completed("resp_B"),
        ]

        all_relayed: list[str] = []
        for ev in continuation_events:
            result = relay.process_event(ev)
            all_relayed.extend(result["relay"])
            assert result["execute_tools"] == []
            assert result["send_continuation"] is None

        # All continuation events relayed
        assert len(all_relayed) == 5

        # Decode each relayed message once and reuse the parsed form for
        # both the ordering check and the text check below.
        decoded = [json.loads(e) for e in all_relayed]
        types = [event["type"] for event in decoded]
        assert types[0] == "response.created"
        assert types[-1] == "response.completed"

        # Verify the text content
        text_events = [
            event for event in decoded if event["type"] == "response.output_text.delta"
        ]
        assert len(text_events) == 1
        assert text_events[0]["delta"] == "The codename is Pegasus-2"

    def test_non_json_message_always_relayed(self):
        """Binary or non-JSON messages pass through regardless."""
        relay = WSMemoryRelayState()

        result = relay.process_event("not valid json {{{")
        assert len(result["relay"]) == 1
        assert result["relay"][0] == "not valid json {{{"

    def test_multiple_memory_tools_in_one_response(self):
        """Multiple memory tools in one response — all suppressed."""
        relay = WSMemoryRelayState()

        events = [
            _response_created("resp_multi"),
            _output_item_added_function_call("memory_search", index=0, call_id="call_1"),
            _output_item_done_function_call("memory_search", index=0, call_id="call_1"),
            # The model decides to save something too
            _output_item_added_function_call("memory_save", index=1, call_id="call_2"),
            _output_item_done_function_call(
                "memory_save",
                index=1,
                call_id="call_2",
                arguments='{"content": "test"}',
            ),
            _response_completed("resp_multi"),
        ]

        all_relayed: list[str] = []
        tool_executions: list[Any] = []
        continuations: list[Any] = []

        for ev in events:
            result = relay.process_event(ev)
            all_relayed.extend(result["relay"])
            tool_executions.extend(result["execute_tools"])
            if result["send_continuation"]:
                continuations.append(result["send_continuation"])

        assert len(all_relayed) == 0
        # Both calls are executed, but only one continuation is requested
        # for the whole response.
        assert len(tool_executions) == 2
        assert {t["name"] for t in tool_executions} == {"memory_search", "memory_save"}
        assert len(continuations) == 1
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
class TestWSMemoryRelayStateReset:
    """State resets properly between responses."""

    def test_state_resets_after_memory_response(self):
        """After a memory response, the relay is ready for a fresh response."""
        relay = WSMemoryRelayState()

        # Memory response
        for ev in [
            _response_created("resp_A"),
            _output_item_added_function_call("memory_search"),
            _output_item_done_function_call("memory_search"),
            _response_completed("resp_A"),
        ]:
            relay.process_event(ev)

        # State should be reset: every field that _reset_response_state
        # touches must be back at its idle value, including the response id.
        assert relay.decided is False
        assert relay.suppress_response is False
        assert len(relay.pending_function_calls) == 0
        assert len(relay.event_buffer) == 0
        assert relay.last_response_id is None

    def test_alternating_memory_and_normal(self):
        """Memory response → normal response → both work correctly."""
        relay = WSMemoryRelayState()

        # 1. Memory response (suppressed)
        for ev in [
            _response_created("resp_A"),
            _output_item_added_function_call("memory_search"),
            _output_item_done_function_call("memory_search"),
            _response_completed("resp_A"),
        ]:
            relay.process_event(ev)

        # 2. Continuation text response (relayed)
        relayed: list[str] = []
        for ev in [
            _response_created("resp_B"),
            _output_item_added_text(),
            _output_text_delta(text="Pegasus-2"),
            _output_item_done_text(),
            _response_completed("resp_B"),
        ]:
            result = relay.process_event(ev)
            relayed.extend(result["relay"])

        assert len(relayed) == 5

        # 3. Another normal response should also work
        relayed2: list[str] = []
        for ev in [
            _response_created("resp_C"),
            _output_item_added_shell(),
            _response_completed("resp_C"),
        ]:
            result = relay.process_event(ev)
            relayed2.extend(result["relay"])

        assert len(relayed2) == 3
|