#!/usr/bin/env python3
"""ENGRAM Research Paper — Figure Generation.
Generates all 15 figures for the ENGRAM paper from results/ data files.
Output: results/figures/*.pdf (LaTeX-compatible, 300 DPI)
Usage:
cd ENGRAM && python scripts/paper_figures.py
python scripts/paper_figures.py --only fig02 # Single figure
python scripts/paper_figures.py --list # List all figures
"""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
from typing import Any
import matplotlib
matplotlib.use("Agg") # Non-interactive backend
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
# ── Configuration ────────────────────────────────────────────────────────
RESULTS_DIR = Path(__file__).parent.parent / "results"
FIGURES_DIR = RESULTS_DIR / "figures"
ABSOLUTE_DIR = RESULTS_DIR / "absolute"
STRESS_DIR = RESULTS_DIR / "stress"
# LaTeX-compatible style
plt.rcParams.update({
"font.family": "serif",
"font.size": 11,
"axes.labelsize": 12,
"axes.titlesize": 13,
"xtick.labelsize": 10,
"ytick.labelsize": 10,
"legend.fontsize": 10,
"figure.dpi": 300,
"savefig.dpi": 300,
"savefig.bbox": "tight",
"savefig.pad_inches": 0.1,
"axes.grid": True,
"grid.alpha": 0.3,
"axes.spines.top": False,
"axes.spines.right": False,
})
# Colorblind-safe palette
COLORS = {
"blue": "#4477AA",
"orange": "#EE6677",
"green": "#228833",
"purple": "#AA3377",
"cyan": "#66CCEE",
"grey": "#BBBBBB",
"red": "#CC3311",
"teal": "#009988",
"yellow": "#CCBB44",
"indigo": "#332288",
}
PASS_COLOR = COLORS["green"]
FAIL_COLOR = COLORS["red"]
# ── Data Loading ─────────────────────────────────────────────────────────
def load_json(path: Path) -> dict[str, Any]:
"""Load JSON file and return parsed dict."""
return json.loads(path.read_text())
def save_figure(fig: plt.Figure, name: str) -> None:
"""Save figure as PDF and PNG."""
FIGURES_DIR.mkdir(parents=True, exist_ok=True)
fig.savefig(FIGURES_DIR / f"{name}.pdf", format="pdf")
fig.savefig(FIGURES_DIR / f"{name}.png", format="png")
plt.close(fig)
print(f" Saved: {name}.pdf + .png")
# ── Figure 2: Frequency Combination Comparison ──────────────────────────
def fig02_frequency_comparison() -> None:
"""Bar chart: 6 frequency combos × recall and margin."""
print("Fig 02: Frequency combination comparison...")
data = load_json(ABSOLUTE_DIR / "multifreq_comparison.json")
results = data["results"]
combos = list(results.keys())
recalls = [results[c]["recall"] * 100 for c in combos]
margins = [results[c]["margin_mean"] * 1000 for c in combos] # ×1000
failures = [results[c]["n_failures"] for c in combos]
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4.5))
# Left: Recall
x = np.arange(len(combos))
bar_colors = [COLORS["green"] if c == "f0+f1" else COLORS["blue"] for c in combos]
bars = ax1.bar(x, recalls, color=bar_colors, edgecolor="white", linewidth=0.5)
ax1.set_xticks(x)
ax1.set_xticklabels(combos, rotation=30, ha="right")
ax1.set_ylabel("Recall@1 (%)")
ax1.set_title("(a) Recall by Frequency Combination")
ax1.set_ylim(60, 102)
for bar, val, nf in zip(bars, recalls, failures):
ax1.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.5,
f"{val:.0f}%\n({nf} fail)", ha="center", va="bottom", fontsize=8)
# Right: Mean margin
bars2 = ax2.bar(x, margins, color=bar_colors, edgecolor="white", linewidth=0.5)
ax2.set_xticks(x)
ax2.set_xticklabels(combos, rotation=30, ha="right")
ax2.set_ylabel("Mean Margin (×10³)")
ax2.set_title("(b) Mean Discrimination Margin")
for bar, val in zip(bars2, margins):
ax2.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.05,
f"{val:.1f}", ha="center", va="bottom", fontsize=8)
fig.suptitle("Multi-Frequency Fingerprint Ablation (N=200)", fontsize=14, y=1.02)
fig.tight_layout()
save_figure(fig, "fig02_frequency_comparison")
# ── Figure 3: Margin Power Law ──────────────────────────────────────────
def fig03_margin_power_law() -> None:
"""Log-log plot: margin vs N for f1 and f0+f1 with fitted power laws."""
print("Fig 03: Margin power law...")
f1_data = load_json(ABSOLUTE_DIR / "margin_compression_law.json")
f0f1_data = load_json(ABSOLUTE_DIR / "multifreq_law.json")
# f1 data
f1_n = [int(n) for n in f1_data["results"].keys()]
f1_margins = [f1_data["results"][str(n)]["mean_margin"] for n in f1_n]
f1_alpha = f1_data["alpha"]
f1_A = f1_data["A"]
# f0+f1 data
f0f1_n = [int(n) for n in f0f1_data["results"].keys()]
f0f1_margins = [f0f1_data["results"][str(n)]["mean_margin"] for n in f0f1_n]
f0f1_alpha = f0f1_data["alpha"]
f0f1_A = f0f1_data["A"]
fig, ax = plt.subplots(figsize=(7, 5))
# Data points
ax.scatter(f1_n, f1_margins, color=COLORS["orange"], s=60, zorder=5, label="f1 (data)")
ax.scatter(f0f1_n, f0f1_margins, color=COLORS["blue"], s=60, zorder=5, label="f0+f1 (data)")
# Fitted curves
n_fit = np.linspace(3, 250, 200)
f1_fit = f1_A * n_fit ** f1_alpha
f0f1_fit = f0f1_A * n_fit ** f0f1_alpha
ax.plot(n_fit, f1_fit, color=COLORS["orange"], linestyle="--", alpha=0.7,
label=f"f1 fit: {f1_A:.4f}·N^{{{f1_alpha:.3f}}}")
ax.plot(n_fit, f0f1_fit, color=COLORS["blue"], linestyle="--", alpha=0.7,
label=f"f0+f1 fit: {f0f1_A:.4f}·N^{{{f0f1_alpha:.3f}}}")
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlabel("Corpus Size N")
ax.set_ylabel("Mean Discrimination Margin")
ax.set_title("Margin Power Law: Graceful Degradation")
ax.legend(loc="upper right")
ax.xaxis.set_major_formatter(ticker.ScalarFormatter())
ax.set_xticks([5, 10, 20, 50, 100, 200])
# Annotation
ax.annotate(
f"f0+f1: α={f0f1_alpha:.3f} (shallower)\nf1: α={f1_alpha:.3f}",
xy=(100, f0f1_A * 100 ** f0f1_alpha), xytext=(30, 0.003),
arrowprops={"arrowstyle": "->", "color": COLORS["grey"]},
fontsize=9, bbox={"boxstyle": "round,pad=0.3", "facecolor": "wheat", "alpha": 0.5}
)
fig.tight_layout()
save_figure(fig, "fig03_margin_power_law")
# ── Figure 4: Recall vs N — Fourier vs FCDB ─────────────────────────────
def fig04_recall_vs_n() -> None:
"""Fourier f0+f1 recall vs FCDB recall across corpus sizes."""
print("Fig 04: Recall vs N (Fourier vs FCDB)...")
f0f1_data = load_json(ABSOLUTE_DIR / "multifreq_law.json")
stress_data = load_json(STRESS_DIR / "STRESS_SUMMARY.json")
# Fourier f0+f1
fourier_n = [int(n) for n in f0f1_data["results"].keys()]
fourier_recall = [f0f1_data["results"][str(n)]["recall"] * 100 for n in fourier_n]
# FCDB cross-model
fcdb_map = stress_data["recall_at_1_vs_n_fcdb"]
fcdb_n = [int(n) for n in fcdb_map.keys()]
fcdb_recall = [v * 100 for v in fcdb_map.values()]
fig, ax = plt.subplots(figsize=(7, 5))
ax.plot(fourier_n, fourier_recall, "o-", color=COLORS["blue"], linewidth=2,
markersize=7, label="Fourier f0+f1 (same-model)", zorder=5)
ax.plot(fcdb_n, fcdb_recall, "s--", color=COLORS["orange"], linewidth=2,
markersize=7, label="FCDB (cross-model)", zorder=5)
# Collapse annotation
ax.axvline(x=100, color=COLORS["red"], linestyle=":", alpha=0.5)
ax.annotate("FCDB collapse\n(N=100)", xy=(100, 30), xytext=(140, 50),
arrowprops={"arrowstyle": "->", "color": COLORS["red"]},
fontsize=9, color=COLORS["red"])
ax.set_xlabel("Corpus Size N")
ax.set_ylabel("Recall@1 (%)")
ax.set_title("Retrieval Recall vs Corpus Size")
ax.legend(loc="lower left")
ax.set_ylim(-5, 105)
ax.set_xlim(0, 210)
fig.tight_layout()
save_figure(fig, "fig04_recall_vs_n")
# ── Figure 5: Cross-Model Strategy Comparison ───────────────────────────
def fig05_cross_model_strategies() -> None:
"""Horizontal bar chart: 9 cross-model methods × margin."""
print("Fig 05: Cross-model strategy comparison...")
strategies = [
("CCA", -0.420, False),
("Residual FCB", -0.382, False),
("Procrustes", -0.104, False),
("RR (K=20)", -0.066, False),
("FCB+ridge", -0.017, False),
("Contrastive", 0.001, True),
("JCB", 0.011, True),
("JCB+delta", 0.037, True),
("FCDB", 0.124, True),
]
names = [s[0] for s in strategies]
margins = [s[1] for s in strategies]
colors = [PASS_COLOR if s[2] else FAIL_COLOR for s in strategies]
fig, ax = plt.subplots(figsize=(8, 5))
y_pos = np.arange(len(names))
bars = ax.barh(y_pos, margins, color=colors, edgecolor="white", linewidth=0.5, height=0.7)
ax.set_yticks(y_pos)
ax.set_yticklabels(names)
ax.set_xlabel("Retrieval Margin")
ax.set_title("Cross-Model Transfer Strategies (Llama 3B → 8B)")
ax.axvline(x=0, color="black", linewidth=0.8)
# Value labels
for bar, val in zip(bars, margins):
x_offset = 0.005 if val >= 0 else -0.005
ha = "left" if val >= 0 else "right"
ax.text(val + x_offset, bar.get_y() + bar.get_height() / 2,
f"{val:+.3f}", ha=ha, va="center", fontsize=9, fontweight="bold")
# Legend
from matplotlib.patches import Patch
legend_elements = [Patch(facecolor=PASS_COLOR, label="PASS (margin > 0)"),
Patch(facecolor=FAIL_COLOR, label="FAIL (margin ≤ 0)")]
ax.legend(handles=legend_elements, loc="lower right")
fig.tight_layout()
save_figure(fig, "fig05_cross_model_strategies")
# ── Figure 6: CKA Layer Similarity ──────────────────────────────────────
def fig06_cka_layers() -> None:
"""CKA similarity per layer: within-family vs cross-family."""
print("Fig 06: CKA layer similarity...")
within = load_json(ABSOLUTE_DIR / "FAMILY_CKA.json")
cross = load_json(ABSOLUTE_DIR / "FAMILY_CKA_CROSS.json")
within_cka = within["layer_ckas"]
cross_cka = cross["layer_ckas"]
layers = list(range(len(within_cka)))
fig, ax = plt.subplots(figsize=(8, 4.5))
ax.plot(layers, within_cka, "o-", color=COLORS["blue"], markersize=5, linewidth=1.5,
label=f"Within-family (Llama 3B↔8B), μ={within['mean_cka']:.3f}")
ax.plot(layers, cross_cka, "s--", color=COLORS["orange"], markersize=5, linewidth=1.5,
label=f"Cross-family (Llama↔Qwen), μ={cross['mean_cka']:.3f}")
ax.axhline(y=0.95, color=COLORS["grey"], linestyle=":", alpha=0.5, label="0.95 threshold")
ax.set_xlabel("Layer Index")
ax.set_ylabel("CKA Similarity")
ax.set_title("Centered Kernel Alignment Across Layers")
ax.legend(loc="lower left", fontsize=9)
ax.set_ylim(0.85, 1.0)
# Annotate min
min_idx_w = int(np.argmin(within_cka))
min_idx_c = int(np.argmin(cross_cka))
ax.annotate(f"min={within_cka[min_idx_w]:.3f}", xy=(min_idx_w, within_cka[min_idx_w]),
xytext=(min_idx_w + 2, within_cka[min_idx_w] - 0.01),
fontsize=8, color=COLORS["blue"])
ax.annotate(f"min={cross_cka[min_idx_c]:.3f}", xy=(min_idx_c, cross_cka[min_idx_c]),
xytext=(min_idx_c + 2, cross_cka[min_idx_c] - 0.01),
fontsize=8, color=COLORS["orange"])
fig.tight_layout()
save_figure(fig, "fig06_cka_layers")
# ── Figure 7: Domain Confusion Before/After ──────────────────────────────
def fig07_confusion_matrix() -> None:
"""Heatmaps: f1 confusion vs f0+f1 confusion across domains."""
print("Fig 07: Domain confusion matrix...")
data = load_json(ABSOLUTE_DIR / "confusion_analysis.json")
domains = sorted({
k.split(" -> ")[0] for k in data["f1_confusion"].keys()
} | {
k.split(" -> ")[1] for k in data["f1_confusion"].keys()
})
def build_matrix(confusion_dict: dict[str, int]) -> np.ndarray:
n = len(domains)
mat = np.zeros((n, n))
for key, count in confusion_dict.items():
src, dst = key.split(" -> ")
if src in domains and dst in domains:
i = domains.index(src)
j = domains.index(dst)
mat[i, j] = count
return mat
f1_mat = build_matrix(data["f1_confusion"])
best_mat = build_matrix(data["best_confusion"])
# Short domain labels
short_labels = [d[:6] for d in domains]
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
im1 = ax1.imshow(f1_mat, cmap="Reds", aspect="auto", interpolation="nearest")
ax1.set_xticks(range(len(domains)))
ax1.set_yticks(range(len(domains)))
ax1.set_xticklabels(short_labels, rotation=45, ha="right", fontsize=8)
ax1.set_yticklabels(short_labels, fontsize=8)
ax1.set_title("(a) f1 Only — 28 Failures")
ax1.set_xlabel("Confused With")
ax1.set_ylabel("True Domain")
fig.colorbar(im1, ax=ax1, shrink=0.8)
im2 = ax2.imshow(best_mat, cmap="Blues", aspect="auto", interpolation="nearest")
ax2.set_xticks(range(len(domains)))
ax2.set_yticks(range(len(domains)))
ax2.set_xticklabels(short_labels, rotation=45, ha="right", fontsize=8)
ax2.set_yticklabels(short_labels, fontsize=8)
ax2.set_title("(b) f0+f1 — 4 Failures")
ax2.set_xlabel("Confused With")
ax2.set_ylabel("True Domain")
fig.colorbar(im2, ax=ax2, shrink=0.8)
fig.suptitle("Domain Confusion Analysis (N=200)", fontsize=14, y=1.02)
fig.tight_layout()
save_figure(fig, "fig07_confusion_matrix")
# ── Figure 8: Domain Recall Radar ────────────────────────────────────────
def fig08_domain_recall_radar() -> None:
"""Radar chart: per-domain recall with f0+f1."""
print("Fig 08: Domain recall radar...")
data = load_json(ABSOLUTE_DIR / "confusion_analysis.json")
domain_recall = data["domain_recall"]
categories = list(domain_recall.keys())
values = [domain_recall[c] * 100 for c in categories]
# Close the polygon
values_closed = values + [values[0]]
n = len(categories)
angles = [i / n * 2 * np.pi for i in range(n)]
angles_closed = angles + [angles[0]]
fig, ax = plt.subplots(figsize=(6, 6), subplot_kw={"projection": "polar"})
ax.plot(angles_closed, values_closed, "o-", color=COLORS["blue"], linewidth=2, markersize=6)
ax.fill(angles_closed, values_closed, color=COLORS["blue"], alpha=0.15)
ax.set_xticks(angles)
ax.set_xticklabels([c.replace("_", "\n") for c in categories], fontsize=9)
ax.set_ylim(80, 102)
ax.set_yticks([85, 90, 95, 100])
ax.set_yticklabels(["85%", "90%", "95%", "100%"], fontsize=8)
ax.set_title("Per-Domain Recall@1 (f0+f1, N=200)", pad=20)
# Annotate minimum
min_idx = int(np.argmin(values))
ax.annotate(f"{values[min_idx]:.0f}%",
xy=(angles[min_idx], values[min_idx]),
xytext=(angles[min_idx] + 0.2, values[min_idx] - 3),
fontsize=9, fontweight="bold", color=COLORS["red"])
fig.tight_layout()
save_figure(fig, "fig08_domain_recall_radar")
# ── Figure 9: HNSW Benchmark ────────────────────────────────────────────
def fig09_hnsw_benchmark() -> None:
"""Bar chart: HNSW vs brute-force latency."""
print("Fig 09: HNSW benchmark...")
data = load_json(ABSOLUTE_DIR / "HNSW_BENCH.json")
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 4))
# Latency comparison
methods = ["Brute-Force", "HNSW"]
latencies = [data["bf_latency_us"], data["hnsw_latency_us"]]
colors = [COLORS["orange"], COLORS["blue"]]
bars = ax1.bar(methods, latencies, color=colors, edgecolor="white", width=0.5)
ax1.set_ylabel("Latency (μs)")
ax1.set_title(f"(a) Search Latency — {data['speedup']:.1f}× Speedup")
for bar, val in zip(bars, latencies):
ax1.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 3,
f"{val:.1f} μs", ha="center", va="bottom", fontsize=10)
# Recall comparison
recalls = [data["bruteforce_recall"] * 100, data["hnsw_recall"] * 100]
bars2 = ax2.bar(methods, recalls, color=colors, edgecolor="white", width=0.5)
ax2.set_ylabel("Recall@1 (%)")
ax2.set_title("(b) Recall Preserved")
ax2.set_ylim(98, 100.5)
for bar, val in zip(bars2, recalls):
ax2.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.05,
f"{val:.1f}%", ha="center", va="bottom", fontsize=10)
fig.suptitle("HNSW Index Benchmark (N=200)", fontsize=14, y=1.02)
fig.tight_layout()
save_figure(fig, "fig09_hnsw_benchmark")
# ── Figure 10: INT8 Compression ──────────────────────────────────────────
def fig10_int8_compression() -> None:
"""Bar chart: FP16 vs INT8 comparison."""
print("Fig 10: INT8 compression...")
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 4))
# Size comparison
configs = ["591 tok", "6,403 tok"]
fp16_sizes = [73.9, 800.4]
int8_sizes = [37.5, 406.5]
x = np.arange(len(configs))
w = 0.35
ax1.bar(x - w / 2, fp16_sizes, w, label="FP16", color=COLORS["orange"], edgecolor="white")
ax1.bar(x + w / 2, int8_sizes, w, label="INT8", color=COLORS["blue"], edgecolor="white")
ax1.set_xticks(x)
ax1.set_xticklabels(configs)
ax1.set_ylabel("File Size (MB)")
ax1.set_title("(a) .eng File Size — 1.97× Compression")
ax1.legend()
# Quality metrics
metrics = ["Cosine\nSimilarity", "Margin\n(FP16)", "Margin\n(INT8)"]
values = [0.99998, 0.381, 0.262]
bar_colors = [COLORS["green"], COLORS["blue"], COLORS["cyan"]]
bars = ax2.bar(metrics, values, color=bar_colors, edgecolor="white", width=0.5)
ax2.set_ylabel("Value")
ax2.set_title("(b) Quality Preservation")
for bar, val in zip(bars, values):
ax2.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
f"{val:.5f}" if val > 0.9 else f"{val:.3f}",
ha="center", va="bottom", fontsize=9)
fig.suptitle("INT8 Quantization Impact", fontsize=14, y=1.02)
fig.tight_layout()
save_figure(fig, "fig10_int8_compression")
# ── Figure 12: Margin Distribution ───────────────────────────────────────
def fig12_margin_distribution() -> None:
"""Distribution comparison: f1 vs f0+f1 summary statistics."""
print("Fig 12: Margin distribution...")
data = load_json(ABSOLUTE_DIR / "multifreq_comparison.json")
results = data["results"]
fig, ax = plt.subplots(figsize=(7, 4.5))
# We'll show key statistics as a visualization
combos = ["f1", "f0+f1"]
means = [results[c]["margin_mean"] * 1000 for c in combos]
medians = [results[c]["margin_median"] * 1000 for c in combos]
mins = [results[c]["margin_min"] * 1000 for c in combos]
x = np.arange(len(combos))
w = 0.25
ax.bar(x - w, means, w, label="Mean", color=COLORS["blue"], edgecolor="white")
ax.bar(x, medians, w, label="Median", color=COLORS["green"], edgecolor="white")
ax.bar(x + w, mins, w, label="Min", color=COLORS["red"], edgecolor="white")
ax.set_xticks(x)
ax.set_xticklabels(combos, fontsize=12)
ax.set_ylabel("Margin (×10³)")
ax.set_title("Margin Statistics: f1 vs f0+f1 (N=200)")
ax.legend()
ax.axhline(y=0, color="black", linewidth=0.5)
# Annotate improvement
ax.annotate(
f"+76% mean margin\n25/28 failures fixed",
xy=(1, means[1]), xytext=(1.3, means[1] + 1),
arrowprops={"arrowstyle": "->", "color": COLORS["green"]},
fontsize=9, bbox={"boxstyle": "round,pad=0.3", "facecolor": "#e6ffe6", "alpha": 0.8}
)
fig.tight_layout()
save_figure(fig, "fig12_margin_distribution")
# ── Figure 13: FCDB Stability-Discrimination Tradeoff ────────────────────
def fig13_fcdb_tradeoff() -> None:
"""Dual-axis: basis stability vs retrieval margin vs corpus size."""
print("Fig 13: FCDB stability-discrimination tradeoff...")
# Data from PAPER_TABLE.md
n_vals = [50, 100, 125, 200]
stability = [0.82, 0.906, 0.983, 0.999] # subspace agreement
margin = [0.124, None, None, 0.013] # Only measured at 50 and 200
margin_n = [50, 200]
margin_v = [0.124, 0.013]
fig, ax1 = plt.subplots(figsize=(7, 5))
ax2 = ax1.twinx()
# Stability (left axis)
line1 = ax1.plot(n_vals, stability, "o-", color=COLORS["blue"], linewidth=2,
markersize=8, label="Basis Stability", zorder=5)
ax1.set_xlabel("Corpus Size N")
ax1.set_ylabel("Subspace Agreement", color=COLORS["blue"])
ax1.tick_params(axis="y", labelcolor=COLORS["blue"])
ax1.set_ylim(0.7, 1.05)
# Margin (right axis)
line2 = ax2.plot(margin_n, margin_v, "s--", color=COLORS["orange"], linewidth=2,
markersize=8, label="Retrieval Margin", zorder=5)
ax2.set_ylabel("Cross-Model Margin", color=COLORS["orange"])
ax2.tick_params(axis="y", labelcolor=COLORS["orange"])
ax2.set_ylim(-0.01, 0.15)
# Threshold line
ax1.axhline(y=0.99, color=COLORS["grey"], linestyle=":", alpha=0.5)
ax1.annotate("Stable (≥0.99)", xy=(125, 0.99), fontsize=8, color=COLORS["grey"])
# Combined legend
lines = line1 + line2
labels = [l.get_label() for l in lines]
ax1.legend(lines, labels, loc="center left")
ax1.set_title("FCDB Stability–Discrimination Tradeoff")
fig.tight_layout()
save_figure(fig, "fig13_fcdb_tradeoff")
# ── Figure 14: TTFT Speedup ─────────────────────────────────────────────
def fig14_ttft_speedup() -> None:
"""Grouped bar chart: cold vs warm TTFT."""
print("Fig 14: TTFT speedup...")
configs = ["3B / 4K tok", "3B / 16K tok", "8B / 591 tok"]
cold_ttft = [11439, 94592, 3508] # ms
warm_ttft = [170, 1777, 116] # ms
speedups = [67.2, 53.2, 30.8]
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4.5))
x = np.arange(len(configs))
w = 0.35
ax1.bar(x - w / 2, cold_ttft, w, label="Cold TTFT", color=COLORS["orange"], edgecolor="white")
ax1.bar(x + w / 2, warm_ttft, w, label="Warm TTFT", color=COLORS["blue"], edgecolor="white")
ax1.set_xticks(x)
ax1.set_xticklabels(configs, fontsize=9)
ax1.set_ylabel("TTFT (ms)")
ax1.set_title("(a) Time to First Token")
ax1.set_yscale("log")
ax1.legend()
# Speedup bars
bars = ax2.bar(configs, speedups, color=COLORS["green"], edgecolor="white", width=0.5)
ax2.set_ylabel("Speedup (×)")
ax2.set_title("(b) KV Cache Restoration Speedup")
ax2.set_xticklabels(configs, fontsize=9)
for bar, val in zip(bars, speedups):
ax2.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.5,
f"{val:.1f}×", ha="center", va="bottom", fontsize=10, fontweight="bold")
fig.suptitle("KV Cache Warm Start Performance", fontsize=14, y=1.02)
fig.tight_layout()
save_figure(fig, "fig14_ttft_speedup")
# ── Figure 15: EGR Overhead Scaling ──────────────────────────────────────
def fig15_egr_overhead() -> None:
"""Scatter/line: EGR overhead vs token count."""
print("Fig 15: EGR overhead scaling...")
tokens = [600, 6403, 600]
overhead_ms = [30.6, 48.8, 84.0]
labels = ["16 layers\n(8-24)", "16 layers\n(8-24)", "32 layers\n(all)"]
colors_pts = [COLORS["blue"], COLORS["blue"], COLORS["orange"]]
fig, ax = plt.subplots(figsize=(6, 4.5))
for t, o, l, c in zip(tokens, overhead_ms, labels, colors_pts):
ax.scatter(t, o, s=100, color=c, zorder=5, edgecolor="white", linewidth=1.5)
ax.annotate(l, xy=(t, o), xytext=(t + 200, o + 2), fontsize=9)
ax.set_xlabel("Context Length (tokens)")
ax.set_ylabel("EGR Overhead (ms)")
ax.set_title("Fingerprint Extraction Overhead")
ax.set_xlim(0, 7000)
ax.set_ylim(20, 95)
# Reference lines
ax.axhline(y=50, color=COLORS["grey"], linestyle=":", alpha=0.3)
ax.text(100, 51, "50ms threshold", fontsize=8, color=COLORS["grey"])
fig.tight_layout()
save_figure(fig, "fig15_egr_overhead")
# ── Figure 1: Architecture Diagram (Mermaid) ────────────────────────────
def fig01_architecture_mermaid() -> None:
"""Generate Mermaid flowchart for system architecture."""
print("Fig 01: Architecture diagram (Mermaid)...")
mermaid = """\
%%{init: {'theme': 'base', 'themeVariables': {'primaryColor': '#4477AA', 'primaryTextColor': '#fff', 'primaryBorderColor': '#335588', 'lineColor': '#666', 'secondaryColor': '#EE6677', 'tertiaryColor': '#228833'}}}%%
flowchart TD
A[LLM Runtime
llama.cpp] -->|KV cache blob| B[Blob Parser]
B -->|Layer keys K| C[Fourier Fingerprint
f0+f1 DFT]
C -->|2048-dim vector| D{Storage}
D -->|.eng binary| E[EIGENGRAM File
v1.2 format]
D -->|HNSW index| F[FAISS IndexHNSW
M=32]
G[Query Session] -->|New KV cache| C
C -->|Query fingerprint| H[Geodesic Retrieval]
F -->|Top-k candidates| H
H --> I{Stage 0
Prior Check}
I -->|chronic failure| J[Skip / LOW]
I -->|ok| K{Stage 1
HNSW Search}
K -->|HIGH / MEDIUM| L[Result]
K -->|below threshold| M{Stage 2
Trajectory}
M -->|interpolation| N{Stage 3
Constraints}
N --> O{Stage 4
Metadata}
O --> L
subgraph Confidence Tracking
P[IndexC
SQLite] ---|update| I
L ---|record| P
end
style A fill:#4477AA,stroke:#335588,color:#fff
style C fill:#228833,stroke:#1a6625,color:#fff
style E fill:#EE6677,stroke:#cc5566,color:#fff
style F fill:#66CCEE,stroke:#55aabb,color:#000
style H fill:#AA3377,stroke:#882266,color:#fff
"""
mermaid_path = FIGURES_DIR / "fig01_architecture.mmd"
mermaid_path.write_text(mermaid)
print(f" Saved: fig01_architecture.mmd")
# ── Figure 11: Retrieval Pipeline (Mermaid) ──────────────────────────────
def fig11_retrieval_pipeline_mermaid() -> None:
"""Generate Mermaid diagram for 4-stage geodesic retrieval."""
print("Fig 11: Retrieval pipeline (Mermaid)...")
mermaid = """\
%%{init: {'theme': 'base'}}%%
flowchart LR
Q[Query
Fingerprint] --> S0
S0[Stage 0
Prior Preemption
IndexC chronic
failure check]
S0 -->|"pass"| S1
S0 -->|"preempt"| SKIP[SKIP
confidence=LOW]
S1[Stage 1
HNSW Search
cosine top-k]
S1 -->|"margin > 0.005"| HIGH[HIGH
199/200 docs]
S1 -->|"margin 0.001-0.005"| MED[MEDIUM]
S1 -->|"margin < 0.001"| S2
S2[Stage 2
Trajectory
interpolation
w=0.3]
S2 --> S3
S3[Stage 3
Negative
Constraints
apophatic layer]
S3 --> S4
S4[Stage 4
Metadata
Disambig
domain + keywords
+ norms]
S4 --> LOW[LOW
1/200 docs
doc_146]
style S0 fill:#66CCEE,stroke:#55aabb
style S1 fill:#4477AA,stroke:#335588,color:#fff
style S2 fill:#CCBB44,stroke:#aa9933
style S3 fill:#EE6677,stroke:#cc5566,color:#fff
style S4 fill:#AA3377,stroke:#882266,color:#fff
style HIGH fill:#228833,stroke:#1a6625,color:#fff
style MED fill:#CCBB44,stroke:#aa9933
style LOW fill:#EE6677,stroke:#cc5566,color:#fff
style SKIP fill:#BBBBBB,stroke:#999999
"""
mermaid_path = FIGURES_DIR / "fig11_retrieval_pipeline.mmd"
mermaid_path.write_text(mermaid)
print(f" Saved: fig11_retrieval_pipeline.mmd")
# ── Consolidated Findings JSON ───────────────────────────────────────────
def generate_findings() -> None:
"""Consolidate all key metrics into a single findings.json."""
print("Generating consolidated findings...")
findings = {
"title": "ENGRAM Protocol — Consolidated Research Findings",
"date": "2026-04-03",
"hardware": {
"platform": "Apple M3, 24GB RAM",
"gpu": "Metal (n_gpu_layers=-1)",
"os": "macOS Darwin 25.4.0",
"llama_cpp": "0.3.19",
"faiss": "1.13.2",
"torch": "2.11.0",
},
"same_model_retrieval": {
"method": "Fourier f0+f1 fingerprint",
"corpus_size": 200,
"n_domains": 10,
"recall_at_1": 0.98,
"n_failures": 4,
"mean_margin": 0.007201,
"margin_power_law": {"A": 0.021342, "alpha": -0.2065},
"f1_only_recall": 0.86,
"f1_only_failures": 28,
"improvement_over_f1": "25/28 failures fixed (+76% mean margin)",
"ml_math_confusion_reduction": "81.5%",
},
"frequency_ablation": {
"combos_tested": 6,
"best": "f0+f1",
"results": {
"f1": {"recall": 0.86, "margin": 0.004087},
"f2": {"recall": 0.715, "margin": 0.002196},
"f1+f2": {"recall": 0.95, "margin": 0.004744},
"f1+f2+f3": {"recall": 0.95, "margin": 0.004129},
"f0+f1": {"recall": 0.98, "margin": 0.007201},
"f1+f3": {"recall": 0.89, "margin": 0.003477},
},
},
"hnsw_index": {
"speedup": 5.65,
"recall": 0.995,
"latency_us": 51.83,
"bruteforce_latency_us": 293.07,
},
"geodesic_retrieval": {
"stages": 4,
"final_recall": 1.0,
"n_high": 0,
"n_medium": 199,
"n_low": 1,
"hard_failure": "doc_146 (resolved by Stage 4 metadata)",
},
"int8_compression": {
"ratio": 1.97,
"cosine_similarity": 0.99998,
"margin_fp16": 0.381,
"margin_int8": 0.262,
"margin_preserved": True,
},
"ttft_speedup": {
"3b_4k": {"cold_ms": 11439, "warm_ms": 170, "speedup": 67.2},
"3b_16k": {"cold_ms": 94592, "warm_ms": 1777, "speedup": 53.2},
"8b_591": {"cold_ms": 3508, "warm_ms": 116, "speedup": 30.8},
},
"cross_model_transfer": {
"n_strategies": 9,
"best_method": "FCDB",
"best_margin": 0.124,
"results": {
"CCA": {"margin": -0.420, "correct": False},
"Residual_FCB": {"margin": -0.382, "correct": False},
"Procrustes": {"margin": -0.104, "correct": False},
"RR": {"margin": -0.066, "correct": False},
"FCB_ridge": {"margin": -0.017, "correct": False},
"Contrastive": {"margin": 0.001, "correct": True},
"JCB": {"margin": 0.011, "correct": True},
"JCB_delta": {"margin": 0.037, "correct": True},
"FCDB": {"margin": 0.124, "correct": True},
},
"key_insight": "Cross-model transfer requires representing documents as directions from a shared reference point (Frechet mean), not positions in space",
},
"fcdb_scaling": {
"v1_n50": {"stability": 0.82, "margin": 0.124},
"v2_n200": {"stability": 0.999, "margin": 0.013},
"collapse_n": 100,
"tradeoff": "Larger corpus stabilizes basis but dilutes per-document signal",
},
"cka_analysis": {
"within_family": {"models": "Llama 3B ↔ 8B", "mean_cka": 0.975, "f0f1_sim": 0.875},
"cross_family": {"models": "Llama ↔ Qwen", "mean_cka": 0.927, "f0f1_sim": 0.259},
"verdict": "Manifolds topologically isomorphic (CKA>0.92 all pairs)",
},
"domain_recall": {
"computer_science": 1.0, "general_world": 0.95, "history": 1.0,
"language_arts": 1.0, "ml_systems": 0.90, "mathematics": 1.0,
"philosophy": 1.0, "medicine": 0.95, "biology": 1.0, "physics": 1.0,
},
"eigengram_format": {
"version": "1.2",
"architectures": ["llama", "gemma", "gemma4/ISWA", "phi", "qwen", "mistral"],
"iswa_support": "Gemma 4 26B dual-cache (5+25 layers, 6144-dim fingerprint)",
},
}
paper_dir = RESULTS_DIR / "paper"
paper_dir.mkdir(parents=True, exist_ok=True)
findings_path = paper_dir / "findings.json"
findings_path.write_text(json.dumps(findings, indent=2))
print(f" Saved: paper/findings.json")
# ── LaTeX Tables ─────────────────────────────────────────────────────────
def generate_latex_tables() -> None:
"""Generate LaTeX table source for the paper."""
print("Generating LaTeX tables...")
tables = r"""\
% ──────────────────────────────────────────────────────────────────────
% Table 1: Multi-Frequency Ablation
% ──────────────────────────────────────────────────────────────────────
\begin{table}[t]
\centering
\caption{Multi-frequency fingerprint ablation at $N=200$. The f0+f1 combination
achieves the highest recall and mean margin, fixing 25 of 28 single-frequency failures.}
\label{tab:frequency-ablation}
\begin{tabular}{lcccc}
\toprule
Frequencies & Recall@1 & Mean Margin & Min Margin & Failures \\
\midrule
$f_1$ & 86.0\% & 4.09$\times 10^{-3}$ & $-4.71\times 10^{-3}$ & 28 \\
$f_2$ & 71.5\% & 2.20$\times 10^{-3}$ & $-5.85\times 10^{-3}$ & 57 \\
$f_1 + f_2$ & 95.0\% & 4.74$\times 10^{-3}$ & $-2.68\times 10^{-3}$ & 10 \\
$f_1 + f_2 + f_3$ & 95.0\% & 4.13$\times 10^{-3}$ & $-2.71\times 10^{-3}$ & 10 \\
\rowcolor{green!10}
$f_0 + f_1$ & \textbf{98.0\%} & \textbf{7.20}$\times 10^{-3}$ & $-4.09\times 10^{-3}$ & \textbf{4} \\
$f_1 + f_3$ & 89.0\% & 3.48$\times 10^{-3}$ & $-4.08\times 10^{-3}$ & 22 \\
\bottomrule
\end{tabular}
\end{table}
% ──────────────────────────────────────────────────────────────────────
% Table 2: Cross-Model Transfer Strategies
% ──────────────────────────────────────────────────────────────────────
\begin{table}[t]
\centering
\caption{Cross-model transfer strategies (Llama 3B $\to$ 8B). Nine methods tested;
FCDB achieves the only reliable positive margin without requiring an adapter.}
\label{tab:cross-model}
\begin{tabular}{lccc}
\toprule
Method & Margin & Correct & Adapter \\
\midrule
CCA & $-0.420$ & \xmark & symmetric \\
Residual FCB & $-0.382$ & \xmark & none \\
Procrustes & $-0.104$ & \xmark & orthogonal \\
Relative Repr. & $-0.066$ & \xmark & none \\
FCB + ridge & $-0.017$ & \xmark & ridge \\
\midrule
Contrastive $\delta$ & $+0.001$ & \cmark & ridge \\
JCB & $+0.011$ & \cmark & none \\
JCB + $\delta$ & $+0.037$ & \cmark & none \\
\rowcolor{green!10}
\textbf{FCDB} & $\mathbf{+0.124}$ & \cmark & \textbf{none} \\
\bottomrule
\end{tabular}
\end{table}
% ──────────────────────────────────────────────────────────────────────
% Table 3: TTFT Speedup
% ──────────────────────────────────────────────────────────────────────
\begin{table}[t]
\centering
\caption{KV cache warm-start performance. TTFT speedup ranges from 27--67$\times$
depending on model size and context length.}
\label{tab:ttft}
\begin{tabular}{lccccc}
\toprule
Model & Tokens & Cold TTFT & Warm TTFT & Speedup & EGR (ms) \\
\midrule
Llama 3.2 3B & 4,002 & 11,439\,ms & 170\,ms & 67.2$\times$ & 9.5 \\
Llama 3.2 3B & 16,382 & 94,592\,ms & 1,777\,ms & 53.2$\times$ & 9.5 \\
Llama 3.1 8B & 591 & 3,508\,ms & 116\,ms & 30.8$\times$ & 30.6 \\
\bottomrule
\end{tabular}
\end{table}
% ──────────────────────────────────────────────────────────────────────
% Table 4: INT8 Compression
% ──────────────────────────────────────────────────────────────────────
\begin{table}[t]
\centering
\caption{INT8 quantization results. Per-row symmetric quantization achieves
1.97$\times$ compression with negligible quality loss (cos\_sim = 0.99998).}
\label{tab:int8}
\begin{tabular}{lcccc}
\toprule
Tokens & FP16 Size & INT8 Size & Ratio & $\cos(s_\text{fp16}, s_\text{int8})$ \\
\midrule
591 & 73.9\,MB & 37.5\,MB & 1.97$\times$ & 0.99998 \\
6,403 & 800.4\,MB & 406.5\,MB & 1.97$\times$ & 0.99998 \\
\bottomrule
\end{tabular}
\end{table}
% ──────────────────────────────────────────────────────────────────────
% Table 5: CKA Analysis
% ──────────────────────────────────────────────────────────────────────
\begin{table}[t]
\centering
\caption{Centered Kernel Alignment (CKA) between model families. High CKA values
($>0.92$) confirm topological isomorphism of key manifolds across architectures.}
\label{tab:cka}
\begin{tabular}{lccc}
\toprule
Comparison & Mean CKA & f0+f1 Sim & Verdict \\
\midrule
Within-family (Llama 3B $\leftrightarrow$ 8B) & 0.975 & 0.875 & Isomorphic \\
Cross-family (Llama $\leftrightarrow$ Qwen) & 0.927 & 0.259 & Isomorphic \\
\bottomrule
\end{tabular}
\end{table}
% ──────────────────────────────────────────────────────────────────────
% Table 6: HNSW Benchmark
% ──────────────────────────────────────────────────────────────────────
\begin{table}[t]
\centering
\caption{HNSW index performance at $N=200$. The index provides 5.65$\times$
speedup over brute-force with no recall loss.}
\label{tab:hnsw}
\begin{tabular}{lcc}
\toprule
Method & Latency ($\mu$s) & Recall@1 \\
\midrule
Brute-force & 293.1 & 99.5\% \\
HNSW ($M=32$) & 51.8 & 99.5\% \\
\midrule
\textbf{Speedup} & \textbf{5.65$\times$} & --- \\
\bottomrule
\end{tabular}
\end{table}
% ──────────────────────────────────────────────────────────────────────
% Table 7: Domain Recall
% ──────────────────────────────────────────────────────────────────────
\begin{table}[t]
\centering
\caption{Per-domain recall@1 with f0+f1 fingerprint at $N=200$.
All domains achieve $\geq 90\%$ recall.}
\label{tab:domain-recall}
\begin{tabular}{lc}
\toprule
Domain & Recall@1 \\
\midrule
Biology & 100.0\% \\
Computer Science & 100.0\% \\
History & 100.0\% \\
Language Arts & 100.0\% \\
Mathematics & 100.0\% \\
Philosophy & 100.0\% \\
Physics & 100.0\% \\
General World & 95.0\% \\
Medicine & 95.0\% \\
ML/Systems & 90.0\% \\
\bottomrule
\end{tabular}
\end{table}
% ──────────────────────────────────────────────────────────────────────
% Table 8: Margin Power Law
% ──────────────────────────────────────────────────────────────────────
\begin{table}[t]
\centering
\caption{Margin scaling law parameters. Both fingerprint methods follow
power-law decay $\bar{m} = A \cdot N^\alpha$ with no hard collapse point.}
\label{tab:power-law}
\begin{tabular}{lccc}
\toprule
Fingerprint & $A$ & $\alpha$ & Recall@200 \\
\midrule
$f_1$ & 0.0181 & $-0.277$ & 86.0\% \\
$f_0 + f_1$ & 0.0213 & $-0.207$ & 98.0\% \\
\bottomrule
\end{tabular}
\end{table}
"""
paper_dir = RESULTS_DIR / "paper"
paper_dir.mkdir(parents=True, exist_ok=True)
tables_path = paper_dir / "tables.tex"
tables_path.write_text(tables)
print(f" Saved: paper/tables.tex")
# ── Registry ─────────────────────────────────────────────────────────────
FIGURE_REGISTRY: dict[str, tuple[str, object]] = {
"fig01": ("System Architecture (Mermaid)", fig01_architecture_mermaid),
"fig02": ("Frequency Combination Comparison", fig02_frequency_comparison),
"fig03": ("Margin Power Law", fig03_margin_power_law),
"fig04": ("Recall vs N (Fourier vs FCDB)", fig04_recall_vs_n),
"fig05": ("Cross-Model Strategy Comparison", fig05_cross_model_strategies),
"fig06": ("CKA Layer Similarity", fig06_cka_layers),
"fig07": ("Domain Confusion Matrix", fig07_confusion_matrix),
"fig08": ("Domain Recall Radar", fig08_domain_recall_radar),
"fig09": ("HNSW Benchmark", fig09_hnsw_benchmark),
"fig10": ("INT8 Compression", fig10_int8_compression),
"fig11": ("Retrieval Pipeline (Mermaid)", fig11_retrieval_pipeline_mermaid),
"fig12": ("Margin Distribution", fig12_margin_distribution),
"fig13": ("FCDB Tradeoff", fig13_fcdb_tradeoff),
"fig14": ("TTFT Speedup", fig14_ttft_speedup),
"fig15": ("EGR Overhead Scaling", fig15_egr_overhead),
"findings": ("Consolidated Findings JSON", generate_findings),
"tables": ("LaTeX Tables", generate_latex_tables),
}
def main() -> None:
parser = argparse.ArgumentParser(description="Generate ENGRAM paper figures")
parser.add_argument("--only", help="Generate only this figure (e.g., fig02)")
parser.add_argument("--list", action="store_true", help="List all figures")
args = parser.parse_args()
if args.list:
print("\nAvailable figures:")
for key, (desc, _) in FIGURE_REGISTRY.items():
print(f" {key:10s} {desc}")
return
FIGURES_DIR.mkdir(parents=True, exist_ok=True)
print(f"\nOutput directory: {FIGURES_DIR}\n")
if args.only:
if args.only not in FIGURE_REGISTRY:
print(f"Unknown figure: {args.only}")
print(f"Available: {', '.join(FIGURE_REGISTRY.keys())}")
sys.exit(1)
desc, func = FIGURE_REGISTRY[args.only]
func()
else:
for key, (desc, func) in FIGURE_REGISTRY.items():
try:
func()
except Exception as e:
print(f" ERROR generating {key}: {e}")
print(f"\nDone. Figures saved to: {FIGURES_DIR}")
if __name__ == "__main__":
main()