Update app.py
Browse files
app.py
CHANGED
|
@@ -3,16 +3,17 @@ import os
|
|
| 3 |
import random
|
| 4 |
import time
|
| 5 |
from datetime import datetime
|
| 6 |
-
# import pandas as pd # 未在后续代码中使用
|
| 7 |
from functools import partial
|
| 8 |
import json
|
| 9 |
import io
|
| 10 |
from huggingface_hub import HfApi
|
| 11 |
from huggingface_hub.hf_api import HfHubHTTPError
|
| 12 |
import traceback
|
|
|
|
| 13 |
|
| 14 |
-
# ==== 全局配置 ====
|
| 15 |
-
BASE_IMAGE_DIR
|
|
|
|
| 16 |
TARGET_DIR_BASENAME = "gt"
|
| 17 |
TARGET_DIR = os.path.join(BASE_IMAGE_DIR, TARGET_DIR_BASENAME)
|
| 18 |
|
|
@@ -31,19 +32,64 @@ if os.path.exists(BASE_IMAGE_DIR):
|
|
| 31 |
except Exception as e: print(f"错误:在扫描 '{BASE_IMAGE_DIR}' 时发生错误: {e}"); METHOD_ROOTS = []
|
| 32 |
else: print(f"警告:基础目录 '{BASE_IMAGE_DIR}' 不存在。将无法加载候选图片。")
|
| 33 |
|
| 34 |
-
SUBJECTS = ["subj01", "subj02", "subj05", "subj07"]
|
| 35 |
SENTINEL_TRIAL_INTERVAL = 20
|
| 36 |
-
NUM_TRIALS_PER_RUN = 100
|
| 37 |
-
LOG_BATCH_SIZE = 5
|
| 38 |
|
| 39 |
DATASET_REPO_ID = "YanmHa/image-aligned-experiment-data"
|
| 40 |
-
INDIVIDUAL_LOGS_FOLDER = "individual_choice_logs"
|
| 41 |
-
BATCH_LOG_FOLDER = "run_logs_batch"
|
| 42 |
CSS = ".gr-block {margin-top: 4px !important; margin-bottom: 4px !important;} .compact_button { padding: 4px 8px; min-width: auto; }"
|
| 43 |
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
master_image_list = []
|
| 46 |
-
# ... (master_image_list 加载逻辑与您上一版代码相同) ...
|
| 47 |
if os.path.exists(TARGET_DIR):
|
| 48 |
try:
|
| 49 |
master_image_list = sorted(
|
|
@@ -60,120 +106,169 @@ if not master_image_list: print(f"关键错误:由于 '{TARGET_DIR}' 问题,
|
|
| 60 |
|
| 61 |
# ==== 辅助函数 ====
|
| 62 |
def get_next_trial_info(current_trial_idx_in_run, current_run_image_list_for_trial, num_trials_in_this_run_for_trial):
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
| 66 |
img_filename_original = current_run_image_list_for_trial[current_trial_idx_in_run]
|
| 67 |
target_full_path = os.path.join(TARGET_DIR, img_filename_original)
|
| 68 |
trial_number_for_display = current_trial_idx_in_run + 1
|
|
|
|
| 69 |
pool = []
|
| 70 |
for m_root_path in METHOD_ROOTS:
|
| 71 |
method_name = os.path.basename(m_root_path)
|
| 72 |
for s_id in SUBJECTS:
|
| 73 |
-
base, ext = os.path.splitext(img_filename_original)
|
|
|
|
| 74 |
candidate_path = os.path.join(m_root_path, s_id, reconstructed_filename)
|
| 75 |
-
if os.path.exists(candidate_path):
|
|
|
|
|
|
|
|
|
|
| 76 |
trial_info = {"image_id": img_filename_original, "target_path": target_full_path, "cur_no": trial_number_for_display, "is_sentinel": False,
|
| 77 |
"left_display_label": "N/A", "left_internal_label": "N/A", "left_path": None,
|
| 78 |
"right_display_label": "N/A", "right_internal_label": "N/A", "right_path": None}
|
|
|
|
| 79 |
is_potential_sentinel_trial = (trial_number_for_display > 0 and trial_number_for_display % SENTINEL_TRIAL_INTERVAL == 0)
|
|
|
|
| 80 |
if is_potential_sentinel_trial:
|
| 81 |
-
|
|
|
|
|
|
|
| 82 |
else:
|
| 83 |
print(f"生成哨兵试验 for '{img_filename_original}' (trial {trial_number_for_display})")
|
| 84 |
trial_info["is_sentinel"] = True
|
| 85 |
sentinel_candidate_target_tuple = ("目标图像", target_full_path)
|
| 86 |
random_reconstruction_candidate_tuple = random.choice(pool)
|
| 87 |
candidates_for_sentinel = [
|
| 88 |
-
(("目标图像", target_full_path), sentinel_candidate_target_tuple[0]),
|
| 89 |
(("重建图", random_reconstruction_candidate_tuple[1]), random_reconstruction_candidate_tuple[0])
|
| 90 |
]
|
| 91 |
random.shuffle(candidates_for_sentinel)
|
| 92 |
trial_info.update({
|
| 93 |
-
"left_display_label": candidates_for_sentinel[0][0][0],
|
| 94 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
})
|
| 96 |
-
|
| 97 |
-
if len(pool) < 2:
|
| 98 |
-
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
trial_info.update({
|
| 101 |
-
"left_display_label": "候选图 1", "left_path":
|
| 102 |
-
"right_display_label": "候选图 2", "right_path":
|
| 103 |
})
|
| 104 |
return trial_info, current_trial_idx_in_run + 1
|
| 105 |
|
| 106 |
-
#
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
# --- 批量保存累积日志的函数 ---
|
| 110 |
-
def save_collected_logs_batch(list_of_log_entries, user_identifier_str, batch_identifier): # batch_identifier 可以是轮次号或时间戳
|
| 111 |
global DATASET_REPO_ID, BATCH_LOG_FOLDER
|
| 112 |
if not list_of_log_entries:
|
| 113 |
-
print("
|
| 114 |
-
return True
|
| 115 |
|
| 116 |
identifier_safe = str(user_identifier_str if user_identifier_str else "unknown_user_session").replace('.', '_').replace(':', '_').replace('/', '_').replace(' ', '_')
|
| 117 |
-
print(f"用户 {identifier_safe} - 准备批量保存 {len(list_of_log_entries)}
|
| 118 |
-
|
| 119 |
try:
|
| 120 |
token = os.getenv("HF_TOKEN")
|
| 121 |
-
if not token: print("错误:HF_TOKEN
|
| 122 |
-
if not DATASET_REPO_ID: print("错误:DATASET_REPO_ID
|
| 123 |
|
| 124 |
api = HfApi(token=token)
|
| 125 |
timestamp_str = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
|
| 126 |
batch_filename = f"batch_user-{identifier_safe}_id-{batch_identifier}_{timestamp_str}_logs-{len(list_of_log_entries)}.jsonl"
|
| 127 |
-
path_in_repo = f"{BATCH_LOG_FOLDER}/{identifier_safe}/{batch_filename}"
|
| 128 |
-
|
| 129 |
jsonl_content = ""
|
| 130 |
for log_entry in list_of_log_entries:
|
| 131 |
try:
|
| 132 |
if isinstance(log_entry, dict): jsonl_content += json.dumps(log_entry, ensure_ascii=False) + "\n"
|
| 133 |
-
else: print(f"
|
| 134 |
except Exception as json_err:
|
| 135 |
-
print(f"
|
| 136 |
-
jsonl_content += json.dumps({"error": "
|
| 137 |
-
|
| 138 |
-
if not jsonl_content.strip(): print(f"用户 {identifier_safe} (批次 {batch_identifier}) 无可序列化日志。"); return True
|
| 139 |
|
| 140 |
log_bytes = jsonl_content.encode('utf-8')
|
| 141 |
file_like_object = io.BytesIO(log_bytes)
|
| 142 |
-
|
| 143 |
-
print(f"准备批量上传日志文件: {path_in_repo} ({len(log_bytes)} bytes)")
|
| 144 |
api.upload_file(
|
| 145 |
path_or_fileobj=file_like_object, path_in_repo=path_in_repo, repo_id=DATASET_REPO_ID, repo_type="dataset",
|
| 146 |
-
commit_message=f"Batch logs for
|
| 147 |
)
|
| 148 |
-
print(f"
|
| 149 |
return True
|
| 150 |
except Exception as e:
|
| 151 |
-
print(f"
|
| 152 |
return False
|
| 153 |
|
|
|
|
| 154 |
# ==== 主要的 Gradio 事件处理函数 ====
|
| 155 |
def process_experiment_step(
|
| 156 |
s_trial_idx_val, s_run_no_val, s_user_logs_val, s_current_trial_data_val, s_user_session_id_val,
|
| 157 |
s_current_run_image_list_val, s_num_trials_this_run_val,
|
| 158 |
action_type=None, choice_value=None, request: gr.Request = None
|
| 159 |
):
|
| 160 |
-
global master_image_list, NUM_TRIALS_PER_RUN, outputs_ui_components_definition, LOG_BATCH_SIZE
|
| 161 |
|
|
|
|
| 162 |
output_s_trial_idx = s_trial_idx_val; output_s_run_no = s_run_no_val
|
| 163 |
output_s_user_logs = list(s_user_logs_val); output_s_current_trial_data = dict(s_current_trial_data_val) if s_current_trial_data_val else {}
|
| 164 |
output_s_user_session_id = s_user_session_id_val; output_s_current_run_image_list = list(s_current_run_image_list_val)
|
| 165 |
output_s_num_trials_this_run = s_num_trials_this_run_val
|
| 166 |
user_ip_fallback = request.client.host if request else "unknown_ip"
|
| 167 |
user_identifier_for_logging = output_s_user_session_id if output_s_user_session_id else user_ip_fallback
|
|
|
|
| 168 |
len_ui_outputs = len(outputs_ui_components_definition)
|
| 169 |
def create_ui_error_tuple(message, progress_msg_text): return (gr.update(visible=False),) * 3 + ("", "", message, progress_msg_text) + (gr.update(interactive=True), gr.update(interactive=False), gr.update(interactive=False)) + (gr.update(visible=False),)
|
| 170 |
def create_no_change_tuple(): return (gr.update(),) * len_ui_outputs
|
| 171 |
user_id_display_text = output_s_user_session_id if output_s_user_session_id else "用户ID待分配"
|
| 172 |
|
|
|
|
| 173 |
if action_type == "record_choice":
|
|
|
|
| 174 |
if output_s_current_trial_data.get("data") and output_s_current_trial_data["data"].get("left_internal_label"):
|
| 175 |
chosen_internal_label = (output_s_current_trial_data["data"]["left_internal_label"] if choice_value == "left" else output_s_current_trial_data["data"]["right_internal_label"])
|
| 176 |
-
# ... (log_entry 创建逻辑,与上一版一致,包含 chosen_method, chosen_subject, chosen_filename) ...
|
| 177 |
parsed_chosen_method, parsed_chosen_subject, parsed_chosen_filename = "N/A", "N/A", "N/A"
|
| 178 |
if chosen_internal_label == "目标图像": parsed_chosen_method, parsed_chosen_subject, parsed_chosen_filename = "TARGET", "GT", output_s_current_trial_data["data"]["image_id"]
|
| 179 |
else:
|
|
@@ -182,84 +277,102 @@ def process_experiment_step(
|
|
| 182 |
elif len(parts) == 2: parsed_chosen_method, parsed_chosen_subject = parts[0].strip(), parts[1].strip()
|
| 183 |
elif len(parts) == 1: parsed_chosen_method = parts[0].strip()
|
| 184 |
log_entry = {
|
| 185 |
-
"timestamp": datetime.now().isoformat(), "user_identifier": user_identifier_for_logging, "run_no": output_s_run_no,
|
| 186 |
"image_id": output_s_current_trial_data["data"]["image_id"],
|
| 187 |
-
"left_internal_label": output_s_current_trial_data["data"]["left_internal_label"],
|
| 188 |
"right_internal_label": output_s_current_trial_data["data"]["right_internal_label"],
|
| 189 |
-
"chosen_side": choice_value, "chosen_internal_label": chosen_internal_label,
|
| 190 |
"chosen_method": parsed_chosen_method, "chosen_subject": parsed_chosen_subject, "chosen_filename": parsed_chosen_filename,
|
| 191 |
-
"trial_sequence_in_run": output_s_current_trial_data["data"]["cur_no"],
|
| 192 |
"is_sentinel": output_s_current_trial_data["data"]["is_sentinel"]
|
| 193 |
}
|
| 194 |
output_s_user_logs.append(log_entry)
|
| 195 |
print(f"用户 {user_identifier_for_logging} 记录选择 (img: {log_entry['image_id']})。当前批次日志数: {len(output_s_user_logs)}")
|
| 196 |
|
| 197 |
-
|
|
|
|
| 198 |
if len(output_s_user_logs) >= LOG_BATCH_SIZE:
|
| 199 |
-
print(f"
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
|
|
|
| 206 |
else:
|
| 207 |
-
print("
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
error_ui_updates = create_ui_error_tuple("记录选择时内部错误。", f"用户ID: {user_id_display_text} | 进度:{output_s_trial_idx}/{output_s_num_trials_this_run}")
|
|
|
|
| 212 |
return output_s_trial_idx, output_s_run_no, output_s_user_logs, output_s_current_trial_data, output_s_user_session_id, output_s_current_run_image_list, output_s_num_trials_this_run, *error_ui_updates
|
| 213 |
|
| 214 |
-
#
|
| 215 |
if action_type == "start_experiment":
|
| 216 |
is_first = (output_s_num_trials_this_run == 0 and output_s_trial_idx == 0 and output_s_run_no == 1)
|
| 217 |
is_completed_for_restart = (output_s_num_trials_this_run > 0 and output_s_trial_idx >= output_s_num_trials_this_run)
|
| 218 |
if is_first or is_completed_for_restart:
|
| 219 |
if not master_image_list: error_ui = create_ui_error_tuple("错误: 无可用目标图片!", f"用户ID: {user_id_display_text} | 进度: 0/0"); return 0, output_s_run_no, output_s_user_logs, {}, output_s_user_session_id, [], 0, *error_ui
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
# 轮次号的增加在这里是正确的,用于显示
|
| 223 |
-
if is_completed_for_restart:
|
| 224 |
output_s_run_no += 1
|
| 225 |
-
|
| 226 |
num_avail = len(master_image_list); run_size = min(num_avail, NUM_TRIALS_PER_RUN)
|
| 227 |
if run_size == 0: error_ui = create_ui_error_tuple("错误: 采样图片数为0!", f"用户ID: {user_id_display_text} | 进度: 0/0"); return 0, output_s_run_no, output_s_user_logs, {}, output_s_user_session_id, [], 0, *error_ui
|
| 228 |
-
|
| 229 |
output_s_current_run_image_list = random.sample(master_image_list, run_size)
|
| 230 |
output_s_num_trials_this_run = run_size
|
| 231 |
output_s_trial_idx = 0
|
| 232 |
-
# output_s_user_logs 不在这里重置,它会持续累积直到达到 LOG_BATCH_SIZE
|
| 233 |
output_s_current_trial_data = {}
|
| 234 |
-
|
| 235 |
-
# 用户会话ID在欢迎页已设置,这里不再修改,除非有特殊需求
|
| 236 |
print(f"开始/继续轮次 {output_s_run_no} (用户ID: {output_s_user_session_id}). 随机选择 {output_s_num_trials_this_run} 张图片.")
|
| 237 |
-
else:
|
| 238 |
print(f"用户 {user_identifier_for_logging} 在第 {output_s_run_no} 轮,试验 {output_s_trial_idx} 点击开始,但轮次未完成。忽略。")
|
| 239 |
no_change_ui = create_no_change_tuple()
|
| 240 |
return output_s_trial_idx, output_s_run_no, output_s_user_logs, output_s_current_trial_data, output_s_user_session_id, output_s_current_run_image_list, output_s_num_trials_this_run, *no_change_ui
|
| 241 |
-
|
| 242 |
-
#
|
| 243 |
if output_s_trial_idx >= output_s_num_trials_this_run and output_s_num_trials_this_run > 0:
|
| 244 |
-
# 仅打印完成信息,实际的批量保存由 record_choice 中的计数器触发
|
| 245 |
print(f"用户 {output_s_user_session_id} 已完成第 {output_s_run_no} 轮。等待下一批或下一轮开始。")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
prog_text = f"用户ID: {output_s_user_session_id} | 进度:{output_s_num_trials_this_run}/{output_s_num_trials_this_run} | 第 {output_s_run_no} 轮 🎉"
|
| 247 |
ui_updates = list(create_ui_error_tuple(f"🎉 第 {output_s_run_no} 轮完成!请点击“开始试验 / 下一轮”继续或开始新批次。", prog_text))
|
| 248 |
-
# ... (UI 更新与之前一致) ...
|
| 249 |
ui_updates[7]=gr.update(interactive=True); ui_updates[8]=gr.update(interactive=False); ui_updates[9]=gr.update(interactive=False)
|
| 250 |
ui_updates[0]=gr.update(value=None,visible=False); ui_updates[1]=gr.update(value=None,visible=False); ui_updates[2]=gr.update(value=None,visible=False)
|
| 251 |
yield output_s_trial_idx, output_s_run_no, output_s_user_logs, output_s_current_trial_data, output_s_user_session_id, output_s_current_run_image_list, output_s_num_trials_this_run, *ui_updates; return
|
| 252 |
|
| 253 |
-
|
| 254 |
-
# ... (
|
| 255 |
if not output_s_current_run_image_list or output_s_num_trials_this_run == 0:
|
| 256 |
error_ui = create_ui_error_tuple("错误: 无法加载试验图片 (列表为空)", f"用户ID: {user_id_display_text} | 进度: N/A")
|
| 257 |
return output_s_trial_idx, output_s_run_no, output_s_user_logs, {"data": None}, output_s_user_session_id, [], 0, *error_ui
|
| 258 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
if trial_info is None:
|
| 260 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
output_s_current_trial_data = {"data": None}
|
| 262 |
-
return output_s_trial_idx, output_s_run_no, output_s_user_logs, output_s_current_trial_data, output_s_user_session_id, output_s_current_run_image_list, output_s_num_trials_this_run, *
|
|
|
|
| 263 |
output_s_current_trial_data = {"data": trial_info}
|
| 264 |
prog_text = f"用户ID: {output_s_user_session_id} | 进度:{trial_info['cur_no']}/{output_s_num_trials_this_run} | 第 {output_s_run_no} 轮"
|
| 265 |
ui_show_target_updates = list(create_no_change_tuple())
|
|
@@ -275,7 +388,9 @@ def process_experiment_step(
|
|
| 275 |
ui_show_candidates_updates[7]=gr.update(interactive=False); ui_show_candidates_updates[8]=gr.update(interactive=True); ui_show_candidates_updates[9]=gr.update(interactive=True)
|
| 276 |
yield next_s_trial_idx_for_state, output_s_run_no, output_s_user_logs, output_s_current_trial_data, output_s_user_session_id, output_s_current_run_image_list, output_s_num_trials_this_run, *ui_show_candidates_updates
|
| 277 |
|
| 278 |
-
|
|
|
|
|
|
|
| 279 |
welcome_page_markdown = """
|
| 280 |
## 欢迎加入实验!
|
| 281 |
您好!非常感谢您抽出宝贵时间参与我们的视觉偏好评估实验。您的选择将帮助我们改进重建算法,让机器生成的图像更贴近人类视觉体验!
|
|
@@ -303,7 +418,6 @@ welcome_page_markdown = """
|
|
| 303 |
再次感谢您的参与与支持!您每一次认真选择都对我们的研究意义重大。祝您一切顺利,实验愉快!
|
| 304 |
"""
|
| 305 |
def handle_agree_and_start(name, gender, age, education, request: gr.Request):
|
| 306 |
-
# ... (此函数与您上一版代码完全一致) ...
|
| 307 |
error_messages_list = []
|
| 308 |
if not name or str(name).strip() == "": error_messages_list.append("姓名 不能为空。")
|
| 309 |
if gender is None or str(gender).strip() == "": error_messages_list.append("性别 必须选择。")
|
|
@@ -311,32 +425,30 @@ def handle_agree_and_start(name, gender, age, education, request: gr.Request):
|
|
| 311 |
elif not (isinstance(age, (int, float)) and 1 <= age <= 120):
|
| 312 |
try: num_age = float(age);
|
| 313 |
except (ValueError, TypeError): error_messages_list.append("年龄必须是一个有效的数字。")
|
| 314 |
-
else:
|
| 315 |
if not (1 <= num_age <= 120): error_messages_list.append("年龄必须在 1 到 120 之间。")
|
| 316 |
-
if education is None or str(education).strip() == "其他": error_messages_list.append("学历 必须选择。")
|
| 317 |
if error_messages_list:
|
| 318 |
full_error_message = "请修正以下错误:\n" + "\n".join([f"- {msg}" for msg in error_messages_list])
|
| 319 |
print(f"用户输入验证失败: {full_error_message}")
|
| 320 |
return gr.update(), False, gr.update(visible=True), gr.update(visible=False), full_error_message
|
| 321 |
s_name = str(name).strip().replace(" ","_").replace("/","_").replace("\\","_")
|
| 322 |
s_gender = str(gender).strip().replace(" ","_").replace("/","_").replace("\\","_")
|
| 323 |
-
s_age = str(int(float(age)))
|
| 324 |
s_education = str(education).strip().replace(" ","_").replace("/","_").replace("\\","_")
|
| 325 |
user_id_str = f"N-{s_name}_G-{s_gender}_A-{s_age}_E-{s_education}"
|
| 326 |
print(f"用户信息收集完毕,生成用户ID: {user_id_str}")
|
| 327 |
return user_id_str, True, gr.update(visible=False), gr.update(visible=True), ""
|
| 328 |
|
| 329 |
with gr.Blocks(css=CSS, title="图像重建主观评估") as demo:
|
| 330 |
-
# ... (所有 State 变量定义,与您上一版代码相同) ...
|
| 331 |
s_show_experiment_ui = gr.State(False); s_trial_index = gr.State(0); s_run_no = gr.State(1)
|
| 332 |
s_user_logs = gr.State([]); s_current_trial_data = gr.State({}); s_user_session_id = gr.State(None)
|
| 333 |
s_current_run_image_list = gr.State([]); s_num_trials_this_run = gr.State(0)
|
| 334 |
-
|
| 335 |
-
welcome_container = gr.Column(visible=True)
|
| 336 |
experiment_container = gr.Column(visible=False)
|
| 337 |
|
| 338 |
with welcome_container:
|
| 339 |
-
# ... (欢迎页UI,与您上一版代码相同) ...
|
| 340 |
gr.Markdown(welcome_page_markdown)
|
| 341 |
with gr.Row(): user_name_input = gr.Textbox(label="请输入您的姓名或代号 (例如 张三 或 User001)", placeholder="例如:张三 -> ZS"); user_gender_input = gr.Radio(label="性别", choices=["男", "女"])
|
| 342 |
with gr.Row(): user_age_input = gr.Number(label="年龄 (请输入1-120的整数)", minimum=1, maximum=120, step=1); user_education_input = gr.Dropdown(label="学历", choices=["其他","初中及以下","高中(含中专)", "大专(含在读)", "本科(含在读)", "硕士(含在读)", "博士(含在读)"])
|
|
@@ -344,7 +456,6 @@ with gr.Blocks(css=CSS, title="图像重建主观评估") as demo:
|
|
| 344 |
btn_agree_and_start = gr.Button("我已阅读上述说明并同意参与实验")
|
| 345 |
|
| 346 |
with experiment_container:
|
| 347 |
-
# ... (实验主界面UI,与您上一版代码相同,包括新的按钮布局和隐藏的标签) ...
|
| 348 |
gr.Markdown("## 🧠 图像重建主观评估实验"); gr.Markdown(f"每轮实验大约有 {NUM_TRIALS_PER_RUN} 次比较。")
|
| 349 |
with gr.Row():
|
| 350 |
with gr.Column(scale=1, min_width=300): left_img = gr.Image(label="左候选图", visible=False, height=400, interactive=False); left_lbl = gr.Textbox(label="左图信息", visible=False, interactive=False, max_lines=1); btn_left = gr.Button("选择左图 (更相似)", interactive=False, elem_classes="compact_button")
|
|
@@ -355,17 +466,17 @@ with gr.Blocks(css=CSS, title="图像重建主观评估") as demo:
|
|
| 355 |
with gr.Row(): btn_start = gr.Button("开始试验 / 下一轮")
|
| 356 |
file_out_placeholder = gr.File(label=" ", visible=False, interactive=False)
|
| 357 |
|
| 358 |
-
outputs_ui_components_definition = [
|
| 359 |
target_img, left_img, right_img, left_lbl, right_lbl, status_text, progress_text,
|
| 360 |
-
btn_start, btn_left, btn_right, file_out_placeholder
|
| 361 |
]
|
| 362 |
-
click_inputs_base = [
|
| 363 |
s_trial_index, s_run_no, s_user_logs, s_current_trial_data, s_user_session_id,
|
| 364 |
s_current_run_image_list, s_num_trials_this_run
|
| 365 |
]
|
| 366 |
-
event_outputs = [
|
| 367 |
s_trial_index, s_run_no, s_user_logs, s_current_trial_data, s_user_session_id,
|
| 368 |
-
s_current_run_image_list, s_num_trials_this_run, *outputs_ui_components_definition
|
| 369 |
]
|
| 370 |
|
| 371 |
btn_agree_and_start.click(fn=handle_agree_and_start, inputs=[user_name_input, user_gender_input, user_age_input, user_education_input], outputs=[s_user_session_id, s_show_experiment_ui, welcome_container, experiment_container, welcome_error_msg])
|
|
@@ -373,9 +484,7 @@ with gr.Blocks(css=CSS, title="图像重建主观评估") as demo:
|
|
| 373 |
btn_left.click(fn=partial(process_experiment_step, action_type="record_choice", choice_value="left"), inputs=click_inputs_base, outputs=event_outputs, queue=True)
|
| 374 |
btn_right.click(fn=partial(process_experiment_step, action_type="record_choice", choice_value="right"), inputs=click_inputs_base, outputs=event_outputs, queue=True)
|
| 375 |
|
| 376 |
-
# ==== 程序入口 ====
|
| 377 |
if __name__ == "__main__":
|
| 378 |
-
# ... (与您上一版相同的启动检查和打印逻辑, 确保 allowed_paths 正确) ...
|
| 379 |
if not master_image_list: print("\n关键错误:程序无法启动,因无目标图片。"); exit()
|
| 380 |
else:
|
| 381 |
print(f"从 '{TARGET_DIR}' 加载 {len(master_image_list)} 张目标图片。每轮选 {NUM_TRIALS_PER_RUN} 张。")
|
|
@@ -383,18 +492,18 @@ if __name__ == "__main__":
|
|
| 383 |
else: print(f"方法根目录: {METHOD_ROOTS}")
|
| 384 |
if not SUBJECTS: print("警告: SUBJECTS 列表为空。")
|
| 385 |
else: print(f"Subjects: {SUBJECTS}")
|
| 386 |
-
print(f"
|
| 387 |
-
if
|
| 388 |
-
# INDIVIDUAL_LOGS_FOLDER 的打印可以保留或移除,因为现在主要用批量
|
| 389 |
-
# if INDIVIDUAL_LOGS_FOLDER: print(f" - 单个选择日志文件夹 (可能未使用): '{INDIVIDUAL_LOGS_FOLDER}/'")
|
| 390 |
-
if not os.getenv("HF_TOKEN"): print("警告: HF_TOKEN 未设置。日志无法保存。\n 请在 Space Secrets 中设置 HF_TOKEN。")
|
| 391 |
else: print("HF_TOKEN 已找到。")
|
|
|
|
|
|
|
| 392 |
path_to_allow_serving_from = BASE_IMAGE_DIR
|
| 393 |
allowed_paths_list = []
|
| 394 |
if os.path.exists(path_to_allow_serving_from) and os.path.isdir(path_to_allow_serving_from):
|
| 395 |
allowed_paths_list.append(os.path.abspath(path_to_allow_serving_from))
|
| 396 |
print(f"Gradio `demo.launch()` 配置 allowed_paths: {allowed_paths_list}")
|
| 397 |
else: print(f"关键警告:图片基础目录 '{path_to_allow_serving_from}' ({os.path.abspath(path_to_allow_serving_from) if path_to_allow_serving_from else 'N/A'}) 不存在或非目录。")
|
|
|
|
| 398 |
print("启动 Gradio 应用...")
|
| 399 |
if allowed_paths_list: demo.launch(allowed_paths=allowed_paths_list)
|
| 400 |
else: demo.launch()
|
|
|
|
| 3 |
import random
|
| 4 |
import time
|
| 5 |
from datetime import datetime
|
|
|
|
| 6 |
from functools import partial
|
| 7 |
import json
|
| 8 |
import io
|
| 9 |
from huggingface_hub import HfApi
|
| 10 |
from huggingface_hub.hf_api import HfHubHTTPError
|
| 11 |
import traceback
|
| 12 |
+
from itertools import combinations
|
| 13 |
|
| 14 |
+
# ==== 全局配置 (部分保持不变) ====
|
| 15 |
+
# ... (BASE_IMAGE_DIR, TARGET_DIR, METHOD_ROOTS, SUBJECTS, etc. 保持不变) ...
|
| 16 |
+
BASE_IMAGE_DIR = "/data/images/images"
|
| 17 |
TARGET_DIR_BASENAME = "gt"
|
| 18 |
TARGET_DIR = os.path.join(BASE_IMAGE_DIR, TARGET_DIR_BASENAME)
|
| 19 |
|
|
|
|
| 32 |
except Exception as e: print(f"错误:在扫描 '{BASE_IMAGE_DIR}' 时发生错误: {e}"); METHOD_ROOTS = []
|
| 33 |
else: print(f"警告:基础目录 '{BASE_IMAGE_DIR}' 不存在。将无法加载候选图片。")
|
| 34 |
|
| 35 |
+
SUBJECTS = ["subj01", "subj02", "subj05", "subj07"]
|
| 36 |
SENTINEL_TRIAL_INTERVAL = 20
|
| 37 |
+
NUM_TRIALS_PER_RUN = 100
|
| 38 |
+
LOG_BATCH_SIZE = 5
|
| 39 |
|
| 40 |
DATASET_REPO_ID = "YanmHa/image-aligned-experiment-data"
|
| 41 |
+
INDIVIDUAL_LOGS_FOLDER = "individual_choice_logs"
|
| 42 |
+
BATCH_LOG_FOLDER = "run_logs_batch"
|
| 43 |
CSS = ".gr-block {margin-top: 4px !important; margin-bottom: 4px !important;} .compact_button { padding: 4px 8px; min-width: auto; }"
|
| 44 |
|
| 45 |
+
|
| 46 |
+
# ==== 全局持久化历史记录 ====
|
| 47 |
+
GLOBAL_HISTORY_FILE = "global_experiment_shown_pairs.json"
|
| 48 |
+
global_shown_pairs_cache = {}
|
| 49 |
+
global_history_has_unsaved_changes = False # <--- 新增:跟踪全局历史是否有未保存的更改
|
| 50 |
+
|
| 51 |
+
def load_global_shown_pairs():
|
| 52 |
+
global global_shown_pairs_cache, global_history_has_unsaved_changes
|
| 53 |
+
if os.path.exists(GLOBAL_HISTORY_FILE):
|
| 54 |
+
try:
|
| 55 |
+
with open(GLOBAL_HISTORY_FILE, 'r', encoding='utf-8') as f:
|
| 56 |
+
data_from_file = json.load(f)
|
| 57 |
+
global_shown_pairs_cache = {
|
| 58 |
+
target_img: {frozenset(pair) for pair in pairs_list}
|
| 59 |
+
for target_img, pairs_list in data_from_file.items()
|
| 60 |
+
}
|
| 61 |
+
print(f"已成功从 '{GLOBAL_HISTORY_FILE}' 加载全局已展示图片对历史。")
|
| 62 |
+
except Exception as e:
|
| 63 |
+
print(f"错误:加载全局历史文件 '{GLOBAL_HISTORY_FILE}' 失败: {e}。将使用空历史记录。")
|
| 64 |
+
global_shown_pairs_cache = {}
|
| 65 |
+
else:
|
| 66 |
+
print(f"信息:全局历史文件 '{GLOBAL_HISTORY_FILE}' 未找到。将创建新的空历史记录。")
|
| 67 |
+
global_shown_pairs_cache = {}
|
| 68 |
+
global_history_has_unsaved_changes = False # 初始化或加载后,标记为无未保存更改
|
| 69 |
+
|
| 70 |
+
def save_global_shown_pairs():
|
| 71 |
+
global global_shown_pairs_cache, global_history_has_unsaved_changes # 确保可以修改这个flag
|
| 72 |
+
# print("尝试保存全局图片对历史...") # 调试信息
|
| 73 |
+
try:
|
| 74 |
+
data_to_save = {
|
| 75 |
+
target_img: [sorted(list(pair_fset)) for pair_fset in pairs_set]
|
| 76 |
+
for target_img, pairs_set in global_shown_pairs_cache.items()
|
| 77 |
+
}
|
| 78 |
+
temp_file = GLOBAL_HISTORY_FILE + ".tmp"
|
| 79 |
+
with open(temp_file, 'w', encoding='utf-8') as f:
|
| 80 |
+
json.dump(data_to_save, f, ensure_ascii=False, indent=2)
|
| 81 |
+
os.replace(temp_file, GLOBAL_HISTORY_FILE)
|
| 82 |
+
print(f"已成功将全局已展示图片对历史保存到 '{GLOBAL_HISTORY_FILE}'。")
|
| 83 |
+
global_history_has_unsaved_changes = False # 保存成功后重置标志
|
| 84 |
+
return True
|
| 85 |
+
except Exception as e:
|
| 86 |
+
print(f"错误:保存全局历史文件 '{GLOBAL_HISTORY_FILE}' 失败: {e}")
|
| 87 |
+
return False
|
| 88 |
+
|
| 89 |
+
load_global_shown_pairs()
|
| 90 |
+
|
| 91 |
+
# ==== 加载所有可用的目标图片 (保持不变) ====
|
| 92 |
master_image_list = []
|
|
|
|
| 93 |
if os.path.exists(TARGET_DIR):
|
| 94 |
try:
|
| 95 |
master_image_list = sorted(
|
|
|
|
| 106 |
|
| 107 |
# ==== 辅助函数 ====
|
| 108 |
def get_next_trial_info(current_trial_idx_in_run, current_run_image_list_for_trial, num_trials_in_this_run_for_trial):
|
| 109 |
+
global TARGET_DIR, METHOD_ROOTS, SUBJECTS, SENTINEL_TRIAL_INTERVAL, global_shown_pairs_cache, global_history_has_unsaved_changes
|
| 110 |
+
|
| 111 |
+
# ... (函数开始部分与上一版相同,直到哨兵试验逻辑结束) ...
|
| 112 |
+
if not current_run_image_list_for_trial or current_trial_idx_in_run >= num_trials_in_this_run_for_trial:
|
| 113 |
+
return None, current_trial_idx_in_run
|
| 114 |
+
|
| 115 |
img_filename_original = current_run_image_list_for_trial[current_trial_idx_in_run]
|
| 116 |
target_full_path = os.path.join(TARGET_DIR, img_filename_original)
|
| 117 |
trial_number_for_display = current_trial_idx_in_run + 1
|
| 118 |
+
|
| 119 |
pool = []
|
| 120 |
for m_root_path in METHOD_ROOTS:
|
| 121 |
method_name = os.path.basename(m_root_path)
|
| 122 |
for s_id in SUBJECTS:
|
| 123 |
+
base, ext = os.path.splitext(img_filename_original)
|
| 124 |
+
reconstructed_filename = f"{base}_0{ext}"
|
| 125 |
candidate_path = os.path.join(m_root_path, s_id, reconstructed_filename)
|
| 126 |
+
if os.path.exists(candidate_path):
|
| 127 |
+
internal_label = f"{method_name}/{s_id}/{reconstructed_filename}"
|
| 128 |
+
pool.append((internal_label, candidate_path))
|
| 129 |
+
|
| 130 |
trial_info = {"image_id": img_filename_original, "target_path": target_full_path, "cur_no": trial_number_for_display, "is_sentinel": False,
|
| 131 |
"left_display_label": "N/A", "left_internal_label": "N/A", "left_path": None,
|
| 132 |
"right_display_label": "N/A", "right_internal_label": "N/A", "right_path": None}
|
| 133 |
+
|
| 134 |
is_potential_sentinel_trial = (trial_number_for_display > 0 and trial_number_for_display % SENTINEL_TRIAL_INTERVAL == 0)
|
| 135 |
+
|
| 136 |
if is_potential_sentinel_trial:
|
| 137 |
+
# 哨兵试验逻辑不变
|
| 138 |
+
if not pool:
|
| 139 |
+
print(f"警告:哨兵图 '{img_filename_original}' (trial {trial_number_for_display}) 无候选。")
|
| 140 |
else:
|
| 141 |
print(f"生成哨兵试验 for '{img_filename_original}' (trial {trial_number_for_display})")
|
| 142 |
trial_info["is_sentinel"] = True
|
| 143 |
sentinel_candidate_target_tuple = ("目标图像", target_full_path)
|
| 144 |
random_reconstruction_candidate_tuple = random.choice(pool)
|
| 145 |
candidates_for_sentinel = [
|
| 146 |
+
(("目标图像", target_full_path), sentinel_candidate_target_tuple[0]),
|
| 147 |
(("重建图", random_reconstruction_candidate_tuple[1]), random_reconstruction_candidate_tuple[0])
|
| 148 |
]
|
| 149 |
random.shuffle(candidates_for_sentinel)
|
| 150 |
trial_info.update({
|
| 151 |
+
"left_display_label": candidates_for_sentinel[0][0][0],
|
| 152 |
+
"left_path": candidates_for_sentinel[0][0][1],
|
| 153 |
+
"left_internal_label": candidates_for_sentinel[0][1],
|
| 154 |
+
"right_display_label": candidates_for_sentinel[1][0][0],
|
| 155 |
+
"right_path": candidates_for_sentinel[1][0][1],
|
| 156 |
+
"right_internal_label": candidates_for_sentinel[1][1],
|
| 157 |
})
|
| 158 |
+
else: # 常规试验
|
| 159 |
+
if len(pool) < 2:
|
| 160 |
+
print(f"警告:常规图 '{img_filename_original}' (trial {trial_number_for_display}) 候选少于2 (找到 {len(pool)})。此试验无法进行。")
|
| 161 |
+
return None, current_trial_idx_in_run
|
| 162 |
+
|
| 163 |
+
target_global_history_set = global_shown_pairs_cache.setdefault(img_filename_original, set())
|
| 164 |
+
all_possible_pairs_in_pool = []
|
| 165 |
+
for c1, c2 in combinations(pool, 2):
|
| 166 |
+
pair_labels_fset = frozenset({c1[0], c2[0]})
|
| 167 |
+
all_possible_pairs_in_pool.append( ((c1, c2), pair_labels_fset) )
|
| 168 |
+
|
| 169 |
+
unseen_globally_pairs_with_data = [
|
| 170 |
+
item for item in all_possible_pairs_in_pool
|
| 171 |
+
if item[1] not in target_global_history_set
|
| 172 |
+
]
|
| 173 |
+
selected_candidates_tuples = None
|
| 174 |
+
|
| 175 |
+
if unseen_globally_pairs_with_data:
|
| 176 |
+
chosen_pair_data_and_labels = random.choice(unseen_globally_pairs_with_data)
|
| 177 |
+
selected_candidates_tuples = chosen_pair_data_and_labels[0]
|
| 178 |
+
chosen_pair_frozenset = chosen_pair_data_and_labels[1]
|
| 179 |
+
target_global_history_set.add(chosen_pair_frozenset)
|
| 180 |
+
global_history_has_unsaved_changes = True # <--- 标记全局历史已更新
|
| 181 |
+
# 不再在此处调用 save_global_shown_pairs()
|
| 182 |
+
# print(f"调试:目标 '{img_filename_original}': 新全局唯一对 {chosen_pair_frozenset} 已添加至缓存。未保存更改标志: {global_history_has_unsaved_changes}")
|
| 183 |
+
else:
|
| 184 |
+
# ... (处理所有对都已展示过的情况,与上一版相同) ...
|
| 185 |
+
print(f"警告:目标图 '{img_filename_original}': 来自当前池的所有 ({len(all_possible_pairs_in_pool)}) 个候选对均已在全局展示过。")
|
| 186 |
+
if all_possible_pairs_in_pool:
|
| 187 |
+
print("将随机选择一个重复的对(全局重复)。")
|
| 188 |
+
chosen_pair_data_and_labels = random.choice(all_possible_pairs_in_pool)
|
| 189 |
+
selected_candidates_tuples = chosen_pair_data_and_labels[0]
|
| 190 |
+
else:
|
| 191 |
+
print(f"错误:即使允许全局重复,也无法从池中选择图片对(池大小 {len(pool)})。")
|
| 192 |
+
return None, current_trial_idx_in_run
|
| 193 |
+
|
| 194 |
+
display_order_candidates = list(selected_candidates_tuples)
|
| 195 |
+
if random.random() > 0.5:
|
| 196 |
+
display_order_candidates = display_order_candidates[::-1]
|
| 197 |
+
|
| 198 |
trial_info.update({
|
| 199 |
+
"left_display_label": "候选图 1", "left_path": display_order_candidates[0][1], "left_internal_label": display_order_candidates[0][0],
|
| 200 |
+
"right_display_label": "候选图 2", "right_path": display_order_candidates[1][1], "right_internal_label": display_order_candidates[1][0],
|
| 201 |
})
|
| 202 |
return trial_info, current_trial_idx_in_run + 1
|
| 203 |
|
| 204 |
+
# ==== 批量保存用户选择日志函数 (保持不变) ====
|
| 205 |
+
def save_collected_logs_batch(list_of_log_entries, user_identifier_str, batch_identifier):
|
|
|
|
|
|
|
|
|
|
| 206 |
global DATASET_REPO_ID, BATCH_LOG_FOLDER
|
| 207 |
if not list_of_log_entries:
|
| 208 |
+
print("批量保存用户日志:没有累积的日志。")
|
| 209 |
+
return True # 认为无日志即成功
|
| 210 |
|
| 211 |
identifier_safe = str(user_identifier_str if user_identifier_str else "unknown_user_session").replace('.', '_').replace(':', '_').replace('/', '_').replace(' ', '_')
|
| 212 |
+
print(f"用户 {identifier_safe} - 准备批量保存 {len(list_of_log_entries)} 条选择日志 (批次标识: {batch_identifier})...")
|
|
|
|
| 213 |
try:
|
| 214 |
token = os.getenv("HF_TOKEN")
|
| 215 |
+
if not token: print("错误:HF_TOKEN 未设置。无法批量保存选择日志。"); return False
|
| 216 |
+
if not DATASET_REPO_ID: print("错误:DATASET_REPO_ID 未配置。无法批量保存选择日志。"); return False
|
| 217 |
|
| 218 |
api = HfApi(token=token)
|
| 219 |
timestamp_str = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
|
| 220 |
batch_filename = f"batch_user-{identifier_safe}_id-{batch_identifier}_{timestamp_str}_logs-{len(list_of_log_entries)}.jsonl"
|
| 221 |
+
path_in_repo = f"{BATCH_LOG_FOLDER}/{identifier_safe}/{batch_filename}"
|
|
|
|
| 222 |
jsonl_content = ""
|
| 223 |
for log_entry in list_of_log_entries:
|
| 224 |
try:
|
| 225 |
if isinstance(log_entry, dict): jsonl_content += json.dumps(log_entry, ensure_ascii=False) + "\n"
|
| 226 |
+
else: print(f"警告:批量保存选择日志时,条目非字典:{log_entry}")
|
| 227 |
except Exception as json_err:
|
| 228 |
+
print(f"错误:批量保存选择日志序列化单条时出错: {log_entry}. 错误: {json_err}")
|
| 229 |
+
jsonl_content += json.dumps({"error": "serialization_failed_in_batch_user_log", "original_data_preview": str(log_entry)[:100],"timestamp": datetime.now().isoformat()}, ensure_ascii=False) + "\n"
|
| 230 |
+
if not jsonl_content.strip(): print(f"用户 {identifier_safe} (批次 {batch_identifier}) 无可序列化选择日志。"); return True
|
|
|
|
| 231 |
|
| 232 |
log_bytes = jsonl_content.encode('utf-8')
|
| 233 |
file_like_object = io.BytesIO(log_bytes)
|
| 234 |
+
print(f"准备批量上传选择日志文件: {path_in_repo} ({len(log_bytes)} bytes)")
|
|
|
|
| 235 |
api.upload_file(
|
| 236 |
path_or_fileobj=file_like_object, path_in_repo=path_in_repo, repo_id=DATASET_REPO_ID, repo_type="dataset",
|
| 237 |
+
commit_message=f"Batch user choice logs for {identifier_safe}, batch_id {batch_identifier} ({len(list_of_log_entries)} entries)"
|
| 238 |
)
|
| 239 |
+
print(f"批量选择日志已成功保存到 HF Dataset: {DATASET_REPO_ID}/{path_in_repo}")
|
| 240 |
return True
|
| 241 |
except Exception as e:
|
| 242 |
+
print(f"批量保存选择日志 (user {identifier_safe}, batch_id {batch_identifier}) 失败: {e}"); traceback.print_exc()
|
| 243 |
return False
|
| 244 |
|
| 245 |
+
|
| 246 |
# ==== 主要的 Gradio 事件处理函数 ====
|
| 247 |
def process_experiment_step(
|
| 248 |
s_trial_idx_val, s_run_no_val, s_user_logs_val, s_current_trial_data_val, s_user_session_id_val,
|
| 249 |
s_current_run_image_list_val, s_num_trials_this_run_val,
|
| 250 |
action_type=None, choice_value=None, request: gr.Request = None
|
| 251 |
):
|
| 252 |
+
global master_image_list, NUM_TRIALS_PER_RUN, outputs_ui_components_definition, LOG_BATCH_SIZE, global_history_has_unsaved_changes
|
| 253 |
|
| 254 |
+
# ... (函数开始部分与上一版类似) ...
|
| 255 |
output_s_trial_idx = s_trial_idx_val; output_s_run_no = s_run_no_val
|
| 256 |
output_s_user_logs = list(s_user_logs_val); output_s_current_trial_data = dict(s_current_trial_data_val) if s_current_trial_data_val else {}
|
| 257 |
output_s_user_session_id = s_user_session_id_val; output_s_current_run_image_list = list(s_current_run_image_list_val)
|
| 258 |
output_s_num_trials_this_run = s_num_trials_this_run_val
|
| 259 |
user_ip_fallback = request.client.host if request else "unknown_ip"
|
| 260 |
user_identifier_for_logging = output_s_user_session_id if output_s_user_session_id else user_ip_fallback
|
| 261 |
+
|
| 262 |
len_ui_outputs = len(outputs_ui_components_definition)
|
| 263 |
def create_ui_error_tuple(message, progress_msg_text): return (gr.update(visible=False),) * 3 + ("", "", message, progress_msg_text) + (gr.update(interactive=True), gr.update(interactive=False), gr.update(interactive=False)) + (gr.update(visible=False),)
|
| 264 |
def create_no_change_tuple(): return (gr.update(),) * len_ui_outputs
|
| 265 |
user_id_display_text = output_s_user_session_id if output_s_user_session_id else "用户ID待分配"
|
| 266 |
|
| 267 |
+
|
| 268 |
if action_type == "record_choice":
|
| 269 |
+
# ... (日志记录逻辑与上一版相同) ...
|
| 270 |
if output_s_current_trial_data.get("data") and output_s_current_trial_data["data"].get("left_internal_label"):
|
| 271 |
chosen_internal_label = (output_s_current_trial_data["data"]["left_internal_label"] if choice_value == "left" else output_s_current_trial_data["data"]["right_internal_label"])
|
|
|
|
| 272 |
parsed_chosen_method, parsed_chosen_subject, parsed_chosen_filename = "N/A", "N/A", "N/A"
|
| 273 |
if chosen_internal_label == "目标图像": parsed_chosen_method, parsed_chosen_subject, parsed_chosen_filename = "TARGET", "GT", output_s_current_trial_data["data"]["image_id"]
|
| 274 |
else:
|
|
|
|
| 277 |
elif len(parts) == 2: parsed_chosen_method, parsed_chosen_subject = parts[0].strip(), parts[1].strip()
|
| 278 |
elif len(parts) == 1: parsed_chosen_method = parts[0].strip()
|
| 279 |
log_entry = {
|
| 280 |
+
"timestamp": datetime.now().isoformat(), "user_identifier": user_identifier_for_logging, "run_no": output_s_run_no,
|
| 281 |
"image_id": output_s_current_trial_data["data"]["image_id"],
|
| 282 |
+
"left_internal_label": output_s_current_trial_data["data"]["left_internal_label"],
|
| 283 |
"right_internal_label": output_s_current_trial_data["data"]["right_internal_label"],
|
| 284 |
+
"chosen_side": choice_value, "chosen_internal_label": chosen_internal_label,
|
| 285 |
"chosen_method": parsed_chosen_method, "chosen_subject": parsed_chosen_subject, "chosen_filename": parsed_chosen_filename,
|
| 286 |
+
"trial_sequence_in_run": output_s_current_trial_data["data"]["cur_no"],
|
| 287 |
"is_sentinel": output_s_current_trial_data["data"]["is_sentinel"]
|
| 288 |
}
|
| 289 |
output_s_user_logs.append(log_entry)
|
| 290 |
print(f"用户 {user_identifier_for_logging} 记录选择 (img: {log_entry['image_id']})。当前批次日志数: {len(output_s_user_logs)}")
|
| 291 |
|
| 292 |
+
|
| 293 |
+
# !!! 修改:当用户日志达到批量大小时,同时尝试保存全局历史(如果它有更改)!!!
|
| 294 |
if len(output_s_user_logs) >= LOG_BATCH_SIZE:
|
| 295 |
+
print(f"累积用户选择日志达到 {LOG_BATCH_SIZE} 条,准备批量保存...")
|
| 296 |
+
batch_id_for_filename = f"run{output_s_run_no}_trialidx{output_s_trial_idx}_logcount{len(output_s_user_logs)}"
|
| 297 |
+
|
| 298 |
+
# 1. 保存用户选择日志
|
| 299 |
+
user_logs_save_success = save_collected_logs_batch(list(output_s_user_logs), user_identifier_for_logging, batch_id_for_filename)
|
| 300 |
+
if user_logs_save_success:
|
| 301 |
+
print("批量用户选择日志已成功(或尝试)保存,将清空累积的用户选择日志列表。")
|
| 302 |
+
output_s_user_logs = [] # 清空已保存的用户日志
|
| 303 |
else:
|
| 304 |
+
print("警告:批量用户选择日志保存失败。选择日志将继续累积,下次达到阈值时重试。")
|
| 305 |
+
|
| 306 |
+
# 2. 检查并保存全局图片对历史(如果自上次保存后有更改)
|
| 307 |
+
if global_history_has_unsaved_changes:
|
| 308 |
+
print("检测到全局图片对历史自上次保存后有更新,将一并保存...")
|
| 309 |
+
save_global_shown_pairs() # 此函数内部会在成功保存后将 global_history_has_unsaved_changes 置为 False
|
| 310 |
+
else:
|
| 311 |
+
print("全局图片对历史自上次保存后无更新,无需保存。")
|
| 312 |
+
else: # 处理记录选择时数据为空的错误
|
| 313 |
+
print(f"用户 {user_identifier_for_logging} 错误:记录选择时数据为空!")
|
| 314 |
error_ui_updates = create_ui_error_tuple("记录选择时内部错误。", f"用户ID: {user_id_display_text} | 进度:{output_s_trial_idx}/{output_s_num_trials_this_run}")
|
| 315 |
+
# 返回所有状态变量的当前值以及错误UI更新
|
| 316 |
return output_s_trial_idx, output_s_run_no, output_s_user_logs, output_s_current_trial_data, output_s_user_session_id, output_s_current_run_image_list, output_s_num_trials_this_run, *error_ui_updates
|
| 317 |
|
| 318 |
+
# ... (start_experiment 逻辑与上一版相同) ...
|
| 319 |
if action_type == "start_experiment":
|
| 320 |
is_first = (output_s_num_trials_this_run == 0 and output_s_trial_idx == 0 and output_s_run_no == 1)
|
| 321 |
is_completed_for_restart = (output_s_num_trials_this_run > 0 and output_s_trial_idx >= output_s_num_trials_this_run)
|
| 322 |
if is_first or is_completed_for_restart:
|
| 323 |
if not master_image_list: error_ui = create_ui_error_tuple("错误: 无可用目标图片!", f"用户ID: {user_id_display_text} | 进度: 0/0"); return 0, output_s_run_no, output_s_user_logs, {}, output_s_user_session_id, [], 0, *error_ui
|
| 324 |
+
|
| 325 |
+
if is_completed_for_restart:
|
|
|
|
|
|
|
| 326 |
output_s_run_no += 1
|
| 327 |
+
|
| 328 |
num_avail = len(master_image_list); run_size = min(num_avail, NUM_TRIALS_PER_RUN)
|
| 329 |
if run_size == 0: error_ui = create_ui_error_tuple("错误: 采样图片数为0!", f"用户ID: {user_id_display_text} | 进度: 0/0"); return 0, output_s_run_no, output_s_user_logs, {}, output_s_user_session_id, [], 0, *error_ui
|
| 330 |
+
|
| 331 |
output_s_current_run_image_list = random.sample(master_image_list, run_size)
|
| 332 |
output_s_num_trials_this_run = run_size
|
| 333 |
output_s_trial_idx = 0
|
|
|
|
| 334 |
output_s_current_trial_data = {}
|
|
|
|
|
|
|
| 335 |
print(f"开始/继续轮次 {output_s_run_no} (用户ID: {output_s_user_session_id}). 随机选择 {output_s_num_trials_this_run} 张图片.")
|
| 336 |
+
else:
|
| 337 |
print(f"用户 {user_identifier_for_logging} 在第 {output_s_run_no} 轮,试验 {output_s_trial_idx} 点击开始,但轮次未完成。忽略。")
|
| 338 |
no_change_ui = create_no_change_tuple()
|
| 339 |
return output_s_trial_idx, output_s_run_no, output_s_user_logs, output_s_current_trial_data, output_s_user_session_id, output_s_current_run_image_list, output_s_num_trials_this_run, *no_change_ui
|
| 340 |
+
|
| 341 |
+
# ... (轮次结束处理与上一版相同) ...
|
| 342 |
if output_s_trial_idx >= output_s_num_trials_this_run and output_s_num_trials_this_run > 0:
|
|
|
|
| 343 |
print(f"用户 {output_s_user_session_id} 已完成第 {output_s_run_no} 轮。等待下一批或下一轮开始。")
|
| 344 |
+
# 检查是否有未保存的全局历史,即使日志批次未满,也可能在轮次结束时考虑保存
|
| 345 |
+
# 但当前逻辑是仅在日志批次满时保存全局历史,这里可以保持一致或添加额外逻辑
|
| 346 |
+
if global_history_has_unsaved_changes:
|
| 347 |
+
print(f"提示:轮次 {output_s_run_no} 结束,仍有未保存的全局图片对历史更改。将在下次日志批量保存时一并处理。")
|
| 348 |
+
|
| 349 |
prog_text = f"用户ID: {output_s_user_session_id} | 进度:{output_s_num_trials_this_run}/{output_s_num_trials_this_run} | 第 {output_s_run_no} 轮 🎉"
|
| 350 |
ui_updates = list(create_ui_error_tuple(f"🎉 第 {output_s_run_no} 轮完成!请点击“开始试验 / 下一轮”继续或开始新批次。", prog_text))
|
|
|
|
| 351 |
ui_updates[7]=gr.update(interactive=True); ui_updates[8]=gr.update(interactive=False); ui_updates[9]=gr.update(interactive=False)
|
| 352 |
ui_updates[0]=gr.update(value=None,visible=False); ui_updates[1]=gr.update(value=None,visible=False); ui_updates[2]=gr.update(value=None,visible=False)
|
| 353 |
yield output_s_trial_idx, output_s_run_no, output_s_user_logs, output_s_current_trial_data, output_s_user_session_id, output_s_current_run_image_list, output_s_num_trials_this_run, *ui_updates; return
|
| 354 |
|
| 355 |
+
|
| 356 |
+
# ... (获取并显示下一个试验的逻辑,与上一版相同,调用 get_next_trial_info 不需要 user_id for history) ...
|
| 357 |
if not output_s_current_run_image_list or output_s_num_trials_this_run == 0:
|
| 358 |
error_ui = create_ui_error_tuple("错误: 无法加载试验图片 (列表为空)", f"用户ID: {user_id_display_text} | 进度: N/A")
|
| 359 |
return output_s_trial_idx, output_s_run_no, output_s_user_logs, {"data": None}, output_s_user_session_id, [], 0, *error_ui
|
| 360 |
+
|
| 361 |
+
trial_info, next_s_trial_idx_for_state = get_next_trial_info(
|
| 362 |
+
output_s_trial_idx,
|
| 363 |
+
output_s_current_run_image_list,
|
| 364 |
+
output_s_num_trials_this_run
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
if trial_info is None:
|
| 368 |
+
print(f"错误:用户 {user_identifier_for_logging},轮次 {output_s_run_no},试验 {output_s_trial_idx}: get_next_trial_info 返回 None。")
|
| 369 |
+
error_msg_display = "无法加载下一个试验,可能是因为候选图片不足或所有唯一组合已用尽。"
|
| 370 |
+
if len(METHOD_ROOTS) * len(SUBJECTS) < 2 :
|
| 371 |
+
error_msg_display = "候选图片来源不足,无法形成对比试验。"
|
| 372 |
+
error_ui_updates = create_ui_error_tuple(error_msg_display, f"用户ID: {user_id_display_text} | 进度:{output_s_trial_idx}/{output_s_num_trials_this_run}")
|
| 373 |
output_s_current_trial_data = {"data": None}
|
| 374 |
+
return output_s_trial_idx, output_s_run_no, output_s_user_logs, output_s_current_trial_data, output_s_user_session_id, output_s_current_run_image_list, output_s_num_trials_this_run, *error_ui_updates
|
| 375 |
+
|
| 376 |
output_s_current_trial_data = {"data": trial_info}
|
| 377 |
prog_text = f"用户ID: {output_s_user_session_id} | 进度:{trial_info['cur_no']}/{output_s_num_trials_this_run} | 第 {output_s_run_no} 轮"
|
| 378 |
ui_show_target_updates = list(create_no_change_tuple())
|
|
|
|
| 388 |
ui_show_candidates_updates[7]=gr.update(interactive=False); ui_show_candidates_updates[8]=gr.update(interactive=True); ui_show_candidates_updates[9]=gr.update(interactive=True)
|
| 389 |
yield next_s_trial_idx_for_state, output_s_run_no, output_s_user_logs, output_s_current_trial_data, output_s_user_session_id, output_s_current_run_image_list, output_s_num_trials_this_run, *ui_show_candidates_updates
|
| 390 |
|
| 391 |
+
|
| 392 |
+
# ==== Gradio UI 定义 和 程序入口 (保持不变) ====
|
| 393 |
+
# ... (welcome_page_markdown, handle_agree_and_start, gr.Blocks, if __name__ == "__main__": etc.) ...
|
| 394 |
welcome_page_markdown = """
|
| 395 |
## 欢迎加入实验!
|
| 396 |
您好!非常感谢您抽出宝贵时间参与我们的视觉偏好评估实验。您的选择将帮助我们改进重建算法,让机器生成的图像更贴近人类视觉体验!
|
|
|
|
| 418 |
再次感谢您的参与与支持!您每一次认真选择都对我们的研究意义重大。祝您一切顺利,实验愉快!
|
| 419 |
"""
|
| 420 |
def handle_agree_and_start(name, gender, age, education, request: gr.Request):
|
|
|
|
| 421 |
error_messages_list = []
|
| 422 |
if not name or str(name).strip() == "": error_messages_list.append("姓名 不能为空。")
|
| 423 |
if gender is None or str(gender).strip() == "": error_messages_list.append("性别 必须选择。")
|
|
|
|
| 425 |
elif not (isinstance(age, (int, float)) and 1 <= age <= 120):
|
| 426 |
try: num_age = float(age);
|
| 427 |
except (ValueError, TypeError): error_messages_list.append("年龄必须是一个有效的数字。")
|
| 428 |
+
else:
|
| 429 |
if not (1 <= num_age <= 120): error_messages_list.append("年龄必须在 1 到 120 之间。")
|
| 430 |
+
if education is None or str(education).strip() == "其他": error_messages_list.append("学历 必须选择。")
|
| 431 |
if error_messages_list:
|
| 432 |
full_error_message = "请修正以下错误:\n" + "\n".join([f"- {msg}" for msg in error_messages_list])
|
| 433 |
print(f"用户输入验证失败: {full_error_message}")
|
| 434 |
return gr.update(), False, gr.update(visible=True), gr.update(visible=False), full_error_message
|
| 435 |
s_name = str(name).strip().replace(" ","_").replace("/","_").replace("\\","_")
|
| 436 |
s_gender = str(gender).strip().replace(" ","_").replace("/","_").replace("\\","_")
|
| 437 |
+
s_age = str(int(float(age)))
|
| 438 |
s_education = str(education).strip().replace(" ","_").replace("/","_").replace("\\","_")
|
| 439 |
user_id_str = f"N-{s_name}_G-{s_gender}_A-{s_age}_E-{s_education}"
|
| 440 |
print(f"用户信息收集完毕,生成用户ID: {user_id_str}")
|
| 441 |
return user_id_str, True, gr.update(visible=False), gr.update(visible=True), ""
|
| 442 |
|
| 443 |
with gr.Blocks(css=CSS, title="图像重建主观评估") as demo:
|
|
|
|
| 444 |
s_show_experiment_ui = gr.State(False); s_trial_index = gr.State(0); s_run_no = gr.State(1)
|
| 445 |
s_user_logs = gr.State([]); s_current_trial_data = gr.State({}); s_user_session_id = gr.State(None)
|
| 446 |
s_current_run_image_list = gr.State([]); s_num_trials_this_run = gr.State(0)
|
| 447 |
+
|
| 448 |
+
welcome_container = gr.Column(visible=True)
|
| 449 |
experiment_container = gr.Column(visible=False)
|
| 450 |
|
| 451 |
with welcome_container:
|
|
|
|
| 452 |
gr.Markdown(welcome_page_markdown)
|
| 453 |
with gr.Row(): user_name_input = gr.Textbox(label="请输入您的姓名或代号 (例如 张三 或 User001)", placeholder="例如:张三 -> ZS"); user_gender_input = gr.Radio(label="性别", choices=["男", "女"])
|
| 454 |
with gr.Row(): user_age_input = gr.Number(label="年龄 (请输入1-120的整数)", minimum=1, maximum=120, step=1); user_education_input = gr.Dropdown(label="学历", choices=["其他","初中及以下","高中(含中专)", "大专(含在读)", "本科(含在读)", "硕士(含在读)", "博士(含在读)"])
|
|
|
|
| 456 |
btn_agree_and_start = gr.Button("我已阅读上述说明并同意参与实验")
|
| 457 |
|
| 458 |
with experiment_container:
|
|
|
|
| 459 |
gr.Markdown("## 🧠 图像重建主观评估实验"); gr.Markdown(f"每轮实验大约有 {NUM_TRIALS_PER_RUN} 次比较。")
|
| 460 |
with gr.Row():
|
| 461 |
with gr.Column(scale=1, min_width=300): left_img = gr.Image(label="左候选图", visible=False, height=400, interactive=False); left_lbl = gr.Textbox(label="左图信息", visible=False, interactive=False, max_lines=1); btn_left = gr.Button("选择左图 (更相似)", interactive=False, elem_classes="compact_button")
|
|
|
|
| 466 |
with gr.Row(): btn_start = gr.Button("开始试验 / 下一轮")
|
| 467 |
file_out_placeholder = gr.File(label=" ", visible=False, interactive=False)
|
| 468 |
|
| 469 |
+
outputs_ui_components_definition = [
|
| 470 |
target_img, left_img, right_img, left_lbl, right_lbl, status_text, progress_text,
|
| 471 |
+
btn_start, btn_left, btn_right, file_out_placeholder
|
| 472 |
]
|
| 473 |
+
click_inputs_base = [
|
| 474 |
s_trial_index, s_run_no, s_user_logs, s_current_trial_data, s_user_session_id,
|
| 475 |
s_current_run_image_list, s_num_trials_this_run
|
| 476 |
]
|
| 477 |
+
event_outputs = [
|
| 478 |
s_trial_index, s_run_no, s_user_logs, s_current_trial_data, s_user_session_id,
|
| 479 |
+
s_current_run_image_list, s_num_trials_this_run, *outputs_ui_components_definition
|
| 480 |
]
|
| 481 |
|
| 482 |
btn_agree_and_start.click(fn=handle_agree_and_start, inputs=[user_name_input, user_gender_input, user_age_input, user_education_input], outputs=[s_user_session_id, s_show_experiment_ui, welcome_container, experiment_container, welcome_error_msg])
|
|
|
|
| 484 |
btn_left.click(fn=partial(process_experiment_step, action_type="record_choice", choice_value="left"), inputs=click_inputs_base, outputs=event_outputs, queue=True)
|
| 485 |
btn_right.click(fn=partial(process_experiment_step, action_type="record_choice", choice_value="right"), inputs=click_inputs_base, outputs=event_outputs, queue=True)
|
| 486 |
|
|
|
|
| 487 |
if __name__ == "__main__":
|
|
|
|
| 488 |
if not master_image_list: print("\n关键错误:程序无法启动,因无目标图片。"); exit()
|
| 489 |
else:
|
| 490 |
print(f"从 '{TARGET_DIR}' 加载 {len(master_image_list)} 张目标图片。每轮选 {NUM_TRIALS_PER_RUN} 张。")
|
|
|
|
| 492 |
else: print(f"方法根目录: {METHOD_ROOTS}")
|
| 493 |
if not SUBJECTS: print("警告: SUBJECTS 列表为空。")
|
| 494 |
else: print(f"Subjects: {SUBJECTS}")
|
| 495 |
+
print(f"用户选择日志保存到 Dataset: '{DATASET_REPO_ID}' 的 '{BATCH_LOG_FOLDER}/' 文件夹")
|
| 496 |
+
if not os.getenv("HF_TOKEN"): print("警告: HF_TOKEN 未设置。日志无法保存到Hugging Face Dataset。\n 请在 Space Secrets 中设置 HF_TOKEN。")
|
|
|
|
|
|
|
|
|
|
| 497 |
else: print("HF_TOKEN 已找到。")
|
| 498 |
+
print(f"全局图片对历史将从 '{GLOBAL_HISTORY_FILE}' 加载/保存到此文件。")
|
| 499 |
+
|
| 500 |
path_to_allow_serving_from = BASE_IMAGE_DIR
|
| 501 |
allowed_paths_list = []
|
| 502 |
if os.path.exists(path_to_allow_serving_from) and os.path.isdir(path_to_allow_serving_from):
|
| 503 |
allowed_paths_list.append(os.path.abspath(path_to_allow_serving_from))
|
| 504 |
print(f"Gradio `demo.launch()` 配置 allowed_paths: {allowed_paths_list}")
|
| 505 |
else: print(f"关键警告:图片基础目录 '{path_to_allow_serving_from}' ({os.path.abspath(path_to_allow_serving_from) if path_to_allow_serving_from else 'N/A'}) 不存在或非目录。")
|
| 506 |
+
|
| 507 |
print("启动 Gradio 应用...")
|
| 508 |
if allowed_paths_list: demo.launch(allowed_paths=allowed_paths_list)
|
| 509 |
else: demo.launch()
|