File size: 11,983 Bytes
ce7a9d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f87f155
 
 
 
ce7a9d0
 
f87f155
ce7a9d0
f87f155
 
ce7a9d0
 
 
 
 
f87f155
ce7a9d0
f87f155
 
 
 
 
ce7a9d0
f87f155
 
 
 
 
 
 
 
 
ce7a9d0
 
 
f87f155
 
 
 
 
 
 
 
 
 
 
ce7a9d0
f87f155
 
ce7a9d0
 
 
 
 
 
f87f155
 
ce7a9d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33794d9
ce7a9d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
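"""
Image Similarity Rating App
---------------------------
Reads pairs.csv (committed to the Space repo) and shows every pair, in a fresh
random order, to each user, with no repetitions within a session. Once a user
has rated all pairs, their votes are appended to votes.parquet in the Hugging
Face dataset repo named by the HF_DATASET_REPO environment variable.
"""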

import io
import os
import random
import uuid
from datetime import datetime, timezone

import gradio as gr
import pandas as pd
from datasets import Dataset
from huggingface_hub import HfApi

# ── Config ────────────────────────────────────────────────────────────────────
HF_TOKEN        = os.environ.get("HF_TOKEN", "")
HF_DATASET_REPO = os.environ.get("HF_DATASET_REPO", "")  # where votes are saved
CSV_PATH        = "pairs.csv"                              # committed to Space repo

# Columns every row of pairs.csv must provide: the two image paths shown to the
# rater, plus the metadata copied into each vote record.
REQUIRED_COLUMNS = (
    "original_path",
    "final_path",
    "describer",
    "generator",
    "experiment",
    "episode",
)

# ── Load pairs CSV once at startup ───────────────────────────────────────────
print("Loading pairs.csv ...")
_pairs_df = pd.read_csv(CSV_PATH)
_missing_columns = [c for c in REQUIRED_COLUMNS if c not in _pairs_df.columns]
if _missing_columns:
    # Fail at startup instead of with a KeyError halfway through a user's session.
    raise ValueError(
        f"{CSV_PATH} is missing required column(s): {', '.join(_missing_columns)}"
    )
# One dict per row, keyed by column name.
_pairs = _pairs_df.to_dict(orient="records")
print(f"Loaded {len(_pairs)} pairs.")


# ── Persistence ───────────────────────────────────────────────────────────────
VOTES_FILE = "votes.parquet"  # single file in the results repo

def save_votes_to_hub(votes: list[dict]):
    """
    Append this session's votes to a single votes.parquet in the results repo.

    Strategy: download existing file -> concat -> upload back.

    Best-effort: every failure is printed, never raised, so the UI keeps
    working. Only a confirmed "file does not exist yet" is treated as a first
    run. Any other error while reading the existing file aborts the save
    rather than overwriting the remote history with just this session's votes.

    NOTE: the read-modify-write cycle is not atomic. If two writers save at
    the same moment, the later upload wins and the earlier rows are lost.

    Args:
        votes: one dict per rating; each dict becomes one row of the parquet.
    """
    if not HF_DATASET_REPO or not HF_TOKEN:
        print("HF_DATASET_REPO or HF_TOKEN not set -- votes not saved remotely.")
        return
    try:
        api = HfApi(token=HF_TOKEN)
        new_df = pd.DataFrame(votes)

        # Check existence explicitly. Treating *every* download error as
        # "first run" would replace the whole remote file with this session's
        # rows after a transient failure (timeout, 5xx, rate limit). Here,
        # such errors propagate to the outer handler and nothing is uploaded.
        if api.file_exists(
            repo_id=HF_DATASET_REPO,
            filename=VOTES_FILE,
            repo_type="dataset",
        ):
            existing_path = api.hf_hub_download(
                repo_id=HF_DATASET_REPO,
                repo_type="dataset",
                filename=VOTES_FILE,
            )
            existing_df = pd.read_parquet(existing_path)
            combined_df = pd.concat([existing_df, new_df], ignore_index=True)
        else:
            # File doesn't exist yet -- first run
            combined_df = new_df

        buf = io.BytesIO()
        combined_df.to_parquet(buf, index=False)
        buf.seek(0)  # rewind so upload_file reads from the start

        api.upload_file(
            path_or_fileobj=buf,
            path_in_repo=VOTES_FILE,
            repo_id=HF_DATASET_REPO,
            repo_type="dataset",
        )
        print(f"Appended {len(votes)} votes to {HF_DATASET_REPO}/{VOTES_FILE} "
              f"(total rows: {len(combined_df)})")
    except Exception as ex:
        # Deliberate best-effort boundary: report and keep the app running.
        print(f"Failed to save votes: {ex}")


# ── CSS ───────────────────────────────────────────────────────────────────────
# Stylesheet passed to gr.Blocks(css=...) in make_app(). The class selectors
# here (.subtitle, .instructions, .scale-list, .scale-val, .progress-*,
# .done-banner, .done-icon) match class names used in the HTML snippets that
# make_app() renders. The final `footer` rule hides the page's <footer>
# element (presumably Gradio's built-in footer -- confirm if the layout changes).
# Everything between the triple quotes is runtime CSS, including the /* */ notes.
CSS = """
/* ── Main Container & Light Background ── */
body, .gradio-container { 
    background: #ffffff !important; 
    color: #000000 !important; 
}

/* ── Instructions Box ── */
.instructions {
    background: #f9f9f9; 
    border: 1px solid #ddd;
    border-radius: 8px;
    padding: 16px 20px;
    margin-bottom: 16px;
    font-size: 0.88rem;
    line-height: 1.6;
    color: #444;
}
.instructions strong { color: #000; }

/* ── Scale Values (Your specific request) ── */
.scale-list li { color: #555; display: flex; align-items: center; gap: 10px; font-size: 0.85rem; }

.scale-val {
    background: #ffffff !important; 
    color: #000000 !important;
    border: 1px solid #ddd; /* Added a light border so white-on-white is visible */
    border-radius: 4px; 
    padding: 2px 8px;
    font-weight: 700; 
    font-size: 0.8rem;
    min-width: 36px; 
    text-align: center;
    flex-shrink: 0;
}

/* ── Progress Bar ── */
.progress-wrap  { margin-bottom: 16px; }
.progress-label { font-size: 0.8rem; color: #888; margin-bottom: 4px; text-align: right; }
.progress-track { height: 4px; border-radius: 4px; overflow: hidden; background: #eee; }
.progress-fill  { height: 100%; border-radius: 4px; transition: width 0.3s ease; background: #000; }

/* ── Header ── */
h1 { font-size: 1.5rem; font-weight: 700; margin: 28px 0 4px; text-align: center; color: #000; }
.subtitle { text-align: center; color: #888; margin-bottom: 20px; font-size: 0.9rem; }

/* ── Done Banner ── */
.done-banner {
    background: #fdfdfd;
    border: 1px solid #bbb;
    border-radius: 8px;
    padding: 64px 24px;
    text-align: center;
    margin: 32px 0;
}
.done-icon  { font-size: 3.5rem; margin-bottom: 16px; }
.done-banner h2 { font-size: 1.8rem; margin: 0 0 12px; color: #000; }
.done-banner p  { margin: 0; font-size: 0.95rem; line-height: 1.7; color: #555; }

footer { display: none !important; }
"""


# ── App ───────────────────────────────────────────────────────────────────────
def make_app():
    """
    Build the Gradio UI for rating how similar image pairs are.

    On page load each session gets its own random permutation of all rows in
    ``_pairs`` (read from pairs.csv at startup). The user scores each pair on
    a 0-10 slider. After the last pair, the session's votes are passed to
    ``save_votes_to_hub`` and a thank-you banner replaces the rating UI.

    Votes are held only in session state until the last pair is submitted,
    so a session abandoned partway through saves nothing.

    Returns:
        The configured ``gr.Blocks`` app, not yet launched.
    """
    # Slider position shown for every new pair.
    DEFAULT_SCORE = 5

    with gr.Blocks(css=CSS, title="Image Similarity Rating") as demo:

        # Per-session state.
        # gr.State evaluates a callable initial value per session (see the
        # gr.State docs), so every visitor gets a distinct anonymous id.
        user_id_state = gr.State(lambda: str(uuid.uuid4()))
        queue_state   = gr.State([])  # this session's shuffled pair records
        index_state   = gr.State(0)   # queue position of the pair on screen
        votes_state   = gr.State([])  # vote dicts recorded so far
        total_state   = gr.State(0)   # number of pairs in the queue

        # Header
        gr.HTML("<h1>Image Similarity Rating</h1>")
        gr.HTML("<p class='subtitle'>Rate how similar Image B is to Image A.</p>")

        # Done banner — hidden until all pairs are rated
        done_html = gr.HTML(visible=False)

        # Rating UI — hidden when done
        with gr.Column(visible=True) as rating_col:

            # Progress
            progress_html = gr.HTML()

            # Images
            with gr.Row(equal_height=True):
                img_left  = gr.Image(label="Image A — Original",  show_label=True,
                                     interactive=False, height=520)
                img_right = gr.Image(label="Image B — Generated", show_label=True,
                                     interactive=False, height=520)

            # Instructions
            gr.HTML("""
            <div class="instructions">
                <strong>How similar is Image B to Image A?</strong><br>
                Image A is the original; Image B was reconstructed by an AI model.
                Rate their overall visual and semantic similarity:
                <ul class="scale-list">
                    <li><span class="scale-val">0</span>   Completely different</li>
                    <li><span class="scale-val">1–3</span> Very different; only a few elements match</li>
                    <li><span class="scale-val">4–6</span> Partial match; some key elements correct, but notable differences</li>
                    <li><span class="scale-val">7–9</span> Strong match; mostly correct with minor differences</li>
                    <li><span class="scale-val">10</span>  Identical or indistinguishable</li>
                </ul>
            </div>
            """)

            # Slider + button
            score_slider = gr.Slider(minimum=0, maximum=10, step=1, value=DEFAULT_SCORE,
                                     label="Similarity score (0–10)", interactive=True)
            next_btn = gr.Button("Submit and continue →", variant="primary", size="lg")

        # ── Helpers ───────────────────────────────────────────────────────

        done_banner_html = """
                <div class="done-banner">
                  <div class="done-icon">✓</div>
                  <h2>Thank you!</h2>
                  <p>You have rated all image pairs.<br>
                     Your responses have been saved and will help us evaluate AI-generated images.</p>
                </div>"""

        empty_banner_html = """
                <div class="done-banner">
                  <h2>No image pairs to rate</h2>
                  <p>pairs.csv does not contain any rows.</p>
                </div>"""

        def build_progress(idx, total):
            """Return progress-bar HTML showing ``idx`` of ``total`` pairs done."""
            pct = int(idx / total * 100) if total else 0
            return f"""
            <div class="progress-wrap">
              <div class="progress-label">{idx} / {total}</div>
              <div class="progress-track">
                <div class="progress-fill" style="width:{pct}%"></div>
              </div>
            </div>"""

        # ── Init on load ──────────────────────────────────────────────────

        def on_load(user_id):
            """Shuffle all pairs for this session and show the first one."""
            # Sampling the full length yields a shuffled copy: each pair appears
            # exactly once, and the shared module-level _pairs is not mutated.
            queue = random.sample(_pairs, len(_pairs))
            total = len(queue)
            if not queue:
                # Empty pairs.csv: there is nothing to show, and queue[0] would raise.
                return (
                    queue, 0, [], total,
                    build_progress(0, total),
                    None,
                    None,
                    DEFAULT_SCORE,
                    gr.update(visible=False),                          # rating_col hidden
                    gr.update(value=empty_banner_html, visible=True),  # explain why
                )
            entry = queue[0]
            return (
                queue, 0, [], total,
                build_progress(0, total),
                entry["original_path"],
                entry["final_path"],
                DEFAULT_SCORE,
                gr.update(visible=True),   # rating_col visible
                gr.update(visible=False),  # done_html hidden
            )

        demo.load(
            on_load,
            inputs=[user_id_state],
            outputs=[queue_state, index_state, votes_state, total_state,
                     progress_html, img_left, img_right, score_slider,
                     rating_col, done_html],
        )

        # ── Submit vote ───────────────────────────────────────────────────

        def on_next(score, queue, idx, votes, total, user_id):
            """Record the score for the current pair, then advance or finish."""
            if idx >= len(queue):
                # No pair is on screen (e.g. a click before on_load filled the
                # queue). Change nothing rather than raise an IndexError.
                return (votes, idx) + (gr.update(),) * 6

            entry = queue[idx]
            vote = {
                "user_id":             user_id,
                # Naive UTC ISO string: the same format datetime.utcnow() gives,
                # without using utcnow (deprecated since Python 3.12).
                "timestamp":           datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
                "vote_index":          idx,
                "score":               int(score),
                "describer":           entry["describer"],
                "generator":           entry["generator"],
                "experiment":          entry["experiment"],
                "episode":             entry["episode"],
                "final_image_url":     entry["final_path"],
                "original_image_url":  entry["original_path"],
            }
            votes = votes + [vote]  # new list; the state's previous value is not mutated
            idx  += 1

            # ── All pairs rated → save, show done banner, hide everything else ──
            if idx >= total:
                save_votes_to_hub(votes)
                return (
                    votes, idx,
                    build_progress(total, total),
                    gr.update(),   # img_left unchanged (hidden with column)
                    gr.update(),   # img_right unchanged
                    gr.update(),   # score_slider unchanged
                    gr.update(visible=False),                         # rating_col → hide entire block
                    gr.update(value=done_banner_html, visible=True),  # done_html → show
                )

            # ── Next pair ─────────────────────────────────────────────────
            next_entry = queue[idx]
            return (
                votes, idx,
                build_progress(idx, total),
                gr.update(value=next_entry["original_path"]),
                gr.update(value=next_entry["final_path"]),
                DEFAULT_SCORE,
                gr.update(visible=True),   # rating_col stays visible
                gr.update(visible=False),  # done_html stays hidden
            )

        next_btn.click(
            on_next,
            inputs=[score_slider, queue_state, index_state, votes_state,
                    total_state, user_id_state],
            outputs=[votes_state, index_state,
                     progress_html, img_left, img_right, score_slider,
                     rating_col, done_html],
        )

    return demo


if __name__ == "__main__":
    # Script entry point: build the UI, then start Gradio with default launch settings.
    app = make_app()
    app.launch()