# tripleratings / app.py
# sherzod-hakimov's picture
# Update app.py
# 674b568 verified
import gradio as gr
import pandas as pd
import uuid
import os
import json
from datetime import datetime
from huggingface_hub import HfApi, hf_hub_download
import requests
from PIL import Image
from io import BytesIO
import random
# ── Config ────────────────────────────────────────────────────────────────────
# Local CSV manifest read by load_csv().
# NOTE(review): load_image_from_hf also passes this value as a Hub `repo_id`;
# a CSV file name looks like the wrong value there — confirm the image repo.
SOURCE_DATASET: str = "selected_images.csv"
# Hub dataset that receives results.jsonl.
RESULTS_DATASET: str = "imagereconstructionteam/tripleratingsresults"
# NOTE(review): not referenced anywhere in the visible code.
CSV_FILENAME: str = "selected_images.csv"
# Empty string when the environment variable is unset.
HF_TOKEN: str = os.getenv("HF_TOKEN", "")
# ── Load CSV ──────────────────────────────────────────────────────────────────
def load_csv():
    """Read the rating manifest from the local CSV file ``SOURCE_DATASET``.

    Returns:
        A ``pandas.DataFrame`` with one row per rating question.
    """
    return pd.read_csv(SOURCE_DATASET)


# Loaded once at import time and read by every request handler via get_row().
DF = load_csv()
# ── Image loading helper ──────────────────────────────────────────────────────
def load_image_from_hf(image_path: str):
    """Load an image as an RGB ``PIL.Image``, or return ``None`` if unavailable.

    ``image_path`` is either an ``http://`` / ``https://`` URL, fetched with
    ``requests``, or a file path inside the ``SOURCE_DATASET`` Hub dataset,
    fetched with ``hf_hub_download``.

    Any failure (network error, HTTP error status, missing file, undecodable
    image) is logged and yields ``None``. Both branches therefore behave the
    same way, and the Gradio event handler shows an empty image slot instead
    of crashing.

    NOTE(review): ``SOURCE_DATASET`` is also the local CSV path read by
    ``load_csv``; a CSV file name looks like the wrong value for a Hub
    ``repo_id`` — confirm which dataset repo relative paths belong to.
    """
    if not image_path:
        return None
    try:
        # A tuple prefix avoids treating relative paths like "http_imgs/a.png" as URLs.
        if image_path.startswith(("http://", "https://")):
            resp = requests.get(image_path, timeout=15)
            resp.raise_for_status()
            with Image.open(BytesIO(resp.content)) as img:
                return img.convert("RGB")
        local = hf_hub_download(
            repo_id=SOURCE_DATASET,
            filename=image_path,
            repo_type="dataset",
            token=HF_TOKEN if HF_TOKEN else None,
        )
        # convert() returns a new, fully loaded image, so the source can be closed.
        with Image.open(local) as img:
            return img.convert("RGB")
    except Exception as err:  # UI boundary: degrade to an empty slot, but say why
        print(f"Could not load image {image_path!r}: {err}")
        return None
# ── Save all results at once ──────────────────────────────────────────────────
def save_all_results(pending: list):
    """Append all pending result dicts to results.jsonl in one upload.

    The current ``results.jsonl`` of ``RESULTS_DATASET`` (if it exists) is
    downloaded. Each dict in ``pending`` is appended as one JSON line, and the
    combined content is uploaded in a single commit.

    Args:
        pending: result dicts to persist. Nothing happens when it is empty.
            Without ``HF_TOKEN`` the records are only printed, not persisted.

    Raises:
        Whatever ``huggingface_hub`` raises if checking, downloading, or
        uploading fails. Errors are deliberately NOT swallowed: treating a
        failed download as "no existing file" would upload only this batch
        and overwrite every previously stored result.

    NOTE(review): this is a read-modify-write of one shared file. Two
    sessions finishing at the same moment can still overwrite each other's
    records; per-session files would avoid that.
    """
    if not pending:
        return
    if not HF_TOKEN:
        print("No HF_TOKEN – results not persisted:", pending)
        return
    api = HfApi(token=HF_TOKEN)

    existing = ""
    # Only a genuinely missing file counts as "start fresh". Network/auth
    # errors propagate instead of silently wiping the stored results.
    if api.file_exists(
        repo_id=RESULTS_DATASET,
        filename="results.jsonl",
        repo_type="dataset",
    ):
        existing_path = hf_hub_download(
            repo_id=RESULTS_DATASET,
            filename="results.jsonl",
            repo_type="dataset",
            token=HF_TOKEN,
        )
        with open(existing_path, "r", encoding="utf-8") as f:
            existing = f.read()

    # A file without a trailing newline would glue the first new record onto
    # the last old one and corrupt the JSONL.
    if existing and not existing.endswith("\n"):
        existing += "\n"
    new_lines = "".join(json.dumps(result) + "\n" for result in pending)

    # Uploading bytes directly avoids a temp file that could leak on failure.
    api.upload_file(
        path_or_fileobj=(existing + new_lines).encode("utf-8"),
        path_in_repo="results.jsonl",
        repo_id=RESULTS_DATASET,
        repo_type="dataset",
        commit_message=f"Add {len(pending)} result(s)",
    )
# ── Row helper ────────────────────────────────────────────────────────────────
def get_row(idx: int):
    """Return row ``idx`` of ``DF`` as a plain dict.

    Returns ``None`` when ``DF`` is empty or ``idx`` is past the last row.
    """
    if not DF.empty and idx < len(DF):
        return DF.iloc[idx].to_dict()
    return None
# ── Gradio app ────────────────────────────────────────────────────────────────
# Stylesheet passed to gr.Blocks(css=...) in build_app().
_APP_CSS = """
@import url('https://fonts.googleapis.com/css2?family=DM+Serif+Display:ital@0;1&family=DM+Sans:wght@300;400;500&display=swap');
:root {
/* Light Palette */
--bg: #ffffff; /* Pure white background */
--surface: #f7f7f8; /* Light gray surface */
--surface2: #efeff1; /* Slightly darker surface for depth */
--accent: #b5d930; /* Darkened lime for better light-mode visibility */
--accent2: #48c6ef; /* Vibrant blue-cyan */
--text: #0e0e11; /* Near black text */
--muted: #62626e; /* Darker muted text for readability */
--border: rgba(0,0,0,0.08); /* Subtle dark border */
--radius: 16px;
}
body, .gradio-container {
background: var(--bg) !important;
font-family: 'DM Sans', sans-serif !important;
color: var(--text) !important;
min-height: 100vh;
}
#header {
text-align: center;
padding: 3rem 1rem 1.5rem;
border-bottom: 1px solid var(--border);
margin-bottom: 2rem;
}
#header h1 {
font-family: 'DM Serif Display', serif;
font-size: clamp(2rem, 5vw, 3.2rem);
color: var(--text);
letter-spacing: -0.02em;
margin: 0 0 .4rem;
}
#header p { color: var(--muted); font-size: .95rem; margin: 0; }
/* ── Progress Bar ── */
#progress_bar {
height: 4px;
background: var(--surface2);
border-radius: 99px;
overflow: hidden;
margin: 0 auto 2rem;
max-width: 600px;
}
#progress_fill {
height: 100%;
background: linear-gradient(90deg, var(--accent2), var(--accent));
border-radius: 99px;
transition: width .4s cubic-bezier(.4,0,.2,1);
}
.orig-label {
text-align: center;
font-size: .7rem;
font-weight: 500;
text-transform: uppercase;
letter-spacing: .15em;
color: var(--muted);
margin-bottom: .5rem;
}
/* ── Image Containers ── */
.gradio-image,
.gradio-image > div,
.image-container,
[data-testid="image"] {
background: transparent !important;
border: none !important;
width: 100% !important;
max-width: 100% !important;
}
.gradio-image img,
.image-container img,
[data-testid="image"] img {
width: 100% !important;
height: 100% !important;
object-fit: contain !important;
border-radius: 14px !important;
display: block !important;
border: 1px solid var(--border) !important; /* Added border to images to pop against white */
}
#orig_img, #orig_img > div, #orig_img [data-testid="image"], #orig_img .image-container { height: 420px !important; }
#img_left, #img_right { height: 380px !important; }
/* ── Buttons ── */
.gr-button-primary, button.primary {
background: var(--accent) !important;
color: #0e0e11 !important; /* Keep text dark for the lime button */
border: none !important;
font-weight: 600 !important;
border-radius: 8px !important;
padding: .65rem 2rem !important;
box-shadow: 0 4px 12px rgba(181, 217, 48, 0.2); /* Soft shadow for depth */
}
.gr-button-secondary, button.secondary {
background: var(--surface2) !important;
color: var(--text) !important;
border: 1px solid var(--border) !important;
border-radius: 8px !important;
}
/* ── Done Box ── */
#done_box { text-align: center; padding: 4rem 2rem; }
#done_box h2 {
font-family: 'DM Serif Display', serif;
font-size: 2.5rem;
color: var(--text); /* Changed from accent to text for cleaner finish */
}
footer { display: none !important; }
.prose h1, .prose h2 { color: var(--text) !important; }
"""


def build_app():
    """Build the Gradio UI for the pairwise image-similarity rating task.

    For each row of ``DF`` the rater sees the original image and the two
    generated images (``first_generated_image_path`` and
    ``final_generated_image`` columns) in random left/right order, and picks
    the one closer to the original. Choices accumulate in per-session state.
    They are written with ``save_all_results`` in one upload right after the
    last row is answered.

    Returns:
        The ``gr.Blocks`` app (not launched).
    """
    with gr.Blocks(title="Image Similarity Rating", css=_APP_CSS) as demo:
        # ── State ──────────────────────────────────────────────────────────────
        # This default is computed once when build_app() runs, so every session
        # would start with the SAME id; demo.load below assigns a fresh uuid per
        # page load. The default only covers a click that beats the load event.
        user_id_state = gr.State(str(uuid.uuid4()))
        current_idx = gr.State(0)
        # Accumulates result dicts in memory; flushed to HF in one shot at the end
        pending_results = gr.State([])
        # Hidden UI component — avoids async state-timing race on first click.
        # 1 = first_generated is on the left, 0 = final_generated is on the left
        left_is_first = gr.Number(value=1, visible=False)
        total = len(DF)
        # ── Header ─────────────────────────────────────────────────────────────
        gr.HTML("""
<div id="header">
<h1>Which reconstruction is closer?</h1>
<p>Look at the original image, then pick the generated version that resembles it most.</p>
</div>
""")
        # ── Progress ───────────────────────────────────────────────────────────
        progress_html = gr.HTML(f"""
<div id="progress_bar">
<div id="progress_fill" style="width:0%"></div>
</div>
<p style="text-align:center;color:var(--muted);font-size:.82rem;margin-bottom:1.5rem">
Question 1 of {total}
</p>
""")
        # ── Main panel ─────────────────────────────────────────────────────────
        with gr.Column(visible=True) as main_panel:
            gr.HTML('<p class="orig-label">Original Image</p>')
            with gr.Row():
                with gr.Column(scale=1):
                    gr.HTML("")
                with gr.Column(scale=3):
                    orig_img = gr.Image(
                        label="", show_label=False,
                        interactive=False, height=420, elem_id="orig_img",
                    )
                with gr.Column(scale=1):
                    gr.HTML("")
            gr.HTML('<p style="text-align:center;color:var(--muted);font-size:.82rem;margin:1.5rem 0 .4rem">'
                    'Which generated image looks <strong style="color:var(--text)">more similar</strong> to the original?</p>')
            with gr.Row(equal_height=True):
                with gr.Column():
                    img_left = gr.Image(
                        label="", show_label=False,
                        interactive=False, height=380, elem_id="img_left",
                    )
                    btn_left = gr.Button("✦ Choose this", variant="primary", size="lg")
                with gr.Column():
                    img_right = gr.Image(
                        label="", show_label=False,
                        interactive=False, height=380, elem_id="img_right",
                    )
                    btn_right = gr.Button("✦ Choose this", variant="primary", size="lg")
        # ── Done screen ────────────────────────────────────────────────────────
        with gr.Column(visible=False) as done_panel:
            gr.HTML("""
<div id="done_box">
<h2>Thank you! 🎉</h2>
<p style="color:var(--muted);font-size:1.1rem;margin-top:.5rem">
You've rated all images. Your answers have been saved.
</p>
</div>
""")

        # ── Helpers ────────────────────────────────────────────────────────────
        def make_progress(idx):
            """Return the progress-bar HTML after ``idx`` questions were answered."""
            # Guard the division: an empty CSV (total == 0) would otherwise
            # raise ZeroDivisionError during the initial load.
            pct = int(min(idx, total) / total * 100) if total else 100
            n = min(idx + 1, total)
            return (f'<div id="progress_bar"><div id="progress_fill" style="width:{pct}%"></div></div>'
                    f'<p style="text-align:center;color:var(--muted);font-size:.82rem;margin-bottom:1.5rem">'
                    f'Question {n} of {total}</p>')

        def load_row(idx):
            """Return UI updates for question ``idx``, or the done screen past the end.

            Returns a 7-tuple matching (main_panel, done_panel, orig_img,
            img_left, img_right, left_is_first, progress_html).
            """
            row = get_row(idx)
            if row is None:
                return (gr.update(visible=False), gr.update(visible=True),
                        None, None, None, 0, make_progress(idx))
            orig = load_image_from_hf(str(row.get("original_image_path", "")))
            first = load_image_from_hf(str(row.get("first_generated_image_path", "")))
            final_ = load_image_from_hf(str(row.get("final_generated_image", "")))
            # Randomize sides so position bias doesn't favour either image.
            flip = 1 if random.random() < 0.5 else 0
            left_img = first if flip == 1 else final_
            right_img = final_ if flip == 1 else first
            return (gr.update(visible=True), gr.update(visible=False),
                    orig, left_img, right_img, flip, make_progress(idx))

        def init_session(idx):
            """Assign a fresh per-session user id, then render question ``idx``."""
            return (str(uuid.uuid4()), *load_row(idx))

        # ── Initial load ───────────────────────────────────────────────────────
        demo.load(
            fn=init_session,
            inputs=[current_idx],
            outputs=[user_id_state, main_panel, done_panel, orig_img, img_left,
                     img_right, left_is_first, progress_html],
        )

        # ── Choice handlers ────────────────────────────────────────────────────
        def handle_choice(side, uid, idx, lif, pending):
            """Record the choice for question ``idx`` and advance to the next one.

            ``side`` is "left" or "right"; ``lif`` says whether the first
            generated image was shown on the left. After the last question the
            accumulated results are uploaded. If the upload raises, state is not
            advanced, so clicking again retries without duplicating records.
            """
            row = get_row(idx)
            if row:
                chose_left = (side == "left")
                choice = "first" if (chose_left == (lif == 1)) else "final"
                # Build a new list — never mutate gr.State in place
                pending = pending + [{
                    "user_id": uid,
                    "timestamp": datetime.utcnow().isoformat(),
                    "choice": choice,
                    **{k: str(v) for k, v in row.items()},
                }]
            next_idx = idx + 1
            main_v, done_v, orig, left_img, right_img, new_lif, prog = load_row(next_idx)
            # Last question just answered — flush everything in one upload
            if get_row(next_idx) is None:
                save_all_results(pending)
                pending = []  # clear so a stray re-render can't double-save
            return next_idx, pending, main_v, done_v, orig, left_img, right_img, new_lif, prog

        _outputs = [current_idx, pending_results, main_panel, done_panel,
                    orig_img, img_left, img_right, left_is_first, progress_html]
        btn_left.click(
            fn=lambda uid, idx, lif, pending: handle_choice("left", uid, idx, lif, pending),
            inputs=[user_id_state, current_idx, left_is_first, pending_results],
            outputs=_outputs,
        )
        btn_right.click(
            fn=lambda uid, idx, lif, pending: handle_choice("right", uid, idx, lif, pending),
            inputs=[user_id_state, current_idx, left_is_first, pending_results],
            outputs=_outputs,
        )
    return demo
# Module-level app object, built once at import time.
demo = build_app()


def _launch():
    """Start the Gradio server for the module-level ``demo`` app."""
    demo.launch()


if __name__ == "__main__":
    _launch()