| |
| """End-to-end smoke test mimicking a blue team submission via Modal. |
| |
| Picks 20 models from the blue team registry and 50 images from the |
| stimuli catalog, runs the full extraction + CKA scoring pipeline |
| on Modal, and verifies the results. |
| |
| Usage: |
| python scripts/smoke_test_submission.py \ |
| --registry configs/blue_team_model_registry.json \ |
| --stimuli configs/blue_team_images.jsonl |
| |
| # With custom counts: |
| python scripts/smoke_test_submission.py \ |
| --num-models 5 --num-stimuli 10 |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import hashlib |
| import json |
| import sys |
| from pathlib import Path |
|
|
|
|
def load_registry(path: str) -> list[dict]:
    """Load the blue team model registry from *path*.

    The file may be either a bare JSON list of model entries or a JSON
    object wrapping that list under a ``"models"`` key; both forms
    yield the list of model dicts.
    """
    parsed = json.loads(Path(path).read_text())
    return parsed["models"] if isinstance(parsed, dict) else parsed
|
|
|
|
def load_stimuli(path: str) -> list[dict]:
    """Load the stimuli catalog from *path*.

    A ``.jsonl`` file is parsed one record per line, skipping blank
    lines.  Any other file is parsed as plain JSON: either a bare list
    or an object wrapping the list under a ``"stimuli"`` key.
    """
    catalog = Path(path)
    if catalog.suffix == ".jsonl":
        records: list[dict] = []
        for raw in catalog.read_text().splitlines():
            if raw.strip():
                records.append(json.loads(raw))
        return records
    parsed = json.loads(catalog.read_text())
    return parsed["stimuli"] if isinstance(parsed, dict) else parsed
|
|
|
|
def cache_key_from_payload(registry: list[dict], stimuli: list[dict]) -> str:
    """Derive a deterministic cache key for a (registry, stimuli) pair.

    The key is ``"smoke_"`` plus the first 12 hex digits of the SHA-1
    of the canonical (``sort_keys=True``) JSON encoding, so the same
    selection always maps to the same cache entry regardless of dict
    key order.
    """
    canonical = json.dumps(
        {"registry": registry, "stimuli": stimuli},
        sort_keys=True,
    )
    return "smoke_" + hashlib.sha1(canonical.encode("utf-8")).hexdigest()[:12]
|
|
|
|
def main() -> None:
    """Run the end-to-end blue team submission smoke test.

    Selects a subset of models and stimuli, runs embedding extraction
    and pairwise CKA scoring remotely on Modal, then validates the
    would-be submission payload against the local validation module.

    Exits with status 1 via ``sys.exit`` on the first failed check;
    prints a success banner and returns normally otherwise.
    """
    parser = argparse.ArgumentParser(description="End-to-end blue team submission smoke test on Modal")
    parser.add_argument("--registry", default="configs/blue_team_model_registry.json",
                        help="Path to blue team model registry JSON")
    parser.add_argument("--stimuli", default="configs/blue_team_images.jsonl",
                        help="Path to stimuli catalog JSONL")
    parser.add_argument("--num-models", type=int, default=20,
                        help="Number of models to select (default: 20)")
    parser.add_argument("--num-stimuli", type=int, default=50,
                        help="Number of stimuli to select (default: 50)")
    parser.add_argument("--app-name", default="iclr2026-eval",
                        help="Modal app name")
    parser.add_argument("--batch-size", type=int, default=64,
                        help="Batch size for embedding extraction")
    args = parser.parse_args()

    # Load the full catalogs before subsetting so the counts can be reported.
    full_registry = load_registry(args.registry)
    full_stimuli = load_stimuli(args.stimuli)

    print(f"Full registry: {len(full_registry)} models")
    print(f"Full stimuli: {len(full_stimuli)} images")

    # Select models with an even stride across the registry so the subset
    # spans the whole list rather than just its head.
    n_models = min(args.num_models, len(full_registry))
    if n_models < len(full_registry):
        step = len(full_registry) / n_models
        indices = [int(i * step) for i in range(n_models)]
        subset_registry = [full_registry[i] for i in indices]
    else:
        subset_registry = full_registry

    # Stimuli are simply the catalog prefix.
    n_stimuli = min(args.num_stimuli, len(full_stimuli))
    subset_stimuli = full_stimuli[:n_stimuli]

    # NOTE(review): assumes each registry entry has a "model_name" key
    # (and, later, a "layer" key) — confirm against the registry schema.
    model_names = [m["model_name"] for m in subset_registry]
    print(f"\nSelected {len(subset_registry)} models:")
    for m in model_names:
        print(f"  - {m}")
    print(f"\nSelected {len(subset_stimuli)} stimuli (first {n_stimuli} from catalog)")

    # Deterministic cache key ties both Modal steps to this exact selection.
    cache_key = cache_key_from_payload(subset_registry, subset_stimuli)
    print(f"\nCache key: {cache_key}")

    # Imported lazily so argument parsing and file loading work even where
    # modal is not installed — NOTE(review): presumably intentional; confirm.
    print("\nConnecting to Modal...")
    import modal

    # Look up the deployed remote functions by app/function name.
    extract_fn = modal.Function.from_name(args.app_name, "extract_embeddings_s3")
    score_fn = modal.Function.from_name(args.app_name, "compute_pairwise_cka")
    print("Modal connection OK")

    print(f"\n{'=' * 60}")
    print("STEP 1: Extracting embeddings...")
    print(f"{'=' * 60}")

    # reuse_cache=False forces a fresh extraction for the smoke test rather
    # than serving stale cached embeddings under the same key.
    extract_result = extract_fn.remote(
        model_registry=subset_registry,
        stimuli=subset_stimuli,
        cache_key=cache_key,
        batch_size=args.batch_size,
        reuse_cache=False,
    )

    print(f"\nExtraction result:")
    print(f"  cache_key: {extract_result.get('cache_key')}")
    print(f"  num_stimuli: {extract_result.get('num_stimuli')}")
    print(f"  models: {len(extract_result.get('models', []))}")

    for m in extract_result.get("models", []):
        print(
            f"  {m['model_name']:45s} layer={m.get('layer', '?'):25s} "
            f"dim={m.get('dim', '?'):>6} samples={m.get('num_samples', '?')}"
        )

    # Sanity-check extraction: every model must have embedded every stimulus
    # and report a positive embedding dimension.
    errors = []
    for m in extract_result.get("models", []):
        if m.get("num_samples") != n_stimuli:
            errors.append(f"  {m['model_name']}: expected {n_stimuli} samples, got {m.get('num_samples')}")
        if m.get("dim", 0) <= 0:
            errors.append(f"  {m['model_name']}: invalid dim {m.get('dim')}")

    if errors:
        print(f"\nExtraction ERRORS:")
        for e in errors:
            print(e)
        sys.exit(1)
    print(f"\nExtraction: ALL {len(extract_result.get('models', []))} models OK")

    print(f"\n{'=' * 60}")
    print("STEP 2: Computing pairwise CKA...")
    print(f"{'=' * 60}")

    # Scoring reads the embeddings written by step 1 via the shared cache_key.
    cka_result = score_fn.remote(
        cache_key=cache_key,
        model_names=model_names,
    )

    avg_cka = cka_result.get("avg_cka", 0.0)
    pairwise = cka_result.get("pairwise", [])

    print(f"\nCKA results:")
    print(f"  avg_cka: {avg_cka:.6f}")
    print(f"  num_pairs: {len(pairwise)}")

    # All unordered pairs: n choose 2.  A mismatch is only warned about,
    # not fatal.
    expected_pairs = n_models * (n_models - 1) // 2
    if len(pairwise) != expected_pairs:
        print(f"  WARNING: expected {expected_pairs} pairs, got {len(pairwise)}")

    # Report extremes for eyeballing.  NOTE(review): with fewer than 10
    # pairs the top-5 and bottom-5 listings overlap.
    sorted_pairs = sorted(pairwise, key=lambda x: x.get("cka", 0), reverse=True)
    print(f"\n  Top 5 most similar pairs:")
    for p in sorted_pairs[:5]:
        ma = p.get("model_a", "?")
        mb = p.get("model_b", "?")
        cka = p.get("cka", 0.0)
        print(f"    {ma:40s} <-> {mb:40s} CKA={cka:.6f}")

    print(f"\n  Bottom 5 least similar pairs:")
    for p in sorted_pairs[-5:]:
        ma = p.get("model_a", "?")
        mb = p.get("model_b", "?")
        cka = p.get("cka", 0.0)
        print(f"    {ma:40s} <-> {mb:40s} CKA={cka:.6f}")

    # Sanity-check CKA values.  NOTE(review): CKA is nominally in [0, 1];
    # the widened [-0.1, 1.5] band presumably tolerates floating-point
    # noise / debiased estimators — confirm intended bounds.
    cka_errors = []
    for p in pairwise:
        cka_val = p.get("cka")
        if cka_val is None:
            cka_errors.append("  Missing CKA value for a pair")
        elif not (-0.1 <= cka_val <= 1.5):
            ma = p.get("model_a", "?")
            mb = p.get("model_b", "?")
            cka_errors.append(f"  {ma} <-> {mb}: CKA={cka_val} out of expected range")

    if cka_errors:
        print(f"\nCKA ERRORS:")
        for e in cka_errors:
            print(e)
        sys.exit(1)

    print(f"\n{'=' * 60}")
    print("STEP 3: Validating submission format...")
    print(f"{'=' * 60}")

    # Build the payload shape a real blue team submission would use.
    # NOTE(review): assumes registry entries carry a "layer" key — a
    # missing key raises KeyError here; confirm schema.
    submission_payload = {
        "models": [
            {"model_name": m["model_name"], "layer_name": m["layer"]}
            for m in subset_registry
        ]
    }
    print(f"\n  Sample submission entry: {json.dumps(submission_payload['models'][0])}")

    try:
        from src.hackathon.validation import (
            BLUE_TEAM_REQUIRED_MODELS,
            load_model_registry,
            load_model_registry_specs,
            validate_blue_submission,
        )

        registry_names = load_model_registry(args.registry)
        registry_specs = load_model_registry_specs(args.registry)

        # Full submission validation only makes sense when the subset has
        # exactly the required model count; otherwise it is skipped.
        if n_models == BLUE_TEAM_REQUIRED_MODELS:
            validated_names = validate_blue_submission(
                submission_payload,
                model_registry=registry_names,
                registry_specs=registry_specs,
            )
            print(f"  Validation OK: {len(validated_names)} models accepted")
        else:
            print(f"  Skipping count validation (selected {n_models}, required {BLUE_TEAM_REQUIRED_MODELS})")

        # Cross-check each selected model's layer against the registry spec.
        mismatches = []
        for m in subset_registry:
            name = m["model_name"]
            if name in registry_specs:
                expected = registry_specs[name]["layer"]
                submitted = m["layer"]
                if submitted != expected:
                    mismatches.append(f"    {name}: layer={submitted} expected={expected}")
        if mismatches:
            print(f"  Layer MISMATCHES:")
            for mm in mismatches:
                print(mm)
            sys.exit(1)
        print(f"  Layer validation OK for all {n_models} models")
    except ImportError as exc:
        # Validation module unavailable (e.g. run outside the repo) is a
        # soft skip, not a failure.
        print(f"  Skipping validation (import failed: {exc})")
    except Exception as exc:
        # Any genuine validation failure is fatal for the smoke test.
        print(f"  Validation FAILED: {exc}")
        sys.exit(1)

    # In this pipeline the blue team score is just the average pairwise CKA.
    blue_score = avg_cka
    print(f"\n{'=' * 60}")
    print("SMOKE TEST PASSED")
    print(f"{'=' * 60}")
    print(f"  Models: {n_models}")
    print(f"  Stimuli: {n_stimuli}")
    print(f"  Avg CKA: {avg_cka:.6f}")
    print(f"  Blue score: {blue_score:.6f}")
    print(f"  Pairs: {len(pairwise)} / {expected_pairs} expected")
    print(f"  All layers: validated against registry")
    print()
|
|
|
|
if __name__ == "__main__":
    # Script entry point; main() signals failure itself via sys.exit(1).
    main()
|
|