Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import random | |
| import time | |
| import uuid | |
| from datetime import datetime | |
| import tarfile | |
| from PIL import Image | |
| import shutil | |
| import gradio as gr | |
| import gspread | |
| from oauth2client.service_account import ServiceAccountCredentials | |
| from huggingface_hub import snapshot_download | |
# ===================== CONFIG =====================
DATASET_REPO_ID = "zaqxsw0526/styleid-data-rebuttal"
# Pre-download default; at startup it is replaced by the root returned from
# ensure_data_downloaded(), and the two derived paths are recomputed from it.
DATA_ROOT_DIR = "data_rebuttal"
T2_SOURCE_DIR = os.path.join(DATA_ROOT_DIR, "p2")  # folder of source (real) face PNGs
STYLIZED_ROOT_DIR = DATA_ROOT_DIR
# set_key -> folder name of that set under the dataset root.
SET_DIR_MAP = {
    "ipa_cp": "ipa_cross_prompt",
    "mtg_cm": "mtg_cross_method",
    "flux_cm": "flux_cross_method",
}
SETS = list(SET_DIR_MAP.keys())
# Style sub-folders expected inside each set folder.
# ipa_cp and flux_cm use the same style names; mtg_cm shares none with them.
STYLES_BY_SET = {
    "ipa_cp": ["comic", "low-poly", "manga", "oil-painting", "pixel-art", "renaissance", "sketch", "ukiyo-e"],
    "mtg_cm": ["brave", "detroit", "doc_brown", "jojo", "moana", "picasso", "pocahontas", "titan_erwin"],
    "flux_cm": ["comic", "low-poly", "manga", "oil-painting", "pixel-art", "renaissance", "sketch", "ukiyo-e"],
}
# Number of regular Task 2 trials (one attention-check repeat follows the last one).
N_TRIALS_TASK2 = 90
# ---- Google Sheets config ----
# Both values can be overridden through environment variables (e.g. Space secrets),
# so the service-account key file does not have to ship under a fixed name.
SERVICE_ACCOUNT_FILE = os.environ.get("STYLEID_SERVICE_ACCOUNT_FILE", "styleid-479107-ccb8a1b36ea4.json")
SPREADSHEET_KEY = os.environ.get("STYLEID_SPREADSHEET_KEY", "1nWQ9UY3BDlpegclD1zy40Kfjj-0HZcoDc16RfO4QFQM")
SHEET_NAME_T2 = "rebuttal_t2"
# Column order of every Task 2 row appended to the sheet.
HEADER_T2 = [
    "timestamp",
    "participant_id",
    "name",
    "age",
    "gender",
    "ethnicity",
    "task",
    "trial_index",
    "human_name",
    "stylized_name",
    "is_positive_pair",
    "user_answer",
    "is_correct",
    "method",  # set_key (ipa_cp / mtg_cm / flux_cm)
    "method_dir",  # set folder (ipa_cross_prompt / mtg_cross_method / flux_cross_method)
    "style",  # style folder name
    "strength",
    "response_time_ms",
    "is_repeat_trial",  # 0 or 1
    "repeat_answer_match",  # 1 (same), 0 (different), "" (N/A)
]
# ===================== DATA DOWNLOAD / EXTRACT =====================
# Set folders that must all exist under the dataset root before the study can run.
_REQUIRED_SET_DIRS = ("ipa_cross_prompt", "mtg_cross_method", "flux_cross_method")
# Style folders that an archive may unpack directly under the dataset root
# instead of under their set folder.
_MTG_STYLE_DIRS = ("brave", "detroit", "doc_brown", "jojo", "moana", "picasso", "pocahontas", "titan_erwin")
# NOTE: these names are identical to the ipa_cross_prompt style names, so any loose
# folder with one of these names is always treated as belonging to flux_cross_method.
_FLUX_STYLE_DIRS = ("comic", "low-poly", "manga", "oil-painting", "pixel-art", "renaissance", "sketch", "ukiyo-e")
_TAR_EXTS = (".tar", ".tar.gz", ".tgz", ".tar.bz2", ".tbz2", ".tar.xz", ".txz")


def _normalize_layout(base_dir: str, set_dir_name: str, style_dirs, label: str) -> None:
    """Move loose style folders from ``base_dir`` into ``base_dir/<set_dir_name>/``.

    Does nothing if the set folder already exists or none of ``style_dirs`` is
    present directly under ``base_dir``.
    """
    set_dir = os.path.join(base_dir, set_dir_name)
    if os.path.isdir(set_dir):
        return
    existing = [s for s in style_dirs if os.path.isdir(os.path.join(base_dir, s))]
    if not existing:
        return
    os.makedirs(set_dir, exist_ok=True)
    moved = 0
    for style in existing:
        dst = os.path.join(set_dir, style)
        if os.path.exists(dst):
            continue
        shutil.move(os.path.join(base_dir, style), dst)
        moved += 1
    print(f"[STYLEID] โ Normalized {label} layout into '{set_dir}' (moved {moved} style dirs).", flush=True)


def _has_all_set_dirs(root: str) -> bool:
    """Return True if every folder in _REQUIRED_SET_DIRS exists directly under ``root``."""
    return all(os.path.isdir(os.path.join(root, d)) for d in _REQUIRED_SET_DIRS)


def _find_actual_root(base_dir: str):
    """Return ``base_dir`` or a direct sub-folder of it that holds all set folders, else None.

    The one-level search covers archives that wrap everything in a prefix folder.
    """
    if _has_all_set_dirs(base_dir):
        return base_dir
    if not os.path.isdir(base_dir):
        return None
    for name in sorted(os.listdir(base_dir)):
        cand = os.path.join(base_dir, name)
        if os.path.isdir(cand) and _has_all_set_dirs(cand):
            return cand
    return None


def _is_within_directory(directory: str, target: str) -> bool:
    """Return True if ``target`` is ``directory`` itself or a path inside it.

    Compares whole path components (os.path.commonpath). A character-wise prefix
    test would wrongly accept siblings such as ``data_rebuttal_evil`` for ``data_rebuttal``.
    """
    abs_directory = os.path.abspath(directory)
    abs_target = os.path.abspath(target)
    try:
        return os.path.commonpath([abs_directory, abs_target]) == abs_directory
    except ValueError:
        # Paths on different drives (Windows) share no common path.
        return False


def _safe_extract(tar: tarfile.TarFile, path: str) -> None:
    """Extract ``tar`` into ``path``, refusing members or links that point outside it.

    Raises:
        RuntimeError: if a member name or a symlink/hardlink target escapes ``path``.
    """
    for member in tar.getmembers():
        if not _is_within_directory(path, os.path.join(path, member.name)):
            raise RuntimeError(f"Unsafe path detected in tar: {member.name}")
        if member.issym():
            # Symlink targets are relative to the member's own folder.
            link_target = os.path.join(path, os.path.dirname(member.name), member.linkname)
        elif member.islnk():
            # Hardlink targets are relative to the archive root.
            link_target = os.path.join(path, member.linkname)
        else:
            continue
        if not _is_within_directory(path, link_target):
            raise RuntimeError(f"Unsafe link detected in tar: {member.name} -> {member.linkname}")
    if hasattr(tarfile, "data_filter"):
        # Python versions with extraction filters: add the stdlib "data" safety filter as well.
        tar.extractall(path, filter="data")
    else:
        tar.extractall(path)


def ensure_data_downloaded():
    """Make sure the rebuttal dataset is available locally and return its root folder.

    Steps:
      1. Move loose mtg/flux style folders under their set folder (an archive may
         unpack its style folders directly under DATA_ROOT_DIR).
      2. If DATA_ROOT_DIR, or one folder level below it, already contains all of
         ipa_cross_prompt / mtg_cross_method / flux_cross_method, return that path.
      3. Otherwise download DATASET_REPO_ID into DATA_ROOT_DIR, extract every tar
         archive found there (rejecting members that would land outside it),
         normalize the layout again and return the detected root.

    Raises:
        RuntimeError: if an archive member escapes DATA_ROOT_DIR, or if no folder
            containing all set folders exists after extraction.
    """
    print(f"[STYLEID] ๐ Checking dataset under ./{DATA_ROOT_DIR} ...", flush=True)
    _normalize_layout(DATA_ROOT_DIR, "mtg_cross_method", _MTG_STYLE_DIRS, "mtg")
    _normalize_layout(DATA_ROOT_DIR, "flux_cross_method", _FLUX_STYLE_DIRS, "flux")
    actual_root = _find_actual_root(DATA_ROOT_DIR)
    if actual_root is not None:
        print(f"[STYLEID] โ Data already present. Ready. (root={actual_root})", flush=True)
        return actual_root

    os.makedirs(DATA_ROOT_DIR, exist_ok=True)
    print(f"[STYLEID] โฌ๏ธ Downloading dataset '{DATASET_REPO_ID}' into '{DATA_ROOT_DIR}' ...", flush=True)
    try:
        snapshot_download(
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
            local_dir=DATA_ROOT_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["p2/**", "ipa_cross_prompt.tar", "mtg_cross_method.tar", "flux_cross_method.tar"],
        )
    except TypeError:
        # Retry without allow_patterns if this huggingface_hub rejects an argument;
        # note this downloads the whole dataset repo.
        snapshot_download(
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
            local_dir=DATA_ROOT_DIR,
            local_dir_use_symlinks=False,
        )
    print("[STYLEID] โ Download complete. Searching for TAR files to extract ...", flush=True)

    # Collect archive paths before extracting. Extracting into the tree while os.walk
    # is still traversing it would make the walk descend into freshly extracted folders.
    tar_paths = [
        os.path.join(root, fname)
        for root, _, files in os.walk(DATA_ROOT_DIR)
        for fname in files
        if fname.lower().endswith(_TAR_EXTS)
    ]
    for tar_path in tar_paths:
        print(f"[STYLEID] ๐ฆ Extracting '{tar_path}' ...", flush=True)
        with tarfile.open(tar_path, "r:*") as tf:
            _safe_extract(tf, DATA_ROOT_DIR)
        print(f"[STYLEID] โ Finished extracting '{tar_path}'.", flush=True)
    if tar_paths:
        print("[STYLEID] ๐ TAR extraction complete.", flush=True)
    else:
        print("[STYLEID] โน๏ธ No TAR files found. Dataset may already be unpacked.", flush=True)

    _normalize_layout(DATA_ROOT_DIR, "mtg_cross_method", _MTG_STYLE_DIRS, "mtg")
    _normalize_layout(DATA_ROOT_DIR, "flux_cross_method", _FLUX_STYLE_DIRS, "flux")
    actual_root = _find_actual_root(DATA_ROOT_DIR)
    if actual_root is None:
        top = os.listdir(DATA_ROOT_DIR)[:50]
        raise RuntimeError(
            f"[STYLEID] โ Extraction finished but could not find expected dirs under '{DATA_ROOT_DIR}'. "
            f"Top entries: {top}"
        )
    print(f"[STYLEID] โ Found extracted root: {actual_root}", flush=True)
    return actual_root
# Resolve the real dataset root (downloading/extracting on first run) and re-derive
# the paths that depend on it. ensure_data_downloaded() reads the DATA_ROOT_DIR set
# in CONFIG; it is intentionally not reassigned here, so editing CONFIG takes effect.
DATA_ROOT_DIR = ensure_data_downloaded()
T2_SOURCE_DIR = os.path.join(DATA_ROOT_DIR, "p2")
STYLIZED_ROOT_DIR = DATA_ROOT_DIR
# ===================== GOOGLE SHEETS =====================
def ensure_header(ws, header):
    """Make sure row 1 of worksheet ``ws`` starts with ``header``.

    If row 1 is empty, or its leading cells differ from ``header``, ``header`` is
    inserted as a new row 1 (existing rows shift down).

    If row 1 cannot be read, the sheet is left untouched and a warning is printed.
    A transient read error must not insert a stray header row above existing data.
    """
    try:
        first = ws.row_values(1)
    except Exception as exc:  # best-effort setup: gspread / network errors
        print(f"[STYLEID] Could not read header row ({exc!r}); leaving sheet unchanged.", flush=True)
        return
    if not first or first[: len(header)] != header:
        ws.insert_row(header, 1)
scope = [
    "https://www.googleapis.com/auth/spreadsheets",
    "https://www.googleapis.com/auth/drive",
]
# Authenticate once at import time with the service-account key file.
creds = ServiceAccountCredentials.from_json_keyfile_name(SERVICE_ACCOUNT_FILE, scope)
gc = gspread.authorize(creds)
sh = gc.open_by_key(SPREADSHEET_KEY)


def _open_or_create_worksheet(spreadsheet, title):
    """Return worksheet ``title`` of ``spreadsheet``; create it (8000 rows x 30 cols) if absent."""
    try:
        return spreadsheet.worksheet(title)
    except gspread.WorksheetNotFound:
        return spreadsheet.add_worksheet(title=title, rows=8000, cols=30)


# Task 2 results are appended to this worksheet; make sure it carries the header row.
sheet_t2 = _open_or_create_worksheet(sh, SHEET_NAME_T2)
ensure_header(sheet_t2, HEADER_T2)
# ===================== INDEX BUILDING =====================
def build_source_base_map(source_dir: str):
    """Map every ``.png`` file in ``source_dir`` from its stem to its file name.

    The extension check is case-insensitive. Sub-folders are not scanned.

    Raises:
        RuntimeError: if ``source_dir`` is not a directory.
    """
    if not os.path.isdir(source_dir):
        raise RuntimeError(f"[STYLEID] โ Source dir not found: {source_dir} (need p2)")
    return {
        os.path.splitext(fname)[0]: fname
        for fname in os.listdir(source_dir)
        if fname.lower().endswith(".png")
    }
# Index the p2 source faces once at import time: stem -> file name, plus the set of stems.
SOURCE_BASE2FILE = build_source_base_map(T2_SOURCE_DIR)
SOURCE_BASES = set(SOURCE_BASE2FILE)
print(f"[STYLEID] โ p2 bases: {len(SOURCE_BASES)}", flush=True)
def build_stylized_buckets(root_dir=None, source_bases=None, set_dir_map=None, styles_by_set=None):
    """Index stylized images per (set_key, style_name).

    Returns:
        (buckets, valid_bases_by_set), where
          buckets[set_key][style_name] = [(base, strength, full_path), ...]
              holds only images whose base is a known p2 source base, and
          valid_bases_by_set[set_key] = sorted list of bases that have at least
              one kept image anywhere in that set.

    File stems are parsed as ``<base>_<strength>`` or ``<base>-<strength>``
    (strength is an int). A stem without such a suffix is its own base with
    strength 0. If the split-off base is not a source base but the whole stem
    is, the whole stem is used with strength 0, so source names that themselves
    end in ``_<digits>`` are not dropped.

    The optional arguments default (at call time) to the module-level
    STYLIZED_ROOT_DIR, SOURCE_BASES, SET_DIR_MAP and STYLES_BY_SET.

    Raises:
        RuntimeError: if a set folder or a style folder is missing.
    """
    root_dir = STYLIZED_ROOT_DIR if root_dir is None else root_dir
    source_bases = SOURCE_BASES if source_bases is None else source_bases
    set_dir_map = SET_DIR_MAP if set_dir_map is None else set_dir_map
    styles_by_set = STYLES_BY_SET if styles_by_set is None else styles_by_set

    buckets = {k: {s: [] for s in styles_by_set[k]} for k in set_dir_map}
    valid_bases_by_set = {k: set() for k in set_dir_map}
    # Stems such as "base_3" or "base-3": base name plus an integer strength suffix.
    pat_underscore = re.compile(r"^(.*)_([0-9]+)$")
    pat_dash = re.compile(r"^(.*)-([0-9]+)$")
    for set_key, set_dir in set_dir_map.items():
        set_path = os.path.join(root_dir, set_dir)
        if not os.path.isdir(set_path):
            raise RuntimeError(f"[STYLEID] โ Missing set dir: {set_path}")
        print(f"[STYLEID] ๐ scanning set={set_key} path={set_path}", flush=True)
        for style_name in styles_by_set[set_key]:
            style_path = os.path.join(set_path, style_name)
            if not os.path.isdir(style_path):
                raise RuntimeError(f"[STYLEID] โ Missing style dir: {style_path}")
            total_png = 0
            parsed = 0  # stems that carried a strength suffix
            kept = 0
            for root, _, files in os.walk(style_path):
                for fname in files:
                    if not fname.lower().endswith(".png"):
                        continue
                    total_png += 1
                    stem = os.path.splitext(fname)[0]
                    m = pat_underscore.match(stem) or pat_dash.match(stem)
                    if m:
                        parsed += 1
                        base, strength = m.group(1), int(m.group(2))
                    else:
                        base, strength = stem, 0
                    if base not in source_bases and stem in source_bases:
                        # The "suffix" is part of the source name itself.
                        base, strength = stem, 0
                    if base not in source_bases:
                        continue
                    buckets[set_key][style_name].append((base, strength, os.path.join(root, fname)))
                    valid_bases_by_set[set_key].add(base)
                    kept += 1
            print(
                f"[STYLEID] - {style_name}: total_png={total_png}, parsed={parsed}, kept(after p2 filter)={kept}",
                flush=True,
            )
            if kept == 0:
                print(f"[STYLEID] WARNING: empty bucket set={set_key}, style={style_name}", flush=True)
    valid_bases_by_set = {k: sorted(v) for k, v in valid_bases_by_set.items()}
    return buckets, valid_bases_by_set
# Scan every (set, style) folder once at import time.
_stylized_index = build_stylized_buckets()
STYLIZED_BUCKETS = _stylized_index[0]
VALID_BASES_BY_SET = _stylized_index[1]
# ===================== SCHEDULING =====================
def make_balanced_binary_schedule(n: int):
    """Return a shuffled list of ``n`` labels: ``n // 2`` ones and the rest zeros.

    1 marks a positive (same-person) pair and 0 a negative pair. For odd ``n``
    the extra element is a 0.
    """
    n_pos = n // 2
    schedule = [1 if i < n_pos else 0 for i in range(n)]
    random.shuffle(schedule)
    return schedule
def make_balanced_style_list(styles: list[str], n: int):
    """Return a shuffled list of length ``n`` drawn from ``styles`` as evenly as possible.

    Every style appears ``n // len(styles)`` times. The ``n % len(styles)``
    leftover slots go to distinct, randomly chosen styles.

    Returns:
        The list; ``[]`` when ``n <= 0``.

    Raises:
        ValueError: if ``n > 0`` but ``styles`` is empty.
    """
    if n <= 0:
        return []
    if not styles:
        raise ValueError("make_balanced_style_list: styles must be non-empty when n > 0")
    q, r = divmod(n, len(styles))
    lst = [s for s in styles for _ in range(q)]
    if r:
        # Distribute the remainder across distinct random styles.
        lst.extend(random.sample(styles, r))
    random.shuffle(lst)
    return lst
def make_joint_set_style_schedule(n_total: int):
    """Return a shuffled list of ``n_total`` (set_key, style_name) pairs.

    Trials are split evenly across SETS (e.g. 90 trials over 3 sets -> 30/30/30).
    Any remainder goes to randomly chosen distinct sets. Within each set, the
    styles are balanced by make_balanced_style_list.
    """
    per_set, leftover = divmod(n_total, len(SETS))
    counts = dict.fromkeys(SETS, per_set)
    if leftover > 0:
        for lucky_set in random.sample(SETS, leftover):
            counts[lucky_set] += 1
    schedule = [
        (set_key, style)
        for set_key, count in counts.items()
        for style in make_balanced_style_list(STYLES_BY_SET[set_key], count)
    ]
    random.shuffle(schedule)
    return schedule
# ===================== IMAGE LOADING =====================
def load_and_resize(path, size=(384, 384)):
    """Load the image at ``path`` as RGB and resize it to ``size`` (bilinear).

    Returns:
        A new PIL image, or None when ``path`` is None or the file does not
        exist (a warning is printed in the latter case).
    """
    if path is None:
        return None
    if not os.path.exists(path):
        print(f"[STYLEID] โ Missing image: {path}", flush=True)
        return None
    # Pillow >= 9.1 exposes Image.Resampling; older releases only have Image.BILINEAR.
    # A missing attribute raises AttributeError, not ImportError, so probe with getattr.
    resample = getattr(Image, "Resampling", Image).BILINEAR
    # The context manager closes the file handle; convert()/resize() return new images.
    with Image.open(path) as img:
        return img.convert("RGB").resize(size, resample)
# ===================== TASK 2 (Verification) =====================
def make_new_trial_task2(state: dict):
    """Prepare the next Task 2 trial and record its metadata in ``state``.

    The (set, style) pair and the positive/negative flag are popped from
    ``state["t2_item_schedule"]`` and ``state["t2_pair_schedule"]``. A random
    fallback is used when a schedule is empty. A stylized image is drawn from
    the matching bucket first, so the style schedule is followed exactly. The
    source image is then chosen: the same base for a positive pair, or a
    different base from the same set for a negative pair.

    Args:
        state: per-session study state; mutated in place and also returned.

    Returns:
        (source_img, stylized_img, status, state). Both images are None once
        N_TRIALS_TASK2 trials have already been shown.

    Raises:
        RuntimeError: if the bucket is empty, the set has fewer than two bases
            for a negative pair, or the chosen source base has no p2 file.
    """
    if state.get("t2_trial_index", 0) >= N_TRIALS_TASK2:
        status = f"Task 2 finished ({N_TRIALS_TASK2} trials)."
        return None, None, status, state

    # --- pop schedules ---
    if state.get("t2_item_schedule"):
        set_key, style_name = state["t2_item_schedule"].pop()
    else:
        # Fallback: draw the style from the SAME set. Sets use different style
        # names (mtg_cm shares none with the others), so a style drawn from an
        # independently chosen set may not exist under set_key.
        set_key = random.choice(SETS)
        style_name = random.choice(STYLES_BY_SET[set_key])
    set_dir = SET_DIR_MAP[set_key]
    if state.get("t2_pair_schedule"):
        is_positive = state["t2_pair_schedule"].pop()
    else:
        is_positive = random.choice([0, 1])

    # --- pick stylized first (to satisfy style schedule exactly) ---
    bucket = STYLIZED_BUCKETS[set_key][style_name]
    if not bucket:
        raise RuntimeError(f"[STYLEID] โ Empty bucket: set={set_key}, style={style_name}")
    styl_base, strength, stylized_path = random.choice(bucket)
    stylized_name = os.path.basename(stylized_path)

    # --- pick source (human) image according to pos/neg ---
    if is_positive:
        human_base = styl_base
    else:
        candidates = VALID_BASES_BY_SET[set_key]
        if len(candidates) < 2:
            raise RuntimeError(f"[STYLEID] โ Not enough bases for negative pair in set '{set_key}'")
        # styl_base is itself one of the candidates, so at least one other remains.
        human_base = random.choice([b for b in candidates if b != styl_base])
    human_name = SOURCE_BASE2FILE.get(human_base)
    if human_name is None:
        raise RuntimeError(f"[STYLEID] โ human base '{human_base}' not found in p2")
    human_path = os.path.join(T2_SOURCE_DIR, human_name)

    state["t2_trial_index"] = state.get("t2_trial_index", 0) + 1
    state["t2_human_name"] = human_name
    state["t2_human_path"] = human_path
    state["t2_stylized_name"] = stylized_name
    state["t2_stylized_path"] = stylized_path
    state["t2_is_positive"] = int(is_positive)
    state["t2_method"] = set_key
    state["t2_method_dir"] = set_dir
    state["t2_style"] = style_name
    state["t2_strength"] = strength
    status = f"Task 2 - Verification | Trial {state['t2_trial_index']} / {N_TRIALS_TASK2}"

    human_img = load_and_resize(human_path)
    stylized_img = load_and_resize(stylized_path)
    # Start the response-time clock only after both images are decoded, so disk
    # and resize time is not counted as participant response time.
    state["t2_start_time"] = time.time()
    return human_img, stylized_img, status, state
def _task2_outputs(img_h, img_s, status, feedback, state, show_task):
    """Assemble the 10 updates returned by submit_answer_task2.

    Order matches submit2.click's outputs: imgH, imgS, status2, feedback2, state,
    t2_title_md, answer2, submit2, end_title_md, end_text_md. With ``show_task``
    True the Task 2 controls are shown and the end screen is hidden; False does
    the reverse.
    """
    return (
        img_h,
        img_s,
        gr.update(value=status, visible=True),
        gr.update(value=feedback, visible=True),
        state,
        gr.update(visible=show_task),
        # Always clear the radio so no answer is preselected afterwards.
        gr.update(visible=show_task, value=None),
        gr.update(visible=show_task),
        gr.update(visible=not show_task),
        gr.update(visible=not show_task),
    )


def _current_pair_updates(state):
    """Return (source, stylized) image updates that re-show the trial stored in ``state``."""
    return (
        gr.update(value=load_and_resize(state.get("t2_human_path")), visible=True),
        gr.update(value=load_and_resize(state.get("t2_stylized_path")), visible=True),
    )


def submit_answer_task2(user_answer, state):
    """Handle one Task 2 answer: validate it, log it to Google Sheets, then advance.

    Flow:
      * no valid choice selected -> re-show the current pair with a warning;
      * trial not started -> re-show it with a "click Start Study" warning;
      * otherwise append one row (HEADER_T2 column order) to ``sheet_t2``. If the
        append fails, keep the current trial so the participant can resubmit;
      * answer to the attention-check repeat -> finish and show the end screen;
      * answer to trial N_TRIALS_TASK2 -> show the same pair again as the
        attention-check repeat;
      * otherwise -> show the next trial.

    Returns:
        10 Gradio updates in the order of submit2.click's outputs.
    """
    valid_choices = ["๊ฐ์ ์ฌ๋ / Same person", "๋ค๋ฅธ ์ฌ๋ / Different person"]

    # no selection
    if user_answer not in valid_choices:
        feedback = "โ ๏ธ Please choose whether the two images show the same person or not."
        status = f"Task 2 - Verification | Trial {state.get('t2_trial_index', 0)} / {N_TRIALS_TASK2}"
        return _task2_outputs(*_current_pair_updates(state), status, feedback, state, True)

    if state.get("t2_start_time") is None:
        feedback = "โ ๏ธ The trial has not started yet. Please click 'Start Study' first."
        return _task2_outputs(*_current_pair_updates(state), "", feedback, state, True)

    rt_ms = int((time.time() - state["t2_start_time"]) * 1000)
    is_positive = int(state["t2_is_positive"])
    user_pos = 1 if user_answer == "๊ฐ์ ์ฌ๋ / Same person" else 0
    is_correct = 1 if user_pos == is_positive else 0
    is_repeat_trial = 1 if state.get("t2_is_repeat_trial", False) else 0
    # "" = not applicable; 1/0 = the repeat answer equals / differs from the first answer.
    repeat_match = ""
    if is_repeat_trial and state.get("t2_repeat_first_answer") is not None:
        repeat_match = 1 if user_answer == state["t2_repeat_first_answer"] else 0

    # One value per HEADER_T2 column, in the same order.
    row = [
        datetime.now().isoformat(),
        state["participant_id"],
        state.get("name"),
        state.get("age"),
        state.get("gender"),
        state.get("ethnicity"),
        "task2",
        state["t2_trial_index"],
        state["t2_human_name"],
        state["t2_stylized_name"],
        is_positive,
        "same" if user_pos else "different",
        is_correct,
        state["t2_method"],
        state["t2_method_dir"],
        state["t2_style"],
        state["t2_strength"],
        rt_ms,
        is_repeat_trial,
        repeat_match,
    ]
    try:
        sheet_t2.append_row(row)
    except Exception as exc:  # network / Sheets API failure at the event-handler boundary
        # Keep the trial unchanged so the participant can resubmit, rather than crashing
        # the handler and losing the response. If the write actually landed despite the
        # error, a resubmission produces a duplicate row with the same trial_index.
        print(f"[STYLEID] Failed to save Task 2 response: {exc!r}", flush=True)
        if is_repeat_trial:
            status = "Task 2 - Attention check trial.\nPlease answer once more."
        else:
            status = f"Task 2 - Verification | Trial {state.get('t2_trial_index', 0)} / {N_TRIALS_TASK2}"
        feedback = "Warning: your response could not be saved. Please select your answer and submit again."
        return _task2_outputs(*_current_pair_updates(state), status, feedback, state, True)

    # if this was repeat trial -> end
    if state.get("t2_is_repeat_trial", False):
        state["t2_is_repeat_trial"] = False
        state["t2_repeat_done"] = True
        return _task2_outputs(
            gr.update(value=None, visible=False),
            gr.update(value=None, visible=False),
            "Task 2 finished.",
            "โ Task 2 finished (attention check complete). Thank you!",
            state,
            False,
        )

    # last original trial -> start attention repeat (the same pair is shown again)
    if state.get("t2_trial_index", 0) == N_TRIALS_TASK2 and not state.get("t2_repeat_done", False):
        state["t2_repeat_first_answer"] = user_answer
        state["t2_is_repeat_trial"] = True
        state["t2_start_time"] = time.time()
        return _task2_outputs(
            *_current_pair_updates(state),
            "Task 2 - Attention check trial.\nPlease answer once more.",
            f"Response saved (RT: {rt_ms} ms).\nPlease answer the same question once again.",
            state,
            True,
        )

    # normal -> next trial
    img_h, img_s, next_status, state = make_new_trial_task2(state)
    return _task2_outputs(
        gr.update(value=img_h, visible=True),
        gr.update(value=img_s, visible=True),
        next_status,
        f"Response saved (RT: {rt_ms} ms). Next trial is shown.",
        state,
        True,
    )
# ===================== UI =====================
# NOTE(review): the non-English (likely Korean) UI strings below appear mis-encoded
# in this copy of the file and would be displayed as-is. Verify the file encoding.
# NOTE(review): the nesting of components under Column/Row was reconstructed from a
# copy without indentation. Confirm which components belong to which container.
with gr.Blocks() as demo:
    # Intro screen: study description plus participant-information form.
    # Gradio lays components out in creation order, so statement order here is the page layout.
    with gr.Column() as intro_col:
        intro_md = gr.Markdown(
            """
# ์์ด๋ดํฐํฐ ๋ณด์กด ์ฌ์ฉ์ ์ฐ๊ตฌ
## Task 2 โ Identity Verification
- ํ ์ฅ์ ์๋ณธ ์ผ๊ตด ์ด๋ฏธ์ง์ ํ ์ฅ์ ์คํ์ผํ๋ ์ผ๊ตด ์ด๋ฏธ์ง๊ฐ ์ ์๋ฉ๋๋ค.
- ๋ ์ด๋ฏธ์ง๊ฐ **๊ฐ์ ์ฌ๋์ธ์ง / ๋ค๋ฅธ ์ฌ๋์ธ์ง** ํ๋จํด ์ฃผ์ธ์.
**โ ๏ธ ์ค์:** ์๋ก๊ณ ์นจ/๋ค๋ก๊ฐ๊ธฐ๋ ์งํ์ํฉ ์ด๊ธฐํ๋ฉ๋๋ค.
# Identity Preservation User Study
## Task 2 โ Identity Verification
- You will see one real (source) face image and one stylized face image.
- Please decide whether the two images show **the same person** or **different people**.
**โ ๏ธ Important:** Refreshing the page or navigating away will **reset your progress**.
""",
            visible=True,
        )
        intro_info_md = gr.Markdown("## Participant Information / ์ฐธ๊ฐ์ ์ ๋ณด", visible=True)
        # Free-text fields; only age is required by init_study (name may be empty).
        with gr.Row():
            name = gr.Textbox(label="์ด๋ฆ (Name)", visible=True)
            age = gr.Textbox(label="๋์ด (Age)", placeholder="e.g., 25", visible=True)
        with gr.Row():
            gender = gr.Radio(choices=["Female", "Male"], label="์ฑ๋ณ (Gender)", interactive=True, visible=True)
            ethnicity = gr.Radio(
                choices=["Asian", "Black", "White", "Hispanic/Latino", "Middle Eastern", "Mixed/Other", "Prefer not to say"],
                label="์ธ์ข (Ethnicity)",
                interactive=True,
                visible=True,
            )
        intro_btn = gr.Button("์์ / Start Study", visible=True)
    # Task 2 screen: hidden title/status/feedback until the study starts; the submit
    # button also starts hidden and is revealed by init_study.
    with gr.Column() as task2_col:
        t2_title_md = gr.Markdown("## Task 2 : ์์ด๋ดํฐํฐ ํ์ธ (Identity Verification)", visible=False)
        status2 = gr.Markdown("", visible=False)
        feedback2 = gr.Markdown("", visible=False)
        # Source face (left) and stylized face (right), delivered as PIL images.
        with gr.Row():
            imgH = gr.Image(label="์๋ณธ ์ผ๊ตด / Source", type="pil", height=320, visible=True)
            imgS = gr.Image(label="์คํ์ผํ ์ผ๊ตด / Stylized", type="pil", height=320, visible=True)
        # These choice strings must match the literals compared in submit_answer_task2.
        answer2 = gr.Radio(
            choices=["๊ฐ์ ์ฌ๋ / Same person", "๋ค๋ฅธ ์ฌ๋ / Different person"],
            label="๋ ์ด๋ฏธ์ง๋ ๊ฐ์ ์ฌ๋์ธ๊ฐ์, ๋ค๋ฅธ ์ฌ๋์ธ๊ฐ์? \nDo these two images show the same person?",
            interactive=True,
            visible=True,
        )
        submit2 = gr.Button("Submit response (Task 2)", visible=False)
    # End screen, revealed after the attention-check repeat trial is answered.
    with gr.Column() as end_col:
        end_title_md = gr.Markdown("## Thank you!", visible=False)
        end_text_md = gr.Markdown("Task 2๋ฅผ ๋ชจ๋ ์๋ฃํ์ จ์ต๋๋ค. ์ฐธ์ฌํด์ฃผ์ ์ ๊ฐ์ฌํฉ๋๋ค.", visible=False)
    # Study state read and updated by the event handlers (initial values below).
    state = gr.State(
        {
            "participant_id": None,
            "name": None,
            "age": None,
            "gender": None,
            "ethnicity": None,
            "t2_trial_index": 0,
            "t2_pair_schedule": [],  # remaining 1/0 (positive/negative pair) flags; filled on Start
            "t2_item_schedule": [],  # remaining (set_key, style) pairs, balanced across sets and within each set
            "t2_start_time": None,
            "t2_human_name": None,
            "t2_stylized_name": None,
            "t2_human_path": None,
            "t2_stylized_path": None,
            "t2_is_positive": None,
            "t2_method": None,
            "t2_method_dir": None,
            "t2_style": None,
            "t2_strength": None,
            "t2_is_repeat_trial": False,
            "t2_repeat_first_answer": None,
            "t2_repeat_done": False,
        }
    )
| # Events | |
| def init_study(name_v, age_v, gender_choice, ethnicity_v, st): | |
| if not age_v or not gender_choice or not ethnicity_v: | |
| status = "โ ๏ธ Please fill in age, gender, and ethnicity." | |
| return ( | |
| gr.update(value=None, visible=True), | |
| gr.update(value=None, visible=True), | |
| gr.update(value="", visible=True), | |
| gr.update(value=status, visible=True), | |
| st, | |
| gr.update(visible=True), | |
| gr.update(visible=True), | |
| gr.update(visible=True), | |
| gr.update(visible=True), | |
| gr.update(visible=True), | |
| gr.update(visible=True), | |
| gr.update(visible=True), | |
| gr.update(visible=False), | |
| gr.update(visible=False), | |
| gr.update(visible=False), | |
| gr.update(visible=False), | |
| gr.update(visible=False), | |
| ) | |
| gender_code = "f" if gender_choice == "Female" else "m" | |
| if st.get("participant_id") is None: | |
| st["participant_id"] = uuid.uuid4().hex | |
| st["name"] = name_v | |
| st["age"] = age_v | |
| st["gender"] = gender_code | |
| st["ethnicity"] = ethnicity_v | |
| st["t2_trial_index"] = 0 | |
| st["t2_is_repeat_trial"] = False | |
| st["t2_repeat_first_answer"] = None | |
| st["t2_repeat_done"] = False | |
| # pos/neg schedule 50:50 | |
| st["t2_pair_schedule"] = make_balanced_binary_schedule(N_TRIALS_TASK2) | |
| st["t2_item_schedule"] = make_joint_set_style_schedule(N_TRIALS_TASK2) | |
| imgH_val, imgS_val, status, st = make_new_trial_task2(st) | |
| return ( | |
| gr.update(value=imgH_val, visible=True), | |
| gr.update(value=imgS_val, visible=True), | |
| gr.update(value=status, visible=True), | |
| gr.update(value="Task 2 started. Please answer the question below.", visible=True), | |
| st, | |
| gr.update(visible=False), # intro_md | |
| gr.update(visible=False), # intro_info_md | |
| gr.update(visible=False), # name | |
| gr.update(visible=False), # age | |
| gr.update(visible=False), # gender | |
| gr.update(visible=False), # ethnicity | |
| gr.update(visible=False), # intro_btn | |
| gr.update(visible=True), # t2_title_md | |
| gr.update(visible=True), # answer2 | |
| gr.update(visible=True), # submit2 | |
| gr.update(visible=False), # end_title_md | |
| gr.update(visible=False), # end_text_md | |
| ) | |
| intro_btn.click( | |
| init_study, | |
| inputs=[name, age, gender, ethnicity, state], | |
| outputs=[ | |
| imgH, imgS, status2, feedback2, state, | |
| intro_md, intro_info_md, name, age, gender, ethnicity, intro_btn, | |
| t2_title_md, answer2, submit2, | |
| end_title_md, end_text_md, | |
| ], | |
| ) | |
| submit2.click( | |
| submit_answer_task2, | |
| inputs=[answer2, state], | |
| outputs=[ | |
| imgH, imgS, status2, feedback2, state, | |
| t2_title_md, answer2, submit2, | |
| end_title_md, end_text_md, | |
| ], | |
| ) | |
if __name__ == "__main__":
    # Start the Gradio server when this file is executed directly.
    demo.launch()