""" Bar charts comparing base vs. checkpoint persuasion outcomes. Chart 1: Outcome breakdown per model (shifted toward A / no change / backfire). Chart 2: Conversion rate among users who initially leaned B (pre > 4). """ import json import os from huggingface_hub import HfApi import numpy as np import matplotlib.pyplot as plt REPOS = { "base": "ehejin/user_study-preference-base_REAL", "checkpoint": "ehejin/user_study-preference-base_DETAILED_checkpoint", } # Cap per study (matches distribution plot) MAX_REVIEWS = {"base": None, "checkpoint": 50} # Pastel palette ROSE = "#E8A9A1" SAGE = "#A8C5A1" ROSE_EDGE = "#C87E75" SAGE_EDGE = "#6F9A67" # Outcome colors for the stacked/grouped chart GREEN = "#A8C5A1" # shifted toward A (success) GREY = "#D4D0C8" # no change RED = "#E8A9A1" # backfire def fetch_rows(repo_id: str, token: str, max_reviews: int = None) -> list: api = HfApi(token=token) files = list(api.list_repo_files(repo_id=repo_id, repo_type="dataset")) json_files = [f for f in files if f.startswith("json/") and f.endswith(".json")] submissions = [] for filepath in json_files: local = api.hf_hub_download( repo_id=repo_id, filename=filepath, repo_type="dataset", token=token, ) with open(local) as f: submissions.append(json.load(f)) submissions.sort(key=lambda s: s.get("start_time", 0)) rows = [] for sub in submissions: for item in sub.get("items", []): pre, post = item.get("pre_rating"), item.get("post_rating") if pre is not None and post is not None: rows.append({"pre": pre, "post": post, "delta": post - pre}) if max_reviews is not None and len(rows) >= max_reviews: return rows return rows def plot_outcome_breakdown(base_rows, ckpt_rows, out_path="/dfs/scratch1/echoi1/prolific_preferences/outcome_breakdown.png"): """Grouped bar chart: counts of each outcome for base vs. checkpoint.""" def counts(rows): shifted_a = sum(1 for r in rows if r["delta"] < 0) no_change = sum(1 for r in rows if r["delta"] == 0) shifted_b = sum(1 for r in rows if r["delta"] > 0) return shifted_a, no_change, shifted_b base_a, base_n, base_b = counts(base_rows) ckpt_a, ckpt_n, ckpt_b = counts(ckpt_rows) categories = ["Shifted toward A\n(success ✓)", "No change", "Shifted toward B\n(backfire ✗)"] base_vals = [base_a, base_n, base_b] ckpt_vals = [ckpt_a, ckpt_n, ckpt_b] x = np.arange(len(categories)) bar_width = 0.38 fig, ax = plt.subplots(figsize=(10, 6), dpi=150) b1 = ax.bar(x - bar_width/2, base_vals, bar_width, color=ROSE, edgecolor=ROSE_EDGE, linewidth=1.5, label=f"Base model (N={len(base_rows)})") b2 = ax.bar(x + bar_width/2, ckpt_vals, bar_width, color=SAGE, edgecolor=SAGE_EDGE, linewidth=1.5, label=f"Fine-tuned (N={len(ckpt_rows)})") # Add count labels on top of each bar for bars in (b1, b2): for bar in bars: h = bar.get_height() ax.text(bar.get_x() + bar.get_width()/2, h + 0.3, f"{int(h)}", ha="center", va="bottom", fontsize=10, fontweight="bold") ax.set_xticks(x) ax.set_xticklabels(categories, fontsize=11) ax.set_ylabel("Number of item-reviews", fontsize=12) ax.set_title("Persuasion outcomes by model", fontsize=14, fontweight="bold", pad=15) ax.legend(loc="upper right", frameon=True, fontsize=11) ax.spines["top"].set_visible(False) ax.spines["right"].set_visible(False) ax.grid(axis="y", alpha=0.3, linestyle="-", linewidth=0.5) ax.set_axisbelow(True) plt.tight_layout() plt.savefig(out_path, bbox_inches="tight", facecolor="white") print(f"Saved: {out_path}") def plot_resistant_conversion(base_rows, ckpt_rows, out_path="/dfs/scratch1/echoi1/prolific_preferences/resistant_conversion.png"): """Among users who pre-rated > 4 (leaning B), how many did each model flip toward A?""" def stats(rows): resistant = [r for r in rows if r["pre"] > 4] if not resistant: return 0, 0, 0.0 converted = sum(1 for r in resistant if r["delta"] < 0) return converted, len(resistant), converted / len(resistant) * 100 base_conv, base_tot, base_pct = stats(base_rows) ckpt_conv, ckpt_tot, ckpt_pct = stats(ckpt_rows) fig, ax = plt.subplots(figsize=(8, 6), dpi=150) x = np.arange(2) vals = [base_conv, ckpt_conv] tots = [base_tot, ckpt_tot] pcts = [base_pct, ckpt_pct] bars = ax.bar( x, vals, width=0.5, color=[ROSE, SAGE], edgecolor=[ROSE_EDGE, SAGE_EDGE], linewidth=1.5, ) # Label each bar with "X / Y (Z%)" for bar, val, tot, pct in zip(bars, vals, tots, pcts): h = bar.get_height() ax.text(bar.get_x() + bar.get_width()/2, h + 0.2, f"{val} / {tot}\n{pct:.1f}%", ha="center", va="bottom", fontsize=11, fontweight="bold") ax.set_xticks(x) ax.set_xticklabels([f"Base\n(N={base_tot})", f"Fine-tuned\n(N={ckpt_tot})"], fontsize=12) ax.set_ylabel("Users converted to prefer A", fontsize=12) ax.set_title( "Converting resistant users\n(those who initially leaned toward B)", fontsize=14, fontweight="bold", pad=15, ) ax.spines["top"].set_visible(False) ax.spines["right"].set_visible(False) ax.grid(axis="y", alpha=0.3, linestyle="-", linewidth=0.5) ax.set_axisbelow(True) plt.tight_layout() plt.savefig(out_path, bbox_inches="tight", facecolor="white") print(f"Saved: {out_path}") if __name__ == "__main__": token = os.getenv("HF_TOKEN") assert token, "set HF_TOKEN" print("Fetching base study...") base_rows = fetch_rows(REPOS["base"], token, MAX_REVIEWS["base"]) print(f" {len(base_rows)} item-reviews") print("Fetching checkpoint study...") ckpt_rows = fetch_rows(REPOS["checkpoint"], token, MAX_REVIEWS["checkpoint"]) print(f" {len(ckpt_rows)} item-reviews") plot_outcome_breakdown(base_rows, ckpt_rows) plot_resistant_conversion(base_rows, ckpt_rows)