# Page-capture header (not program text): Spaces: Running | File size: 6,391 Bytes | 0f4326e
"""
Bar charts comparing base vs. checkpoint persuasion outcomes.
Chart 1: Outcome breakdown per model (shifted toward A / no change / backfire).
Chart 2: Conversion rate among users who initially leaned B (pre > 4).
"""
import json
import os
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
from huggingface_hub import HfApi
# Hugging Face dataset repositories, keyed by study name.
REPOS = dict(
    base="ehejin/user_study-preference-base_REAL",
    checkpoint="ehejin/user_study-preference-base_DETAILED_checkpoint",
)

# Per-study limit on item-reviews (matches distribution plot); None = no limit.
MAX_REVIEWS = dict(base=None, checkpoint=50)

# Pastel fill colors and their darker outline counterparts.
ROSE, ROSE_EDGE = "#E8A9A1", "#C87E75"
SAGE, SAGE_EDGE = "#A8C5A1", "#6F9A67"

# Per-outcome colors for the stacked/grouped chart.
GREEN = "#A8C5A1"  # shifted toward A (success)
GREY = "#D4D0C8"  # no change
RED = "#E8A9A1"  # backfire
def _collect_rating_rows(submissions, max_reviews):
    """Flatten submissions into rating rows, stopping once the cap is reached."""
    rows = []
    for sub in submissions:
        for item in sub.get("items", []):
            # Check the cap before appending so a cap of 0 yields no rows.
            if max_reviews is not None and len(rows) >= max_reviews:
                return rows
            pre, post = item.get("pre_rating"), item.get("post_rating")
            # Items missing either rating are skipped.
            if pre is not None and post is not None:
                rows.append({"pre": pre, "post": post, "delta": post - pre})
    return rows


def fetch_rows(repo_id: str, token: str, max_reviews: Optional[int] = None) -> list:
    """Download a study's submission files and flatten them into rating rows.

    Every file in the dataset repo whose path starts with ``json/`` and ends
    with ``.json`` is downloaded and parsed. Submissions are ordered by their
    ``start_time`` field (0 when absent), and each entry of a submission's
    ``items`` list that carries both ``pre_rating`` and ``post_rating``
    becomes one row.

    Args:
        repo_id: Hugging Face dataset repository to read.
        token: Access token handed to the Hub client.
        max_reviews: Maximum number of rows to return; ``None`` means no
            limit. Zero or a negative value returns an empty list.

    Returns:
        A list of dicts with keys ``pre``, ``post`` and ``delta``
        (``post - pre``), in sorted submission order.
    """
    api = HfApi(token=token)
    repo_files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")
    json_files = [
        name for name in repo_files
        if name.startswith("json/") and name.endswith(".json")
    ]

    submissions = []
    for filepath in json_files:
        local = api.hf_hub_download(
            repo_id=repo_id, filename=filepath,
            repo_type="dataset", token=token,
        )
        # Read as UTF-8 explicitly rather than relying on the platform's
        # locale encoding.
        with open(local, encoding="utf-8") as f:
            submissions.append(json.load(f))

    # Every file is loaded before capping because the cap keeps the rows
    # from the submissions that sort first by start_time.
    # NOTE(review): assumes start_time values are mutually comparable (and
    # comparable with the 0 fallback) -- confirm against the stored data.
    submissions.sort(key=lambda s: s.get("start_time", 0))
    return _collect_rating_rows(submissions, max_reviews)
def plot_outcome_breakdown(base_rows, ckpt_rows,
                           out_path="/dfs/scratch1/echoi1/prolific_preferences/outcome_breakdown.png"):
    """Save a grouped bar chart of outcome counts for base vs. checkpoint.

    Each row is classified by the sign of its ``delta``: negative is
    "shifted toward A", zero is "no change", positive is "shifted toward B".
    Bar heights are raw counts, not proportions, so the two groups are only
    directly comparable when they contain the same number of rows.

    Args:
        base_rows: Rows (dicts with a ``delta`` key) for the base model.
        ckpt_rows: Rows (dicts with a ``delta`` key) for the checkpoint.
        out_path: Destination image file for the chart.
    """
    def counts(rows):
        shifted_a = sum(1 for r in rows if r["delta"] < 0)
        no_change = sum(1 for r in rows if r["delta"] == 0)
        shifted_b = sum(1 for r in rows if r["delta"] > 0)
        return shifted_a, no_change, shifted_b

    base_vals = list(counts(base_rows))
    ckpt_vals = list(counts(ckpt_rows))
    categories = ["Shifted toward A\n(success ✓)", "No change", "Shifted toward B\n(backfire ✗)"]
    x = np.arange(len(categories))
    bar_width = 0.38

    fig, ax = plt.subplots(figsize=(10, 6), dpi=150)
    try:
        # Two bars per category, centered on the tick position.
        b1 = ax.bar(x - bar_width/2, base_vals, bar_width,
                    color=ROSE, edgecolor=ROSE_EDGE, linewidth=1.5,
                    label=f"Base model (N={len(base_rows)})")
        b2 = ax.bar(x + bar_width/2, ckpt_vals, bar_width,
                    color=SAGE, edgecolor=SAGE_EDGE, linewidth=1.5,
                    label=f"Fine-tuned (N={len(ckpt_rows)})")

        # Add count labels on top of each bar
        for bars in (b1, b2):
            for bar in bars:
                h = bar.get_height()
                ax.text(bar.get_x() + bar.get_width()/2, h + 0.3,
                        f"{int(h)}", ha="center", va="bottom",
                        fontsize=10, fontweight="bold")

        ax.set_xticks(x)
        ax.set_xticklabels(categories, fontsize=11)
        ax.set_ylabel("Number of item-reviews", fontsize=12)
        ax.set_title("Persuasion outcomes by model",
                     fontsize=14, fontweight="bold", pad=15)
        ax.legend(loc="upper right", frameon=True, fontsize=11)
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.grid(axis="y", alpha=0.3, linestyle="-", linewidth=0.5)
        ax.set_axisbelow(True)

        fig.tight_layout()
        fig.savefig(out_path, bbox_inches="tight", facecolor="white")
    finally:
        # pyplot keeps every figure alive until it is closed; release this
        # one even if drawing or saving fails.
        plt.close(fig)
    print(f"Saved: {out_path}")
def plot_resistant_conversion(base_rows, ckpt_rows,
                              out_path="/dfs/scratch1/echoi1/prolific_preferences/resistant_conversion.png"):
    """Among users who pre-rated > 4 (leaning B), how many did each model flip toward A?

    A row counts as "converted" when its ``pre`` rating is above 4 and its
    ``delta`` is negative, i.e. the post rating is lower than the pre
    rating; the post rating does not have to fall to 4 or below. Bar
    heights are the converted counts, and each bar is annotated with
    "converted / total" and the percentage.

    Args:
        base_rows: Rows (dicts with ``pre`` and ``delta`` keys) for the base model.
        ckpt_rows: Rows (dicts with ``pre`` and ``delta`` keys) for the checkpoint.
        out_path: Destination image file for the chart.
    """
    def stats(rows):
        resistant = [r for r in rows if r["pre"] > 4]
        if not resistant:
            # No qualifying rows: report zeros rather than dividing by zero.
            return 0, 0, 0.0
        converted = sum(1 for r in resistant if r["delta"] < 0)
        return converted, len(resistant), converted / len(resistant) * 100

    base_conv, base_tot, base_pct = stats(base_rows)
    ckpt_conv, ckpt_tot, ckpt_pct = stats(ckpt_rows)

    x = np.arange(2)
    vals = [base_conv, ckpt_conv]
    tots = [base_tot, ckpt_tot]
    pcts = [base_pct, ckpt_pct]

    fig, ax = plt.subplots(figsize=(8, 6), dpi=150)
    try:
        bars = ax.bar(
            x, vals, width=0.5,
            color=[ROSE, SAGE], edgecolor=[ROSE_EDGE, SAGE_EDGE], linewidth=1.5,
        )

        # Label each bar with "X / Y (Z%)"
        for bar, val, tot, pct in zip(bars, vals, tots, pcts):
            h = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2, h + 0.2,
                    f"{val} / {tot}\n{pct:.1f}%",
                    ha="center", va="bottom",
                    fontsize=11, fontweight="bold")

        ax.set_xticks(x)
        ax.set_xticklabels([f"Base\n(N={base_tot})", f"Fine-tuned\n(N={ckpt_tot})"], fontsize=12)
        ax.set_ylabel("Users converted to prefer A", fontsize=12)
        ax.set_title(
            "Converting resistant users\n(those who initially leaned toward B)",
            fontsize=14, fontweight="bold", pad=15,
        )
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.grid(axis="y", alpha=0.3, linestyle="-", linewidth=0.5)
        ax.set_axisbelow(True)

        fig.tight_layout()
        fig.savefig(out_path, bbox_inches="tight", facecolor="white")
    finally:
        # pyplot keeps every figure alive until it is closed; release this
        # one even if drawing or saving fails.
        plt.close(fig)
    print(f"Saved: {out_path}")
# Script entry point: fetch both studies, report their sizes, draw both charts.
if __name__ == "__main__":
    token = os.getenv("HF_TOKEN")
    if not token:
        # Explicit exit instead of `assert`: assertions are removed when
        # Python runs with -O, which would let a missing token through.
        raise SystemExit("set HF_TOKEN")

    print("Fetching base study...")
    base_rows = fetch_rows(REPOS["base"], token, MAX_REVIEWS["base"])
    print(f" {len(base_rows)} item-reviews")

    print("Fetching checkpoint study...")
    ckpt_rows = fetch_rows(REPOS["checkpoint"], token, MAX_REVIEWS["checkpoint"])
    print(f" {len(ckpt_rows)} item-reviews")

    plot_outcome_breakdown(base_rows, ckpt_rows)
    plot_resistant_conversion(base_rows, ckpt_rows)