| """ |
| Launch separator-attention and random-target intervention experiments across 8 GPUs. |
| Each GPU runs N trials, collecting per-number success data. |
| Assembles two plots: |
| 1. Success probability per number when intervening at separator-attending positions |
| 2. Success probability per number when intervening with random target |
| """ |
| import os |
| import subprocess |
| import sys |
| import time |
| import numpy as np |
| import matplotlib |
| matplotlib.use('Agg') |
| import matplotlib.pyplot as plt |
|
|
# Directory holding this launcher; the checkpoint and the worker script are
# both resolved relative to it.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))

# Everything this script writes (plots, per-GPU results, logs) lives under
# OUTPUT_BASE.
OUTPUT_BASE = os.path.join(SCRIPT_DIR, "outputs")
PLOT_DIR = os.path.join(
    OUTPUT_BASE,
    "plots_V256_B16_LR3e-2_MI200000_E64_H1_L2_ds1337_is1337_ckpt200000",
)
# Per-GPU result files (gpu{g}.npz) are written here; an existing file is
# treated as a cached result and that GPU is not relaunched.
TMP_DIR = os.path.join(OUTPUT_BASE, "tmp_results", "separator_random")
# Model checkpoint handed to every worker through --ckpt.
CKPT = os.path.join(
    SCRIPT_DIR,
    "sortgpt_k16_methfixed_mlp1_L2_N256_E64_pos0_fln1_wd0p0_lr0p03_dseed1337_iseed1337__final.pt",
)


# One worker process is started per GPU index in range(NUM_GPUS).
NUM_GPUS = 8
# Passed to each worker as --trials.
TRIALS_PER_GPU = 1000
# Intervention intensities that get their own curve in each plot.
INTENSITIES = [2.0, 6.0, 10.0]
# Second line of every plot title.
TAG = "V=256 B=16 lr=0.03 iters=200000 dseed=1337 iseed=1337"
|
|
|
|
def launch_workers():
    """Start one worker subprocess per GPU that has no cached result yet.

    For every GPU index in ``range(NUM_GPUS)`` the expected result file is
    ``TMP_DIR/gpu{g}.npz``.  If that file already exists the GPU is skipped;
    otherwise ``separator_intervention_worker.py`` is started with the
    current interpreter, and its stdout and stderr are redirected to
    ``OUTPUT_BASE/separator_logs/gpu{g}.log``.

    Returns:
        A list of ``(process, log_file, gpu_index)`` tuples, one per worker
        that was started.  The log files are left open for the lifetime of
        the workers; the caller is responsible for closing them.

    Raises:
        OSError: If a log file cannot be opened or a worker cannot be
            spawned.
    """
    os.makedirs(TMP_DIR, exist_ok=True)
    log_dir = os.path.join(OUTPUT_BASE, 'separator_logs')
    os.makedirs(log_dir, exist_ok=True)
    worker_script = os.path.join(SCRIPT_DIR, 'separator_intervention_worker.py')

    procs = []
    for g in range(NUM_GPUS):
        out = os.path.join(TMP_DIR, f'gpu{g}.npz')
        if os.path.exists(out):
            # A result file from an earlier run counts as cached.
            continue
        lf = open(os.path.join(log_dir, f'gpu{g}.log'), 'w')
        try:
            proc = subprocess.Popen(
                [sys.executable, worker_script,
                 '--ckpt', CKPT, '--gpu', str(g),
                 '--trials', str(TRIALS_PER_GPU), '--out', out],
                stdout=lf, stderr=subprocess.STDOUT, cwd=SCRIPT_DIR)
        except BaseException:
            # The handle is not yet tracked in ``procs``, so nobody else
            # would close it if the spawn fails.
            lf.close()
            raise
        procs.append((proc, lf, g))
    return procs
|
|
|
|
def wait_for_workers(procs):
    """Block until every launched worker has exited, printing progress.

    While any process is still running, sleeps 10 seconds and then reports
    how many ``TMP_DIR/gpu{g}.npz`` result files exist (this count includes
    results cached from earlier runs).  Once all processes have exited, a
    warning is printed for each worker whose exit code is non-zero.

    Args:
        procs: List of ``(process, log_file, gpu_index)`` tuples as returned
            by ``launch_workers``.  Every ``log_file`` is closed before this
            function returns or raises.
    """
    t0 = time.time()
    try:
        while any(p.poll() is None for p, _, _ in procs):
            time.sleep(10)
            done = sum(1 for g in range(NUM_GPUS)
                       if os.path.exists(os.path.join(TMP_DIR, f'gpu{g}.npz')))
            print(f"  [{time.time()-t0:.0f}s] {done}/{NUM_GPUS} GPUs done", flush=True)
    finally:
        # Close the log handles even if polling is interrupted (e.g. Ctrl-C),
        # so they are not leaked on the error path.
        for _, lf, _ in procs:
            lf.close()

    for proc, _, g in procs:
        if proc.returncode != 0:
            print(f"  WARN: GPU {g} exited with code {proc.returncode}", flush=True)

    print(f"All workers done in {time.time()-t0:.0f}s", flush=True)
|
|
|
|
def load_and_combine():
    """Load every available per-GPU result file and concatenate the records.

    Reads ``TMP_DIR/gpu{g}.npz`` for each GPU index in ``range(NUM_GPUS)``,
    skipping indices whose file is missing.  From each file the arrays
    stored under ``'sep_data'`` and ``'rand_data'`` are collected; empty
    arrays are dropped so they cannot affect the concatenation.

    Returns:
        A ``(sep, rand)`` tuple of arrays.  When no records were found for
        one of them, an empty ``int32`` array of shape ``(0, 3)`` is
        returned in its place.
        NOTE(review): the 3-column layout is taken from that fallback shape;
        confirm against the worker script.

    Raises:
        KeyError: If a result file lacks ``'sep_data'`` or ``'rand_data'``.
    """
    all_sep, all_rand = [], []
    for g in range(NUM_GPUS):
        f = os.path.join(TMP_DIR, f'gpu{g}.npz')
        if not os.path.exists(f):
            continue
        # ``np.load`` on an .npz keeps the archive open until closed, and
        # each key lookup re-reads the array from disk; read each array
        # exactly once and release the file straight away.
        with np.load(f) as d:
            sep_part = d['sep_data']
            rand_part = d['rand_data']
        if len(sep_part) > 0:
            all_sep.append(sep_part)
        if len(rand_part) > 0:
            all_rand.append(rand_part)

    sep = np.concatenate(all_sep) if all_sep else np.empty((0, 3), dtype=np.int32)
    rand = np.concatenate(all_rand) if all_rand else np.empty((0, 3), dtype=np.int32)
    print(f"Combined: sep={len(sep)} records, rand={len(rand)} records")
    return sep, rand
|
|
|
|
def plot_per_number(data, title_prefix, filename, tag, *,
                    vocab_size=256, min_count=10, window=11):
    """Plot success probability per number for each intensity.

    The figure has two stacked panels.  The top panel shows, for every
    intensity in ``INTENSITIES``, the mean of column 2 of ``data`` grouped
    by the number in column 0 (raw curve), plus a moving-average version of
    that curve (dashed).  The bottom panel shows how many records each
    number has at the highest intensity.  The figure is written to
    ``PLOT_DIR/filename``.

    Args:
        data: 2-D array of records.  Column 0 is matched against the
            numbers ``0 .. vocab_size - 1``, column 1 against the
            intensities, and column 2 is averaged as the success value.
        title_prefix: First part of the plot title.
        filename: Output file name inside ``PLOT_DIR``.
        tag: Text shown on the second line of the title.
        vocab_size: Number of distinct numbers on the x-axis.
        min_count: Minimum number of records a number needs at a given
            intensity before its mean is plotted.
        window: Width of the moving-average window (a positive, odd value
            keeps the window centred).  Smoothing is drawn only when at
            least ``window`` numbers have data.
    """
    os.makedirs(PLOT_DIR, exist_ok=True)

    colors = {2.0: '#1f77b4', 6.0: '#ff7f0e', 10.0: '#d62728'}
    fig, axes = plt.subplots(2, 1, figsize=(14, 8),
                             gridspec_kw={'height_ratios': [3, 1]})
    numbers = np.arange(vocab_size)
    kernel = np.ones(window)

    ax = axes[0]
    for intens in INTENSITIES:
        subset = data[data[:, 1] == intens]
        if len(subset) == 0:
            continue

        # Per-number mean success; ``has_data`` marks the numbers that met
        # the minimum sample count (entries elsewhere stay at zero).
        means = np.zeros(vocab_size)
        has_data = np.zeros(vocab_size, dtype=bool)
        for n_val in range(vocab_size):
            nm = subset[:, 0] == n_val
            if nm.sum() >= min_count:
                means[n_val] = subset[nm, 2].mean()
                has_data[n_val] = True

        color = colors.get(intens, '#333')
        ax.plot(numbers[has_data], means[has_data], color=color,
                linewidth=0.8, alpha=0.6, label=f'raw int={intens}')

        if has_data.sum() >= window:
            # Moving average over the numbers that actually have data:
            # window sum divided by the count of contributing entries.
            # This avoids the bias that a plain ``mode='same'`` convolution
            # introduces, where implicit zero padding drags the curve down
            # at both ends of the vocabulary and gaps must be filled with
            # an arbitrary placeholder value.
            window_sum = np.convolve(means, kernel, mode='same')
            window_cnt = np.convolve(has_data.astype(float), kernel, mode='same')
            # At positions with data the count is at least 1, so the
            # division is always well defined.
            smoothed = window_sum[has_data] / window_cnt[has_data]
            ax.plot(numbers[has_data], smoothed,
                    color=color, linewidth=2.5,
                    linestyle='--', label=f'smoothed int={intens}')

    ax.set_ylabel('Success Probability', fontsize=12)
    ax.set_title(f'{title_prefix} (Layer 0)\n{tag}', fontsize=12, fontweight='bold')
    ax.legend(fontsize=8, ncol=2, loc='lower left')
    ax.grid(True, alpha=0.3)
    ax.set_ylim(-0.05, 1.1)
    ax.set_xlim(0, vocab_size - 1)

    # Bottom panel: sample count per number at the highest intensity only.
    ax2 = axes[1]
    max_intens = max(INTENSITIES)
    mask_hi = data[:, 1] == max_intens
    counts = np.array([(mask_hi & (data[:, 0] == n)).sum()
                       for n in range(vocab_size)])
    ax2.bar(range(vocab_size), counts, width=1, color='#666', alpha=0.5)
    ax2.set_xlabel('Number in Vocabulary', fontsize=12)
    ax2.set_ylabel('Sample Count', fontsize=10)
    ax2.set_xlim(0, vocab_size - 1)
    ax2.grid(True, alpha=0.2)

    fig.tight_layout()
    out_path = os.path.join(PLOT_DIR, filename)
    fig.savefig(out_path, dpi=200, bbox_inches='tight')
    # Close this specific figure rather than whichever one is "current".
    plt.close(fig)
    print(f"Plot saved: {out_path}")
|
|
|
|
def main():
    """Run the full experiment: launch workers, combine results, plot.

    Steps:
      1. Report how many per-GPU result files are already cached.
      2. Launch workers for the GPUs without a cached result and wait for
         them to exit.
      3. Load and concatenate all available result files.
      4. Write one per-number plot for the separator-attending data and one
         for the random-target data; a warning is printed for either
         dataset that turned out empty.
    """
    t0 = time.time()
    print("=" * 60)
    print("SEPARATOR & RANDOM INTERVENTION EXPERIMENT (Layer 0)")
    print("=" * 60)

    cached = sum(1 for g in range(NUM_GPUS)
                 if os.path.exists(os.path.join(TMP_DIR, f'gpu{g}.npz')))
    print(f"Cached GPUs: {cached}/{NUM_GPUS}")

    procs = launch_workers()
    if procs:
        print(f"Launched {len(procs)} workers")
        wait_for_workers(procs)
    else:
        print("All workers already cached")

    sep, rand = load_and_combine()

    if len(sep) > 0:
        plot_per_number(sep,
                        'Intervention Success when Attending to Separator',
                        'intervention_pernumber_separator_layer0.png', TAG)
    else:
        print("WARNING: No separator-attending data collected!")

    if len(rand) > 0:
        plot_per_number(rand,
                        'Intervention Success with Random Target',
                        'intervention_pernumber_random_layer0.png', TAG)
    else:
        # Mirror the separator case so a missing plot is never silent.
        print("WARNING: No random-target data collected!")

    elapsed = time.time() - t0
    print(f"\nTotal time: {elapsed:.0f}s ({elapsed/60:.1f}m)")
|
|
|
|
# Run the experiment only when this file is executed as a script, so that
# importing the module does not launch any workers.
if __name__ == "__main__":
    main()
|
|