# iclr2026-realign-challenge / scripts / red_team_smoke_test.py
# Initial commit: ICLR 2026 Representational Alignment Challenge (d6c8a4f)
"""Red-team smoke test.

Optionally seeds a dummy dataset through Modal, then submits several
stimulus sets via ``app.submit_red`` and sanity-checks the results.
"""
from __future__ import annotations
import argparse
import json
import os
import sys
from pathlib import Path
# Add project root to path so we can import app module
PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
import modal
from dotenv import load_dotenv
# Populate os.environ from a local .env so the HACKATHON_* variables
# required by main() are available before any of them are read.
load_dotenv()
def _require_env(name: str) -> str:
value = os.environ.get(name, "").strip()
if not value:
raise ValueError(f"Missing required env var: {name}")
return value
def _seed_dummy_dataset(app_name: str) -> None:
    """Remotely invoke the deployed ``seed_dummy_dataset`` Modal function.

    Args:
        app_name: Name of the deployed Modal app hosting the function.
    """
    remote_seed = modal.Function.from_name(app_name, "seed_dummy_dataset")
    result = remote_seed.remote(num_images=6, image_size=224, dataset_name="dummy")
    print(f"Seeded dataset at {result['dataset_root']}")
def _load_stimuli_from_catalog(stimuli_path: str) -> list[dict[str, str]]:
"""Load stimuli from a JSONL catalog file."""
stimuli = []
with open(stimuli_path, "r") as f:
for line in f:
line = line.strip()
if line:
stimuli.append(json.loads(line))
return stimuli
def main() -> None:
    """Run the red-team smoke test end to end.

    Reads configuration from HACKATHON_* environment variables and CLI
    flags, optionally seeds a dummy dataset via Modal, then submits three
    stimulus sets through ``app.submit_red`` and asserts on the results.

    Raises:
        ValueError: if required env vars are missing, the custom stimuli
            file does not exist, or (in S3 mode) the catalog has < 6 entries.
    """
    parser = argparse.ArgumentParser(description="Red team smoke test")
    parser.add_argument(
        "--stimuli",
        type=str,
        default=None,
        help="Path to custom stimuli JSONL file. If not provided, uses HACKATHON_STIMULI_CATALOG env var.",
    )
    parser.add_argument(
        "--s3",
        action="store_true",
        help="Use S3-backed datasets (requires aws-s3-credentials Modal secret).",
    )
    parser.add_argument(
        "--skip-seed",
        action="store_true",
        help="Skip seeding dummy dataset (use when testing with real images).",
    )
    args = parser.parse_args()
    # Fail fast if required configuration is absent.
    _require_env("HACKATHON_MODAL_ENABLE")
    _require_env("HACKATHON_MODEL_REGISTRY")
    data_dir = _require_env("HACKATHON_DATA_DIR")
    # Use custom stimuli if provided, otherwise use env var
    if args.stimuli:
        stimuli_path = Path(args.stimuli).resolve()
        if not stimuli_path.exists():
            raise ValueError(f"Stimuli file not found: {stimuli_path}")
        # Override the env var so downstream code sees the same catalog.
        os.environ["HACKATHON_STIMULI_CATALOG"] = str(stimuli_path)
        print(f"Using custom stimuli: {stimuli_path}")
    else:
        _require_env("HACKATHON_STIMULI_CATALOG")
    # Re-read after the possible override above so both branches converge.
    stimuli_catalog_path = os.environ.get("HACKATHON_STIMULI_CATALOG", "").strip()
    # Set S3 mode if requested
    if args.s3:
        os.environ["HACKATHON_USE_S3"] = "true"
        print("S3 mode enabled - will use extract_embeddings_s3 function")
    app_name = os.environ.get("HACKATHON_MODAL_APP", "iclr2026-eval")
    # Only seed dummy dataset if not using real images
    if not args.skip_seed and not args.s3:
        _seed_dummy_dataset(app_name)
    elif args.skip_seed:
        print("Skipping dummy dataset seeding")
    elif args.s3:
        print("S3 mode: skipping dummy dataset seeding (using real images)")
    # NOTE(review): imported late, after the env-var mutations above —
    # presumably the app module reads HACKATHON_* at import time; confirm
    # before hoisting this import to the top of the file.
    from app import submit_red
    # Build stimulus sets based on mode
    if args.s3:
        # Load stimuli from catalog for S3 mode
        catalog_stimuli = _load_stimuli_from_catalog(stimuli_catalog_path)
        # Need index 5 below, so at least 6 entries are required.
        if len(catalog_stimuli) < 6:
            raise ValueError(f"Need at least 6 stimuli in catalog, found {len(catalog_stimuli)}")
        # Create test sets from real catalog stimuli
        stimulus_sets = [
            catalog_stimuli[0:2],  # First 2 stimuli
            catalog_stimuli[2:5],  # Next 3 stimuli
            [catalog_stimuli[1], catalog_stimuli[5]],  # Mixed selection
        ]
        print(f"Using {len(catalog_stimuli)} stimuli from catalog for S3 mode")
    else:
        # Use dummy stimuli for local mode
        stimulus_sets = [
            [
                {"dataset_name": "dummy", "image_identifier": "images/img_0000.png"},
                {"dataset_name": "dummy", "image_identifier": "images/img_0001.png"},
            ],
            [
                {"dataset_name": "dummy", "image_identifier": "images/img_0002.png"},
                {"dataset_name": "dummy", "image_identifier": "images/img_0003.png"},
                {"dataset_name": "dummy", "image_identifier": "images/img_0004.png"},
            ],
            [
                {"dataset_name": "dummy", "image_identifier": "images/img_0001.png"},
                {"dataset_name": "dummy", "image_identifier": "images/img_0005.png"},
            ],
        ]
    # Submit each stimulus set under a distinct submitter name and
    # sanity-check the returned tables (leaderboard/pairwise look like
    # DataFrames given .tail()/.to_dict()/.empty — confirm in app module).
    for idx, stimuli in enumerate(stimulus_sets, start=1):
        payload = json.dumps({"differentiating_images": stimuli})
        submitter = f"red-test-{idx}"
        msg, leaderboard, pairwise = submit_red(submitter, payload)
        print(f"Submission {idx} message: {msg}")
        print(f"Submission {idx} leaderboard: {leaderboard.tail(1).to_dict(orient='records')}")
        print(f"Submission {idx} pairwise: {pairwise.to_dict(orient='records')}")
        assert not pairwise.empty, "Pairwise table should not be empty."
    # Confirm the submissions were persisted to disk under data_dir.
    red_path = Path(data_dir) / "red_submissions.json"
    assert red_path.exists(), f"Missing submission file: {red_path}"
    print("Red team smoke test complete.")
# Script entry point: run the smoke test only when executed directly.
if __name__ == "__main__":
    main()