# prolific_preferences / data_analysis / plot_distribution.py
# ehejin's picture
# sync w/ detailed repo
# 0f4326e
"""
Bar charts comparing base vs. checkpoint persuasion outcomes.
Chart 1: Outcome breakdown per model (shifted toward A / no change / backfire).
Chart 2: Number converted toward A (conversion rate shown in the bar labels) among users who initially leaned B (pre > 4).
"""
import json
import os
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
from huggingface_hub import HfApi
# Dataset repos read by fetch_rows, keyed by study arm.
REPOS = dict(
    base="ehejin/user_study-preference-base_REAL",
    checkpoint="ehejin/user_study-preference-base_DETAILED_checkpoint",
)

# Cap per study (matches distribution plot); None means "no cap".
MAX_REVIEWS = dict(base=None, checkpoint=50)

# Pastel palette: fill colors and their darker edge colors.
ROSE, ROSE_EDGE = "#E8A9A1", "#C87E75"
SAGE, SAGE_EDGE = "#A8C5A1", "#6F9A67"

# Outcome colors for the stacked/grouped chart.
# NOTE(review): these three are not referenced by the code visible in this
# file — confirm whether another module uses them before removing.
GREEN = SAGE       # shifted toward A (success)
GREY = "#D4D0C8"   # no change
RED = ROSE         # backfire
def fetch_rows(repo_id: str, token: str, max_reviews: Optional[int] = None) -> list:
    """Download a repo's submission JSON files and flatten them into rating rows.

    Every file under ``json/`` ending in ``.json`` is downloaded and parsed.
    Submissions are ordered by their ``start_time`` field (a submission
    without one sorts as ``0``). Each entry of a submission's ``items`` list
    that has both a ``pre_rating`` and a ``post_rating`` becomes one row.

    Args:
        repo_id: Hugging Face dataset repo to read from.
        token: Access token used for listing and downloading files.
        max_reviews: Maximum number of rows to return; ``None`` means no cap.

    Returns:
        A list of dicts with keys ``"pre"``, ``"post"`` and ``"delta"``
        (``post - pre``), in submission order, holding at most
        ``max_reviews`` rows.
    """
    api = HfApi(token=token)
    repo_files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")
    json_files = [f for f in repo_files if f.startswith("json/") and f.endswith(".json")]

    submissions = []
    for filepath in json_files:
        local = api.hf_hub_download(
            repo_id=repo_id, filename=filepath,
            repo_type="dataset", token=token,
        )
        # JSON text is UTF-8; do not depend on the platform's locale default.
        with open(local, encoding="utf-8") as f:
            submissions.append(json.load(f))

    # All files must be loaded before sorting, so the cap cannot cut downloads short.
    # NOTE(review): assumes every present start_time is comparable with the
    # numeric 0 fallback (i.e. is a number, not a string) — confirm.
    submissions.sort(key=lambda s: s.get("start_time", 0))

    rows = []
    for sub in submissions:
        for item in sub.get("items", []):
            # Check the cap before appending so max_reviews=0 yields no rows
            # (checking only after an append would always let one row through).
            if max_reviews is not None and len(rows) >= max_reviews:
                return rows
            pre, post = item.get("pre_rating"), item.get("post_rating")
            if pre is not None and post is not None:
                rows.append({"pre": pre, "post": post, "delta": post - pre})
    return rows
def plot_outcome_breakdown(base_rows: list, ckpt_rows: list,
                           out_path: str = "/dfs/scratch1/echoi1/prolific_preferences/outcome_breakdown.png") -> None:
    """Save a grouped bar chart of outcome counts for base vs. checkpoint.

    Each row is classified by the sign of its ``"delta"``: negative counts as
    "shifted toward A", zero as "no change", positive as "shifted toward B".

    Args:
        base_rows: Rows (dicts with a ``"delta"`` key) for the base model.
        ckpt_rows: Rows (dicts with a ``"delta"`` key) for the fine-tuned model.
        out_path: File path the PNG is written to.
    """
    def counts(rows):
        """Return (shifted toward A, no change, shifted toward B) counts."""
        shifted_a = sum(1 for r in rows if r["delta"] < 0)
        no_change = sum(1 for r in rows if r["delta"] == 0)
        shifted_b = sum(1 for r in rows if r["delta"] > 0)
        return shifted_a, no_change, shifted_b

    base_vals = list(counts(base_rows))
    ckpt_vals = list(counts(ckpt_rows))
    categories = ["Shifted toward A\n(success ✓)", "No change", "Shifted toward B\n(backfire ✗)"]

    x = np.arange(len(categories))
    bar_width = 0.38

    fig, ax = plt.subplots(figsize=(10, 6), dpi=150)
    try:
        # Two bars per category, centered on the tick: base left, fine-tuned right.
        b1 = ax.bar(x - bar_width/2, base_vals, bar_width,
                    color=ROSE, edgecolor=ROSE_EDGE, linewidth=1.5,
                    label=f"Base model (N={len(base_rows)})")
        b2 = ax.bar(x + bar_width/2, ckpt_vals, bar_width,
                    color=SAGE, edgecolor=SAGE_EDGE, linewidth=1.5,
                    label=f"Fine-tuned (N={len(ckpt_rows)})")

        # Add count labels on top of each bar
        for bars in (b1, b2):
            for bar in bars:
                h = bar.get_height()
                ax.text(bar.get_x() + bar.get_width()/2, h + 0.3,
                        f"{int(h)}", ha="center", va="bottom",
                        fontsize=10, fontweight="bold")

        ax.set_xticks(x)
        ax.set_xticklabels(categories, fontsize=11)
        ax.set_ylabel("Number of item-reviews", fontsize=12)
        ax.set_title("Persuasion outcomes by model",
                     fontsize=14, fontweight="bold", pad=15)
        ax.legend(loc="upper right", frameon=True, fontsize=11)
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.grid(axis="y", alpha=0.3, linestyle="-", linewidth=0.5)
        ax.set_axisbelow(True)

        # Operate on this figure explicitly rather than pyplot's "current" one.
        fig.tight_layout()
        fig.savefig(out_path, bbox_inches="tight", facecolor="white")
    finally:
        # pyplot keeps a reference to every figure until it is closed; release
        # it even if saving fails so repeated calls do not accumulate figures.
        plt.close(fig)
    print(f"Saved: {out_path}")
def plot_resistant_conversion(base_rows: list, ckpt_rows: list,
                              out_path: str = "/dfs/scratch1/echoi1/prolific_preferences/resistant_conversion.png") -> None:
    """Save a bar chart of how many B-leaning rows each model moved toward A.

    A row is "resistant" when its ``"pre"`` rating is > 4 (leaning B) and
    "converted" when its ``"delta"`` is negative. Bar heights are converted
    counts; each bar is labeled with ``converted / resistant`` and the
    percentage.

    Args:
        base_rows: Rows (dicts with ``"pre"`` and ``"delta"``) for the base model.
        ckpt_rows: Rows (dicts with ``"pre"`` and ``"delta"``) for the fine-tuned model.
        out_path: File path the PNG is written to.
    """
    def stats(rows):
        """Return (converted count, resistant count, conversion percentage)."""
        resistant = [r for r in rows if r["pre"] > 4]
        if not resistant:
            # No resistant rows: report zeros rather than dividing by zero.
            return 0, 0, 0.0
        converted = sum(1 for r in resistant if r["delta"] < 0)
        return converted, len(resistant), converted / len(resistant) * 100

    base_conv, base_tot, base_pct = stats(base_rows)
    ckpt_conv, ckpt_tot, ckpt_pct = stats(ckpt_rows)

    x = np.arange(2)
    vals = [base_conv, ckpt_conv]
    tots = [base_tot, ckpt_tot]
    pcts = [base_pct, ckpt_pct]

    fig, ax = plt.subplots(figsize=(8, 6), dpi=150)
    try:
        bars = ax.bar(
            x, vals, width=0.5,
            color=[ROSE, SAGE], edgecolor=[ROSE_EDGE, SAGE_EDGE], linewidth=1.5,
        )

        # Label each bar with "X / Y (Z%)"
        for bar, val, tot, pct in zip(bars, vals, tots, pcts):
            h = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2, h + 0.2,
                    f"{val} / {tot}\n{pct:.1f}%",
                    ha="center", va="bottom",
                    fontsize=11, fontweight="bold")

        ax.set_xticks(x)
        ax.set_xticklabels([f"Base\n(N={base_tot})", f"Fine-tuned\n(N={ckpt_tot})"], fontsize=12)
        ax.set_ylabel("Users converted to prefer A", fontsize=12)
        ax.set_title(
            "Converting resistant users\n(those who initially leaned toward B)",
            fontsize=14, fontweight="bold", pad=15,
        )
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.grid(axis="y", alpha=0.3, linestyle="-", linewidth=0.5)
        ax.set_axisbelow(True)

        # Operate on this figure explicitly rather than pyplot's "current" one.
        fig.tight_layout()
        fig.savefig(out_path, bbox_inches="tight", facecolor="white")
    finally:
        # pyplot keeps a reference to every figure until it is closed; release
        # it even if saving fails so repeated calls do not accumulate figures.
        plt.close(fig)
    print(f"Saved: {out_path}")
if __name__ == "__main__":
    # Script entry point: fetch both studies, then render both charts to
    # their default output paths.
    token = os.getenv("HF_TOKEN")
    # Explicit check instead of `assert`, which is stripped when Python runs
    # with -O and would let a missing token through to the network calls.
    if not token:
        raise SystemExit("set HF_TOKEN")

    print("Fetching base study...")
    base_rows = fetch_rows(REPOS["base"], token, MAX_REVIEWS["base"])
    print(f" {len(base_rows)} item-reviews")

    print("Fetching checkpoint study...")
    ckpt_rows = fetch_rows(REPOS["checkpoint"], token, MAX_REVIEWS["checkpoint"])
    print(f" {len(ckpt_rows)} item-reviews")

    plot_outcome_breakdown(base_rows, ckpt_rows)
    plot_resistant_conversion(base_rows, ckpt_rows)