# Repository header: JingdongZhang · "Bug Fixed" · commit f559ca1
import hmac
import os
import random
import time
import uuid
from datetime import datetime

import gradio as gr
import pandas as pd
from huggingface_hub import HfApi, login
# ================= Configuration =================
# Folder holding the comparison images shown in the study.
IMAGE_DIR: str = "./user_study_pairs"
# Local CSV that every submission is appended to.
RESULT_FILE: str = "user_study_results.csv"

# Example images displayed on the welcome page.
EXAMPLE_POS_DIR: str = "./example_pos"
EXAMPLE_NEG_DIR: str = "./example_neg"

# How many images are sampled per task for each participant.
SAMPLES_PER_TASK: int = 5

# Task keys. Images in IMAGE_DIR are assigned to a task by filename prefix.
TASKS: list = ['haze', 'shadow', 'reflections', 'lens_flares']

# Chinese ('cn') and English ('en') display names per task key.
TASK_NAMES: dict = dict(
    haze={'cn': '雾霾', 'en': 'Fog/Haze'},
    shadow={'cn': '阴影', 'en': 'Shadows'},
    reflections={'cn': '玻璃反光', 'en': 'Reflections'},
    lens_flares={'cn': '镜头光晕', 'en': 'Lens Flares'},
)
# ================= Core Logic =================
def get_random_batch(image_dir=None, tasks=None, samples_per_task=None):
    """Draw a stratified random sample of study images, one stratum per task.

    An image belongs to a task when its filename starts with the task key
    (e.g. ``haze_01.jpg`` -> ``'haze'``); only .jpg/.jpeg/.png files are
    considered (case-insensitive). Up to ``samples_per_task`` images are
    drawn per task without replacement; a task with fewer images contributes
    all of them. The combined batch is shuffled so tasks are interleaved.

    Args:
        image_dir: Folder to scan. Defaults to ``IMAGE_DIR``.
        tasks: Task keys used as filename prefixes. Defaults to ``TASKS``.
        samples_per_task: Images drawn per task. Defaults to
            ``SAMPLES_PER_TASK``.

    Returns:
        A list of ``(full_path, filename, task)`` tuples. Empty when
        ``image_dir`` is not an existing directory.
    """
    image_dir = IMAGE_DIR if image_dir is None else image_dir
    tasks = TASKS if tasks is None else tasks
    samples_per_task = SAMPLES_PER_TASK if samples_per_task is None else samples_per_task

    # isdir (not exists): a plain file at this path would make listdir raise.
    if not os.path.isdir(image_dir):
        return []

    # os.listdir order is arbitrary; sorting makes a seeded run reproducible
    # across machines.
    image_files = sorted(
        name for name in os.listdir(image_dir)
        if name.lower().endswith(('.jpg', '.jpeg', '.png'))
    )

    batch = []
    for task in tasks:
        candidates = [name for name in image_files if name.startswith(task)]
        # random.sample raises ValueError if k exceeds the population size.
        picked = random.sample(candidates, min(samples_per_task, len(candidates)))
        batch.extend((os.path.join(image_dir, name), name, task) for name in picked)

    random.shuffle(batch)
    return batch
# ================= Data Persistence =================
# Hugging Face Dataset repository that receives one CSV per submission.
# Replace with your own Dataset ID when redeploying.
DATASET_REPO_ID: str = "jdzhang0929/UniSER-User-Study-Results"
# Column order of every saved results CSV (local backup and uploads).
_RESULT_COLUMNS = ["User_ID", "Timestamp", "Image_Name", "Task", "Is_Success"]


def _results_to_frame(user_id, results):
    """Build one row per answer; all rows share a single submission timestamp."""
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    return pd.DataFrame(
        [
            {
                "User_ID": user_id,
                "Timestamp": timestamp,
                "Image_Name": img_name,
                "Task": task,
                "Is_Success": is_success,
            }
            for img_name, task, is_success in results
        ],
        columns=_RESULT_COLUMNS,
    )


def _append_local_csv(df, path):
    """Append *df* to the CSV at *path*; the header is written only for a new file."""
    try:
        is_new = not os.path.exists(path)
        df.to_csv(
            path,
            mode='w' if is_new else 'a',
            header=is_new,
            index=False,
            encoding='utf_8_sig',
        )
    except OSError as err:
        # The local file is only a backup: log and still try the cloud upload.
        print(f"❌ 本地 CSV 保存失败: {err}")


def _upload_to_dataset(df, user_id):
    """Upload *df* as its own CSV to ``DATASET_REPO_ID``; skipped without ``HF_TOKEN``."""
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        print("⚠️ 未找到 HF_TOKEN,无法同步到云端 Dataset")
        return

    # One file per submission so concurrent participants never target the
    # same path; the random suffix separates equal IDs within the same second.
    unique_filename = f"results_{user_id}_{int(time.time())}_{uuid.uuid4().hex[:8]}.csv"
    # Serialize in memory: no temp file to leak on failure or collide on disk.
    payload = df.to_csv(index=False).encode("utf-8-sig")
    try:
        # Hand the token to this client instead of calling login(), which
        # installs it as a global default credential on every submission.
        HfApi(token=hf_token).upload_file(
            path_or_fileobj=payload,
            path_in_repo=unique_filename,
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
        )
    except Exception as e:  # best effort: hub/network errors must not break the survey
        print(f"❌ 上传 Dataset 失败: {e}")
    else:
        print(f"✅ 数据已同步到 Dataset: {unique_filename}")


def save_result(user_id, results, result_file=None):
    """Persist one participant's answers locally and to the Hugging Face Dataset.

    Two independent, best-effort channels are used:

    1. Append the rows to the local CSV (a temporary backup).
    2. Upload the same rows as a separate, uniquely named CSV to
       ``DATASET_REPO_ID``, when the ``HF_TOKEN`` environment variable is set.

    A failure in one channel is logged and does not prevent the other.

    Args:
        user_id: Participant ID written to every row and to the upload name.
        results: Iterable of ``(image_name, task, is_success)`` tuples.
        result_file: Local CSV path. Defaults to ``RESULT_FILE``.
    """
    df = _results_to_frame(user_id, results)
    if df.empty:
        # Nothing answered: writing would only add an empty line to the CSV.
        return
    _append_local_csv(df, RESULT_FILE if result_file is None else result_file)
    _upload_to_dataset(df, user_id)
def generate_label_html(task_key):
    """Return the HTML label bar shown above a comparison image.

    The shadow task has three panels (input, shadow mask, output); every
    other key, including unknown ones, has two (input, output). Styling
    comes from the ``big-label-container`` and ``sub-label`` CSS classes.
    """
    input_col = "<div style='flex:1;'>Input<br><span class='sub-label'>(输入原图)</span></div>"
    if task_key != 'shadow':
        columns = [
            input_col,
            "<div style='flex:1; border-left: 1px solid #ccc;'>Output<br><span class='sub-label'>(修复结果)</span></div>",
        ]
    else:
        columns = [
            input_col,
            "<div style='flex:1; border-left: 1px solid #ccc; border-right: 1px solid #ccc;'>Masked Shadow<br><span class='sub-label'>(阴影蒙版)</span></div>",
            "<div style='flex:1;'>Output<br><span class='sub-label'>(修复结果)</span></div>",
        ]
    # Leading/trailing empty entries reproduce the surrounding newlines.
    lines = ["", "<div class='big-label-container'>", *columns, "</div>", ""]
    return "\n".join(lines)
def generate_question_text(task_key):
    """Return the bilingual two-question evaluation prompt as HTML.

    Question 1 asks whether the artifact was removed. For the shadow task it
    refers to the red masked region; for other keys it names the task via
    TASK_NAMES. Unknown keys fall back to '干扰效应' / 'Artifacts'.
    Question 2 (background fidelity) is the same for every task.
    """
    names = TASK_NAMES.get(task_key, {})
    if task_key == 'shadow':
        first_cn = "1. <b>蒙版红色区域 (Masked Red Region)</b> 内的阴影是否去除干净?"
        first_en = "1. Is the shadow within the <b>Masked Red Region</b> removed cleanly?"
    else:
        first_cn = f"1. 画面中的 <b>“{names.get('cn', '干扰效应')}”</b> 是否去除干净?"
        first_en = f"1. Is the <b>“{names.get('en', 'Artifacts')}”</b> removed cleanly?"

    # Leading/trailing empty entries reproduce the surrounding newlines.
    lines = [
        "",
        "<div class='question-box'>",
        "<h3>请仔细评估 / Please Evaluate:</h3>",
        f"<p>{first_cn}<br><span class='en-text'>{first_en}</span></p>",
        "<hr style='margin: 10px 0; opacity: 0.3; border-top: 1px solid #ccc;'>",
        "<p>2. 原本的背景内容是否 <b>保持完整</b> 且未被篡改?<br>",
        "<span class='en-text'>2. Is the original background <b>preserved</b> without distortion?</span></p>",
        "</div>",
        "",
    ]
    return "\n".join(lines)
# ================= Styles (CSS) =================
# Stylesheet injected into the page with gr.HTML (the string carries its own
# <style> tag). It forces explicit text colours with !important on the label
# bar and question box, including italic (em/i) text, so the text stays
# readable when a dark theme is active.
CUSTOM_CSS = """
<style>
footer {visibility: hidden}
/* --- 1. 容器基础样式 --- */
.big-label-container {
display: flex;
text-align: center;
font-weight: bold;
font-size: 1.1em;
background-color: #e8f4f8; /* 浅蓝背景 */
padding: 8px 0;
border-radius: 8px;
margin-bottom: 5px;
width: 100%;
align-items: center;
}
.question-box {
background-color: #f9f9f9; /* 浅灰背景 */
padding: 15px;
border-radius: 10px;
border: 1px solid #ddd;
margin-top: 10px;
}
/* --- 2. 暴力覆盖夜间模式文字颜色 (Nuclear Override) --- */
/* 针对普通文本 (p, span, div) */
.question-box, .question-box p, .question-box span, .question-box div,
.big-label-container, .big-label-container div, .big-label-container span {
color: #333333 !important; /* 强制深灰 */
}
/* [新增] 针对斜体文字 (em, i) -> 修复英文说明看不清的问题 */
.question-box em, .question-box i {
color: #555555 !important; /* 稍微浅一点的深灰,区别于正文 */
font-style: italic;
}
/* 针对粗体字 (strong, b, h3) -> 修复中文标题 */
.question-box strong, .question-box b, .question-box h3,
.big-label-container strong, .big-label-container b {
color: #2c3e50 !important; /* 强制深蓝/深灰 */
}
/* --- 3. 特殊文字颜色 --- */
/* 英文翻译文字 (span class='en-text') */
span.en-text {
color: #666666 !important;
font-style: italic;
font-size: 0.95em;
}
/* 标签的小字说明 */
span.sub-label {
font-size: 0.8em;
font-weight: normal;
color: #555555 !important;
}
/* --- 4. Loading 界面 --- */
.loading-container {
display: flex;
justify-content: center;
align-items: center;
height: 300px;
font-size: 1.5em;
opacity: 0.7;
flex-direction: column;
}
</style>
"""
# ================= UI Layout =================
with gr.Blocks(title="Visual Perception Study") as demo:
    # Inject the custom stylesheet (CUSTOM_CSS carries its own <style> tag).
    gr.HTML(CUSTOM_CSS)

    # --- Per-session state ---
    state_batch = gr.State([])    # sampled (path, filename, task) tuples
    state_index = gr.State(0)     # index of the image currently displayed
    state_results = gr.State([])  # recorded (filename, task, is_success) tuples
    state_user_id = gr.State("")  # participant ID assigned by start_survey

    # ================= Page 1: welcome + examples =================
    with gr.Group(visible=True) as welcome_page:
        gr.Markdown("# 📋 图像修复主观质量评估 / Perceptual Image Restoration Study")
        with gr.Row():
            with gr.Column(scale=2):
                gr.Markdown("""
### 👋 欢迎 / Welcome
本研究旨在评估图像修复算法在真实场景下的表现。
This study evaluates image restoration algorithms in real-world scenarios.
### 🎯 评估标准 / Criteria for Success
请判断修复结果是否 **同时满足** 以下两个条件:
Please judge whether the restoration meets **BOTH** criteria:
1. **去除彻底 (Clean Removal)**: 干扰(雾霾/烟雾/阴影/反光/镜头炫光)已消失。
2. **保真度 (Content Fidelity)**: 背景物体、结构、颜色未被篡改。
""")
                start_btn = gr.Button("🚀 开始测试 / Start Survey", variant="primary", size="lg")
                gr.Markdown(f"""
* **题目数量 / Questions**: {len(TASKS) * SAMPLES_PER_TASK}
* **匿名性 / Anonymity**: 结果仅用于学术统计。
""")
        gr.Markdown("---")
        gr.Markdown("### 📚 示例参考 / Examples")

        # --- Positive examples: a single gallery row of the files that exist ---
        gr.Markdown("#### ✅ 成功示例 / Success Cases (Clean & Preserved)")
        pos_files = ["pos_haze.jpg", "pos_shadow.jpg", "pos_ref.jpg", "pos_flare.jpg"]
        pos_captions = ["Haze Removal", "Shadow Removal", "Reflection Removal", "Flare Removal"]
        pos_gallery_data = []  # [(path, caption), ...]
        for p_file, p_cap in zip(pos_files, pos_captions):
            p_path = os.path.join(EXAMPLE_POS_DIR, p_file)
            if os.path.exists(p_path):
                pos_gallery_data.append((p_path, p_cap))
        if pos_gallery_data:
            gr.Gallery(
                value=pos_gallery_data,
                show_label=False,
                columns=4,  # keep all four examples on one row
                rows=1,
                height="auto",
                object_fit="contain",
            )

        # --- Negative examples: one column per image that exists ---
        gr.Markdown("#### ❌ 失败示例 / Failure Cases")
        neg_examples = [
            ("neg_identity_changed.jpg",
             "**❌ 失败原因:背景结构改变**<br>*(Failure: Structure/Identity Changed)*"),
            ("neg_not_fully_removed.jpg",
             "**❌ 失败原因:干扰未去干净**<br>*(Failure: Artifacts Not Fully Removed)*"),
        ]
        with gr.Row():
            for neg_file, neg_caption in neg_examples:
                neg_path = os.path.join(EXAMPLE_NEG_DIR, neg_file)
                if os.path.exists(neg_path):
                    with gr.Column():
                        gr.Image(neg_path, interactive=False, show_label=False)
                        gr.Markdown(neg_caption, elem_classes="question-box")

    # ================= Page 2: quiz (with a loading placeholder) =================
    # 2.1 Answering view
    with gr.Group(visible=False) as quiz_main_ui:
        # Placeholder; start_survey overwrites it with the real progress.
        progress_md = gr.Markdown("### ⏳ Progress: 1 / 20")
        with gr.Column():
            label_html = gr.HTML(value="")  # per-task panel labels
            # Toolbar options (download/fullscreen) are left at Gradio defaults.
            image_display = gr.Image(
                label="",
                interactive=False,
                show_label=False,
                container=True,
            )
            question_html = gr.HTML(value="")
            with gr.Row():
                btn_fail = gr.Button("❌ 失败 / Failure\n(Incomplete / Distorted)", size="lg", variant="stop")
                btn_pass = gr.Button("✅ 成功 / Success\n(Clean & Preserved)", size="lg", variant="primary")

    # 2.2 Shown instead of the answering view while an answer is processed,
    # so the buttons cannot be clicked again in the meantime.
    with gr.Group(visible=False) as loading_ui:
        gr.HTML("""
<div class='loading-container'>
<p>🔄 Loading next image...</p>
<p>正在加载下一张...</p>
</div>
""")

    # ================= Page 3: thank-you page =================
    with gr.Group(visible=False) as end_page:
        gr.Markdown("""
# 🎉 感谢您的参与! / Thank You!
您的评估数据已成功提交。您可以直接关闭此页面。
Your evaluation data has been submitted. You may close this page now.
""")

    # Admin panel: password-gated download of the local results CSV.
    gr.Markdown("---")
    with gr.Accordion("🔧 管理员面板 / Admin Panel", open=False):
        with gr.Row():
            admin_pass = gr.Textbox(label="Password", type="password")
            check_btn = gr.Button("Download Data")
        file_download = gr.File(label="Result CSV", interactive=False, visible=False)
        msg_box = gr.Markdown("")

        def verify_and_download(password):
            """Reveal the local results CSV when *password* matches the admin password."""
            # NOTE(security): the fallback literal is readable by anyone who can
            # see this source (e.g. a public Space). Set ADMIN_PASSWORD as a
            # secret and remove the fallback.
            expected = os.environ.get("ADMIN_PASSWORD", "010929")
            # Constant-time comparison; bytes because compare_digest rejects
            # non-ASCII str input.
            if not hmac.compare_digest((password or "").encode("utf-8"), expected.encode("utf-8")):
                return {file_download: gr.update(visible=False), msg_box: "❌ Wrong Password"}
            if os.path.exists(RESULT_FILE):
                return {file_download: gr.update(value=RESULT_FILE, visible=True), msg_box: "✅ Success!"}
            return {file_download: gr.update(visible=False), msg_box: "⚠️ No data yet."}

        check_btn.click(verify_and_download, inputs=[admin_pass], outputs=[file_download, msg_box])

    # ================= Event handlers =================
    def start_survey():
        """Sample a fresh batch, assign a participant ID and show the first image.

        Raises:
            gr.Error: when no study images are found. The welcome page stays
                visible and the participant sees the message.
        """
        batch = get_random_batch()
        if not batch:
            raise gr.Error("未找到测试图片,请联系管理员 / No study images found. Please contact the organizer.")
        # The timestamp prefix keeps IDs roughly sortable; the random suffix
        # separates participants who start within the same second.
        u_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
        first_path, _, first_task = batch[0]
        return {
            welcome_page: gr.update(visible=False),
            quiz_main_ui: gr.update(visible=True),
            loading_ui: gr.update(visible=False),
            end_page: gr.update(visible=False),
            state_batch: batch,
            state_index: 0,
            state_results: [],
            state_user_id: u_id,
            image_display: first_path,
            progress_md: f"### ⏳ Progress: 1 / {len(batch)}",
            question_html: generate_question_text(first_task),
            label_html: generate_label_html(first_task),
        }

    # Answers are handled in two steps so the UI reacts immediately:
    # step 1 swaps the buttons for the loading view; step 2 records the answer
    # and then shows the next image or the thank-you page.
    def switch_to_loading():
        """Step 1: hide the answering view (prevents double clicks) and show loading."""
        return {
            quiz_main_ui: gr.update(visible=False),
            loading_ui: gr.update(visible=True),
        }

    def _show_end_page(index, results):
        """Updates that hide the quiz/loading views and reveal the thank-you page."""
        return {
            quiz_main_ui: gr.update(visible=False),
            loading_ui: gr.update(visible=False),
            end_page: gr.update(visible=True),
            state_index: index,
            state_results: results,
            image_display: gr.update(),
            progress_md: gr.update(),
            question_html: gr.update(),
            label_html: gr.update(),
        }

    def process_and_next(is_success, index, batch, results, u_id):
        """Step 2: record the verdict for the current image, then advance or finish.

        Args:
            is_success: 1 for "Success", 0 for "Failure".
            index: position of the answered image within *batch*.
            batch: the participant's (path, filename, task) list.
            results: answers recorded so far.
            u_id: participant ID passed to save_result after the last answer.
        """
        # Guard against stray events after the last answer or with no batch.
        if not batch or index >= len(batch):
            return _show_end_page(index, results)

        _, current_name, current_task = batch[index]
        # Build a new list rather than mutating the State's list in place, so
        # the stored session state only changes through the returned outputs.
        results = results + [(current_name, current_task, is_success)]
        next_index = index + 1

        if next_index >= len(batch):
            save_result(u_id, results)
            return _show_end_page(next_index, results)

        next_path, _, next_task = batch[next_index]
        return {
            quiz_main_ui: gr.update(visible=True),
            loading_ui: gr.update(visible=False),
            end_page: gr.update(visible=False),
            state_index: next_index,
            state_results: results,
            image_display: next_path,
            progress_md: f"### ⏳ Progress: {next_index + 1} / {len(batch)}",
            question_html: generate_question_text(next_task),
            label_html: generate_label_html(next_task),
        }

    def submit_success(index, batch, results, u_id):
        """Record a "Success" verdict for the current image."""
        return process_and_next(1, index, batch, results, u_id)

    def submit_failure(index, batch, results, u_id):
        """Record a "Failure" verdict for the current image."""
        return process_and_next(0, index, batch, results, u_id)

    # --- Event bindings ---
    start_btn.click(
        start_survey,
        inputs=[],
        outputs=[welcome_page, quiz_main_ui, loading_ui, end_page, state_batch, state_index,
                 state_results, state_user_id, image_display, progress_md, question_html, label_html],
    )

    answer_inputs = [state_index, state_batch, state_results, state_user_id]
    answer_outputs = [quiz_main_ui, loading_ui, end_page, state_index, state_results,
                      image_display, progress_md, question_html, label_html]
    for answer_button, answer_handler in ((btn_pass, submit_success), (btn_fail, submit_failure)):
        # queue=False runs the loading switch immediately instead of waiting in
        # the queue; .then() runs the processing step after it.
        answer_button.click(
            switch_to_loading,
            inputs=[],
            outputs=[quiz_main_ui, loading_ui],
            queue=False,
        ).then(
            answer_handler,
            inputs=answer_inputs,
            outputs=answer_outputs,
        )
if __name__ == "__main__":
    # Launch with a public share link and server-side rendering turned off.
    launch_options = dict(share=True, ssr_mode=False)
    demo.launch(**launch_options)