# StyleID / app.py
# zaqxsw0526's picture
# Update app.py
# 4724c4a verified
import os
import re
import random
import time
import uuid
from datetime import datetime
import tarfile
from PIL import Image
import shutil
import gradio as gr
import gspread
from oauth2client.service_account import ServiceAccountCredentials
from huggingface_hub import snapshot_download
# ===================== CONFIG =====================
# Dataset repo id and the local folder layout derived from it.
DATASET_REPO_ID = "zaqxsw0526/styleid-data-rebuttal"
DATA_ROOT_DIR = "data_rebuttal"
T2_SOURCE_DIR = os.path.join(DATA_ROOT_DIR, "p2")
STYLIZED_ROOT_DIR = DATA_ROOT_DIR

# set_key -> name of that set's directory under STYLIZED_ROOT_DIR.
SET_DIR_MAP = dict(
    ipa_cp="ipa_cross_prompt",
    mtg_cm="mtg_cross_method",
    flux_cm="flux_cross_method",
)
SETS = [*SET_DIR_MAP]

# set_key -> style folder names expected inside that set's directory.
STYLES_BY_SET = {
    "ipa_cp": [
        "comic", "low-poly", "manga", "oil-painting",
        "pixel-art", "renaissance", "sketch", "ukiyo-e",
    ],
    "mtg_cm": [
        "brave", "detroit", "doc_brown", "jojo",
        "moana", "picasso", "pocahontas", "titan_erwin",
    ],
    "flux_cm": [
        "comic", "low-poly", "manga", "oil-painting",
        "pixel-art", "renaissance", "sketch", "ukiyo-e",
    ],
}

N_TRIALS_TASK2 = 90

# ---- Google Sheets config ----
SERVICE_ACCOUNT_FILE = "styleid-479107-ccb8a1b36ea4.json"
SPREADSHEET_KEY = "1nWQ9UY3BDlpegclD1zy40Kfjj-0HZcoDc16RfO4QFQM"
SHEET_NAME_T2 = "rebuttal_t2"

# Column order of one logged response row.
HEADER_T2 = [
    # participant
    "timestamp", "participant_id", "name", "age", "gender", "ethnicity",
    # trial
    "task", "trial_index", "human_name", "stylized_name", "is_positive_pair",
    # response
    "user_answer", "is_correct",
    "method",  # set_key (ipa_cp / mtg_cm / flux_cm)
    "method_dir",  # set_dir (ipa_cross_prompt / mtg_cross_method / flux_cross_method)
    "style",  # style folder name
    "strength",
    "response_time_ms",
    "is_repeat_trial",  # 0 or 1
    "repeat_answer_match",  # 1 (same), 0 (different), "" (N/A)
]
# ===================== DATA DOWNLOAD / EXTRACT =====================
# Style folders that a set's tar may unpack directly under the data root
# instead of under the set's own directory.
_MTG_STYLE_DIRS = (
    "brave",
    "detroit",
    "doc_brown",
    "jojo",
    "moana",
    "picasso",
    "pocahontas",
    "titan_erwin",
)
_FLUX_STYLE_DIRS = (
    "comic",
    "low-poly",
    "manga",
    "oil-painting",
    "pixel-art",
    "renaissance",
    "sketch",
    "ukiyo-e",
)
# A directory only counts as the data root when all of these exist inside it.
_REQUIRED_SET_DIRS = ("ipa_cross_prompt", "mtg_cross_method", "flux_cross_method")
_TAR_EXTENSIONS = (".tar", ".tar.gz", ".tgz", ".tar.bz2", ".tbz2", ".tar.xz", ".txz")


def _normalize_flat_layout(base_dir, set_dir_name, style_dirs, label):
    """Move style dirs found directly under base_dir into base_dir/<set_dir_name>/.

    Does nothing when the set directory already exists or when none of the
    style directories are present at the top level.
    """
    set_dir = os.path.join(base_dir, set_dir_name)
    if os.path.isdir(set_dir):
        return
    existing = [s for s in style_dirs if os.path.isdir(os.path.join(base_dir, s))]
    if not existing:
        return
    os.makedirs(set_dir, exist_ok=True)
    for style in existing:
        dst = os.path.join(set_dir, style)
        if os.path.exists(dst):
            continue
        shutil.move(os.path.join(base_dir, style), dst)
    print(
        f"[STYLEID] โœ… Normalized {label} layout into '{set_dir}' (moved {len(existing)} style dirs).",
        flush=True,
    )


def _find_actual_root(base_dir):
    """Return the directory holding every required set dir, or None.

    Checks base_dir itself first, then each immediate sub-directory (covers a
    tar that wraps its content in a single prefix folder). Both checks require
    the same three set directories.
    """
    def has_all_sets(root):
        return all(os.path.isdir(os.path.join(root, d)) for d in _REQUIRED_SET_DIRS)

    if has_all_sets(base_dir):
        return base_dir
    if not os.path.isdir(base_dir):
        return None
    # Sorted so the choice is deterministic if several candidates qualify.
    for name in sorted(os.listdir(base_dir)):
        cand = os.path.join(base_dir, name)
        if os.path.isdir(cand) and has_all_sets(cand):
            return cand
    return None


def _is_within_directory(directory, target):
    """Return True when target resolves to directory itself or a path below it."""
    abs_directory = os.path.abspath(directory)
    abs_target = os.path.abspath(target)
    try:
        # commonpath compares whole path components; commonprefix compares
        # characters and would accept a sibling such as "<directory>_evil/x".
        return os.path.commonpath([abs_directory, abs_target]) == abs_directory
    except ValueError:
        # Raised e.g. for paths on different drives: cannot be inside.
        return False


def _safe_extract(tar, path):
    """Extract tar into path, refusing archives whose members escape path.

    Raises:
        RuntimeError: if any member name resolves outside path.
    """
    for member in tar.getmembers():
        member_path = os.path.join(path, member.name)
        if not _is_within_directory(path, member_path):
            raise RuntimeError(f"Unsafe path detected in tar: {member.name}")
    if hasattr(tarfile, "data_filter"):
        # Also rejects unsafe links/special files on interpreters that ship
        # extraction filters.
        tar.extractall(path, filter="data")
    else:
        tar.extractall(path)


def ensure_data_downloaded():
    """Make sure the rebuttal dataset is available locally and return its root.

    Steps:
      1. Normalize a flat mtg/flux layout under DATA_ROOT_DIR and return early
         if all expected set directories are already present.
      2. Otherwise download the dataset snapshot into DATA_ROOT_DIR.
      3. Extract every tar archive found under DATA_ROOT_DIR into DATA_ROOT_DIR.
      4. Normalize again and locate the root directory.

    Returns:
        The directory that contains ipa_cross_prompt/, mtg_cross_method/ and
        flux_cross_method/ (DATA_ROOT_DIR itself or one sub-directory of it).

    Raises:
        RuntimeError: if a tar member would extract outside DATA_ROOT_DIR, or
            if the expected directories are still missing after extraction.
    """
    print(f"[STYLEID] ๐Ÿ” Checking dataset under ./{DATA_ROOT_DIR} ...", flush=True)

    # already extracted?
    _normalize_flat_layout(DATA_ROOT_DIR, "mtg_cross_method", _MTG_STYLE_DIRS, "mtg")
    _normalize_flat_layout(DATA_ROOT_DIR, "flux_cross_method", _FLUX_STYLE_DIRS, "flux")
    actual_root = _find_actual_root(DATA_ROOT_DIR)
    if actual_root is not None:
        print(f"[STYLEID] โœ… Data already present. Ready. (root={actual_root})", flush=True)
        return actual_root

    os.makedirs(DATA_ROOT_DIR, exist_ok=True)
    print(f"[STYLEID] โฌ‡๏ธ Downloading dataset '{DATASET_REPO_ID}' into '{DATA_ROOT_DIR}' ...", flush=True)
    try:
        snapshot_download(
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
            local_dir=DATA_ROOT_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["p2/**", "ipa_cross_prompt.tar", "mtg_cross_method.tar", "flux_cross_method.tar"],
        )
    except TypeError:
        # NOTE(review): presumably a fallback for huggingface_hub versions whose
        # snapshot_download rejects allow_patterns; it downloads the whole repo.
        # Confirm which keyword actually triggers the TypeError.
        snapshot_download(
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
            local_dir=DATA_ROOT_DIR,
            local_dir_use_symlinks=False,
        )
    print("[STYLEID] โœ… Download complete. Searching for TAR files to extract ...", flush=True)

    extracted_any = False
    for root, _, files in os.walk(DATA_ROOT_DIR):
        for fname in files:
            if fname.lower().endswith(_TAR_EXTENSIONS):
                tar_path = os.path.join(root, fname)
                print(f"[STYLEID] ๐Ÿ“ฆ Extracting '{tar_path}' ...", flush=True)
                with tarfile.open(tar_path, "r:*") as tf:
                    _safe_extract(tf, DATA_ROOT_DIR)
                print(f"[STYLEID] โœ… Finished extracting '{tar_path}'.", flush=True)
                extracted_any = True
    if extracted_any:
        print("[STYLEID] ๐ŸŽ‰ TAR extraction complete.", flush=True)
    else:
        print("[STYLEID] โ„น๏ธ No TAR files found. Dataset may already be unpacked.", flush=True)

    _normalize_flat_layout(DATA_ROOT_DIR, "mtg_cross_method", _MTG_STYLE_DIRS, "mtg")
    _normalize_flat_layout(DATA_ROOT_DIR, "flux_cross_method", _FLUX_STYLE_DIRS, "flux")
    actual_root = _find_actual_root(DATA_ROOT_DIR)
    if actual_root is None:
        top = os.listdir(DATA_ROOT_DIR)[:50]
        raise RuntimeError(
            f"[STYLEID] โŒ Extraction finished but could not find expected dirs under '{DATA_ROOT_DIR}'. "
            f"Top entries: {top}"
        )
    print(f"[STYLEID] โœ… Found extracted root: {actual_root}", flush=True)
    return actual_root
# Resolve the data root at import time (downloads/extracts on first run) and
# re-derive the dependent paths from whatever root ensure_data_downloaded()
# returns, which may be a sub-directory of "data_rebuttal".
DATA_ROOT_DIR = "data_rebuttal"
DATA_ROOT_DIR = ensure_data_downloaded()
T2_SOURCE_DIR = os.path.join(DATA_ROOT_DIR, "p2")
STYLIZED_ROOT_DIR = DATA_ROOT_DIR
# ===================== GOOGLE SHEETS =====================
def ensure_header(ws, header):
    """Insert *header* as row 1 of worksheet *ws* unless row 1 already starts with it.

    A failure while reading row 1 is treated the same as an empty sheet.
    """
    try:
        current = ws.row_values(1)
    except Exception:
        current = []
    header_missing = (not current) or current[: len(header)] != header
    if header_missing:
        ws.insert_row(header, 1)
# Import-time setup: authorize with the service-account key file, open the
# spreadsheet by key, and get (or create) the Task 2 worksheet.
scope = [
    "https://www.googleapis.com/auth/spreadsheets",
    "https://www.googleapis.com/auth/drive",
]
creds = ServiceAccountCredentials.from_json_keyfile_name(SERVICE_ACCOUNT_FILE, scope)
gc = gspread.authorize(creds)
sh = gc.open_by_key(SPREADSHEET_KEY)
try:
    sheet_t2 = sh.worksheet(SHEET_NAME_T2)
except gspread.WorksheetNotFound:
    # First run against this spreadsheet: create the worksheet.
    sheet_t2 = sh.add_worksheet(title=SHEET_NAME_T2, rows=8000, cols=30)
# Make sure row 1 carries the expected column names before rows are appended.
ensure_header(sheet_t2, HEADER_T2)
# ===================== INDEX BUILDING =====================
def build_source_base_map(source_dir: str):
    """Map each PNG's base name (file name without extension) to its file name.

    Only entries of *source_dir* whose name ends in ".png" (case-insensitive)
    are included.

    Raises:
        RuntimeError: if *source_dir* is not an existing directory.
    """
    if not os.path.isdir(source_dir):
        raise RuntimeError(f"[STYLEID] โŒ Source dir not found: {source_dir} (need p2)")
    return {
        os.path.splitext(entry)[0]: entry
        for entry in os.listdir(source_dir)
        if entry.lower().endswith(".png")
    }
# Index the source (p2) PNGs once at import time: base name -> file name, plus
# a set of base names for membership tests.
SOURCE_BASE2FILE = build_source_base_map(T2_SOURCE_DIR)
SOURCE_BASES = set(SOURCE_BASE2FILE.keys())
print(f"[STYLEID] โœ… p2 bases: {len(SOURCE_BASES)}", flush=True)
def build_stylized_buckets():
    """Scan every set/style folder and group stylized PNGs by (set_key, style).

    Returns:
        (buckets, valid_bases_by_set) where
        buckets[set_key][style_name] is a list of (base, strength, full_path)
        for files whose base exists in SOURCE_BASES, and
        valid_bases_by_set[set_key] is the sorted list of bases that have at
        least one such file in that set.

    Raises:
        RuntimeError: if a set directory or a style directory is missing.
    """
    # ipa side: stems look like base_3 / base-3 (trailing digits = strength).
    underscore_suffix = re.compile(r"^(.*)_([0-9]+)$")
    hyphen_suffix = re.compile(r"^(.*)-([0-9]+)$")

    def split_strength(stem):
        # Stems without a numeric suffix keep the whole stem as base, strength 0.
        hit = underscore_suffix.match(stem) or hyphen_suffix.match(stem)
        if hit is None:
            return stem, 0
        return hit.group(1), int(hit.group(2))

    buckets = {set_key: {style: [] for style in STYLES_BY_SET[set_key]} for set_key in SETS}
    bases_seen = {set_key: set() for set_key in SETS}

    for set_key, set_dir in SET_DIR_MAP.items():
        set_path = os.path.join(STYLIZED_ROOT_DIR, set_dir)
        if not os.path.isdir(set_path):
            raise RuntimeError(f"[STYLEID] โŒ Missing set dir: {set_path}")
        print(f"[STYLEID] ๐Ÿ”Ž scanning set={set_key} path={set_path}", flush=True)

        for style_name in STYLES_BY_SET[set_key]:
            style_path = os.path.join(set_path, style_name)
            if not os.path.isdir(style_path):
                raise RuntimeError(f"[STYLEID] โŒ Missing style dir: {style_path}")

            total_png = parsed = kept = 0
            for dirpath, _subdirs, filenames in os.walk(style_path):
                for fname in filenames:
                    if not fname.lower().endswith(".png"):
                        continue
                    total_png += 1
                    base, strength = split_strength(os.path.splitext(fname)[0])
                    parsed += 1
                    if base not in SOURCE_BASES:
                        # No matching source image in p2: unusable for a trial.
                        continue
                    buckets[set_key][style_name].append(
                        (base, strength, os.path.join(dirpath, fname))
                    )
                    bases_seen[set_key].add(base)
                    kept += 1
            print(
                f"[STYLEID] - {style_name}: total_png={total_png}, parsed={parsed}, kept(after p2 filter)={kept}",
                flush=True,
            )

    valid_bases_by_set = {set_key: sorted(bases) for set_key, bases in bases_seen.items()}
    return buckets, valid_bases_by_set
# Scan the stylized sets once at import time; the results are module-level
# lookups read by make_new_trial_task2().
STYLIZED_BUCKETS, VALID_BASES_BY_SET = build_stylized_buckets()
# ===================== SCHEDULING =====================
def make_balanced_binary_schedule(n: int):
    """Return a shuffled list of n flags with n // 2 ones and the rest zeros."""
    n_positive = n // 2
    schedule = [1] * n_positive
    schedule += [0] * (n - n_positive)
    random.shuffle(schedule)
    return schedule
def make_balanced_style_list(styles: list[str], n: int):
    """Return a shuffled list of length n in which styles appear as evenly as possible.

    Every style gets n // len(styles) slots; the remaining n % len(styles)
    slots go to distinct styles drawn at random.
    """
    per_style, leftover = divmod(n, len(styles))
    balanced = [style for style in styles for _ in range(per_style)]
    # distribute remainder randomly across styles
    if leftover > 0:
        balanced += random.sample(styles, leftover)
    random.shuffle(balanced)
    return balanced
def make_joint_set_style_schedule(n_total: int):
    """Return a shuffled list of n_total (set_key, style_name) pairs.

    - sets share the trials evenly (e.g., 90 trials & 3 sets -> 30/30/30);
      any remainder goes to sets drawn at random
    - within each set, styles are balanced by make_balanced_style_list
    """
    per_set, extra = divmod(n_total, len(SETS))
    quota = dict.fromkeys(SETS, per_set)
    if extra > 0:
        for lucky_set in random.sample(SETS, extra):
            quota[lucky_set] += 1

    schedule = []
    for set_key, n_trials in quota.items():
        set_styles = make_balanced_style_list(STYLES_BY_SET[set_key], n_trials)
        schedule += [(set_key, style) for style in set_styles]
    random.shuffle(schedule)
    return schedule
# ===================== IMAGE LOADING =====================
def load_and_resize(path, size=(384, 384)):
    """Load the image at *path* as RGB and resize it to *size* (bilinear).

    Args:
        path: image file path, or None.
        size: target (width, height) passed to ``Image.resize``.

    Returns:
        The resized PIL image, or None when *path* is None or does not exist
        (a missing file is also reported on stdout).
    """
    if path is None:
        return None
    if not os.path.exists(path):
        print(f"[STYLEID] โŒ Missing image: {path}", flush=True)
        return None
    # Newer Pillow exposes the filters on Image.Resampling; older releases only
    # have the module-level constant. A missing attribute raises
    # AttributeError rather than ImportError, so probe with getattr.
    resampling = getattr(Image, "Resampling", Image)
    # Close the underlying file as soon as the pixel data has been converted.
    with Image.open(path) as opened:
        rgb = opened.convert("RGB")
    return rgb.resize(size, resampling.BILINEAR)
# ===================== TASK 2 (Verification) =====================
def make_new_trial_task2(state: dict):
    """Draw the next verification trial and record it in *state*.

    Pops one (set_key, style_name) item and one positive/negative flag from the
    schedules stored in *state* (falling back to random draws when a schedule
    is empty or missing), picks a stylized image from that bucket, and pairs it
    with the matching source image (positive) or a different one (negative).

    Args:
        state: per-participant session dict; mutated in place.

    Returns:
        (source_image, stylized_image, status_text, state). Both images are
        None once N_TRIALS_TASK2 trials have already been issued.

    Raises:
        RuntimeError: if the chosen bucket is empty, the set has fewer than two
            usable bases for a negative pair, or the source file is unknown.
    """
    if state.get("t2_trial_index", 0) >= N_TRIALS_TASK2:
        status = f"Task 2 finished ({N_TRIALS_TASK2} trials)."
        return None, None, status, state

    # --- pop schedules ---
    if state.get("t2_item_schedule"):
        set_key, style_name = state["t2_item_schedule"].pop()
    else:
        # Fallback: the style must be drawn from the SAME set as set_key,
        # otherwise the bucket lookup below can fail with KeyError.
        set_key = random.choice(SETS)
        style_name = random.choice(STYLES_BY_SET[set_key])
    set_dir = SET_DIR_MAP[set_key]

    if state.get("t2_pair_schedule"):
        is_positive = state["t2_pair_schedule"].pop()
    else:
        is_positive = random.choice([0, 1])

    # --- pick stylized first (to satisfy style schedule exactly) ---
    bucket = STYLIZED_BUCKETS[set_key][style_name]
    if not bucket:
        raise RuntimeError(f"[STYLEID] โŒ Empty bucket: set={set_key}, style={style_name}")
    styl_base, strength, stylized_path = random.choice(bucket)
    stylized_name = os.path.basename(stylized_path)

    # --- pick human according to pos/neg ---
    if is_positive:
        human_base = styl_base
    else:
        candidates = VALID_BASES_BY_SET[set_key]
        if len(candidates) < 2:
            raise RuntimeError(f"[STYLEID] โŒ Not enough bases for negative pair in set '{set_key}'")
        # choose a different base
        human_base = random.choice([b for b in candidates if b != styl_base])

    human_name = SOURCE_BASE2FILE.get(human_base)
    if human_name is None:
        raise RuntimeError(f"[STYLEID] โŒ human base '{human_base}' not found in p2")
    human_path = os.path.join(T2_SOURCE_DIR, human_name)

    # Record everything submit_answer_task2 needs to log this trial.
    state["t2_trial_index"] = state.get("t2_trial_index", 0) + 1
    state["t2_start_time"] = time.time()
    state["t2_human_name"] = human_name
    state["t2_human_path"] = human_path
    state["t2_stylized_name"] = stylized_name
    state["t2_stylized_path"] = stylized_path
    state["t2_is_positive"] = int(is_positive)
    state["t2_method"] = set_key
    state["t2_method_dir"] = set_dir
    state["t2_style"] = style_name
    state["t2_strength"] = strength

    status = f"Task 2 - Verification | Trial {state['t2_trial_index']} / {N_TRIALS_TASK2}"
    return load_and_resize(human_path), load_and_resize(stylized_path), status, state
def submit_answer_task2(user_answer, state):
    """Validate and log the answer for the current trial, then advance the study.

    Flow: reject a missing/unknown answer or a not-yet-started trial; otherwise
    append one row to the results sheet and either finish (after the repeat
    trial), start the repeat of the last trial as an attention check, or show
    the next trial.

    Returns:
        A 10-tuple in the order (source image, stylized image, status,
        feedback, state, task title, answer radio, submit button, end title,
        end text).
    """
    same_label = "๊ฐ™์€ ์‚ฌ๋žŒ / Same person"
    different_label = "๋‹ค๋ฅธ ์‚ฌ๋žŒ / Different person"

    def screen(img_h, img_s, status, feedback, st, finished=False):
        # Task widgets and the end screen are mutually exclusive.
        task_on = not finished
        return (
            gr.update(value=img_h, visible=task_on),
            gr.update(value=img_s, visible=task_on),
            gr.update(value=status, visible=True),
            gr.update(value=feedback, visible=True),
            st,
            gr.update(visible=task_on),  # task title
            gr.update(visible=task_on, value=None),  # answer radio (selection cleared)
            gr.update(visible=task_on),  # submit button
            gr.update(visible=finished),  # end title
            gr.update(visible=finished),  # end text
        )

    def current_pair():
        human = load_and_resize(state.get("t2_human_path"))
        stylized = load_and_resize(state.get("t2_stylized_path"))
        return human, stylized

    # no selection
    if user_answer not in (same_label, different_label):
        feedback = "โš ๏ธ Please choose whether the two images show the same person or not."
        status = f"Task 2 - Verification | Trial {state.get('t2_trial_index', 0)} / {N_TRIALS_TASK2}"
        return screen(*current_pair(), status, feedback, state)

    if state.get("t2_start_time") is None:
        feedback = "โš ๏ธ The trial has not started yet. Please click 'Start Study' first."
        return screen(*current_pair(), "", feedback, state)

    rt_ms = int((time.time() - state["t2_start_time"]) * 1000)
    is_positive = int(state["t2_is_positive"])
    user_pos = 1 if user_answer == same_label else 0
    is_correct = 1 if user_pos == is_positive else 0

    is_repeat_trial = 1 if state.get("t2_is_repeat_trial", False) else 0
    repeat_match = ""
    if is_repeat_trial and state.get("t2_repeat_first_answer") is not None:
        repeat_match = 1 if user_answer == state["t2_repeat_first_answer"] else 0

    # Column order must match HEADER_T2.
    sheet_t2.append_row(
        [
            datetime.now().isoformat(),
            state["participant_id"],
            state.get("name"),
            state.get("age"),
            state.get("gender"),
            state.get("ethnicity"),
            "task2",
            state["t2_trial_index"],
            state["t2_human_name"],
            state["t2_stylized_name"],
            is_positive,
            "same" if user_pos else "different",
            is_correct,
            state["t2_method"],
            state["t2_method_dir"],
            state["t2_style"],
            state["t2_strength"],
            rt_ms,
            is_repeat_trial,
            repeat_match,
        ]
    )

    # if this was repeat trial -> end
    if state.get("t2_is_repeat_trial", False):
        state["t2_is_repeat_trial"] = False
        state["t2_repeat_done"] = True
        return screen(
            None,
            None,
            "Task 2 finished.",
            "โœ… Task 2 finished (attention check complete). Thank you!",
            state,
            finished=True,
        )

    # last original trial -> start attention repeat
    on_last_trial = state.get("t2_trial_index", 0) == N_TRIALS_TASK2
    if on_last_trial and not state.get("t2_repeat_done", False):
        state["t2_repeat_first_answer"] = user_answer
        state["t2_is_repeat_trial"] = True
        state["t2_start_time"] = time.time()
        repeat_status = "Task 2 - Attention check trial.\nPlease answer once more."
        repeat_feedback = f"Response saved (RT: {rt_ms} ms).\nPlease answer the same question once again."
        return screen(*current_pair(), repeat_status, repeat_feedback, state)

    # normal -> next trial
    next_h, next_s, next_status, next_state = make_new_trial_task2(state)
    feedback = f"Response saved (RT: {rt_ms} ms). Next trial is shown."
    return screen(next_h, next_s, next_status, feedback, next_state)
# ===================== UI =====================
# Gradio app: an intro/participant form, the Task 2 trial view, and an end
# screen, all in one Blocks layout whose parts are shown/hidden via updates.
# NOTE(review): the row/column nesting below was reconstructed from statement
# order — confirm against the deployed layout.
with gr.Blocks() as demo:
    # Intro
    with gr.Column() as intro_col:
        intro_md = gr.Markdown(
            """
# ์•„์ด๋ดํ‹ฐํ‹ฐ ๋ณด์กด ์‚ฌ์šฉ์ž ์—ฐ๊ตฌ
## Task 2 โ€“ Identity Verification
- ํ•œ ์žฅ์˜ ์›๋ณธ ์–ผ๊ตด ์ด๋ฏธ์ง€์™€ ํ•œ ์žฅ์˜ ์Šคํƒ€์ผํ™”๋œ ์–ผ๊ตด ์ด๋ฏธ์ง€๊ฐ€ ์ œ์‹œ๋ฉ๋‹ˆ๋‹ค.
- ๋‘ ์ด๋ฏธ์ง€๊ฐ€ **๊ฐ™์€ ์‚ฌ๋žŒ์ธ์ง€ / ๋‹ค๋ฅธ ์‚ฌ๋žŒ์ธ์ง€** ํŒ๋‹จํ•ด ์ฃผ์„ธ์š”.
**โš ๏ธ ์ค‘์š”:** ์ƒˆ๋กœ๊ณ ์นจ/๋’ค๋กœ๊ฐ€๊ธฐ๋Š” ์ง„ํ–‰์ƒํ™ฉ ์ดˆ๊ธฐํ™”๋ฉ๋‹ˆ๋‹ค.
# Identity Preservation User Study
## Task 2 โ€“ Identity Verification
- You will see one real (source) face image and one stylized face image.
- Please decide whether the two images show **the same person** or **different people**.
**โš ๏ธ Important:** Refreshing the page or navigating away will **reset your progress**.
""",
            visible=True,
        )
        intro_info_md = gr.Markdown("## Participant Information / ์ฐธ๊ฐ€์ž ์ •๋ณด", visible=True)
        with gr.Row():
            name = gr.Textbox(label="์ด๋ฆ„ (Name)", visible=True)
            age = gr.Textbox(label="๋‚˜์ด (Age)", placeholder="e.g., 25", visible=True)
        with gr.Row():
            gender = gr.Radio(choices=["Female", "Male"], label="์„ฑ๋ณ„ (Gender)", interactive=True, visible=True)
            ethnicity = gr.Radio(
                choices=["Asian", "Black", "White", "Hispanic/Latino", "Middle Eastern", "Mixed/Other", "Prefer not to say"],
                label="์ธ์ข… (Ethnicity)",
                interactive=True,
                visible=True,
            )
        intro_btn = gr.Button("์‹œ์ž‘ / Start Study", visible=True)
    # Task2
    with gr.Column() as task2_col:
        t2_title_md = gr.Markdown("## Task 2 : ์•„์ด๋ดํ‹ฐํ‹ฐ ํ™•์ธ (Identity Verification)", visible=False)
        status2 = gr.Markdown("", visible=False)
        feedback2 = gr.Markdown("", visible=False)
        with gr.Row():
            imgH = gr.Image(label="์›๋ณธ ์–ผ๊ตด / Source", type="pil", height=320, visible=True)
            imgS = gr.Image(label="์Šคํƒ€์ผํ™” ์–ผ๊ตด / Stylized", type="pil", height=320, visible=True)
        # These choice strings are compared verbatim in submit_answer_task2.
        answer2 = gr.Radio(
            choices=["๊ฐ™์€ ์‚ฌ๋žŒ / Same person", "๋‹ค๋ฅธ ์‚ฌ๋žŒ / Different person"],
            label="๋‘ ์ด๋ฏธ์ง€๋Š” ๊ฐ™์€ ์‚ฌ๋žŒ์ธ๊ฐ€์š”, ๋‹ค๋ฅธ ์‚ฌ๋žŒ์ธ๊ฐ€์š”? \nDo these two images show the same person?",
            interactive=True,
            visible=True,
        )
        submit2 = gr.Button("Submit response (Task 2)", visible=False)
    # End
    with gr.Column() as end_col:
        end_title_md = gr.Markdown("## Thank you!", visible=False)
        end_text_md = gr.Markdown("Task 2๋ฅผ ๋ชจ๋‘ ์™„๋ฃŒํ•˜์…จ์Šต๋‹ˆ๋‹ค. ์ฐธ์—ฌํ•ด์ฃผ์…”์„œ ๊ฐ์‚ฌํ•ฉ๋‹ˆ๋‹ค.", visible=False)
    # State
    # Initial per-session values; init_study and the trial functions fill them in.
    state = gr.State(
        {
            "participant_id": None,
            "name": None,
            "age": None,
            "gender": None,
            "ethnicity": None,
            "t2_trial_index": 0,
            "t2_pair_schedule": [],  # pos/neg 50:50
            "t2_item_schedule": [],  # (set,style) schedule with balanced set & balanced styles within set
            "t2_start_time": None,
            "t2_human_name": None,
            "t2_stylized_name": None,
            "t2_human_path": None,
            "t2_stylized_path": None,
            "t2_is_positive": None,
            "t2_method": None,
            "t2_method_dir": None,
            "t2_style": None,
            "t2_strength": None,
            "t2_is_repeat_trial": False,
            "t2_repeat_first_answer": None,
            "t2_repeat_done": False,
        }
    )

    # Events
    def init_study(name_v, age_v, gender_choice, ethnicity_v, st):
        """Validate the participant form, reset the session and show the first trial.

        Returns 17 values matching the ``outputs`` list of ``intro_btn.click``:
        images, status, feedback, state, the seven intro widgets, the three
        task widgets and the two end-screen widgets.
        """
        if not age_v or not gender_choice or not ethnicity_v:
            # Form incomplete: keep the intro visible and show the warning in
            # the feedback slot. Name is not required by this check.
            status = "โš ๏ธ Please fill in age, gender, and ethnicity."
            return (
                gr.update(value=None, visible=True),  # imgH
                gr.update(value=None, visible=True),  # imgS
                gr.update(value="", visible=True),  # status2
                gr.update(value=status, visible=True),  # feedback2
                st,
                gr.update(visible=True),  # intro_md
                gr.update(visible=True),  # intro_info_md
                gr.update(visible=True),  # name
                gr.update(visible=True),  # age
                gr.update(visible=True),  # gender
                gr.update(visible=True),  # ethnicity
                gr.update(visible=True),  # intro_btn
                gr.update(visible=False),  # t2_title_md
                gr.update(visible=False),  # answer2
                gr.update(visible=False),  # submit2
                gr.update(visible=False),  # end_title_md
                gr.update(visible=False),  # end_text_md
            )
        gender_code = "f" if gender_choice == "Female" else "m"
        # Keep an existing participant id if this session already has one.
        if st.get("participant_id") is None:
            st["participant_id"] = uuid.uuid4().hex
        st["name"] = name_v
        st["age"] = age_v
        st["gender"] = gender_code
        st["ethnicity"] = ethnicity_v
        st["t2_trial_index"] = 0
        st["t2_is_repeat_trial"] = False
        st["t2_repeat_first_answer"] = None
        st["t2_repeat_done"] = False
        # pos/neg schedule 50:50
        st["t2_pair_schedule"] = make_balanced_binary_schedule(N_TRIALS_TASK2)
        st["t2_item_schedule"] = make_joint_set_style_schedule(N_TRIALS_TASK2)
        imgH_val, imgS_val, status, st = make_new_trial_task2(st)
        return (
            gr.update(value=imgH_val, visible=True),
            gr.update(value=imgS_val, visible=True),
            gr.update(value=status, visible=True),
            gr.update(value="Task 2 started. Please answer the question below.", visible=True),
            st,
            gr.update(visible=False),  # intro_md
            gr.update(visible=False),  # intro_info_md
            gr.update(visible=False),  # name
            gr.update(visible=False),  # age
            gr.update(visible=False),  # gender
            gr.update(visible=False),  # ethnicity
            gr.update(visible=False),  # intro_btn
            gr.update(visible=True),  # t2_title_md
            gr.update(visible=True),  # answer2
            gr.update(visible=True),  # submit2
            gr.update(visible=False),  # end_title_md
            gr.update(visible=False),  # end_text_md
        )

    # The order of `outputs` must match the tuple returned by init_study.
    intro_btn.click(
        init_study,
        inputs=[name, age, gender, ethnicity, state],
        outputs=[
            imgH, imgS, status2, feedback2, state,
            intro_md, intro_info_md, name, age, gender, ethnicity, intro_btn,
            t2_title_md, answer2, submit2,
            end_title_md, end_text_md,
        ],
    )
    # The order of `outputs` must match the 10-tuple returned by submit_answer_task2.
    submit2.click(
        submit_answer_task2,
        inputs=[answer2, state],
        outputs=[
            imgH, imgS, status2, feedback2, state,
            t2_title_md, answer2, submit2,
            end_title_md, end_text_md,
        ],
    )

if __name__ == "__main__":
    demo.launch()