"""Gradio app for pairwise image-similarity rating.

For every row of the source CSV the rater sees an original image and two
generated candidates (in random left/right order) and picks the one that
looks more similar to the original.  Answers are collected in per-session
state and appended to ``results.jsonl`` in a Hugging Face dataset repo once
the last question has been answered.
"""
import json
import os
import random
import uuid
from datetime import datetime, timezone
from io import BytesIO

import gradio as gr
import pandas as pd
import requests
from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
from PIL import Image

# ── Config ────────────────────────────────────────────────────────────────────
# NOTE(review): SOURCE_DATASET is used both as the local CSV path (load_csv)
# and as the Hub ``repo_id`` for non-URL image paths (load_image_from_hf).
# A CSV filename is unlikely to be a real dataset repo id — confirm the
# intended value.
SOURCE_DATASET = "selected_images.csv"
RESULTS_DATASET = "imagereconstructionteam/tripleratingsresults"
CSV_FILENAME = "selected_images.csv"
HF_TOKEN = os.environ.get("HF_TOKEN", "")

RESULTS_FILENAME = "results.jsonl"

APP_CSS = """
@import url('https://fonts.googleapis.com/css2?family=DM+Serif+Display:ital@0;1&family=DM+Sans:wght@300;400;500&display=swap');

:root {
    /* Light Palette */
    --bg: #ffffff;               /* Pure white background */
    --surface: #f7f7f8;          /* Light gray surface */
    --surface2: #efeff1;         /* Slightly darker surface for depth */
    --accent: #b5d930;           /* Darkened lime for better light-mode visibility */
    --accent2: #48c6ef;          /* Vibrant blue-cyan */
    --text: #0e0e11;             /* Near black text */
    --muted: #62626e;            /* Darker muted text for readability */
    --border: rgba(0,0,0,0.08);  /* Subtle dark border */
    --radius: 16px;
}

body, .gradio-container {
    background: var(--bg) !important;
    font-family: 'DM Sans', sans-serif !important;
    color: var(--text) !important;
    min-height: 100vh;
}

#header {
    text-align: center;
    padding: 3rem 1rem 1.5rem;
    border-bottom: 1px solid var(--border);
    margin-bottom: 2rem;
}
#header h1 {
    font-family: 'DM Serif Display', serif;
    font-size: clamp(2rem, 5vw, 3.2rem);
    color: var(--text);
    letter-spacing: -0.02em;
    margin: 0 0 .4rem;
}
#header p {
    color: var(--muted);
    font-size: .95rem;
    margin: 0;
}

/* ── Progress Bar ── */
#progress_bar {
    height: 4px;
    background: var(--surface2);
    border-radius: 99px;
    overflow: hidden;
    margin: 0 auto 2rem;
    max-width: 600px;
}
#progress_fill {
    height: 100%;
    background: linear-gradient(90deg, var(--accent2), var(--accent));
    border-radius: 99px;
    transition: width .4s cubic-bezier(.4,0,.2,1);
}

.orig-label {
    text-align: center;
    font-size: .7rem;
    font-weight: 500;
    text-transform: uppercase;
    letter-spacing: .15em;
    color: var(--muted);
    margin-bottom: .5rem;
}

/* ── Image Containers ── */
.gradio-image,
.gradio-image > div,
.image-container,
[data-testid="image"] {
    background: transparent !important;
    border: none !important;
    width: 100% !important;
    max-width: 100% !important;
}
.gradio-image img,
.image-container img,
[data-testid="image"] img {
    width: 100% !important;
    height: 100% !important;
    object-fit: contain !important;
    border-radius: 14px !important;
    display: block !important;
    border: 1px solid var(--border) !important; /* Added border to images to pop against white */
}
#orig_img,
#orig_img > div,
#orig_img [data-testid="image"],
#orig_img .image-container {
    height: 420px !important;
}
#img_left, #img_right {
    height: 380px !important;
}

/* ── Buttons ── */
.gr-button-primary, button.primary {
    background: var(--accent) !important;
    color: #0e0e11 !important; /* Keep text dark for the lime button */
    border: none !important;
    font-weight: 600 !important;
    border-radius: 8px !important;
    padding: .65rem 2rem !important;
    box-shadow: 0 4px 12px rgba(181, 217, 48, 0.2); /* Soft shadow for depth */
}
.gr-button-secondary, button.secondary {
    background: var(--surface2) !important;
    color: var(--text) !important;
    border: 1px solid var(--border) !important;
    border-radius: 8px !important;
}

/* ── Done Box ── */
#done_box {
    text-align: center;
    padding: 4rem 2rem;
}
#done_box h2 {
    font-family: 'DM Serif Display', serif;
    font-size: 2.5rem;
    color: var(--text); /* Changed from accent to text for cleaner finish */
}

footer { display: none !important; }
.prose h1, .prose h2 { color: var(--text) !important; }
"""


# ── Load CSV ──────────────────────────────────────────────────────────────────
def load_csv():
    """Read the rating tasks from the CSV at ``SOURCE_DATASET``."""
    return pd.read_csv(SOURCE_DATASET)


DF = load_csv()


# ── Image loading helper ──────────────────────────────────────────────────────
def load_image_from_hf(image_path: str):
    """Load one image as an RGB ``PIL.Image``.

    Paths starting with ``"http"`` are fetched with ``requests``; anything
    else is treated as a filename inside the ``SOURCE_DATASET`` dataset repo
    and fetched with ``hf_hub_download``.

    Returns:
        The decoded image, or ``None`` when ``image_path`` is empty or the
        image cannot be fetched or decoded.
    """
    if not image_path:
        return None
    # Best-effort boundary: a single broken image must not crash the UI
    # callback, so both branches degrade to ``None`` — but the failure is
    # logged instead of being swallowed silently.
    try:
        if image_path.startswith("http"):
            resp = requests.get(image_path, timeout=15)
            resp.raise_for_status()
            return Image.open(BytesIO(resp.content)).convert("RGB")
        local = hf_hub_download(
            repo_id=SOURCE_DATASET,
            filename=image_path,
            repo_type="dataset",
            token=HF_TOKEN if HF_TOKEN else None,
        )
        return Image.open(local).convert("RGB")
    except Exception as exc:
        print(f"Could not load image {image_path!r}: {exc}")
        return None


# ── Save all results at once ──────────────────────────────────────────────────
def save_all_results(pending: list):
    """Append all pending result dicts to results.jsonl in one upload.

    Does nothing when ``pending`` is empty.  Without ``HF_TOKEN`` the results
    are only printed.  Otherwise the current ``results.jsonl`` is downloaded
    from ``RESULTS_DATASET``, the new records are appended (one JSON object
    per line) and the whole file is uploaded again.

    Raises:
        Any download error other than "file does not exist yet", and any
        upload error.  Download errors are deliberately not swallowed:
        treating a transient failure as an empty file would overwrite every
        previously stored result.
    """
    if not pending:
        return
    if not HF_TOKEN:
        print("No HF_TOKEN – results not persisted:", pending)
        return

    api = HfApi(token=HF_TOKEN)
    try:
        existing_path = hf_hub_download(
            repo_id=RESULTS_DATASET,
            filename=RESULTS_FILENAME,
            repo_type="dataset",
            token=HF_TOKEN,
        )
    except EntryNotFoundError:
        # First upload ever: the repo exists but has no results file yet.
        existing = ""
    else:
        with open(existing_path, "r", encoding="utf-8") as f:
            existing = f.read()

    # Guard against a file without trailing newline, which would otherwise
    # fuse its last record with the first new one.
    if existing and not existing.endswith("\n"):
        existing += "\n"

    appended = "".join(json.dumps(result) + "\n" for result in pending)

    # NOTE(review): download → append → upload is not atomic; two raters
    # finishing at the same moment can still overwrite each other.
    # Uploading bytes directly avoids a temporary file that could leak when
    # the upload raises.
    api.upload_file(
        path_or_fileobj=(existing + appended).encode("utf-8"),
        path_in_repo=RESULTS_FILENAME,
        repo_id=RESULTS_DATASET,
        repo_type="dataset",
        commit_message=f"Add {len(pending)} result(s)",
    )


# ── Row helpers ───────────────────────────────────────────────────────────────
def get_row(idx: int):
    """Return row ``idx`` of ``DF`` as a dict, or ``None`` if out of range."""
    if DF.empty or idx < 0 or idx >= len(DF):
        return None
    return DF.iloc[idx].to_dict()


def _path_cell(row: dict, key: str) -> str:
    """Return ``row[key]`` as a string, mapping missing/NaN cells to ``""``."""
    value = row.get(key, "")
    # pandas represents empty CSV cells as NaN; str(NaN) would yield the
    # bogus path "nan".
    if pd.isna(value):
        return ""
    return str(value)


# ── Gradio app ────────────────────────────────────────────────────────────────
def build_app():
    """Build and return the ``gr.Blocks`` rating app (not yet launched)."""
    # NOTE(review): the markup inside the gr.HTML(...) strings below was
    # rebuilt from the ids/classes that APP_CSS targets (#header,
    # #progress_bar, #progress_fill, .orig-label, #done_box).  Header copy
    # and inline styles are assumptions — confirm against the intended design.
    with gr.Blocks(title="Image Similarity Rating", css=APP_CSS) as demo:

        # ── State ──────────────────────────────────────────────────────────
        # Filled in per session by start_session().  A uuid generated here
        # would be evaluated once at build time and shared by every rater.
        user_id_state = gr.State("")
        current_idx = gr.State(0)
        # Accumulates result dicts in memory; flushed to HF in one shot at the end
        pending_results = gr.State([])
        # Hidden UI component — avoids async state-timing race on first click.
        # 1 = first_generated is on the left, 0 = final_generated is on the left
        left_is_first = gr.Number(value=1, visible=False)

        total = len(DF)

        # ── Helpers ────────────────────────────────────────────────────────
        def make_progress(idx):
            """Return the progress-bar HTML for question index ``idx``."""
            # An empty CSV has nothing to rate: show a full bar instead of
            # dividing by zero.
            pct = int(min(idx, total) / total * 100) if total else 100
            n = min(idx + 1, total)
            return (
                '<div id="progress_bar">'
                f'<div id="progress_fill" style="width:{pct}%"></div>'
                '</div>'
                '<p style="text-align:center;">'
                f'Question {n} of {total}</p>'
            )

        def load_row(idx):
            """Return the UI updates that display row ``idx``.

            The tuple is ordered as (main_panel, done_panel, orig_img,
            img_left, img_right, left_is_first, progress_html).  When
            ``idx`` is past the last row the done screen is shown instead.
            """
            row = get_row(idx)
            if row is None:
                return (
                    gr.update(visible=False),
                    gr.update(visible=True),
                    None,
                    None,
                    None,
                    0,
                    make_progress(idx),
                )

            orig = load_image_from_hf(_path_cell(row, "original_image_path"))
            first = load_image_from_hf(
                _path_cell(row, "first_generated_image_path")
            )
            final_ = load_image_from_hf(
                _path_cell(row, "final_generated_image")
            )

            # Randomise sides so position cannot bias the rater.
            flip = 1 if random.random() < 0.5 else 0
            left_img = first if flip == 1 else final_
            right_img = final_ if flip == 1 else first
            return (
                gr.update(visible=True),
                gr.update(visible=False),
                orig,
                left_img,
                right_img,
                flip,
                make_progress(idx),
            )

        # ── Header ─────────────────────────────────────────────────────────
        gr.HTML('<div id="header"><h1>Image Similarity Rating</h1></div>')

        # ── Progress ───────────────────────────────────────────────────────
        progress_html = gr.HTML(make_progress(0))

        # ── Main panel ─────────────────────────────────────────────────────
        with gr.Column(visible=True) as main_panel:
            gr.HTML('<div class="orig-label">Original Image</div>')
            with gr.Row():
                with gr.Column(scale=1):
                    gr.HTML("")
                with gr.Column(scale=3):
                    orig_img = gr.Image(
                        label="",
                        show_label=False,
                        interactive=False,
                        height=420,
                        elem_id="orig_img",
                    )
                with gr.Column(scale=1):
                    gr.HTML("")

            gr.HTML(
                '<p style="text-align:center;">'
                'Which generated image looks more similar to the original?'
                '</p>'
            )

            with gr.Row(equal_height=True):
                with gr.Column():
                    img_left = gr.Image(
                        label="",
                        show_label=False,
                        interactive=False,
                        height=380,
                        elem_id="img_left",
                    )
                    btn_left = gr.Button(
                        "✦ Choose this", variant="primary", size="lg"
                    )
                with gr.Column():
                    img_right = gr.Image(
                        label="",
                        show_label=False,
                        interactive=False,
                        height=380,
                        elem_id="img_right",
                    )
                    btn_right = gr.Button(
                        "✦ Choose this", variant="primary", size="lg"
                    )

        # ── Done screen ────────────────────────────────────────────────────
        with gr.Column(visible=False) as done_panel:
            gr.HTML(
                '<div id="done_box">'
                '<h2>Thank you! 🎉</h2>'
                "<p>You've rated all images. Your answers have been saved.</p>"
                '</div>'
            )

        # ── Initial load ───────────────────────────────────────────────────
        def start_session(idx):
            """Issue a fresh rater id and show the session's current row."""
            return (str(uuid.uuid4()), *load_row(idx))

        demo.load(
            fn=start_session,
            inputs=[current_idx],
            outputs=[user_id_state, main_panel, done_panel, orig_img,
                     img_left, img_right, left_is_first, progress_html],
        )

        # ── Choice handlers ────────────────────────────────────────────────
        def handle_choice(side, uid, idx, lif, pending):
            """Record the rater's pick for row ``idx`` and advance.

            ``side`` is ``"left"`` or ``"right"``; ``lif`` is the hidden
            ``left_is_first`` flag that says which candidate was shown on
            the left.  After the last row all answers are uploaded.
            """
            # Fallback for a click that beats the page-load event.
            uid = uid or str(uuid.uuid4())

            row = get_row(idx)
            if row is not None:
                chose_left = (side == "left")
                choice = "first" if (chose_left == (lif == 1)) else "final"
                # Same naive-UTC string datetime.utcnow() produced, without
                # the deprecated call.
                timestamp = (
                    datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
                )
                # Build a new list — never mutate gr.State in place
                pending = pending + [{
                    "user_id": uid,
                    "timestamp": timestamp,
                    "choice": choice,
                    **{k: str(v) for k, v in row.items()},
                }]

            next_idx = idx + 1
            (main_v, done_v, orig, left_img,
             right_img, new_lif, prog) = load_row(next_idx)

            # Last question just answered — flush everything in one upload
            if get_row(next_idx) is None:
                save_all_results(pending)
                pending = []  # clear so a stray re-render can't double-save

            return (uid, next_idx, pending, main_v, done_v, orig,
                    left_img, right_img, new_lif, prog)

        _inputs = [user_id_state, current_idx, left_is_first, pending_results]
        _outputs = [user_id_state, current_idx, pending_results, main_panel,
                    done_panel, orig_img, img_left, img_right, left_is_first,
                    progress_html]

        btn_left.click(
            fn=lambda uid, idx, lif, pending: handle_choice(
                "left", uid, idx, lif, pending
            ),
            inputs=_inputs,
            outputs=_outputs,
        )
        btn_right.click(
            fn=lambda uid, idx, lif, pending: handle_choice(
                "right", uid, idx, lif, pending
            ),
            inputs=_inputs,
            outputs=_outputs,
        )

    return demo


demo = build_app()

if __name__ == "__main__":
    demo.launch()