# NOTE: "Spaces: Sleeping / Sleeping" was page-header residue captured with
# this file; it is not part of the program.
| import os | |
| import sys | |
| import glob | |
| import json | |
| import random | |
| import re | |
| from functools import partial | |
| from datetime import datetime | |
| from collections import defaultdict, Counter | |
| import gradio as gr | |
| from loguru import logger | |
# --- Global State ---
# Single module-level dict holding the whole evaluation session.
# NOTE(review): because it is module-level, it appears to be shared by every
# browser session served by this process — confirm that one participant per
# process is the intended deployment.
GLOBAL_STATE = {
    "participant_id": None,  # set by check_and_confirm_id once the ID is accepted
    "data_loaded": False,  # set to True at the end of load_evaluation_data
    "all_eval_data": [],  # list of {"prompt", "predictions", "category"} dicts
    "shuffled_indices": [],  # shuffled presentation order over all_eval_data
    "current_prompt_index": 0,  # position within shuffled_indices
    "current_criterion_index": 0,  # index into CRITERIA
    "image_mapping": {},  # comma-joined blendshape indices -> image filename
    "image_dir": "",  # directory that holds the mapped image files
    "evaluation_results": {},  # prompt text -> saved ranks / orders / absolute scores
    "image_orders": {},  # criterion name -> shuffled CONDITIONS for the shown prompt
    "start_time": None,  # datetime set when data loading succeeds
    "end_time": None,  # datetime set when the last prompt is saved
    "current_ranks": {},  # image label -> selected rank (None if unselected)
    "current_absolute_score": None,  # selected 1-7 score for the best image
    # Added: selected 1-7 score for the worst-matching image.
    "current_absolute_score_worst": None,
}
# --- Configuration ---
# Root of the per-participant prediction files and of the exported results.
BASE_RESULTS_DIR = "./results"
LOG_DIR = "./logs"
# Holds the blendshape-to-image mapping file and the image sub-directory.
COMBINED_DATA_DIR = "./combined_data"
IMAGE_SUBDIR = os.path.join("lapwing", "images")  # joined onto COMBINED_DATA_DIR
MAPPING_FILENAME = "combination_to_filename.json"
# Conditions being compared; one prediction file per condition is required.
CONDITIONS = ["Ours", "w_o_Proto_Loss", "w_o_HitL", "w_o_Tuning", "LLM-based"]
# Evaluation criteria, visited in this order for every prompt.
CRITERIA = ["Alignment", "Naturalness", "Attractiveness"]
# Japanese guidance text, index-aligned with CRITERIA.
CRITERIA_GUIDANCE_JP = [
    "テキストと表情がどれだけ一致しているか",
    "テキストの感情に沿ったセリフを言っていると想像したとき、表情がどれだけ自然か",
    "テキストの感情に沿ったセリフを言っていると想像したとき、表情がどれだけ魅力的か"
]
# English guidance text, index-aligned with CRITERIA.
CRITERIA_GUIDANCE_EN = [
    "how well the expression aligns with the text",
    "imagining the character is speaking a line that matches the emotion of the text, how natural the facial expression is",
    "imagining the character is speaking a line that matches the emotion of the text, how attractive the facial expression is"
]
# On-screen labels for the five image slots, left to right.
IMAGE_LABELS = ['A', 'B', 'C', 'D', 'E']
| # --- Helper Functions --- | |
def get_image_path_from_prediction(prediction: dict) -> str:
    """Resolve a prediction to the path of its rendered image.

    The prediction's ``blendshape_index`` dict is ordered by the integer value
    of its keys, its values are comma-joined, and the resulting key is looked
    up in the loaded image mapping.

    Returns:
        The image path under the configured image directory, or ``""`` when
        the mapping is not loaded, ``blendshape_index`` is not a dict, or no
        file is mapped to the key.
    """
    mapping = GLOBAL_STATE["image_mapping"]
    if not mapping:
        logger.error("Image mapping is not loaded.")
        return ""

    blendshapes = prediction.get("blendshape_index", {})
    if not isinstance(blendshapes, dict):
        logger.error(f"blendshape_index is not a dictionary: {blendshapes}")
        return ""

    # Keys are compared numerically so that "10" sorts after "9".
    ordered_slots = sorted(blendshapes, key=int)
    lookup_key = ",".join(str(blendshapes[slot]) for slot in ordered_slots)

    image_name = mapping.get(lookup_key)
    if image_name:
        return os.path.join(GLOBAL_STATE["image_dir"], image_name)

    logger.warning(f"No image found for blendshape key: {lookup_key}")
    return ""
# Revised (2): load_evaluation_data also reads prompt_category.
def load_evaluation_data(participant_id: str):
    """Load the image mapping and one participant's predictions into GLOBAL_STATE.

    For every entry of CONDITIONS, one ``{participant_id}_{cond}_*.jsonl`` file
    is read from ``BASE_RESULTS_DIR/<participant_id>/<cond>/``. Prompts that
    have a prediction for every condition are kept, shuffled, and the session
    counters are reset.

    Args:
        participant_id: Raw text from the participant-ID textbox. Surrounding
            whitespace is ignored, and the value must be a single path
            component because it is used to build filesystem paths.

    Returns:
        A 3-tuple matching the outputs (setup_status, load_data_btn,
        tab_evaluation): an HTML status message, an update for the load
        button, and an update for the evaluation tab.
    """

    def _failure(message: str):
        # On failure the load button stays usable and the evaluation tab locked.
        return (
            f"<p class='feedback red'>{message}</p>",
            gr.update(interactive=True),
            gr.update(interactive=False),
        )

    # The ID-confirmation handler strips the ID; do the same here so that
    # "P01 " resolves to the same directory, and refuse anything that could
    # escape BASE_RESULTS_DIR.
    participant_id = (participant_id or "").strip()
    if (not participant_id
            or participant_id in (".", "..")
            or os.path.basename(participant_id) != participant_id):
        return _failure("Error: Invalid participant ID.")

    mapping_path = os.path.join(COMBINED_DATA_DIR, MAPPING_FILENAME)
    if not os.path.exists(mapping_path):
        return _failure(f"Error: Mapping file not found at {mapping_path}")
    with open(mapping_path, 'r', encoding='utf-8') as f:
        GLOBAL_STATE["image_mapping"] = json.load(f)["mapping"]
    GLOBAL_STATE["image_dir"] = os.path.join(COMBINED_DATA_DIR, IMAGE_SUBDIR)
    logger.info(f"Successfully loaded image mapping. Image directory: {GLOBAL_STATE['image_dir']}")

    participant_dir = os.path.join(BASE_RESULTS_DIR, participant_id)
    if not os.path.isdir(participant_dir):
        return _failure(f"Error: Participant directory not found: {participant_dir}")

    merged_data = defaultdict(lambda: {"predictions": {}, "category": None})
    found_files = 0
    for cond in CONDITIONS:
        pattern = os.path.join(participant_dir, cond, f"{participant_id}_{cond}_*.jsonl")
        # glob returns matches in arbitrary order; sort so the chosen file is
        # the same on every run.
        files = sorted(glob.glob(pattern))
        if not files:
            logger.warning(f"No prediction file found for condition '{cond}' with pattern: {pattern}")
            continue
        if len(files) > 1:
            logger.warning(f"Multiple prediction files for condition '{cond}'; using {files[0]}")
        found_files += 1
        with open(files[0], 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    # Blank or trailing lines are not JSON records.
                    continue
                data = json.loads(line)
                prompt = data["text_prompt"]
                merged_data[prompt]["predictions"][cond] = data["prediction"]
                if not merged_data[prompt]["category"]:
                    merged_data[prompt]["category"] = data.get("prompt_category")

    if found_files != len(CONDITIONS):
        return _failure(
            f"Error: Found prediction files for only {found_files}/{len(CONDITIONS)} conditions.")

    # Only prompts covered by every condition can be ranked.
    eval_data = [
        {"prompt": p, "predictions": d["predictions"], "category": d["category"]}
        for p, d in merged_data.items() if len(d["predictions"]) == len(CONDITIONS)
    ]
    GLOBAL_STATE["all_eval_data"] = eval_data
    if not eval_data:
        return _failure("Error: No valid evaluation data could be loaded.")

    GLOBAL_STATE["shuffled_indices"] = list(range(len(eval_data)))
    random.shuffle(GLOBAL_STATE["shuffled_indices"])
    GLOBAL_STATE["current_prompt_index"] = 0
    GLOBAL_STATE["current_criterion_index"] = 0
    GLOBAL_STATE["data_loaded"] = True
    GLOBAL_STATE["start_time"] = datetime.now()
    # Start from a clean slate so results, orders and the completion time of
    # an earlier load in the same process cannot leak into this session.
    GLOBAL_STATE["end_time"] = None
    GLOBAL_STATE["image_orders"] = {}
    GLOBAL_STATE["evaluation_results"] = {item["prompt"]: {} for item in eval_data}

    logger.info(f"Loaded and merged data for {len(GLOBAL_STATE['all_eval_data'])} prompts.")
    done_msg = "<p class='feedback green'>Data loaded successfully. Please proceed to the 'Evaluation' tab. / データの読み込みに成功しました。「評価」タブに進んでください。</p>"
    return done_msg, gr.update(interactive=False, visible=False), gr.update(interactive=True)
| # --- Core Logic --- | |
def _create_button_updates():
    """Return variant updates for the rank buttons, five (ranks 1-5) per image label.

    The list is label-major: all ranks for the first label, then the second,
    and so on. The button matching a label's currently selected rank is
    'primary'; every other button is 'secondary'.
    """
    selections = GLOBAL_STATE["current_ranks"]
    return [
        gr.update(variant='primary' if rank == selections.get(label) else 'secondary')
        for label in IMAGE_LABELS
        for rank in range(1, 6)
    ]
def handle_rank_button_click(image_label, rank):
    """Toggle ``rank`` for ``image_label`` and return refreshed button variants.

    Clicking the rank that is already selected clears it; any other click
    selects the clicked rank.
    """
    ranks = GLOBAL_STATE["current_ranks"]
    already_selected = ranks.get(image_label) == rank
    ranks[image_label] = None if already_selected else rank
    return _create_button_updates()
def handle_absolute_score_click(score):
    """Toggle the absolute score and return variants for its seven buttons.

    Clicking the already-selected score clears the selection. The returned
    list has one update per score 1-7, 'primary' for the selected score.
    """
    previous = GLOBAL_STATE["current_absolute_score"]
    selected = None if previous == score else score
    GLOBAL_STATE["current_absolute_score"] = selected
    return [
        gr.update(variant='primary' if value == selected else 'secondary')
        for value in range(1, 8)
    ]
# Added: click handler for the worst-match absolute score buttons.
def handle_absolute_score_worst_click(score):
    """Toggle the worst-image absolute score and return variants for its seven buttons.

    Clicking the already-selected score clears the selection. The returned
    list has one update per score 1-7, 'primary' for the selected score.
    """
    previous = GLOBAL_STATE["current_absolute_score_worst"]
    selected = None if previous == score else score
    GLOBAL_STATE["current_absolute_score_worst"] = selected
    return [
        gr.update(variant='primary' if value == selected else 'secondary')
        for value in range(1, 8)
    ]
# Revised (1): display_current_prompt_and_criterion, fixing a UI freeze.
def display_current_prompt_and_criterion():
    """Build the 52 component updates that render the current evaluation screen.

    Output order: progress, prompt, guidance, error display, 5 images,
    25 rank buttons, best-score group + 7 buttons, worst-score group +
    7 buttons, previous button, next button.

    When no data is loaded or every prompt has been evaluated, a "finished"
    view is returned with both navigation buttons disabled. Otherwise the
    function also refreshes the current ranks and absolute scores in
    GLOBAL_STATE from any previously saved answers.
    """
    total_prompts = len(GLOBAL_STATE["all_eval_data"])
    if not GLOBAL_STATE["data_loaded"] or GLOBAL_STATE["current_prompt_index"] >= total_prompts:
        done_msg = "<p class='feedback green' style='text-align: center; font-size: 1.2em;'>All prompts have been evaluated! Please proceed to the 'Export' tab. <br>すべてのプロンプトの評価が完了しました!「エクスポート」タブに進んでください。</p>"
        cleared_rank_buttons = [gr.update(variant='secondary')] * 25
        cleared_score_buttons = [gr.update(variant='secondary')] * 7
        finished_view = [
            gr.update(value="Finished! / 完了!"),
            gr.update(value=""),
            gr.update(value=done_msg),
            gr.update(value="", visible=False),
        ]
        finished_view += [gr.update(value=None)] * 5
        finished_view += cleared_rank_buttons
        finished_view.append(gr.update(visible=False))  # abs_group_best
        finished_view += cleared_score_buttons
        finished_view.append(gr.update(visible=False))  # abs_group_worst
        finished_view += cleared_score_buttons
        finished_view += [gr.update(interactive=False), gr.update(interactive=False)]
        return finished_view

    position = GLOBAL_STATE["current_prompt_index"]
    criterion_idx = GLOBAL_STATE["current_criterion_index"]
    entry = GLOBAL_STATE["all_eval_data"][GLOBAL_STATE["shuffled_indices"][position]]
    prompt_text = entry["prompt"]
    criterion_name = CRITERIA[criterion_idx]

    progress_text = f"Prompt {position + 1} / {total_prompts} - **{criterion_name}**"
    prompt_display_text = f"## \"{prompt_text}\""
    guidance_text = f"### Please rank the 5 images based on **{CRITERIA_GUIDANCE_EN[criterion_idx]}**.<br>5つの画像を、**「{CRITERIA_GUIDANCE_JP[criterion_idx]}」**を基準にランキング付けしてください。"

    # Showing the first criterion discards all stored orders; each criterion
    # then gets its own shuffled order on first display.
    if criterion_idx == 0:
        GLOBAL_STATE["image_orders"] = {}
    orders = GLOBAL_STATE["image_orders"]
    if criterion_name not in orders:
        orders[criterion_name] = random.sample(CONDITIONS, len(CONDITIONS))
    current_image_order = orders[criterion_name]

    image_updates = []
    for cond_name in current_image_order:
        img_path = get_image_path_from_prediction(entry["predictions"][cond_name])
        usable = bool(img_path) and os.path.exists(img_path)
        image_updates.append(gr.update(value=img_path if usable else None))

    saved = GLOBAL_STATE["evaluation_results"].get(prompt_text, {})
    saved_ranks = saved.get("ranks", {}).get(criterion_name)
    if saved_ranks:
        # Saved ranks are keyed by condition; translate them to the labels
        # those conditions occupy in the order being shown now.
        label_for = dict(zip(current_image_order, IMAGE_LABELS))
        GLOBAL_STATE["current_ranks"] = {
            label_for[cond]: rank for cond, rank in saved_ranks.items() if cond in label_for
        }
    else:
        GLOBAL_STATE["current_ranks"] = dict.fromkeys(IMAGE_LABELS)
    rank_button_updates = _create_button_updates()

    # Absolute scores are only collected (and shown) for "Alignment".
    is_alignment_criterion = criterion_name == "Alignment"
    best_score = saved.get("absolute_score") if is_alignment_criterion else None
    worst_score = saved.get("absolute_score_worst") if is_alignment_criterion else None
    GLOBAL_STATE["current_absolute_score"] = best_score
    GLOBAL_STATE["current_absolute_score_worst"] = worst_score
    best_button_updates = [
        gr.update(variant='primary' if value == best_score else 'secondary')
        for value in range(1, 8)
    ]
    worst_button_updates = [
        gr.update(variant='primary' if value == worst_score else 'secondary')
        for value in range(1, 8)
    ]

    can_go_back = position > 0 or criterion_idx > 0
    return [
        gr.update(value=progress_text),
        gr.update(value=prompt_display_text),
        gr.update(value=guidance_text),
        gr.update(value="", visible=False),
        *image_updates,
        *rank_button_updates,
        gr.update(visible=is_alignment_criterion),
        *best_button_updates,
        gr.update(visible=is_alignment_criterion),
        *worst_button_updates,
        gr.update(interactive=can_go_back),
        gr.update(interactive=True),
    ]
# Revised: validate_and_navigate.
def validate_and_navigate():
    """Validate the current answers, save them, and advance to the next screen.

    Returns:
        A list of 53 updates: one for the export tab followed by the 52
        evaluation-panel updates. On a validation failure only the error
        display changes. When the last prompt is saved, the export tab is
        made interactive and the finished view is returned.
    """

    def _reject(message):
        # 53 outputs = export tab + 52 evaluation components; error_display
        # is the 5th of them (index 4). Everything else is left untouched.
        updates = [gr.update()] * 53
        updates[4] = gr.update(
            value=f"<p class='feedback red' style='font-size: 1.2em; text-align: center;'>{message}</p>",
            visible=True)
        return updates

    # A queued click can arrive after the final prompt was saved (or before
    # data is loaded); indexing shuffled_indices would then raise IndexError.
    # Re-render the current view instead of validating.
    if (not GLOBAL_STATE["data_loaded"]
            or GLOBAL_STATE["current_prompt_index"] >= len(GLOBAL_STATE["all_eval_data"])):
        export_tab_update = gr.update(interactive=True) if GLOBAL_STATE.get("end_time") else gr.update()
        return [export_tab_update] + display_current_prompt_and_criterion()

    ranks = GLOBAL_STATE["current_ranks"]
    criterion_name = CRITERIA[GLOBAL_STATE["current_criterion_index"]]
    is_alignment_criterion = (criterion_name == "Alignment")

    # --- Validation ---
    if any(r is None for r in ranks.values()):
        return _reject("Please rank all 5 images. / 5つすべての画像を評価してください。")
    if 1 not in ranks.values():
        return _reject(
            "You must assign a rank of '1' to at least one image. / 最低1つは「1位」を付けてください。")
    if is_alignment_criterion:
        best_score = GLOBAL_STATE["current_absolute_score"]
        worst_score = GLOBAL_STATE["current_absolute_score_worst"]
        if best_score is None:
            return _reject(
                "Please provide an absolute score for the BEST matching image (1-7). / 最も一致している画像について、絶対評価(1~7)を選択してください。")
        if worst_score is None:
            return _reject(
                "Please provide an absolute score for the WORST matching image (1-7). / 最も一致していない画像について、絶対評価(1~7)を選択してください。")
        if worst_score > best_score:
            return _reject(
                "The score for the WORST matching image cannot be higher than the score for the BEST matching image.<br>"
                "「最も一致していない画像」のスコアが「最も一致している画像」のスコアを上回ることはできません。"
            )

    # Tie rule: after `count` images share a rank, the next distinct rank
    # must be at least current_rank + count (e.g. 1, 1, 3, 4, 5).
    sorted_ranks = sorted(ranks.values())
    rank_counts = Counter(sorted_ranks)
    i = 0
    while i < len(sorted_ranks):
        current_rank = sorted_ranks[i]
        count = rank_counts[current_rank]
        if i + count < len(sorted_ranks):
            next_rank = sorted_ranks[i + count]
            expected_next_rank = current_rank + count
            if next_rank < expected_next_rank:
                return _reject(
                    f"Ranking rule violation (tie-breaking). After {count} instance(s) of rank '{current_rank}', the next rank must be >= {expected_next_rank}, but it is '{next_rank}'. / 順位付けのルール違反です。'{current_rank}'位が{count}つあるため、次の順位は{expected_next_rank}位以上である必要がありますが、'{next_rank}'位が入力されています。")
        i += count
    # --- End of Validation ---

    prompt_idx = GLOBAL_STATE["shuffled_indices"][GLOBAL_STATE["current_prompt_index"]]
    prompt_text = GLOBAL_STATE["all_eval_data"][prompt_idx]["prompt"]
    current_image_order = GLOBAL_STATE["image_orders"][criterion_name]

    # Ranks are entered per on-screen label; store them per condition so they
    # do not depend on the shuffled display order.
    label_to_condition = dict(zip(IMAGE_LABELS, current_image_order))
    ranks_by_condition = {label_to_condition[label]: rank for label, rank in ranks.items()}

    prompt_results = GLOBAL_STATE["evaluation_results"].setdefault(prompt_text, {})
    prompt_results.setdefault("ranks", {})[criterion_name] = ranks_by_condition
    prompt_results.setdefault("orders", {})[criterion_name] = current_image_order
    if is_alignment_criterion:
        prompt_results["absolute_score"] = GLOBAL_STATE["current_absolute_score"]
        prompt_results["absolute_score_worst"] = GLOBAL_STATE["current_absolute_score_worst"]
    logger.info(
        f"Saved rank for P:{GLOBAL_STATE['participant_id']}, Prompt:'{prompt_text}', Criterion:{criterion_name}, Ranks:{ranks_by_condition}")

    # Advance: next criterion, or the first criterion of the next prompt.
    GLOBAL_STATE["current_criterion_index"] += 1
    if GLOBAL_STATE["current_criterion_index"] >= len(CRITERIA):
        GLOBAL_STATE["current_criterion_index"] = 0
        GLOBAL_STATE["current_prompt_index"] += 1

    if GLOBAL_STATE["current_prompt_index"] >= len(GLOBAL_STATE["all_eval_data"]):
        GLOBAL_STATE["end_time"] = datetime.now()
        # Activate the export tab on completion.
        return [gr.update(interactive=True)] + display_current_prompt_and_criterion()
    # Keep the export tab state as is.
    return [gr.update()] + display_current_prompt_and_criterion()
def navigate_previous():
    """Step back one criterion (or to the last criterion of the previous prompt).

    At the very first screen (prompt 0, criterion 0) nothing moves. Without
    this guard, the criterion index would wrap to the last criterion while the
    prompt index stays clamped at 0, which jumps forward instead of back. That
    is reachable when a second "Previous" click is queued before the button is
    disabled.

    Returns:
        The evaluation-panel updates from display_current_prompt_and_criterion.
    """
    prompt_index = GLOBAL_STATE["current_prompt_index"]
    criterion_index = GLOBAL_STATE["current_criterion_index"]

    if prompt_index <= 0 and criterion_index <= 0:
        GLOBAL_STATE["current_prompt_index"] = 0
        GLOBAL_STATE["current_criterion_index"] = 0
    elif criterion_index > 0:
        GLOBAL_STATE["current_criterion_index"] = criterion_index - 1
    else:
        GLOBAL_STATE["current_criterion_index"] = len(CRITERIA) - 1
        GLOBAL_STATE["current_prompt_index"] = prompt_index - 1

    return display_current_prompt_and_criterion()
# Revised: export_results.
def export_results(participant_id, alignment_reason, naturalness_reason, attractiveness_reason, optional_comment):
    """Write all saved evaluation results to a timestamped JSON file.

    Args:
        participant_id: Raw text from the participant-ID textbox. It is
            stripped and must be a single path component, because it names
            the output directory under BASE_RESULTS_DIR.
        alignment_reason, naturalness_reason, attractiveness_reason: Required
            free-text reasoning for each criterion.
        optional_comment: Optional free-text comment, stored as given.

    Returns:
        A pair for (download_file, export_status): an update exposing the
        written file plus a success message, or ``None`` plus an HTML error
        message when validation or writing fails.
    """
    reasons = (alignment_reason, naturalness_reason, attractiveness_reason)
    if not all(reason and reason.strip() for reason in reasons):
        error_msg = "<p class='feedback red'>Please fill in the reasoning for all three criteria (Alignment, Naturalness, Attractiveness). / 3つの評価基準(一致度, 自然さ, 魅力度)すべての判断理由を記入してください。</p>"
        return None, error_msg

    participant_id = (participant_id or "").strip()
    if not participant_id:
        return None, "<p class='feedback red'>Participant ID is missing. / 参加者IDがありません。</p>"
    # The textbox stays editable after confirmation; never let its content
    # steer the output path outside BASE_RESULTS_DIR.
    if participant_id in (".", "..") or os.path.basename(participant_id) != participant_id:
        return None, "<p class='feedback red'>Invalid participant ID. / 無効な参加者IDです。</p>"

    output_dir = os.path.join(BASE_RESULTS_DIR, participant_id)
    filename = f"evaluation_results_{participant_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    filepath = os.path.join(output_dir, filename)

    started = GLOBAL_STATE.get("start_time")
    finished = GLOBAL_STATE.get("end_time")
    duration = (finished - started).total_seconds() if started and finished else None

    prompt_to_category = {item["prompt"]: item["category"] for item in GLOBAL_STATE["all_eval_data"]}
    final_results_list = []
    for prompt, data in GLOBAL_STATE["evaluation_results"].items():
        if not data:
            # Prompts with nothing saved are left out of the export.
            continue
        ranks_data = data.get("ranks", {})
        orders_data = data.get("orders", {})
        final_results_list.append({
            "prompt": prompt,
            "prompt_category": prompt_to_category.get(prompt),
            "image_order_alignment": orders_data.get("Alignment", []),
            "image_order_naturalness": orders_data.get("Naturalness", []),
            "image_order_attractiveness": orders_data.get("Attractiveness", []),
            "alignment_ranks": ranks_data.get("Alignment", {}),
            "naturalness_ranks": ranks_data.get("Naturalness", {}),
            "attractiveness_ranks": ranks_data.get("Attractiveness", {}),
            "alignment_absolute_score": data.get("absolute_score"),
            "alignment_absolute_score_worst": data.get("absolute_score_worst"),
        })

    export_data = {
        "metadata": {
            "participant_id": participant_id,
            "export_timestamp": datetime.now().isoformat(),
            "total_prompts_evaluated": len(final_results_list),
            "evaluation_duration_seconds": duration,
            "reasoning": {
                "alignment": alignment_reason,
                "naturalness": naturalness_reason,
                "attractiveness": attractiveness_reason,
            },
            "optional_comment": optional_comment,
        },
        "results": final_results_list
    }

    try:
        # Directory creation is inside the try so a filesystem error is
        # reported in the UI instead of escaping the handler.
        os.makedirs(output_dir, exist_ok=True)
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(export_data, f, ensure_ascii=False, indent=2)
        logger.info(f"Successfully exported results to: {filepath}")
    except Exception as e:  # UI boundary: report any failure to the participant.
        logger.error(f"Failed to write export file: {e}")
        return None, f"<p class='feedback red'>An error occurred during file export: {e}</p>"

    upload_link = "https://drive.google.com/drive/folders/1ujIPF-67Y6OG8qBm1TYG3FsmuYxqSAcR?usp=drive_link"
    status_message = f"""
<div class='feedback green' style='text-align: left;'>
<p><b>エクスポートが完了しました。/ Export complete.</b></p>
<p>上のボタンからJSONファイルをダウンロードし、指定された場所にアップロードして実験を終了してください。ご協力ありがとうございました。</p>
<p>Please download the JSON file and upload it to the designated location. Thank you for your cooperation.</p>
<p><b>アップロード先 / Upload to:</b> <a href='{upload_link}' target='_blank'>{upload_link}</a></p>
</div>"""
    return gr.update(value=filepath, visible=True), status_message
# Revised: create_gradio_interface.
def create_gradio_interface():
    """Build the three-tab Gradio app (Setup, Evaluation, Export) and wire its events.

    Returns:
        The ``gr.Blocks`` application object; the caller launches it.

    NOTE(review): the nesting of the ``with`` blocks was reconstructed from
    statement order because the indentation was not available — confirm the
    layout against the running UI.
    """
    css = """
.gradio-container { font-family: 'Arial', sans-serif; }
.feedback { padding: 10px; border-radius: 5px; font-weight: bold; text-align: center; margin-top: 10px; }
.feedback.green { background-color: #e6ffed; color: #2f6f4a; }
.feedback.red { background-color: #ffe6e6; color: #b30000; }
.image-label { font-size: 2.5em; font-weight: bold; margin-bottom: 10px; color: #333; }
.prompt-display { text-align: center; margin-bottom: 5px; padding: 15px; background-color: #f0f8ff; border-radius: 8px; }
.prompt-sub-guidance { text-align: center; font-size: 0.9em; color: #555; margin-top: 5px; margin-bottom: 15px; }
.rank-instruction {
color: #D32F2F;
font-size: 1.1em;
text-align: left;
margin-bottom: 20px;
padding: 15px;
border: 1px solid #f5c6cb;
border-radius: 8px;
background-color: #f8d7da;
line-height: 1.6;
}
.rank-instruction ul { padding-left: 20px; margin: 0; }
.rank-guidance { text-align: center; margin-bottom: 10px; font-size: 1.2em; }
.rank-btn-row { justify-content: center; gap: 5px !important; }
.rank-btn {
min-width: 65px !important;
max-width: 65px !important;
height: 45px !important;
font-size: 1.2em !important;
font-weight: bold !important;
border-radius: 8px !important;
border: 1px solid #ccc !important;
}
.rank-btn.secondary {
background: #f0f0f0 !important;
color: #333 !important;
}
.rank-btn.secondary:hover {
background: #e0e0e0 !important;
border-color: #bbb !important;
}
.absolute-eval-group {
border: 1px solid #ddd;
border-radius: 8px;
padding: 15px;
margin-top: 20px;
}
"""
    with gr.Blocks(title="Expression Evaluation Experiment", css=css) as app:
        gr.Markdown("# Text-to-Expression Evaluation Experiment / テキストからの表情生成 評価実験")
        with gr.Tabs() as tabs:
            # --- Tab 1: participant ID entry, instructions and data loading ---
            with gr.TabItem("1. Setup / セットアップ") as tab_setup:
                gr.Markdown("## (A) Participant Information / 参加者情報")
                gr.Markdown("Please enter your participant ID and click 'Confirm'. / 参加者IDを入力して「確定」を押してください。")
                with gr.Row():
                    participant_id_input = gr.Textbox(label="Participant ID", placeholder="e.g., P01")
                    confirm_id_btn = gr.Button("Confirm / 確定", variant="primary")
                setup_warning = gr.Markdown(visible=False)
                # Hidden until check_and_confirm_id accepts the ID.
                with gr.Group(visible=False) as setup_main_group:
                    gr.Markdown("---")
                    gr.Markdown("## (B) Instructions & Data Loading / 注意事項とデータ読み込み")
                    gr.Markdown(
                        """<div style='padding: 15px; border: 1px solid #f0ad4e; border-radius: 5px; background-color: #fcf8e3;'>
<h4>注意事項 / Instructions</h4>
<ul>
<li><b>この作業はPCで行ってください。/ Please perform this task on a PC.</b></li>
<li>途中で止めずに最後まで続けてください。ファイルをアップロードして完了となります。/ Please continue until the end. The experiment is complete when you upload the file.</li>
<li>ブラウザーをリロードしないでください (データが破損します)。/ Do not reload the browser (this will corrupt the data).</li>
</ul></div>""")
                    gr.Markdown(
                        "Click the button below to load your evaluation data. / 下のボタンを押して、評価データを読み込んでください。")
                    load_data_btn = gr.Button("Load Data / データ読み込み", variant="primary")
                    setup_status = gr.Markdown("Waiting to start...")
            # --- Tab 2: ranking screen; made interactive by load_evaluation_data ---
            with gr.TabItem("2. Evaluation / 評価", interactive=False) as tab_evaluation:
                progress_text = gr.Markdown("Prompt 0 / 0")
                image_components = []
                rank_buttons = []
                # One column per image label, each with the image and five rank
                # buttons; rank_buttons ends up label-major (A1..A5, B1..B5, ...).
                with gr.Row(equal_height=False):
                    for label in IMAGE_LABELS:
                        with gr.Column(scale=1):
                            with gr.Group():
                                gr.Markdown(f"<div class='image-label' style='text-align: center;'>{label}</div>")
                                img = gr.Image(type="filepath", show_label=False, height=300)
                                image_components.append(img)
                                with gr.Row(elem_classes="rank-btn-row"):
                                    rank_list = ["1位", "2位", "3位", "4位", "5位"]
                                    for rank_val in range(1, 6):
                                        btn = gr.Button(str(rank_list[rank_val-1]), variant='secondary', elem_classes="rank-btn")
                                        rank_buttons.append(btn)
                prompt_display = gr.Markdown("## \"Prompt Text Here\"", elem_classes="prompt-display")
                gr.Markdown(
                    "<p class='prompt-sub-guidance'>You may use AI or web search for the meaning of the text. However, please do not ask an AI about the emotion of the image itself.<br>意味についてはAIに聞いたりネット検索しても構いません。ただし、画像そのものの感情をAIに尋ねるのを止めてください。</p>")
                guidance_display = gr.Markdown("### Guidance", elem_classes="rank-guidance")
                error_display = gr.Markdown(visible=False)
                gr.Markdown(
                    """
<b>ランキングの付け方 / How to Rank:</b>
<ul>
<li><b>全く同じ表情の画像には、同じ順位</b>を付けてください。(Assign the <b>same rank</b> to identical expressions.)</li>
<li><b>少しでも違う表情の画像には、違う順位</b>を付けてください。(Assign <b>different ranks</b> to different expressions.)</li>
<li><b>必ず1位から</b>順位を付けてください。(You <b>must</b> assign a rank of '1' to at least one image.)</li>
<li>同順位がある場合、<b>その人数分だけ次の順位を飛ばしてください</b>。(When you have ties, <b>skip the next rank(s) accordingly</b>.)
<ul>
<li>例1: 1位が2つある場合、次は3位になります (Ex. 1: If there are two '1st' places, the next rank is '3rd'. e.g., <code>1, 1, 3, 4, 5</code>).</li>
<li>例2: 1位が1つ、2位が3つある場合、次は5位になります (Ex. 2: If there is one '1st' and three '2nd' places, the next rank is '5th'. e.g., <code>1, 2, 2, 2, 5</code>).</li>
</ul>
</li>
</ul>
""",
                    elem_classes="rank-instruction"
                )
                # Revised: absolute-score (Best) UI; shown only for the
                # "Alignment" criterion by display_current_prompt_and_criterion.
                with gr.Group(visible=False, elem_classes="absolute-eval-group") as absolute_eval_group_best:
                    gr.Markdown("---")
                    gr.Markdown(
                        "#### 絶対評価 (Best) / Absolute Score (Best)\n最もテキストと一致している画像について、どのていど一致しているかを評価してください。\n(Please evaluate the degree of alignment for the image that **best** matches the text.)")
                    absolute_score_buttons = []
                    with gr.Row():
                        with gr.Column(scale=1):
                            gr.Markdown(
                                "<p style='text-align: right; margin-top: 10px;'>1 (全く一致してない / Not at all)</p>")
                        with gr.Column(scale=3):
                            with gr.Row(elem_classes="rank-btn-row"):
                                for i in range(1, 8):
                                    btn = gr.Button(str(i), variant='secondary', elem_classes="rank-btn")
                                    absolute_score_buttons.append(btn)
                        with gr.Column(scale=1):
                            gr.Markdown("<p style='text-align: left; margin-top: 10px;'>7 (完全に一致 / Absolutely)</p>")
                # Added: absolute-score (Worst) UI, same 1-7 scale.
                with gr.Group(visible=False, elem_classes="absolute-eval-group") as absolute_eval_group_worst:
                    gr.Markdown(
                        "#### 絶対評価 (Worst) / Absolute Score (Worst)\n最もテキストと一致していない画像について、どのていど一致していないかを評価してください。\n(Please evaluate the degree of alignment for the image that **least** matches the text.)")
                    absolute_score_worst_buttons = []
                    with gr.Row():
                        with gr.Column(scale=1):
                            gr.Markdown(
                                "<p style='text-align: right; margin-top: 10px;'>1 (全く一致してない / Not at all)</p>")
                        with gr.Column(scale=3):
                            with gr.Row(elem_classes="rank-btn-row"):
                                for i in range(1, 8):
                                    btn = gr.Button(str(i), variant='secondary', elem_classes="rank-btn")
                                    absolute_score_worst_buttons.append(btn)
                        with gr.Column(scale=1):
                            gr.Markdown("<p style='text-align: left; margin-top: 10px;'>7 (完全に一致 / Absolutely)</p>")
                with gr.Row():
                    prev_btn = gr.Button("← Previous / 前へ", interactive=False)
                    next_btn = gr.Button("Save & Next / 保存して次へ →", variant="primary")
            # --- Tab 3: free-text reasoning and JSON export; enabled on completion ---
            with gr.TabItem("3. Export / エクスポート", interactive=False) as tab_export:
                gr.Markdown("## (C) Final Comments & Export / 最終コメントとエクスポート")
                gr.Markdown(
                    "Thank you for completing the evaluation. Please provide the reasoning for your judgments for each criterion below. / 評価お疲れ様でした。以下の各評価基準について、判断の理由をご記入ください。")
                with gr.Group():
                    gr.Markdown("#### Reasoning for Judgments (Required) / 判断理由(必須)")
                    alignment_reason_box = gr.Textbox(label="Alignment / 一致度", lines=3,
                                                      placeholder="Why did you rank them this way for alignment? / なぜ一致度について、このような順位付けをしましたか?")
                    naturalness_reason_box = gr.Textbox(label="Naturalness / 自然さ", lines=3,
                                                        placeholder="Why did you rank them this way for naturalness? / なぜ自然さについて、このような順位付けをしましたか?")
                    attractiveness_reason_box = gr.Textbox(label="Attractiveness / 魅力度", lines=3,
                                                           placeholder="Why did you rank them this way for attractiveness? / なぜ魅力度について、このような順位付けをしましたか?")
                with gr.Group():
                    gr.Markdown("#### Overall Comments (Optional) / 全体的な感想(任意)")
                    optional_comment_box = gr.Textbox(label="Any other comments? / その他、実験全体に関するご意見・ご感想",
                                                      lines=4,
                                                      placeholder="e.g., 'Image B often looked the most natural.' / 例:「Bの画像が最も自然に見えることが多かったです。」")
                gr.Markdown("---")
                gr.Markdown(
                    "Finally, click the button below to export your results. / 最後に、下のボタンを押して結果をエクスポートしてください。")
                export_btn = gr.Button("Export Results / 結果をエクスポート", variant="primary")
                download_file = gr.File(label="Download JSON", visible=False)
                export_status = gr.Markdown()

        # --- Event Handlers ---
        def check_and_confirm_id(pid):
            """Accept IDs of the form 'P' + two digits and reveal the loading section.

            Returns updates for (setup_warning, setup_main_group).
            """
            pid = pid.strip()
            if re.fullmatch(r"P\d{2}", pid):
                GLOBAL_STATE["participant_id"] = pid
                return gr.update(visible=False), gr.update(visible=True)
            else:
                error_msg = "<p class='feedback red'>Invalid ID. Must be 'P' followed by two digits (e.g., P01). / 無効なIDです。「P」と数字2桁の形式(例: P01)で入力してください。</p>"
                return gr.update(value=error_msg, visible=True), gr.update(visible=False)

        confirm_id_btn.click(check_and_confirm_id, [participant_id_input], [setup_warning, setup_main_group])
        load_data_btn.click(load_evaluation_data, [participant_id_input], [setup_status, load_data_btn, tab_evaluation])
        # Revised: all_eval_outputs includes the new UI components. Its order
        # (4 text fields, 5 images, 25 rank buttons, best group + 7 buttons,
        # worst group + 7 buttons, prev, next = 52) must match the list
        # returned by display_current_prompt_and_criterion.
        all_eval_outputs = [
            progress_text, prompt_display, guidance_display, error_display, *image_components,
            *rank_buttons,
            absolute_eval_group_best, *absolute_score_buttons,
            absolute_eval_group_worst, *absolute_score_worst_buttons,
            prev_btn, next_btn
        ]
        # Bind each rank button to its (label, rank) pair; btn_idx walks
        # rank_buttons in the same label-major order used when creating them.
        btn_idx = 0
        for label in IMAGE_LABELS:
            for rank_val in range(1, 6):
                btn = rank_buttons[btn_idx]
                btn.click(
                    partial(handle_rank_button_click, label, rank_val),
                    [],
                    rank_buttons
                )
                btn_idx += 1
        for i, btn in enumerate(absolute_score_buttons):
            btn.click(
                partial(handle_absolute_score_click, i + 1),
                [],
                absolute_score_buttons
            )
        # Added: wire the click handlers for the new (Worst) score buttons.
        for i, btn in enumerate(absolute_score_worst_buttons):
            btn.click(
                partial(handle_absolute_score_worst_click, i + 1),
                [],
                absolute_score_worst_buttons
            )
        tab_evaluation.select(display_current_prompt_and_criterion, [], all_eval_outputs)
        # Revised: tab_export is prepended to next_btn's outputs, so
        # validate_and_navigate returns 53 updates.
        next_btn.click(validate_and_navigate, [], [tab_export, *all_eval_outputs])
        prev_btn.click(navigate_previous, [], all_eval_outputs)
        export_tab_interactive_components = [alignment_reason_box, naturalness_reason_box, attractiveness_reason_box,
                                             optional_comment_box, export_btn]

        def on_select_export_tab():
            """Enable the five export inputs only once the evaluation has finished."""
            # end_time is set only when all evaluations are complete
            if GLOBAL_STATE.get("end_time"):
                return [gr.update(interactive=True)] * 5
            # This logic is now handled by next_btn click, but kept as a fallback.
            return [gr.update(interactive=False)] * 5

        tab_export.select(on_select_export_tab, [], export_tab_interactive_components)
        export_btn.click(
            export_results,
            [participant_id_input, alignment_reason_box, naturalness_reason_box, attractiveness_reason_box,
             optional_comment_box],
            [download_file, export_status]
        )
    return app
if __name__ == "__main__":
    # Seed the shuffles from the current wall-clock time.
    random.seed(datetime.now().timestamp())

    # Replace loguru's default sink with stderr (INFO and above) plus a log
    # file under LOG_DIR configured with rotation="10 MB".
    os.makedirs(LOG_DIR, exist_ok=True)
    logger.remove()
    logger.add(sys.stderr, level="INFO")
    logger.add(os.path.join(LOG_DIR, "evaluation_ui_log_{time}.log"), rotation="10 MB")

    app = create_gradio_interface()
    app.launch(share=True, debug=True)