Spaces:
Running
Running
"""
Bar charts comparing base vs. checkpoint persuasion outcomes.

Chart 1: Outcome breakdown per model (shifted toward A / no change / backfire).
Chart 2: Conversions (count, annotated with the rate) among item-reviews that
         initially leaned B (pre > 4).
"""
| import json | |
| import os | |
| from huggingface_hub import HfApi | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
# Hugging Face dataset repositories holding the raw study submissions,
# keyed by the study name used throughout this script.
REPOS = {
    "base": "ehejin/user_study-preference-base_REAL",
    "checkpoint": "ehejin/user_study-preference-base_DETAILED_checkpoint",
}

# Cap per study (matches distribution plot). None leaves the study uncapped.
MAX_REVIEWS = {"base": None, "checkpoint": 50}

# Pastel palette: fill colour and a darker matching edge colour per model.
ROSE = "#E8A9A1"
SAGE = "#A8C5A1"
ROSE_EDGE = "#C87E75"
SAGE_EDGE = "#6F9A67"

# Outcome colors for the stacked/grouped chart
GREEN = "#A8C5A1"  # shifted toward A (success)
GREY = "#D4D0C8"  # no change
RED = "#E8A9A1"  # backfire
def fetch_rows(repo_id: str, token: str, max_reviews: int = None) -> list:
    """Download every submission JSON from a dataset repo and flatten it to rows.

    Args:
        repo_id: Hugging Face dataset repository to read from.
        token: Hugging Face access token used for listing and downloading.
        max_reviews: Optional cap on the number of rows returned. ``None``
            (the default) returns every complete rating pair; a value of
            zero or less returns an empty list without touching the network.

    Returns:
        A list of ``{"pre", "post", "delta"}`` dicts, one per item that has
        both a ``pre_rating`` and a ``post_rating``, with
        ``delta = post - pre``. Rows follow submission ``start_time`` order,
        so a cap keeps the earliest submissions.
    """
    if max_reviews is not None and max_reviews <= 0:
        return []

    api = HfApi(token=token)
    repo_files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")
    json_files = [
        name for name in repo_files
        if name.startswith("json/") and name.endswith(".json")
    ]

    submissions = []
    for filepath in json_files:
        local = api.hf_hub_download(
            repo_id=repo_id, filename=filepath,
            repo_type="dataset", token=token,
        )
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(local, encoding="utf-8") as f:
            submissions.append(json.load(f))

    # All files must be loaded before capping: the cap is applied in
    # chronological order, which is only known once everything is sorted.
    # NOTE(review): assumes "start_time" values are mutually comparable and
    # comparable with the fallback 0 -- confirm against the stored schema.
    submissions.sort(key=lambda s: s.get("start_time", 0))

    rows = []
    for sub in submissions:
        for item in sub.get("items", []):
            pre, post = item.get("pre_rating"), item.get("post_rating")
            if pre is None or post is None:
                # Skip items the participant did not rate both before and after.
                continue
            rows.append({"pre": pre, "post": post, "delta": post - pre})
            if max_reviews is not None and len(rows) >= max_reviews:
                return rows
    return rows
def plot_outcome_breakdown(base_rows, ckpt_rows,
                           out_path="/dfs/scratch1/echoi1/prolific_preferences/outcome_breakdown.png"):
    """Grouped bar chart: counts of each outcome for base vs. checkpoint.

    Each row is classified by the sign of its ``delta``: negative is a shift
    toward A, zero is no change, positive is a shift toward B.

    Args:
        base_rows: Row dicts (each with a ``"delta"`` key) for the base model.
        ckpt_rows: Row dicts (each with a ``"delta"`` key) for the checkpoint.
        out_path: Where the PNG is written.
    """
    def counts(rows):
        # Single pass; returns (shifted toward A, no change, shifted toward B).
        shifted_a = no_change = shifted_b = 0
        for r in rows:
            delta = r["delta"]
            if delta < 0:
                shifted_a += 1
            elif delta == 0:
                no_change += 1
            elif delta > 0:
                shifted_b += 1
        return shifted_a, no_change, shifted_b

    base_vals = list(counts(base_rows))
    ckpt_vals = list(counts(ckpt_rows))
    categories = ["Shifted toward A\n(success ✓)", "No change", "Shifted toward B\n(backfire ✗)"]

    x = np.arange(len(categories))
    bar_width = 0.38

    fig, ax = plt.subplots(figsize=(10, 6), dpi=150)
    try:
        b1 = ax.bar(x - bar_width/2, base_vals, bar_width,
                    color=ROSE, edgecolor=ROSE_EDGE, linewidth=1.5,
                    label=f"Base model (N={len(base_rows)})")
        b2 = ax.bar(x + bar_width/2, ckpt_vals, bar_width,
                    color=SAGE, edgecolor=SAGE_EDGE, linewidth=1.5,
                    label=f"Fine-tuned (N={len(ckpt_rows)})")

        # Add count labels on top of each bar
        for bars in (b1, b2):
            for bar in bars:
                h = bar.get_height()
                ax.text(bar.get_x() + bar.get_width()/2, h + 0.3,
                        f"{int(h)}", ha="center", va="bottom",
                        fontsize=10, fontweight="bold")

        ax.set_xticks(x)
        ax.set_xticklabels(categories, fontsize=11)
        ax.set_ylabel("Number of item-reviews", fontsize=12)
        ax.set_title("Persuasion outcomes by model",
                     fontsize=14, fontweight="bold", pad=15)
        ax.legend(loc="upper right", frameon=True, fontsize=11)
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.grid(axis="y", alpha=0.3, linestyle="-", linewidth=0.5)
        ax.set_axisbelow(True)

        # Operate on this figure explicitly rather than pyplot's "current" one.
        fig.tight_layout()
        fig.savefig(out_path, bbox_inches="tight", facecolor="white")
    finally:
        # pyplot keeps every figure alive until it is closed; release it even
        # if drawing or saving fails.
        plt.close(fig)
    print(f"Saved: {out_path}")
def plot_resistant_conversion(base_rows, ckpt_rows,
                              out_path="/dfs/scratch1/echoi1/prolific_preferences/resistant_conversion.png"):
    """Among users who pre-rated > 4 (leaning B), how many did each model flip toward A?

    A row counts as "converted" when its ``delta`` is negative, i.e. the
    post-rating is lower than the pre-rating; the post-rating is not required
    to cross any particular threshold. Bar heights are raw counts, and each
    bar is annotated with ``converted / resistant`` and the percentage.

    Args:
        base_rows: Row dicts (with ``"pre"`` and ``"delta"``) for the base model.
        ckpt_rows: Row dicts (with ``"pre"`` and ``"delta"``) for the checkpoint.
        out_path: Where the PNG is written.
    """
    def stats(rows):
        # Returns (converted count, resistant count, conversion percentage).
        resistant = [r for r in rows if r["pre"] > 4]
        if not resistant:
            # Avoid dividing by zero when nobody leaned toward B.
            return 0, 0, 0.0
        converted = sum(1 for r in resistant if r["delta"] < 0)
        return converted, len(resistant), converted / len(resistant) * 100

    base_conv, base_tot, base_pct = stats(base_rows)
    ckpt_conv, ckpt_tot, ckpt_pct = stats(ckpt_rows)

    x = np.arange(2)
    vals = [base_conv, ckpt_conv]
    tots = [base_tot, ckpt_tot]
    pcts = [base_pct, ckpt_pct]

    fig, ax = plt.subplots(figsize=(8, 6), dpi=150)
    try:
        bars = ax.bar(
            x, vals, width=0.5,
            color=[ROSE, SAGE], edgecolor=[ROSE_EDGE, SAGE_EDGE], linewidth=1.5,
        )

        # Label each bar with "X / Y (Z%)"
        for bar, val, tot, pct in zip(bars, vals, tots, pcts):
            h = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2, h + 0.2,
                    f"{val} / {tot}\n{pct:.1f}%",
                    ha="center", va="bottom",
                    fontsize=11, fontweight="bold")

        ax.set_xticks(x)
        ax.set_xticklabels([f"Base\n(N={base_tot})", f"Fine-tuned\n(N={ckpt_tot})"], fontsize=12)
        ax.set_ylabel("Users converted to prefer A", fontsize=12)
        ax.set_title(
            "Converting resistant users\n(those who initially leaned toward B)",
            fontsize=14, fontweight="bold", pad=15,
        )
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.grid(axis="y", alpha=0.3, linestyle="-", linewidth=0.5)
        ax.set_axisbelow(True)

        # Operate on this figure explicitly rather than pyplot's "current" one.
        fig.tight_layout()
        fig.savefig(out_path, bbox_inches="tight", facecolor="white")
    finally:
        # pyplot keeps every figure alive until it is closed; release it even
        # if drawing or saving fails.
        plt.close(fig)
    print(f"Saved: {out_path}")
if __name__ == "__main__":
    # Script entry point: fetch both studies, then render the two charts.
    token = os.getenv("HF_TOKEN")
    if not token:
        # Explicit exit instead of `assert`, which `python -O` strips out.
        raise SystemExit("set HF_TOKEN")

    print("Fetching base study...")
    base_rows = fetch_rows(REPOS["base"], token, MAX_REVIEWS["base"])
    print(f" {len(base_rows)} item-reviews")

    print("Fetching checkpoint study...")
    ckpt_rows = fetch_rows(REPOS["checkpoint"], token, MAX_REVIEWS["checkpoint"])
    print(f" {len(ckpt_rows)} item-reviews")

    plot_outcome_breakdown(base_rows, ckpt_rows)
    plot_resistant_conversion(base_rows, ckpt_rows)