# Source: iclr2026-realign-challenge/scripts/smoke_test_submission.py
# Author: siddsuresh97 — Initial commit: ICLR 2026 Representational Alignment Challenge
# Commit: d6c8a4f
#!/usr/bin/env python3
"""End-to-end smoke test mimicking a blue team submission via Modal.
Picks 20 models from the blue team registry and 50 images from the
stimuli catalog, runs the full extraction + CKA scoring pipeline
on Modal, and verifies the results.
Usage:
python scripts/smoke_test_submission.py \
--registry configs/blue_team_model_registry.json \
--stimuli configs/blue_team_images.jsonl
# With custom counts:
python scripts/smoke_test_submission.py \
--num-models 5 --num-stimuli 10
"""
from __future__ import annotations
import argparse
import hashlib
import json
import sys
from pathlib import Path
def load_registry(path: str) -> list[dict]:
    """Read the blue team model registry from *path*.

    Accepts either a bare JSON list of model entries or an object
    that wraps the list under a ``"models"`` key; returns the list.
    """
    parsed = json.loads(Path(path).read_text())
    return parsed["models"] if isinstance(parsed, dict) else parsed
def load_stimuli(path: str) -> list[dict]:
    """Read the stimuli catalog from *path*.

    A ``.jsonl`` file is parsed one JSON object per line (blank lines
    skipped). Any other file is parsed as plain JSON: either a bare
    list, or an object wrapping the list under a ``"stimuli"`` key.
    """
    catalog_path = Path(path)
    if catalog_path.suffix == ".jsonl":
        raw_lines = catalog_path.read_text().splitlines()
        return [json.loads(raw) for raw in raw_lines if raw.strip()]
    parsed = json.loads(catalog_path.read_text())
    return parsed["stimuli"] if isinstance(parsed, dict) else parsed
def cache_key_from_payload(registry: list[dict], stimuli: list[dict]) -> str:
    """Derive a short, deterministic cache key for a (registry, stimuli) pair.

    The payload is serialized with sorted keys so that equal inputs always
    produce the same bytes; the key is ``"smoke_"`` plus the first 12 hex
    characters of the SHA-1 digest of that serialization.
    """
    serialized = json.dumps(
        {"registry": registry, "stimuli": stimuli},
        sort_keys=True,
    ).encode("utf-8")
    return "smoke_" + hashlib.sha1(serialized).hexdigest()[:12]
def main():
    """Run the end-to-end blue team smoke test against a deployed Modal app.

    Pipeline: load the registry/stimuli catalogs, select a subset of models
    and images, extract embeddings remotely, compute pairwise CKA remotely,
    then validate the submission payload format locally. Exits with status 1
    on any validation failure; prints a summary on success.
    """
    parser = argparse.ArgumentParser(description="End-to-end blue team submission smoke test on Modal")
    parser.add_argument("--registry", default="configs/blue_team_model_registry.json",
                        help="Path to blue team model registry JSON")
    parser.add_argument("--stimuli", default="configs/blue_team_images.jsonl",
                        help="Path to stimuli catalog JSONL")
    parser.add_argument("--num-models", type=int, default=20,
                        help="Number of models to select (default: 20)")
    parser.add_argument("--num-stimuli", type=int, default=50,
                        help="Number of stimuli to select (default: 50)")
    parser.add_argument("--app-name", default="iclr2026-eval",
                        help="Modal app name")
    parser.add_argument("--batch-size", type=int, default=64,
                        help="Batch size for embedding extraction")
    args = parser.parse_args()
    # Load full registry and stimuli
    full_registry = load_registry(args.registry)
    full_stimuli = load_stimuli(args.stimuli)
    print(f"Full registry: {len(full_registry)} models")
    print(f"Full stimuli: {len(full_stimuli)} images")
    # Select subset - pick models evenly spaced for diversity
    n_models = min(args.num_models, len(full_registry))
    if n_models < len(full_registry):
        # Evenly spaced fractional stride over the registry; int() truncation
        # keeps every index in range while spreading picks across the list.
        step = len(full_registry) / n_models
        indices = [int(i * step) for i in range(n_models)]
        subset_registry = [full_registry[i] for i in indices]
    else:
        subset_registry = full_registry
    # Stimuli are simply taken from the front of the catalog.
    n_stimuli = min(args.num_stimuli, len(full_stimuli))
    subset_stimuli = full_stimuli[:n_stimuli]
    model_names = [m["model_name"] for m in subset_registry]
    print(f"\nSelected {len(subset_registry)} models:")
    for m in model_names:
        print(f" - {m}")
    print(f"\nSelected {len(subset_stimuli)} stimuli (first {n_stimuli} from catalog)")
    # Build cache key (deterministic digest of the exact subset used)
    cache_key = cache_key_from_payload(subset_registry, subset_stimuli)
    print(f"\nCache key: {cache_key}")
    # Connect to Modal
    print("\nConnecting to Modal...")
    # Imported lazily so the script can print catalog info even when the
    # modal package is absent; lookup assumes the app exposes these two
    # function names — TODO confirm against the deployed app definition.
    import modal
    extract_fn = modal.Function.from_name(args.app_name, "extract_embeddings_s3")
    score_fn = modal.Function.from_name(args.app_name, "compute_pairwise_cka")
    print("Modal connection OK")
    # Step 1: Extract embeddings
    print(f"\n{'=' * 60}")
    print("STEP 1: Extracting embeddings...")
    print(f"{'=' * 60}")
    extract_result = extract_fn.remote(
        model_registry=subset_registry,
        stimuli=subset_stimuli,
        cache_key=cache_key,
        batch_size=args.batch_size,
        reuse_cache=False,  # Force fresh extraction for smoke test
    )
    # NOTE(review): the result is treated as a dict with keys cache_key,
    # num_stimuli, and models (each model having model_name/layer/dim/
    # num_samples) — verify against the remote function's return schema.
    print(f"\nExtraction result:")
    print(f" cache_key: {extract_result.get('cache_key')}")
    print(f" num_stimuli: {extract_result.get('num_stimuli')}")
    print(f" models: {len(extract_result.get('models', []))}")
    for m in extract_result.get("models", []):
        print(
            f" {m['model_name']:45s} layer={m.get('layer', '?'):25s} "
            f"dim={m.get('dim', '?'):>6} samples={m.get('num_samples', '?')}"
        )
    # Validate extraction: every model must have embedded all selected
    # stimuli with a positive feature dimension.
    errors = []
    for m in extract_result.get("models", []):
        if m.get("num_samples") != n_stimuli:
            errors.append(f" {m['model_name']}: expected {n_stimuli} samples, got {m.get('num_samples')}")
        if m.get("dim", 0) <= 0:
            errors.append(f" {m['model_name']}: invalid dim {m.get('dim')}")
    if errors:
        print(f"\nExtraction ERRORS:")
        for e in errors:
            print(e)
        sys.exit(1)
    print(f"\nExtraction: ALL {len(extract_result.get('models', []))} models OK")
    # Step 2: Compute pairwise CKA
    print(f"\n{'=' * 60}")
    print("STEP 2: Computing pairwise CKA...")
    print(f"{'=' * 60}")
    cka_result = score_fn.remote(
        cache_key=cache_key,
        model_names=model_names,
    )
    avg_cka = cka_result.get("avg_cka", 0.0)
    pairwise = cka_result.get("pairwise", [])
    print(f"\nCKA results:")
    print(f" avg_cka: {avg_cka:.6f}")
    print(f" num_pairs: {len(pairwise)}")
    # Unordered pairs over n_models: n choose 2.
    expected_pairs = n_models * (n_models - 1) // 2
    if len(pairwise) != expected_pairs:
        print(f" WARNING: expected {expected_pairs} pairs, got {len(pairwise)}")
    # Show top 5 and bottom 5 pairs
    sorted_pairs = sorted(pairwise, key=lambda x: x.get("cka", 0), reverse=True)
    print(f"\n Top 5 most similar pairs:")
    for p in sorted_pairs[:5]:
        ma = p.get("model_a", "?")
        mb = p.get("model_b", "?")
        cka = p.get("cka", 0.0)
        print(f" {ma:40s} <-> {mb:40s} CKA={cka:.6f}")
    print(f"\n Bottom 5 least similar pairs:")
    for p in sorted_pairs[-5:]:
        ma = p.get("model_a", "?")
        mb = p.get("model_b", "?")
        cka = p.get("cka", 0.0)
        print(f" {ma:40s} <-> {mb:40s} CKA={cka:.6f}")
    # Validate CKA results: values must exist and be roughly in [0, 1].
    cka_errors = []
    for p in pairwise:
        cka_val = p.get("cka")
        if cka_val is None:
            cka_errors.append(" Missing CKA value for a pair")
        elif not (-0.1 <= cka_val <= 1.5):  # Allow slight numerical overshoot
            ma = p.get("model_a", "?")
            mb = p.get("model_b", "?")
            cka_errors.append(f" {ma} <-> {mb}: CKA={cka_val} out of expected range")
    if cka_errors:
        print(f"\nCKA ERRORS:")
        for e in cka_errors:
            print(e)
        sys.exit(1)
    # Step 3: Validate submission format (local check)
    print(f"\n{'=' * 60}")
    print("STEP 3: Validating submission format...")
    print(f"{'=' * 60}")
    # Shape the payload the way a real blue team submission would:
    # registry "layer" becomes the submission's "layer_name".
    submission_payload = {
        "models": [
            {"model_name": m["model_name"], "layer_name": m["layer"]}
            for m in subset_registry
        ]
    }
    print(f"\n Sample submission entry: {json.dumps(submission_payload['models'][0])}")
    try:
        # Project-local validation helpers; skipped gracefully when the
        # package is not importable (e.g. run outside the repo root).
        from src.hackathon.validation import (
            BLUE_TEAM_REQUIRED_MODELS,
            load_model_registry,
            load_model_registry_specs,
            validate_blue_submission,
        )
        registry_names = load_model_registry(args.registry)
        registry_specs = load_model_registry_specs(args.registry)
        if n_models == BLUE_TEAM_REQUIRED_MODELS:
            # Full validation only applies when the subset matches the
            # officially required model count.
            validated_names = validate_blue_submission(
                submission_payload,
                model_registry=registry_names,
                registry_specs=registry_specs,
            )
            print(f" Validation OK: {len(validated_names)} models accepted")
        else:
            print(f" Skipping count validation (selected {n_models}, required {BLUE_TEAM_REQUIRED_MODELS})")
            # Still check layer matching
            mismatches = []
            for m in subset_registry:
                name = m["model_name"]
                if name in registry_specs:
                    expected = registry_specs[name]["layer"]
                    submitted = m["layer"]
                    if submitted != expected:
                        mismatches.append(f" {name}: layer={submitted} expected={expected}")
            if mismatches:
                print(f" Layer MISMATCHES:")
                for mm in mismatches:
                    print(mm)
                sys.exit(1)
            print(f" Layer validation OK for all {n_models} models")
    except ImportError as exc:
        # Best-effort: missing validation package is not a smoke-test failure.
        print(f" Skipping validation (import failed: {exc})")
    except Exception as exc:
        # Any other validation error is a hard failure.
        print(f" Validation FAILED: {exc}")
        sys.exit(1)
    # Summary
    blue_score = avg_cka
    print(f"\n{'=' * 60}")
    print("SMOKE TEST PASSED")
    print(f"{'=' * 60}")
    print(f" Models: {n_models}")
    print(f" Stimuli: {n_stimuli}")
    print(f" Avg CKA: {avg_cka:.6f}")
    print(f" Blue score: {blue_score:.6f}")
    print(f" Pairs: {len(pairwise)} / {expected_pairs} expected")
    print(f" All layers: validated against registry")
    print()
# Script entry point: only run the smoke test when executed directly.
if __name__ == "__main__":
    main()