File size: 6,391 Bytes
0f4326e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""
Bar charts comparing base vs. checkpoint persuasion outcomes.

Chart 1: Outcome breakdown per model (shifted toward A / no change / backfire).
Chart 2: Conversion rate among users who initially leaned B (pre > 4).
"""
import json
import os
from huggingface_hub import HfApi
import numpy as np
import matplotlib.pyplot as plt


# Hugging Face dataset repositories holding each study's submissions,
# keyed by the study name used throughout this script.
REPOS = dict(
    base="ehejin/user_study-preference-base_REAL",
    checkpoint="ehejin/user_study-preference-base_DETAILED_checkpoint",
)

# Per-study cap on item-reviews (None = no cap); matches the distribution plot.
MAX_REVIEWS = dict(base=None, checkpoint=50)

# Pastel palette: a fill colour and a darker edge colour per model.
ROSE, ROSE_EDGE = "#E8A9A1", "#C87E75"
SAGE, SAGE_EDGE = "#A8C5A1", "#6F9A67"

# Outcome colours for a stacked/grouped outcome chart.
GREEN = "#A8C5A1"   # shifted toward A (success)
GREY = "#D4D0C8"    # no change
RED = "#E8A9A1"     # backfire


def fetch_rows(repo_id: str, token: str, max_reviews: "int | None" = None) -> list:
    """Download a study's submissions from a Hugging Face dataset repo and flatten them.

    Every ``json/*.json`` file in the dataset repo is downloaded and parsed as
    one submission. Submissions are ordered by their ``start_time`` field
    (missing values sort as 0). Each item that has both a ``pre_rating`` and a
    ``post_rating`` becomes one row.

    Args:
        repo_id: Dataset repository id, e.g. ``"owner/name"``.
        token: Hugging Face access token used for listing and downloading.
        max_reviews: If given, stop after this many rows, taken in submission
            order. ``None`` means no cap. ``0`` or a negative value yields no
            rows.

    Returns:
        A list of dicts ``{"pre": pre, "post": post, "delta": post - pre}``.
    """
    # Handle non-positive caps up front. The in-loop check below only runs
    # after a row has been appended, so on its own it would always keep at
    # least one row.
    if max_reviews is not None and max_reviews <= 0:
        return []

    api = HfApi(token=token)
    json_files = [
        f for f in api.list_repo_files(repo_id=repo_id, repo_type="dataset")
        if f.startswith("json/") and f.endswith(".json")
    ]

    submissions = []
    for filepath in json_files:
        local = api.hf_hub_download(
            repo_id=repo_id, filename=filepath,
            repo_type="dataset", token=token,
        )
        with open(local, encoding="utf-8") as f:
            submissions.append(json.load(f))
    # NOTE(review): assumes all start_time values are mutually comparable
    # (e.g. all numeric). A mix of types, or None values, would raise
    # TypeError here.
    submissions.sort(key=lambda s: s.get("start_time", 0))

    rows = []
    for sub in submissions:
        for item in sub.get("items", []):
            pre, post = item.get("pre_rating"), item.get("post_rating")
            if pre is None or post is None:
                continue  # skip items missing either rating
            rows.append({"pre": pre, "post": post, "delta": post - pre})
            if max_reviews is not None and len(rows) >= max_reviews:
                return rows
    return rows


def plot_outcome_breakdown(base_rows, ckpt_rows,
                           out_path="/dfs/scratch1/echoi1/prolific_preferences/outcome_breakdown.png"):
    """Save a grouped bar chart of outcome counts for the base vs. fine-tuned model.

    Each row is classified by the sign of its ``delta`` (post - pre rating):
    negative -> "Shifted toward A", zero -> "No change", and positive ->
    "Shifted toward B".

    Args:
        base_rows: Rows for the base model, each a dict with a numeric ``"delta"``.
        ckpt_rows: Rows for the fine-tuned checkpoint, in the same format.
        out_path: Destination path of the saved PNG.
    """
    def counts(rows):
        """Return (shifted_toward_a, no_change, shifted_toward_b) counts."""
        shifted_a = sum(1 for r in rows if r["delta"] < 0)
        no_change = sum(1 for r in rows if r["delta"] == 0)
        shifted_b = sum(1 for r in rows if r["delta"] > 0)
        return shifted_a, no_change, shifted_b

    base_a, base_n, base_b = counts(base_rows)
    ckpt_a, ckpt_n, ckpt_b = counts(ckpt_rows)

    categories = ["Shifted toward A\n(success ✓)", "No change", "Shifted toward B\n(backfire ✗)"]
    base_vals = [base_a, base_n, base_b]
    ckpt_vals = [ckpt_a, ckpt_n, ckpt_b]

    x = np.arange(len(categories))
    bar_width = 0.38

    fig, ax = plt.subplots(figsize=(10, 6), dpi=150)

    b1 = ax.bar(x - bar_width/2, base_vals, bar_width,
                color=ROSE, edgecolor=ROSE_EDGE, linewidth=1.5,
                label=f"Base model  (N={len(base_rows)})")
    b2 = ax.bar(x + bar_width/2, ckpt_vals, bar_width,
                color=SAGE, edgecolor=SAGE_EDGE, linewidth=1.5,
                label=f"Fine-tuned  (N={len(ckpt_rows)})")

    # Count label above each bar. A fixed offset in points (rather than in
    # data units) keeps the gap constant regardless of the y-axis scale.
    for bars in (b1, b2):
        for bar in bars:
            h = bar.get_height()
            ax.annotate(f"{int(h)}",
                        xy=(bar.get_x() + bar.get_width()/2, h),
                        xytext=(0, 3), textcoords="offset points",
                        ha="center", va="bottom",
                        fontsize=10, fontweight="bold")

    ax.set_xticks(x)
    ax.set_xticklabels(categories, fontsize=11)
    ax.set_ylabel("Number of item-reviews", fontsize=12)
    ax.set_title("Persuasion outcomes by model",
                 fontsize=14, fontweight="bold", pad=15)
    ax.legend(loc="upper right", frameon=True, fontsize=11)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.grid(axis="y", alpha=0.3, linestyle="-", linewidth=0.5)
    ax.set_axisbelow(True)

    # Operate on the figure we created instead of pyplot's "current" figure,
    # then close it so repeated calls do not accumulate open figures.
    fig.tight_layout()
    fig.savefig(out_path, bbox_inches="tight", facecolor="white")
    plt.close(fig)
    print(f"Saved: {out_path}")


def plot_resistant_conversion(base_rows, ckpt_rows,
                              out_path="/dfs/scratch1/echoi1/prolific_preferences/resistant_conversion.png"):
    """Save a bar chart of how many initially B-leaning users each model moved toward A.

    A row counts as "resistant" when its ``pre`` rating is > 4, which is treated
    here as leaning toward B (presumably 4 is the scale midpoint — confirm). It
    counts as "converted" when its ``delta`` (post - pre) is negative. Bar
    heights are raw counts, and each bar is annotated with
    ``converted / resistant`` and the conversion rate.

    Args:
        base_rows: Rows for the base model, dicts with numeric ``"pre"`` and ``"delta"``.
        ckpt_rows: Rows for the fine-tuned checkpoint, in the same format.
        out_path: Destination path of the saved PNG.
    """
    def stats(rows):
        """Return (converted, resistant_total, percent_converted) for one model."""
        resistant = [r for r in rows if r["pre"] > 4]
        if not resistant:
            return 0, 0, 0.0
        # NOTE(review): delta < 0 counts any move toward A, not necessarily a
        # crossing to the A side (e.g. pre 6 -> post 5 still has post > 4).
        converted = sum(1 for r in resistant if r["delta"] < 0)
        return converted, len(resistant), converted / len(resistant) * 100

    base_conv, base_tot, base_pct = stats(base_rows)
    ckpt_conv, ckpt_tot, ckpt_pct = stats(ckpt_rows)

    fig, ax = plt.subplots(figsize=(8, 6), dpi=150)
    x = np.arange(2)
    vals = [base_conv, ckpt_conv]
    tots = [base_tot, ckpt_tot]
    pcts = [base_pct, ckpt_pct]

    bars = ax.bar(
        x, vals, width=0.5,
        color=[ROSE, SAGE], edgecolor=[ROSE_EDGE, SAGE_EDGE], linewidth=1.5,
    )

    # Label each bar with "X / Y" and the rate on a second line.
    for bar, val, tot, pct in zip(bars, vals, tots, pcts):
        # With no resistant users the rate is undefined, so show "n/a" rather
        # than a misleading "0.0%".
        rate = f"{pct:.1f}%" if tot else "n/a"
        # Offset in points keeps the gap constant regardless of y-axis scale.
        ax.annotate(f"{val} / {tot}\n{rate}",
                    xy=(bar.get_x() + bar.get_width()/2, bar.get_height()),
                    xytext=(0, 3), textcoords="offset points",
                    ha="center", va="bottom",
                    fontsize=11, fontweight="bold")

    ax.set_xticks(x)
    ax.set_xticklabels([f"Base\n(N={base_tot})", f"Fine-tuned\n(N={ckpt_tot})"], fontsize=12)
    ax.set_ylabel("Users converted to prefer A", fontsize=12)
    ax.set_title(
        "Converting resistant users\n(those who initially leaned toward B)",
        fontsize=14, fontweight="bold", pad=15,
    )
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.grid(axis="y", alpha=0.3, linestyle="-", linewidth=0.5)
    ax.set_axisbelow(True)

    # Operate on the figure we created instead of pyplot's "current" figure,
    # then close it so repeated calls do not accumulate open figures.
    fig.tight_layout()
    fig.savefig(out_path, bbox_inches="tight", facecolor="white")
    plt.close(fig)
    print(f"Saved: {out_path}")


if __name__ == "__main__":
    token = os.getenv("HF_TOKEN")
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not token:
        raise SystemExit("set HF_TOKEN")

    print("Fetching base study...")
    base_rows = fetch_rows(REPOS["base"], token, MAX_REVIEWS["base"])
    print(f"  {len(base_rows)} item-reviews")

    print("Fetching checkpoint study...")
    ckpt_rows = fetch_rows(REPOS["checkpoint"], token, MAX_REVIEWS["checkpoint"])
    print(f"  {len(ckpt_rows)} item-reviews")

    plot_outcome_breakdown(base_rows, ckpt_rows)
    plot_resistant_conversion(base_rows, ckpt_rows)