# axcpt-llm-control / scripts / run_leakage_reduced_embedding_analysis.py
# Repository page residue: neuro-bot's picture; "Upload 91 files"; commit a668ecc (verified).
"""Leakage-reduced exploratory embedding analysis for AX-CPT representations.
This second-pass analysis removes explicit dataset labels, condition labels,
and direct DCM indicator fields from the serialized text before embedding. It
keeps original labels only as non-embedded metadata for grouping summaries.
"""
from __future__ import annotations
import argparse
import csv
import json
from pathlib import Path
import numpy as np
from run_embedding_analysis import (
MODEL_NAME,
TOKEN_RE,
condition_vector_rows,
cosine_pair_summary,
embed_rows,
metadata_rows,
pca_2d,
projection_rows,
read_jsonl,
save_embedding_bundle,
write_csv,
)
# Model label written into outputs: the imported MODEL_NAME plus a suffix
# marking that the embedded text was masked first.
MASKED_MODEL_NAME = f"{MODEL_NAME}_leakage_reduced_input_v1"
# Keys of `key=value` fields that are dropped from the serialized text before
# embedding (compared against the part of each pipe-delimited field that
# precedes the first "=").
REMOVED_FIELDS = {
    "dataset",
    "condition",
    "dcm_invocation_rate",
    "use_dcm",
    "dcm_invoked",
}
def write_jsonl(path: Path, rows: list[dict[str, object]]) -> None:
    """Serialize ``rows`` to ``path`` in JSON Lines form.

    Missing parent directories are created. Every row becomes one line of
    JSON with sorted keys and non-ASCII characters left unescaped; an
    existing file at ``path`` is overwritten.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as sink:
        sink.writelines(
            json.dumps(record, ensure_ascii=False, sort_keys=True) + "\n"
            for record in rows
        )
def mask_pipe_fields(
    piece: str,
    removed_fields: set[str] | frozenset[str] | None = None,
) -> str:
    """Drop leakage-prone ``key=value`` fields from one pipe-delimited piece.

    The piece is split on ``"|"``; each field is stripped of surrounding
    whitespace and empty fields are discarded. A field is removed when the
    text before its first ``"="`` (or the whole field when it has no ``"="``)
    is one of the blocked keys. Surviving fields are re-joined with ``"|"``
    and no surrounding spaces.

    Args:
        piece: One pipe-delimited segment of serialized text.
        removed_fields: Keys to drop. ``None`` (the default) uses the
            module-level ``REMOVED_FIELDS``.

    Returns:
        The segment with blocked fields removed; ``""`` when nothing is left.
    """
    blocked = REMOVED_FIELDS if removed_fields is None else removed_fields
    kept: list[str] = []
    for raw_field in piece.split("|"):
        field = raw_field.strip()
        if not field:
            continue
        # Strip the key as well: without this, "dataset = x" would keep the
        # key "dataset " (trailing space), miss the lookup, and leak the label.
        key = field.partition("=")[0].strip()
        if key in blocked:
            continue
        kept.append(field)
    return "|".join(kept)
def mask_serialized_text(text: str) -> str:
    """Apply field masking to a full serialized representation string.

    When the text contains the ``" || trials: "`` marker, the part before the
    first marker is masked as a single pipe-delimited header and the part
    after it is split on ``" ; "`` with each piece masked separately; the
    marker is then re-inserted between the two. Without the marker the whole
    text is treated as a ``" ; "``-separated list of pieces.
    """
    marker = " || trials: "
    piece_sep = " ; "
    head, found, tail = text.partition(marker)
    if not found:
        return piece_sep.join(map(mask_pipe_fields, text.split(piece_sep)))
    clean_head = mask_pipe_fields(head)
    clean_tail = piece_sep.join(map(mask_pipe_fields, tail.split(piece_sep)))
    return clean_head + marker + clean_tail
def make_analysis_rows(rows: list[dict[str, object]]) -> list[dict[str, object]]:
    """Return shallow copies of ``rows`` whose ``serialized_text`` is masked.

    All other keys (including the original labels) are carried over
    unchanged; the input rows themselves are not modified.
    """
    return [
        {**source, "serialized_text": mask_serialized_text(str(source["serialized_text"]))}
        for source in rows
    ]
def make_masked_representation_rows(rows: list[dict[str, object]], prefix: str) -> list[dict[str, object]]:
    """Build label-free export records for the masked representations.

    Each record gets a positional id of the form ``"<prefix>::<6-digit index>"``,
    the representation level, ``n_trials`` (``""`` when absent) and the masked
    serialized text. Rows whose level is ``"sliding_window"`` additionally
    carry their window size and start/end trial indices (``""`` when absent).
    """
    window_keys = ("window_size", "window_start_trial_idx", "window_end_trial_idx")
    records: list[dict[str, object]] = []
    for position, source in enumerate(rows):
        level = source["representation_level"]
        record: dict[str, object] = {
            "masked_representation_id": f"{prefix}::{position:06d}",
            "representation_level": level,
            "n_trials": source.get("n_trials", ""),
            "serialized_text": mask_serialized_text(str(source["serialized_text"])),
        }
        if level == "sliding_window":
            for key in window_keys:
                record[key] = source.get(key, "")
        records.append(record)
    return records
def masked_condition_vector_rows(rows: list[dict[str, object]], embeddings: np.ndarray) -> list[dict[str, object]]:
    """Return ``condition_vector_rows`` output relabelled with the masked model name.

    Every row produced by ``condition_vector_rows(rows, embeddings)`` has its
    ``embedding_model`` entry set to ``MASKED_MODEL_NAME`` in place, and the
    same list is returned.
    """
    label = MASKED_MODEL_NAME
    vector_rows = condition_vector_rows(rows, embeddings)
    for entry in vector_rows:
        entry["embedding_model"] = label
    return vector_rows
def write_masking_report(path: Path, condition_count: int, sliding_count: int, dim: int = 256) -> None:
    """Write the Markdown report describing what was masked before embedding.

    Args:
        path: Destination file; missing parent directories are created and an
            existing file is overwritten.
        condition_count: Number of condition-level rows, shown under "Inputs".
        sliding_count: Number of sliding-window rows, shown under "Inputs".
        dim: Hashing dimensionality stated in the "Embedding Model" section.
            Defaults to 256, the value the text previously hard-coded; callers
            that embed with a different dimension should pass it so the report
            matches the run.

    The input file paths and the list of removed fields in the text are fixed
    strings; they are not derived from the command line or ``REMOVED_FIELDS``.
    """
    report = f"""# Leakage-Reduced Embedding Analysis
This is a second-pass exploratory embedding analysis. It tests whether broad separation remains after removing obvious metadata from the text passed into the embedding step.
## Inputs
- `outputs/condition_level_representations.jsonl`: {condition_count} rows.
- `outputs/sliding_window_representations.jsonl`: {sliding_count} rows.
## Removed Or Masked From Embedded Text
The following pipe-delimited serialized text fields are removed before embedding:
- `dataset`
- `condition`
- `dcm_invocation_rate`
- `use_dcm`
- `dcm_invoked`
Original dataset and condition labels are retained only outside the embedded text as grouping metadata for similarity summaries and plots.
## Still Present In Embedded Text
The masked text may still contain trial type, cue, probe, parsed response, correctness, invalid-response flags, previous-trial type/correctness, distractor count, context window, and sequence/order structure. These remaining fields can still distinguish experimental families or conditions.
## Embedding Model
- Model/library: `{MASKED_MODEL_NAME}` using the same local hashed token n-gram vectorizer as `scripts/run_embedding_analysis.py`.
- Text processing: lowercase alphanumeric tokenization with regex `{TOKEN_RE.pattern}`.
- Vectorization: deterministic signed CRC32 feature hashing into {dim} dimensions.
- Normalization: L2 normalization.
- Projection: deterministic PCA via `numpy.linalg.svd`.
This is not a neural embedding, latent model state analysis, logit analysis, probability analysis, reaction-time analysis, or cost/latency analysis.
"""
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(report, encoding="utf-8")
def write_report(
    path: Path,
    condition_count: int,
    sliding_count: int,
    condition_explained: list[float],
    sliding_explained: list[float],
    condition_group_summary: list[dict[str, object]],
    sliding_group_summary: list[dict[str, object]],
) -> None:
    """Write the Markdown results report for the leakage-reduced run.

    The report lists the embedded row counts, the first two explained-variance
    ratios of each PCA (six decimals), and one bullet per group-summary row
    showing its mean/min/max cosine similarity and pair count. Missing parent
    directories of ``path`` are created and an existing file is overwritten.
    """

    def _as_bullets(entries: list[dict[str, object]]) -> str:
        # One Markdown bullet per summary row; an empty list gives "".
        return "\n".join(
            f"- {entry['comparison_group']}: mean={entry['mean_cosine_similarity']}, "
            f"min={entry['min_cosine_similarity']}, max={entry['max_cosine_similarity']}, n={entry['n_pairs']}"
            for entry in entries
        )

    markdown = f"""# Leakage-Reduced Exploratory Embedding Report
This analysis recomputes embeddings after removing explicit dataset labels, condition labels, and direct DCM indicator fields from the serialized text.
## Rows Embedded
- Condition-level rows: {condition_count}
- Sliding-window rows: {sliding_count}
- Trial-level rows: not embedded in this pass.
## PCA Summary
- Condition-level PCA explained variance ratio: PC1={condition_explained[0]:.6f}, PC2={condition_explained[1]:.6f}
- Sliding-window PCA explained variance ratio: PC1={sliding_explained[0]:.6f}, PC2={sliding_explained[1]:.6f}
## Similarity Group Summary
Condition-level:
{_as_bullets(condition_group_summary)}
Sliding-window:
{_as_bullets(sliding_group_summary)}
## Interpretation Scope
If separation remains, it should be interpreted as separation in the remaining serialized observable fields, not as evidence about latent model internals. Remaining fields include response symbols, correctness/invalid flags, context-window values, trial order, and AX-CPT event sequences.
"""
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as sink:
        sink.write(markdown)
def condition_category(row: dict[str, object]) -> str | None:
    """Classify one similarity-summary row into a comparison group.

    Rules, in order of precedence:
      * ``pair_type == "within_condition"``       -> ``"within_condition"``
      * the two datasets differ                    -> ``"v4b_vs_v5"``
      * both datasets are ``"axcpt_v4b"``          -> ``"v4b_vs_v4b"``
      * both datasets are ``"axcpt_v5_dcm"``       -> ``"v5_base_vs_base"``,
        ``"v5_dcm_vs_dcm"`` or ``"v5_base_vs_dcm"`` according to the
        ``_BASE`` / ``_DCM`` suffixes of the two conditions, otherwise
        ``"v5_other"``
      * any other shared dataset                   -> ``None``
    """
    ds_a, ds_b = str(row["dataset_a"]), str(row["dataset_b"])
    cond_a, cond_b = str(row["condition_a"]), str(row["condition_b"])

    if row["pair_type"] == "within_condition":
        return "within_condition"
    if ds_a != ds_b:
        return "v4b_vs_v5"
    # From here on both sides share one dataset.
    if ds_a == "axcpt_v4b":
        return "v4b_vs_v4b"
    if ds_a != "axcpt_v5_dcm":
        return None

    def _arm(condition: str) -> str:
        # A condition name can end with at most one of the two suffixes.
        if condition.endswith("_BASE"):
            return "base"
        if condition.endswith("_DCM"):
            return "dcm"
        return "other"

    arms = {_arm(cond_a), _arm(cond_b)}
    if arms == {"base"}:
        return "v5_base_vs_base"
    if arms == {"dcm"}:
        return "v5_dcm_vs_dcm"
    if arms == {"base", "dcm"}:
        return "v5_base_vs_dcm"
    return "v5_other"
def build_group_summary(rows: list[dict[str, object]], level: str) -> list[dict[str, object]]:
    """Collapse per-pair similarity rows into one record per comparison group.

    Rows are bucketed by ``condition_category``; rows with no category or an
    empty ``mean_cosine_similarity`` are ignored. For each reported group the
    record holds the number of contributing rows, the summed ``n_pairs``, and
    the unweighted mean / min / max of the rows' ``mean_cosine_similarity``
    values, rounded to six decimals. Note that min and max are taken over
    those per-row means, not over the rows' own min/max columns.

    Only the six groups in the fixed reporting order below are emitted, in
    that order; any other category (such as ``"v5_other"``) is accumulated
    but never reported.
    """
    report_order = (
        "within_condition",
        "v4b_vs_v4b",
        "v5_base_vs_base",
        "v5_dcm_vs_dcm",
        "v5_base_vs_dcm",
        "v4b_vs_v5",
    )
    means_by_group: dict[str, list[float]] = {}
    pairs_by_group: dict[str, int] = {}
    for row in rows:
        group = condition_category(row)
        if group is None or row["mean_cosine_similarity"] == "":
            continue
        means_by_group.setdefault(group, []).append(float(row["mean_cosine_similarity"]))
        pairs_by_group[group] = pairs_by_group.get(group, 0) + int(row["n_pairs"])

    summary: list[dict[str, object]] = []
    for group in report_order:
        # A bucket exists only once a value has been appended, so presence
        # in the mapping implies a non-empty list.
        if group not in means_by_group:
            continue
        means = means_by_group[group]
        summary.append(
            {
                "representation_level": level,
                "comparison_group": group,
                "n_summary_rows": len(means),
                "n_pairs": pairs_by_group[group],
                "mean_cosine_similarity": round(sum(means) / len(means), 6),
                "min_cosine_similarity": round(min(means), 6),
                "max_cosine_similarity": round(max(means), 6),
            }
        )
    return summary
def main() -> int:
    """Run the leakage-reduced embedding analysis from the command line.

    Reads the condition-level and sliding-window representation files from
    ``--input-dir``, masks their serialized text, embeds the masked text with
    ``--dim`` dimensions, and writes the masked representations, embedding
    bundles, metadata/vector/similarity/projection CSVs, a JSON config and two
    Markdown reports into ``--output-dir``.

    Returns:
        ``0`` once every output has been written.
    """
    cli = argparse.ArgumentParser(description="Run leakage-reduced exploratory embedding analysis.")
    cli.add_argument("--input-dir", type=Path, default=Path("outputs"))
    cli.add_argument("--output-dir", type=Path, default=Path("outputs/embedding_analysis_leakage_reduced"))
    cli.add_argument("--dim", type=int, default=256)
    args = cli.parse_args()
    source_dir = args.input_dir
    target_dir = args.output_dir

    # Paths that are both used for I/O and recorded in the config file.
    condition_input = source_dir / "condition_level_representations.jsonl"
    sliding_input = source_dir / "sliding_window_representations.jsonl"
    masked_condition_output = target_dir / "leakage_reduced_condition_level_representations.jsonl"
    masked_sliding_output = target_dir / "leakage_reduced_sliding_window_representations.jsonl"

    # Column layouts of the CSV outputs.
    metadata_fields = [
        "row_idx",
        "representation_id",
        "representation_level",
        "dataset",
        "condition",
        "n_trials",
        "window_size",
        "window_start_trial_idx",
        "window_end_trial_idx",
    ]
    vector_fields = [
        "row_idx",
        "representation_id",
        "dataset",
        "condition",
        "embedding_model",
        "embedding_dim",
        "embedding_vector_json",
    ]
    similarity_fields = [
        "representation_level",
        "dataset_a",
        "condition_a",
        "dataset_b",
        "condition_b",
        "pair_type",
        "n_pairs",
        "mean_cosine_similarity",
        "mean_cosine_distance",
        "min_cosine_similarity",
        "max_cosine_similarity",
        "std_cosine_similarity",
    ]
    group_summary_fields = [
        "representation_level",
        "comparison_group",
        "n_summary_rows",
        "n_pairs",
        "mean_cosine_similarity",
        "min_cosine_similarity",
        "max_cosine_similarity",
    ]
    projection_fields = [
        "row_idx",
        "representation_id",
        "representation_level",
        "dataset",
        "condition",
        "window_size",
        "window_start_trial_idx",
        "window_end_trial_idx",
        "pc1",
        "pc2",
        "pc1_explained_variance_ratio",
        "pc2_explained_variance_ratio",
    ]

    # Load the inputs and mask the text; labels stay on the rows as metadata.
    raw_condition = read_jsonl(condition_input)
    raw_sliding = read_jsonl(sliding_input)
    condition_masked = make_analysis_rows(raw_condition)
    sliding_masked = make_analysis_rows(raw_sliding)

    target_dir.mkdir(parents=True, exist_ok=True)
    write_jsonl(masked_condition_output, make_masked_representation_rows(raw_condition, "masked_condition"))
    write_jsonl(masked_sliding_output, make_masked_representation_rows(raw_sliding, "masked_sliding_window"))

    # Embed the masked text and persist the raw vectors.
    condition_vectors = embed_rows(condition_masked, dim=args.dim)
    sliding_vectors = embed_rows(sliding_masked, dim=args.dim)
    save_embedding_bundle(target_dir / "condition_embeddings.npz", condition_masked, condition_vectors)
    save_embedding_bundle(target_dir / "sliding_window_embeddings.npz", sliding_masked, sliding_vectors)

    write_csv(target_dir / "condition_embedding_metadata.csv", metadata_rows(condition_masked), metadata_fields)
    write_csv(target_dir / "sliding_window_embedding_metadata.csv", metadata_rows(sliding_masked), metadata_fields)
    write_csv(
        target_dir / "condition_embedding_vectors.csv",
        masked_condition_vector_rows(condition_masked, condition_vectors),
        vector_fields,
    )

    # Pairwise cosine summaries, then the coarser per-group roll-up.
    condition_similarity = cosine_pair_summary(condition_masked, condition_vectors, "condition_leakage_reduced")
    sliding_similarity = cosine_pair_summary(sliding_masked, sliding_vectors, "sliding_window_leakage_reduced")
    write_csv(target_dir / "condition_embedding_similarity_pairs.csv", condition_similarity, similarity_fields)
    write_csv(target_dir / "sliding_window_embedding_similarity_summary.csv", sliding_similarity, similarity_fields)

    condition_group_summary = build_group_summary(condition_similarity, "condition_leakage_reduced")
    sliding_group_summary = build_group_summary(sliding_similarity, "sliding_window_leakage_reduced")
    write_csv(
        target_dir / "leakage_reduced_similarity_group_summary.csv",
        condition_group_summary + sliding_group_summary,
        group_summary_fields,
    )

    # Two-dimensional PCA projections.
    condition_coords, condition_explained = pca_2d(condition_vectors)
    sliding_coords, sliding_explained = pca_2d(sliding_vectors)
    write_csv(
        target_dir / "condition_embedding_projection_2d.csv",
        projection_rows(condition_masked, condition_coords, condition_explained),
        projection_fields,
    )
    write_csv(
        target_dir / "sliding_window_embedding_projection_2d.csv",
        projection_rows(sliding_masked, sliding_coords, sliding_explained),
        projection_fields,
    )

    # Record how this run was configured.
    config = {
        "analysis_label": "leakage_reduced_exploratory_embedding_analysis",
        "embedding_model": MASKED_MODEL_NAME,
        "embedding_dim": args.dim,
        "library": "numpy",
        "removed_from_embedded_text": sorted(REMOVED_FIELDS),
        "labels_retained_only_for_grouping": ["dataset", "condition"],
        "inputs": {
            "condition_level": str(condition_input),
            "sliding_window": str(sliding_input),
        },
        "masked_outputs": {
            "condition_level": str(masked_condition_output),
            "sliding_window": str(masked_sliding_output),
        },
    }
    config_text = json.dumps(config, ensure_ascii=False, indent=2, sort_keys=True) + "\n"
    (target_dir / "embedding_model_config.json").write_text(config_text, encoding="utf-8")

    condition_total = len(condition_masked)
    sliding_total = len(sliding_masked)
    # NOTE(review): only the row counts are passed here, so the masking
    # report's vectorization line describes 256 dimensions regardless of --dim.
    write_masking_report(
        target_dir / "leakage_reduced_masking_report.md",
        condition_count=condition_total,
        sliding_count=sliding_total,
    )
    write_report(
        target_dir / "leakage_reduced_embedding_report.md",
        condition_count=condition_total,
        sliding_count=sliding_total,
        condition_explained=condition_explained,
        sliding_explained=sliding_explained,
        condition_group_summary=condition_group_summary,
        sliding_group_summary=sliding_group_summary,
    )

    print(f"Embedded {len(condition_masked)} leakage-reduced condition-level rows.")
    print(f"Embedded {len(sliding_masked)} leakage-reduced sliding-window rows.")
    print(f"Wrote leakage-reduced embedding outputs to {args.output_dir}")
    return 0
# Script entry point: run main() and use its return value as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())