Mirror of GitHub source: OpenEnv-compliant LeniencyBench environment + training scripts
6b4f87f verified | """Generate plots from the v6 partial run logs + Llama baseline. | |
| Outputs (all 150 dpi PNG, axis-labeled, captioned): | |
| outputs/sft_loss_v6.png SFT loss curve over 1000 steps (3B run) | |
| outputs/direction_split_v6.png Pre vs post-SFT direction-split accuracy | |
| outputs/cross_model_baseline.png Llama 3.1 8B vs Qwen 2.5 3B leniency bias | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import re | |
| from pathlib import Path | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
# Directory that receives every PNG written by this script; created at import
# time so the plot functions can save into it directly.
OUT = Path("outputs")
OUT.mkdir(exist_ok=True)

# Log text of the v6 run; parse_sft_losses() scans it for loss entries.
LOG = Path("outputs/v6_full_logs.txt")
| # --------------------------------------------------------------------------- | |
| # Parse SFT loss curve from v6 logs | |
| # --------------------------------------------------------------------------- | |
def parse_sft_losses() -> list[tuple[float, float]]:
    """Return ``(epoch, loss)`` pairs parsed from the log file ``LOG``.

    A log entry is a dict-style fragment of the form
    ``{'loss': X, ..., 'epoch': Y}`` on a single line.  ``X`` and ``Y`` may
    be bare numbers or numbers wrapped in single quotes; both spellings are
    accepted.  Entries whose epoch is greater than 1.0 are dropped, since
    only the SFT phase (epoch in [0, 1]) is plotted.

    Returns:
        The pairs in the order they occur in the log.  The list is empty
        when the log file does not exist or holds no matching entry, so the
        caller can skip the plot instead of aborting the whole script.
    """
    try:
        # Decode leniently: the numeric entries are plain ASCII, so replacing
        # undecodable bytes elsewhere in the log cannot affect the matches.
        text = LOG.read_text(encoding="utf-8", errors="replace")
    except FileNotFoundError:
        return []

    # The quotes around each value are optional; `.*?` does not cross line
    # breaks, so 'loss' and 'epoch' must belong to the same log line.
    entry = re.compile(
        r"\{'loss': '?([\d.eE+\-]+)'?.*?'epoch': '?([\d.]+)'?\}"
    )

    points: list[tuple[float, float]] = []
    for match in entry.finditer(text):
        try:
            loss = float(match.group(1))
            epoch = float(match.group(2))
        except ValueError:
            # The character classes can match non-numbers such as "1.2.3".
            continue
        if epoch <= 1.0:  # only SFT (epoch in [0, 1])
            points.append((epoch, loss))
    return points
| # --------------------------------------------------------------------------- | |
| # Plot 1: SFT loss curve | |
| # --------------------------------------------------------------------------- | |
def plot_sft_loss():
    """Plot the SFT loss curve and save it as ``sft_loss_v6.png`` in ``OUT``.

    The points come from ``parse_sft_losses()``.  When it returns nothing, a
    skip notice is printed and no file is written.
    """
    points = parse_sft_losses()
    if not points:
        print("[skip] no SFT loss points found in v6 logs")
        return

    epoch_values = [epoch for epoch, _ in points]
    loss_values = [loss for _, loss in points]
    curve_color = "#2a6df4"

    figure, axes = plt.subplots(figsize=(8, 4.4))
    # A line for the trend plus markers (drawn on top) for each log entry.
    axes.plot(epoch_values, loss_values, color=curve_color, linewidth=1.5, alpha=0.85)
    axes.scatter(epoch_values, loss_values, color=curve_color, s=10, zorder=3)

    axes.set_xlabel("Training progress (fraction of 1 epoch over 16,000 samples)")
    axes.set_ylabel("Cross-entropy loss")
    axes.set_title("SFT loss over training (Qwen 2.5 3B + LoRA r=16, lr=2e-4)")
    axes.set_yscale("log")
    axes.grid(alpha=0.3, which="both")
    figure.tight_layout()

    target = OUT / "sft_loss_v6.png"
    figure.savefig(target, dpi=150)
    plt.close(figure)

    first_loss = loss_values[0]
    last_loss = loss_values[-1]
    print(f"wrote {target} ({len(points)} points, loss {first_loss:.3f} -> {last_loss:.3f})")
| # --------------------------------------------------------------------------- | |
| # Plot 2: Pre vs post-SFT direction-split accuracy on Qwen 2.5 3B | |
| # --------------------------------------------------------------------------- | |
def plot_direction_split_v6():
    """Draw pre-training vs post-SFT accuracy bars for both drift directions.

    The (correct, total) counts are hard-coded values taken from the v6 logs.
    Each bar carries its percentage above it and its raw count below the
    axis.  The figure is saved as ``direction_split_v6.png`` in ``OUT``.
    """
    # (correct, total) per direction, already extracted from the v6 logs.
    pre = {"tightening": (0, 23), "loosening": (3, 14)}
    post = {"tightening": (21, 23), "loosening": (10, 14)}
    labels = ["Tightening", "Loosening"]
    keys = [label.lower() for label in labels]

    def as_percent(counts):
        correct, total = counts
        return correct / total * 100

    pre_acc = [as_percent(pre[key]) for key in keys]
    post_acc = [as_percent(post[key]) for key in keys]

    positions = list(range(len(labels)))
    width = 0.36
    left_centers = [pos - width / 2 for pos in positions]
    right_centers = [pos + width / 2 for pos in positions]

    fig, ax = plt.subplots(figsize=(8, 4.6))
    pre_bars = ax.bar(left_centers, pre_acc, width, label="Pre-training",
                      color="#d5342a")
    post_bars = ax.bar(right_centers, post_acc, width, label="Post-SFT (1 epoch)",
                       color="#2a6df4")

    # Percentage just above every bar.
    for bar in [*pre_bars, *post_bars]:
        height = bar.get_height()
        center = bar.get_x() + bar.get_width() / 2
        ax.text(center, height + 1.6, f"{height:.1f}%",
                ha="center", va="bottom", fontsize=11, fontweight="bold")

    # Raw counts below the zero line, under each bar.
    for key, left, right in zip(keys, left_centers, right_centers):
        pre_correct, pre_total = pre[key]
        post_correct, post_total = post[key]
        ax.text(left, -7, f"{pre_correct}/{pre_total}",
                ha="center", fontsize=9, color="#666")
        ax.text(right, -7, f"{post_correct}/{post_total}",
                ha="center", fontsize=9, color="#666")

    ax.set_xticks(positions)
    ax.set_xticklabels(labels)
    ax.set_ylabel("Drift-sensitive accuracy")
    # The axis starts below zero to leave room for the count annotations.
    ax.set_ylim(-10, 105)
    tick_values = [0, 20, 40, 60, 80, 100]
    ax.set_yticks(tick_values)
    ax.set_yticklabels([f"{tick}%" for tick in tick_values])
    ax.set_title("Qwen 2.5 3B: leniency bias before vs after one epoch of SFT")
    ax.legend(loc="upper left", framealpha=0.95)
    ax.grid(alpha=0.2, axis="y")
    ax.axhline(0, color="#888", linewidth=0.6)
    fig.tight_layout()

    target = OUT / "direction_split_v6.png"
    fig.savefig(target, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"wrote {target}")
| # --------------------------------------------------------------------------- | |
| # Plot 3: Cross-model baseline (Llama 3.1 8B vs Qwen 2.5 3B, both untrained) | |
| # --------------------------------------------------------------------------- | |
def plot_cross_model():
    """Plot per-direction accuracy of Llama 3.1 8B next to Qwen 2.5 3B.

    Llama counts are aggregated from ``eval_results.json``: every entry of
    ``data["summary"]["per_drift"]`` is mapped to a direction with
    ``drift_env.policy.drift_direction`` and its ``correct`` / ``total``
    values are summed.  Entries whose direction is neither "tightening" nor
    "loosening" are ignored.  Qwen counts are hard-coded from the v6 logs.
    The figure is saved as ``cross_model_baseline.png`` in ``OUT``.

    Edge cases:
        * If ``eval_results.json`` does not exist, a skip notice is printed
          and nothing is written.
        * A direction with zero evaluated items is drawn as an empty bar
          labelled "n/a" rather than raising ``ZeroDivisionError``.
        * The y-axis is 0-60 with ticks up to 50 unless a bar is too tall
          for that range, in which case the axis and ticks are extended.
    """
    # Llama 3.1 8B from eval_results.json
    try:
        with open("eval_results.json", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        print("[skip] eval_results.json not found; cross-model plot not generated")
        return

    # Imported here so the other plots do not require the drift_env package.
    from drift_env.policy import drift_direction

    llama = {"tightening": [0, 0], "loosening": [0, 0]}  # [correct, total]
    for name, stats in data["summary"]["per_drift"].items():
        direction = drift_direction(name)
        if direction in llama:
            llama[direction][0] += stats["correct"]
            llama[direction][1] += stats["total"]
    qwen = {"tightening": (0, 23), "loosening": (3, 14)}  # from v6 logs

    def accuracy(counts):
        """Return the percentage correct, or None when nothing was evaluated."""
        correct, total = counts
        return correct / total * 100 if total else None

    labels = ["Tightening", "Loosening"]
    llama_acc = [accuracy(llama[label.lower()]) for label in labels]
    qwen_acc = [accuracy(qwen[label.lower()]) for label in labels]
    # Bars need a numeric height; "no data" is drawn as a zero-height bar.
    llama_heights = [0.0 if value is None else value for value in llama_acc]
    qwen_heights = [0.0 if value is None else value for value in qwen_acc]

    x = range(len(labels))
    width = 0.36
    fig, ax = plt.subplots(figsize=(8, 4.6))
    b1 = ax.bar([i - width / 2 for i in x], llama_heights, width,
                label="Llama 3.1 8B (untrained, 8 episodes)", color="#7b3fbf")
    b2 = ax.bar([i + width / 2 for i in x], qwen_heights, width,
                label="Qwen 2.5 3B (untrained, 200 samples)", color="#1f8a3b")
    for bars, values in ((b1, llama_acc), (b2, qwen_acc)):
        for bar, value in zip(bars, values):
            caption = "n/a" if value is None else f"{value:.1f}%"
            ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 1.6,
                    caption,
                    ha="center", va="bottom", fontsize=11, fontweight="bold")

    ax.set_xticks(list(x))
    ax.set_xticklabels(labels)
    ax.set_ylabel("Drift-sensitive accuracy")
    # Keep the 0-60 range whenever it fits; otherwise leave 10 points of
    # headroom above the tallest bar so its value label is not clipped.
    top = max(60, max(llama_heights + qwen_heights) + 10)
    ax.set_ylim(0, top)
    ticks = list(range(0, int(top), 10))
    ax.set_yticks(ticks)
    ax.set_yticklabels([f"{v}%" for v in ticks])
    ax.set_title("Leniency bias is direction-asymmetric across model families")
    ax.legend(loc="upper left", framealpha=0.95)
    ax.grid(alpha=0.2, axis="y")
    fig.tight_layout()

    p = OUT / "cross_model_baseline.png"
    fig.savefig(p, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"wrote {p}")
| # --------------------------------------------------------------------------- | |
if __name__ == "__main__":
    # Produce the three figures one after another, then report completion.
    for make_plot in (plot_sft_loss, plot_direction_split_v6, plot_cross_model):
        make_plot()
    print("done")