|
|
""" |
|
|
Cross-Prompt Robustness Analysis |
|
|
|
|
|
Evaluates whether supernodes/features generalize across prompt variations: |
|
|
- Entity swaps: Dallas->Oakland, Texas->California, Austin->Sacramento |
|
|
- Same semantic structure, different entities |
|
|
|
|
|
Measures: |
|
|
1. Feature survival rate (>=70% activation overlap) |
|
|
2. Concept shift accuracy (features activate on swapped entities) |
|
|
3. Layer distribution similarity |
|
|
4. Notable failure modes |
|
|
|
|
|
Usage: |
|
|
python scripts/experiments/cross_prompt_robustness.py \ |
|
|
--prompt1_csv <path> --prompt2_csv <path> \ |
|
|
--prompt1_name "Dallas" --prompt2_name "Oakland" \ |
|
|
--entity_mapping '{"Dallas":"Oakland","Texas":"California","Austin":"Sacramento"}' \ |
|
|
--output_json <path> |
|
|
""" |

import argparse
import json
from pathlib import Path
from typing import Any, Dict, List, Tuple

import numpy as np
import pandas as pd
from scipy.stats import ks_2samp


def load_and_prepare_data(csv_path: Path) -> pd.DataFrame:
    """Load a probe-activation CSV and validate the columns the analysis needs."""
    df = pd.read_csv(csv_path)

    required = ["feature_key", "layer", "prompt", "activation_max", "peak_token", "peak_token_idx"]
    missing = [col for col in required if col not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns: {missing}")

    return df
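
# Illustrative only (hypothetical values, not taken from real data): rows in
# these CSVs are expected to look roughly like
#   feature_key,layer,prompt,activation_max,peak_token,peak_token_idx
#   L12_F3045,12,"Fact: the capital of Texas is Austin",4.21,Austin,8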


def extract_entity_from_prompt(prompt: str) -> str:
    """Extract the main entity being tested in a probe prompt.

    Assumes probe prompts of the form "<template>: ... <entity>", so the
    entity is taken to be the last word after the colon.
    """
    if ":" in prompt:
        after_colon = prompt.split(":", 1)[1].strip()
        words = after_colon.split()
        if words:
            return words[-1].strip(".,!?")
    return ""


def get_probe_concepts(df: pd.DataFrame) -> List[str]:
    """Extract unique probe concepts from prompts."""
    concepts = set()
    for prompt in df["prompt"].unique():
        entity = extract_entity_from_prompt(prompt)
        if entity:
            concepts.add(entity)
    return sorted(concepts)


def compute_activation_overlap(
    activations1: np.ndarray,
    activations2: np.ndarray,
    threshold_percentile: float = 50.0
) -> float:
    """
    Compute overlap of high-activation positions.

    Returns the Jaccard index (intersection over union) of the position sets
    whose activations reach each array's threshold percentile.
    """
    if len(activations1) == 0 or len(activations2) == 0:
        return 0.0

    # Threshold each run independently, then compare the active position sets.
    thresh1 = np.percentile(activations1, threshold_percentile)
    thresh2 = np.percentile(activations2, threshold_percentile)

    active1 = set(np.where(activations1 >= thresh1)[0])
    active2 = set(np.where(activations2 >= thresh2)[0])

    if len(active1) == 0 or len(active2) == 0:
        return 0.0

    overlap = len(active1.intersection(active2))
    union = len(active1.union(active2))

    return overlap / union if union > 0 else 0.0
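
# Worked example (hypothetical values): with the default 50th percentile,
# a1 = [0.1, 0.9, 0.8, 0.2] and a2 = [0.1, 0.2, 0.8, 0.9] both threshold at
# 0.5, giving active position sets {1, 2} and {2, 3}; the Jaccard overlap is
# |{2}| / |{1, 2, 3}| = 1/3.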


def compute_peak_token_consistency(
    df1: pd.DataFrame,
    df2: pd.DataFrame,
    feature_key: str,
    entity_mapping: Dict[str, str]
) -> Tuple[float, Dict[str, Any]]:
    """
    Check whether peak tokens shift appropriately with entity swaps.

    Returns:
    - consistency score (0-1)
    - details dict with matched/mismatched tokens
    """
    feature1 = df1[df1["feature_key"] == feature_key]
    feature2 = df2[df2["feature_key"] == feature_key]

    if len(feature1) == 0 or len(feature2) == 0:
        return 0.0, {"error": "missing feature"}

    # Map each prompt to the token where this feature peaks.
    peaks1 = {row["prompt"]: row["peak_token"] for _, row in feature1.iterrows()}
    peaks2 = {row["prompt"]: row["peak_token"] for _, row in feature2.iterrows()}

    matches = 0
    total = 0
    details = {"matches": [], "mismatches": []}

    for old_entity, new_entity in entity_mapping.items():
        for prompt1, peak1 in peaks1.items():
            if old_entity.lower() not in prompt1.lower():
                continue

            # Reconstruct the swapped prompt and look for it in the second run.
            prompt2_candidate = prompt1.replace(old_entity, new_entity)
            if prompt2_candidate not in peaks2:
                continue

            peak2 = peaks2[prompt2_candidate]
            total += 1

            # Coerce to str in case a peak token was read from the CSV as NaN.
            p1 = str(peak1).strip().lower()
            p2 = str(peak2).strip().lower()

            if p1 == old_entity.lower() and p2 == new_entity.lower():
                # Entity-tracking feature: the peak moved with the swapped entity.
                matches += 1
                details["matches"].append({
                    "prompt1": prompt1,
                    "prompt2": prompt2_candidate,
                    "peak1": peak1,
                    "peak2": peak2,
                    "expected_shift": f"{old_entity}->{new_entity}"
                })
            elif peak1 == peak2:
                # Functional/structural token: the peak is stable across the swap.
                matches += 1
                details["matches"].append({
                    "prompt1": prompt1,
                    "prompt2": prompt2_candidate,
                    "peak1": peak1,
                    "peak2": peak2,
                    "note": "functional token (stable)"
                })
            else:
                details["mismatches"].append({
                    "prompt1": prompt1,
                    "prompt2": prompt2_candidate,
                    "peak1": peak1,
                    "peak2": peak2,
                    "expected": f"{old_entity}->{new_entity}"
                })

    consistency = matches / total if total > 0 else 0.0
    return consistency, details
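
# Illustrative reading (hypothetical prompts): a feature that peaks on
# "Dallas" in "Fact: Dallas is a city" and on "Oakland" in the rebuilt prompt
# "Fact: Oakland is a city" counts as a match (entity-tracking); a peak that
# stays on the same token in both prompts also counts (functional token);
# anything else is logged as a mismatch.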


def compute_layer_distribution_similarity(
    df1: pd.DataFrame,
    df2: pd.DataFrame,
    feature_key: str
) -> Tuple[float, Dict[str, Any]]:
    """
    Measure similarity of the two layer distributions with a two-sample KS test.

    Returns:
    - KS p-value (higher = less evidence the distributions differ)
    - details dict with summary statistics
    """
    feature1 = df1[df1["feature_key"] == feature_key]
    feature2 = df2[df2["feature_key"] == feature_key]

    if len(feature1) == 0 or len(feature2) == 0:
        return 0.0, {"error": "missing feature"}

    layers1 = feature1["layer"].values
    layers2 = feature2["layer"].values

    statistic, pvalue = ks_2samp(layers1, layers2)

    details = {
        "mean_layer1": float(np.mean(layers1)),
        "mean_layer2": float(np.mean(layers2)),
        "std_layer1": float(np.std(layers1)),
        "std_layer2": float(np.std(layers2)),
        "ks_statistic": float(statistic),
        "ks_pvalue": float(pvalue),
    }

    return pvalue, details
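
# Intuition with hypothetical layer lists: identical samples give p = 1.0,
# while fully separated samples drive p toward 0, e.g.
#   ks_2samp([4, 5, 6], [4, 5, 6]).pvalue     # 1.0 (same distribution)
#   ks_2samp([4, 5, 6], [20, 21, 22]).pvalue  # 0.1 (exact test, n = m = 3)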


def analyze_feature_robustness(
    df1: pd.DataFrame,
    df2: pd.DataFrame,
    feature_key: str,
    entity_mapping: Dict[str, str],
    survival_threshold: float = 0.7
) -> Dict[str, Any]:
    """Comprehensive robustness analysis for a single feature."""
    result = {
        "feature_key": feature_key,
        "survives": False,
        "metrics": {}
    }

    feature1 = df1[df1["feature_key"] == feature_key]
    feature2 = df2[df2["feature_key"] == feature_key]

    if len(feature1) == 0 or len(feature2) == 0:
        result["error"] = "Feature not found in both datasets"
        return result

    activations1 = feature1["activation_max"].values
    activations2 = feature2["activation_max"].values

    # Positional overlap assumes both CSVs list probe prompts in the same
    # order; zero-pad the shorter run so the position sets are comparable.
    max_len = max(len(activations1), len(activations2))
    act1_padded = np.pad(activations1, (0, max_len - len(activations1)), constant_values=0)
    act2_padded = np.pad(activations2, (0, max_len - len(activations2)), constant_values=0)

    overlap = compute_activation_overlap(act1_padded, act2_padded)
    result["metrics"]["activation_overlap"] = float(overlap)

    token_consistency, token_details = compute_peak_token_consistency(
        df1, df2, feature_key, entity_mapping
    )
    result["metrics"]["peak_token_consistency"] = float(token_consistency)
    result["peak_token_details"] = token_details

    layer_similarity, layer_details = compute_layer_distribution_similarity(
        df1, df2, feature_key
    )
    result["metrics"]["layer_distribution_pvalue"] = float(layer_similarity)
    result["layer_details"] = layer_details

    # Survival is judged on activation overlap alone; the other metrics are
    # reported for diagnosis.
    result["survives"] = overlap >= survival_threshold

    return result
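
# Usage sketch (hypothetical feature key): given two loaded frames,
#   analyze_feature_robustness(df1, df2, "L12_F3045",
#                              {"Dallas": "Oakland"}, survival_threshold=0.7)
# returns a dict whose "survives" flag is driven by activation overlap alone,
# with token and layer metrics included for diagnosis.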


def analyze_cross_prompt_robustness(
    df1: pd.DataFrame,
    df2: pd.DataFrame,
    prompt1_name: str,
    prompt2_name: str,
    entity_mapping: Dict[str, str],
    survival_threshold: float = 0.7
) -> Dict[str, Any]:
    """Full cross-prompt robustness analysis over all shared features."""
    features1 = set(df1["feature_key"].unique())
    features2 = set(df2["feature_key"].unique())
    common_features = features1.intersection(features2)

    print(f"\n{prompt1_name}: {len(features1)} features")
    print(f"{prompt2_name}: {len(features2)} features")
    print(f"Common features: {len(common_features)}")

    # Per-feature robustness.
    feature_results = []
    survival_count = 0

    for feature_key in sorted(common_features):
        result = analyze_feature_robustness(
            df1, df2, feature_key, entity_mapping, survival_threshold
        )
        feature_results.append(result)

        if result.get("survives", False):
            survival_count += 1

    survival_rate = survival_count / len(common_features) if len(common_features) > 0 else 0.0

    # Aggregate metrics over features analyzed without error.
    activation_overlaps = [r["metrics"].get("activation_overlap", 0) for r in feature_results if "error" not in r]
    peak_consistencies = [r["metrics"].get("peak_token_consistency", 0) for r in feature_results if "error" not in r]
    layer_pvalues = [r["metrics"].get("layer_distribution_pvalue", 0) for r in feature_results if "error" not in r]

    aggregate_metrics = {
        "mean_activation_overlap": float(np.mean(activation_overlaps)) if activation_overlaps else 0.0,
        "std_activation_overlap": float(np.std(activation_overlaps)) if activation_overlaps else 0.0,
        "mean_peak_token_consistency": float(np.mean(peak_consistencies)) if peak_consistencies else 0.0,
        "std_peak_token_consistency": float(np.std(peak_consistencies)) if peak_consistencies else 0.0,
        "mean_layer_pvalue": float(np.mean(layer_pvalues)) if layer_pvalues else 0.0,
        "features_with_similar_layers": sum(1 for p in layer_pvalues if p > 0.05),
    }

    failure_modes = identify_failure_modes(feature_results, df1, df2)

    return {
        "prompt1": prompt1_name,
        "prompt2": prompt2_name,
        "entity_mapping": entity_mapping,
        "n_features_prompt1": len(features1),
        "n_features_prompt2": len(features2),
        "n_common_features": len(common_features),
        "survival_threshold": survival_threshold,
        "survival_count": survival_count,
        "survival_rate": survival_rate,
        "aggregate_metrics": aggregate_metrics,
        "feature_results": feature_results,
        "failure_modes": failure_modes,
    }


def identify_failure_modes(
    feature_results: List[Dict[str, Any]],
    df1: pd.DataFrame,
    df2: pd.DataFrame
) -> Dict[str, Any]:
    """Identify common failure patterns among non-surviving features."""
    failures = [r for r in feature_results if not r.get("survives", False)]

    failure_modes = {
        "n_failures": len(failures),
        "categories": {}
    }

    # Categories are not mutually exclusive; one failure can appear in several.
    low_overlap = [f for f in failures if f["metrics"].get("activation_overlap", 1.0) < 0.3]
    inconsistent_peaks = [f for f in failures if f["metrics"].get("peak_token_consistency", 1.0) < 0.5]
    shifted_layers = [f for f in failures if f["metrics"].get("layer_distribution_pvalue", 1.0) < 0.01]

    failure_modes["categories"]["low_activation_overlap"] = {
        "count": len(low_overlap),
        "examples": [f["feature_key"] for f in low_overlap[:5]]
    }

    failure_modes["categories"]["inconsistent_peak_tokens"] = {
        "count": len(inconsistent_peaks),
        "examples": [f["feature_key"] for f in inconsistent_peaks[:5]]
    }

    failure_modes["categories"]["layer_distribution_shift"] = {
        "count": len(shifted_layers),
        "examples": [f["feature_key"] for f in shifted_layers[:5]]
    }

    return failure_modes


def print_summary(results: Dict[str, Any]):
    """Print human-readable summary."""
    print("\n" + "=" * 80)
    print("CROSS-PROMPT ROBUSTNESS ANALYSIS")
    print("=" * 80)

    print(f"\nPrompt 1: {results['prompt1']}")
    print(f"Prompt 2: {results['prompt2']}")
    print("\nEntity Mapping:")
    for old, new in results['entity_mapping'].items():
        print(f"  {old} -> {new}")

    print("\nFeature Coverage:")
    print(f"  {results['prompt1']}: {results['n_features_prompt1']} features")
    print(f"  {results['prompt2']}: {results['n_features_prompt2']} features")
    print(f"  Common: {results['n_common_features']} features")

    print(f"\nSurvival Analysis (threshold >={results['survival_threshold']:.0%} overlap):")
    print(f"  Survived: {results['survival_count']} / {results['n_common_features']}")
    print(f"  Survival Rate: {results['survival_rate']:.2%}")

    agg = results['aggregate_metrics']
    print("\nAggregate Metrics:")
    print(f"  Mean Activation Overlap: {agg['mean_activation_overlap']:.3f} +/- {agg['std_activation_overlap']:.3f}")
    print(f"  Mean Peak Token Consistency: {agg['mean_peak_token_consistency']:.3f} +/- {agg['std_peak_token_consistency']:.3f}")
    print(f"  Mean Layer P-value: {agg['mean_layer_pvalue']:.3f}")
    print(f"  Features with Similar Layers (p>0.05): {agg['features_with_similar_layers']} / {results['n_common_features']}")

    fm = results['failure_modes']
    print(f"\nFailure Modes ({fm['n_failures']} total failures):")
    for category, data in fm['categories'].items():
        print(f"  {category.replace('_', ' ').title()}: {data['count']}")
        if data['examples']:
            print(f"    Examples: {', '.join(data['examples'][:3])}")


def main():
    parser = argparse.ArgumentParser(description="Cross-prompt robustness analysis")
    parser.add_argument("--prompt1_csv", type=str, required=True, help="CSV for prompt 1")
    parser.add_argument("--prompt2_csv", type=str, required=True, help="CSV for prompt 2")
    parser.add_argument("--prompt1_name", type=str, required=True, help="Name for prompt 1")
    parser.add_argument("--prompt2_name", type=str, required=True, help="Name for prompt 2")
    parser.add_argument("--entity_mapping", type=str, required=True,
                        help='JSON entity mapping, e.g. \'{"Dallas":"Oakland","Texas":"California"}\'')
    parser.add_argument("--survival_threshold", type=float, default=0.7,
                        help="Activation overlap threshold for survival (default: 0.7)")
    parser.add_argument("--output_json", type=str, default=None, help="Output JSON path")

    args = parser.parse_args()

    entity_mapping = json.loads(args.entity_mapping)

    print(f"Loading {args.prompt1_name} from {args.prompt1_csv}...")
    df1 = load_and_prepare_data(Path(args.prompt1_csv))

    print(f"Loading {args.prompt2_name} from {args.prompt2_csv}...")
    df2 = load_and_prepare_data(Path(args.prompt2_csv))

    results = analyze_cross_prompt_robustness(
        df1, df2,
        args.prompt1_name, args.prompt2_name,
        entity_mapping,
        args.survival_threshold
    )

    print_summary(results)

    # Default the output path to sit next to the first CSV if none was given.
    if args.output_json:
        output_path = Path(args.output_json)
    else:
        output_path = Path(args.prompt1_csv).parent / "cross_prompt_robustness.json"

    print(f"\nSaving results to {output_path}...")
    with open(output_path, "w") as f:
        json.dump(results, f, indent=2)

    print("\nDone!")


if __name__ == "__main__":
    main()