openenv
leniencybench / scripts /make_v6_plots.py
shreyas-garg's picture
Mirror of GitHub source: OpenEnv-compliant LeniencyBench environment + training scripts
6b4f87f verified
Raw
History Blame Contribute Delete
6.86 kB
"""Generate plots from the v6 partial run logs + Llama baseline.
Outputs (all 150 dpi PNG, axis-labeled, captioned):
outputs/sft_loss_v6.png SFT loss curve over 1000 steps (3B run)
outputs/direction_split_v6.png Pre vs post-SFT direction-split accuracy
outputs/cross_model_baseline.png Llama 3.1 8B vs Qwen 2.5 3B leniency bias
"""
from __future__ import annotations
import json
import re
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
# Directory every figure is saved into. It is created at import time so the
# plotting functions can write straight into it.
OUT = Path("outputs")
# Captured log of the v6 run; parse_sft_losses() scans it for
# `{'loss': ..., 'epoch': ...}` records.
LOG = OUT / "v6_full_logs.txt"
OUT.mkdir(exist_ok=True)
# ---------------------------------------------------------------------------
# Parse SFT loss curve from v6 logs
# ---------------------------------------------------------------------------
# One loss record: `{'loss': X, ..., 'epoch': Y}` with 'epoch' as the final key.
# Numbers may appear bare (X) or single-quoted ('X'), so the quotes are optional.
# The number class admits decimals and exponents (e.g. 3e-05); tokens that
# float() rejects are filtered out by the caller.
_SFT_LOSS_RE = re.compile(
    r"\{'loss': '?([\d.eE+\-]+)'?.*?'epoch': '?([\d.eE+\-]+)'?\}"
)


def parse_sft_losses(path: Path | None = None) -> list[tuple[float, float]]:
    """Return (epoch, loss) pairs parsed from `{'loss': X, ..., 'epoch': Y}` lines.

    Args:
        path: Log file to scan. Defaults to the module-level ``LOG``.

    Returns:
        Pairs in file order, restricted to records with ``epoch <= 1.0`` (the
        SFT phase). Records whose captured numbers fail ``float()`` are skipped.
        Summary dicts such as ``{'train_loss': ...}`` do not match, because the
        record must open with ``{'loss'``.

    Raises:
        FileNotFoundError: if the log file does not exist.
    """
    log_path = LOG if path is None else path
    # The log is captured console output and may contain bytes that are not
    # valid UTF-8. Loss records are plain ASCII, so replacing undecodable bytes
    # cannot corrupt them, and it avoids aborting on a stray byte.
    text = log_path.read_text(encoding="utf-8", errors="replace")
    out: list[tuple[float, float]] = []
    for m in _SFT_LOSS_RE.finditer(text):
        try:
            loss = float(m.group(1))
            epoch = float(m.group(2))
        except ValueError:
            # e.g. "1.2.3" or "e-": admitted by the character class, not numbers
            continue
        if epoch <= 1.0:  # only SFT (epoch in [0, 1])
            out.append((epoch, loss))
    return out
# ---------------------------------------------------------------------------
# Plot 1: SFT loss curve
# ---------------------------------------------------------------------------
def plot_sft_loss():
    """Draw the SFT loss curve from the v6 log into outputs/sft_loss_v6.png.

    The curve uses a log-scaled y axis, with every parsed point also drawn as a
    marker. If the log yields no loss points, a skip notice is printed and
    nothing is written.
    """
    pts = parse_sft_losses()
    if len(pts) == 0:
        print("[skip] no SFT loss points found in v6 logs")
        return

    epochs = [epoch for epoch, _ in pts]
    losses = [loss for _, loss in pts]
    curve_color = "#2a6df4"

    fig, ax = plt.subplots(figsize=(8, 4.4))
    ax.plot(epochs, losses, color=curve_color, linewidth=1.5, alpha=0.85)
    # Markers on top of the line so individual logging steps stay visible.
    ax.scatter(epochs, losses, color=curve_color, s=10, zorder=3)

    ax.set_title("SFT loss over training (Qwen 2.5 3B + LoRA r=16, lr=2e-4)")
    ax.set_xlabel("Training progress (fraction of 1 epoch over 16,000 samples)")
    ax.set_ylabel("Cross-entropy loss")
    ax.set_yscale("log")
    ax.grid(alpha=0.3, which="both")  # minor grid too, for the log axis
    fig.tight_layout()

    p = OUT / "sft_loss_v6.png"
    fig.savefig(p, dpi=150)
    plt.close(fig)
    print(f"wrote {p} ({len(pts)} points, loss {losses[0]:.3f} -> {losses[-1]:.3f})")
# ---------------------------------------------------------------------------
# Plot 2: Pre vs post-SFT direction-split accuracy on Qwen 2.5 3B
# ---------------------------------------------------------------------------
def plot_direction_split_v6():
    """Bar-chart Qwen 2.5 3B drift-sensitive accuracy before vs after one SFT epoch.

    Counts are (correct, total) pairs hardcoded from the v6 logs. Each bar is
    labelled with its percentage just above it and its raw count just below
    the zero line. The figure is written to outputs/direction_split_v6.png.
    """
    # (correct, total) per drift direction, copied from the v6 logs.
    pre = {"tightening": (0, 23), "loosening": (3, 14)}
    post = {"tightening": (21, 23), "loosening": (10, 14)}
    labels = ["Tightening", "Loosening"]
    keys = [lbl.lower() for lbl in labels]

    def as_percent(counts):
        # Percent correct for each direction, in the order of `labels`.
        return [counts[key][0] / counts[key][1] * 100 for key in keys]

    width = 0.36
    half = width / 2
    centers = list(range(len(labels)))

    fig, ax = plt.subplots(figsize=(8, 4.6))
    pre_bars = ax.bar([c - half for c in centers], as_percent(pre), width,
                      label="Pre-training", color="#d5342a")
    post_bars = ax.bar([c + half for c in centers], as_percent(post), width,
                       label="Post-SFT (1 epoch)", color="#2a6df4")

    # Percentage label sitting just above each bar.
    for bar in (*pre_bars, *post_bars):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width() / 2, height + 1.6, f"{height:.1f}%",
                ha="center", va="bottom", fontsize=11, fontweight="bold")

    # Raw correct/total counts, drawn in the negative band under each bar.
    for center, key in zip(centers, keys):
        for offset, counts in ((-half, pre), (half, post)):
            correct, total = counts[key]
            ax.text(center + offset, -7, f"{correct}/{total}",
                    ha="center", fontsize=9, color="#666")

    tick_values = [0, 20, 40, 60, 80, 100]
    ax.set_xticks(centers)
    ax.set_xticklabels(labels)
    ax.set_ylabel("Drift-sensitive accuracy")
    ax.set_ylim(-10, 105)  # headroom below zero for the count annotations
    ax.set_yticks(tick_values)
    ax.set_yticklabels([f"{v}%" for v in tick_values])
    ax.set_title("Qwen 2.5 3B: leniency bias before vs after one epoch of SFT")
    ax.legend(loc="upper left", framealpha=0.95)
    ax.grid(alpha=0.2, axis="y")
    ax.axhline(0, color="#888", linewidth=0.6)
    fig.tight_layout()

    p = OUT / "direction_split_v6.png"
    fig.savefig(p, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"wrote {p}")
# ---------------------------------------------------------------------------
# Plot 3: Cross-model baseline (Llama 3.1 8B vs Qwen 2.5 3B, both untrained)
# ---------------------------------------------------------------------------
def plot_cross_model():
    """Compare untrained Llama 3.1 8B and Qwen 2.5 3B drift-sensitive accuracy.

    Llama counts are summed per drift direction from
    ``data["summary"]["per_drift"]`` in ``eval_results.json``. Each entry
    supplies ``correct`` and ``total``, and drift names are mapped to a
    direction with ``drift_env.policy.drift_direction``. Directions other than
    "tightening"/"loosening" are ignored. Qwen counts are hardcoded from the v6
    logs. The figure is written to outputs/cross_model_baseline.png.

    Raises:
        FileNotFoundError: if eval_results.json does not exist.
        KeyError: if the JSON lacks the summary/per_drift/correct/total fields.
    """
    # Llama 3.1 8B from eval_results.json
    data = json.loads(Path("eval_results.json").read_text(encoding="utf-8"))
    from drift_env.policy import drift_direction
    llama = {"tightening": [0, 0], "loosening": [0, 0]}
    for name, stats in data["summary"]["per_drift"].items():
        d = drift_direction(name)
        if d in llama:
            llama[d][0] += stats["correct"]
            llama[d][1] += stats["total"]
    qwen = {"tightening": (0, 23), "loosening": (3, 14)}  # from v6 logs
    labels = ["Tightening", "Loosening"]

    def accuracy(counts, model):
        # Percent correct per label; 0.0 plus a warning when a direction has no
        # samples, instead of raising ZeroDivisionError.
        acc = []
        for lbl in labels:
            correct, total = counts[lbl.lower()]
            if total == 0:
                print(f"[warn] {model}: no {lbl.lower()} drifts found; plotting 0%")
                acc.append(0.0)
            else:
                acc.append(correct / total * 100)
        return acc

    llama_acc = accuracy(llama, "Llama 3.1 8B")
    qwen_acc = accuracy(qwen, "Qwen 2.5 3B")
    x = range(len(labels))
    width = 0.36
    fig, ax = plt.subplots(figsize=(8, 4.6))
    # NOTE(review): the "8 episodes" in the Llama label is not derived from the
    # JSON; confirm it matches the eval_results.json being plotted.
    b1 = ax.bar([i - width/2 for i in x], llama_acc, width,
                label="Llama 3.1 8B (untrained, 8 episodes)", color="#7b3fbf")
    b2 = ax.bar([i + width/2 for i in x], qwen_acc, width,
                label="Qwen 2.5 3B (untrained, 200 samples)", color="#1f8a3b")
    for bars in (b1, b2):
        for b in bars:
            ax.text(b.get_x() + b.get_width()/2, b.get_height() + 1.6,
                    f"{b.get_height():.1f}%",
                    ha="center", va="bottom", fontsize=11, fontweight="bold")
    ax.set_xticks(list(x))
    ax.set_xticklabels(labels)
    ax.set_ylabel("Drift-sensitive accuracy")
    # Llama values are loaded from disk, so size the axis to fit the tallest
    # bar plus its label. When every value is <= 50% this reproduces the
    # 0-60% range with ticks every 10% (0..50).
    y_top = max(60.0, max(llama_acc + qwen_acc) + 10.0)
    yticks = list(range(0, int(y_top), 10))
    ax.set_ylim(0, y_top)
    ax.set_yticks(yticks)
    ax.set_yticklabels([f"{v}%" for v in yticks])
    ax.set_title("Leniency bias is direction-asymmetric across model families")
    ax.legend(loc="upper left", framealpha=0.95)
    ax.grid(alpha=0.2, axis="y")
    fig.tight_layout()
    p = OUT / "cross_model_baseline.png"
    fig.savefig(p, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"wrote {p}")
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Run each plotter in order; an exception in one stops the remaining ones.
    for _make_plot in (plot_sft_loss, plot_direction_split_v6, plot_cross_model):
        _make_plot()
    print("done")