# userstudy / app_fixed.py
# Wenjiawang0312
# store
# d1d5132
import gradio as gr
import os
import json
from datetime import datetime
from datasets import load_dataset
import random
# Module-level cache filled in by load_videos_from_huggingface():
#   DATASET    - the loaded Hugging Face split (None until loaded)
#   VIDEO_DATA - {scene_name: {method_name: {'index', 'path', 'item'}}} (None until loaded)
DATASET, VIDEO_DATA = None, None
def _video_path_of(item):
    """Return an item's video reference: its 'video' field, else 'path', else None."""
    if 'video' in item:
        return item['video']
    if 'path' in item:
        return item['path']
    return None


def _index_videos_by_scene(dataset):
    """Group dataset items into {scene_name: {method_name: {'index', 'path', 'item'}}}.

    Paths are assumed to look like "videos/<scene_name>/<method>.mp4": the
    parent directory names the scene and the file name minus ".mp4" names the
    method. Items without a usable path string are skipped with a warning;
    a later item with the same (scene, method) replaces an earlier one.
    """
    video_data = {}
    for idx, item in enumerate(dataset):
        video_path = _video_path_of(item)
        if video_path is None:
            print(f"警告: 第 {idx} 项没有视频路径字段")
            continue
        if not isinstance(video_path, str):
            # e.g. a decoded video object; we need the path to derive scene/method.
            print(f"警告: 第 {idx} 项的视频字段不是路径字符串, 已跳过 / item {idx}: video field is not a path string, skipped")
            continue

        path_parts = video_path.split('/')
        if len(path_parts) < 2:
            print(f"警告: 第 {idx} 项路径 {video_path!r} 缺少场景目录, 已跳过 / item {idx}: no scene directory in path, skipped")
            continue

        scene_name, file_name = path_parts[-2], path_parts[-1]
        # Strip only a trailing ".mp4" (str.replace would also hit inner matches).
        method_name = file_name[:-len('.mp4')] if file_name.endswith('.mp4') else file_name

        scene_videos = video_data.setdefault(scene_name, {})
        if method_name in scene_videos:
            print(f"警告: 场景 {scene_name} 的方法 {method_name} 重复, 使用第 {idx} 项 / duplicate method, keeping item {idx}")
        # Keep the dataset index so the original row can be located later.
        scene_videos[method_name] = {
            'index': idx,
            'path': video_path,
            'item': item,
        }
    return video_data


def load_videos_from_huggingface(dataset_name="WenjiaWang/videoforuser", split="train"):
    """Load the study videos from the Hugging Face Hub and index them by scene.

    Sets the module globals DATASET (the loaded split) and VIDEO_DATA
    ({scene_name: {method_name: info}}). VIDEO_DATA is replaced only after the
    whole index has been built, so a failure never leaves a partial index
    behind (get_question_folders only reloads while VIDEO_DATA is None).

    Args:
        dataset_name: Hub dataset repository id.
        split: dataset split to load.

    Returns:
        True on success, False if loading or indexing raised; the error is
        printed with its traceback.
    """
    global DATASET, VIDEO_DATA
    try:
        print(f"正在加载数据集: {dataset_name}...")
        DATASET = load_dataset(dataset_name, split=split)
        print(f"成功加载数据集,共 {len(DATASET)} 个视频")
        video_data = _index_videos_by_scene(DATASET)
    except Exception as e:  # top-level boundary: report and signal failure
        print(f"加载数据集失败: {e}")
        import traceback
        traceback.print_exc()
        return False

    VIDEO_DATA = video_data
    print(f"组织完成,共 {len(VIDEO_DATA)} 个场景")
    return True
def get_question_folders():
    """Return the survey's scene names in sorted order.

    The dataset is loaded lazily on first call; an empty list is returned
    when that load fails.
    """
    if VIDEO_DATA is None and not load_videos_from_huggingface():
        return []
    return sorted(VIDEO_DATA)
def get_videos_for_question(scene_name):
    """Return one scene's videos under shuffled, anonymous display names.

    The method order is randomised on every call, so a display name such as
    "Method A" does not reveal which method produced the video.

    Args:
        scene_name: key into VIDEO_DATA.

    Returns:
        (videos, method_mapping). videos maps display name -> the video path
        recorded by the loader; method_mapping maps display name -> real
        method name. Both are empty when no data is loaded or the scene is
        unknown. NOTE: the survey UI renders only Method A-D, so methods
        beyond the fourth get a display name but are not shown.
    """
    if VIDEO_DATA is None or scene_name not in VIDEO_DATA:
        return {}, {}

    scene_videos = VIDEO_DATA[scene_name]
    shuffled_methods = list(scene_videos)
    random.shuffle(shuffled_methods)  # anonymise: display order != method identity

    videos = {}
    method_mapping = {}
    for i, method_name in enumerate(shuffled_methods):
        display_name = f"Method {chr(65 + i)}"  # Method A, B, C, D, ...
        # Use the loader's recorded path. It already falls back from the
        # item's 'video' field to 'path'; reading item['video'] directly
        # silently dropped every item that only has a 'path' field.
        videos[display_name] = scene_videos[method_name]['path']
        method_mapping[display_name] = method_name
    return videos, method_mapping
def save_ratings(scene_name, ratings_data, method_mapping, path="ratings_data.json"):
    """Append one scene's ratings to the JSON ratings log.

    Display names ("Method A", ...) in ratings_data are translated back to
    real method names via method_mapping; unmapped names are stored as-is.
    The log is a JSON list of {"timestamp", "scene", "ratings"} entries.

    An existing log that cannot be read or is not a JSON list is moved aside
    to "<path>.corrupt-<timestamp>" instead of being overwritten, so earlier
    ratings are never silently lost. The new log is written to a temp file
    and swapped in atomically.

    Args:
        scene_name: scene the ratings belong to.
        ratings_data: {display_name: ratings} for the scene.
        method_mapping: {display_name: real_method_name}.
        path: location of the JSON log (default "ratings_data.json").

    Returns:
        The confirmation message shown in the UI.

    Raises:
        OSError: if the log cannot be written, or an unreadable log cannot be
            moved aside.
    """
    import tempfile

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    mapped_ratings = {
        method_mapping.get(display_name, display_name): ratings
        for display_name, ratings in ratings_data.items()
    }

    all_data = []
    if os.path.exists(path):
        try:
            with open(path, "r", encoding="utf-8") as f:
                existing = json.load(f)
        except (OSError, ValueError):  # ValueError covers JSON/Unicode decode errors
            existing = None
        if isinstance(existing, list):
            all_data = existing
        else:
            # Never overwrite an unreadable log: keep it for manual recovery.
            backup = f"{path}.corrupt-{timestamp}"
            suffix = 1
            while os.path.exists(backup):
                backup = f"{path}.corrupt-{timestamp}-{suffix}"
                suffix += 1
            os.replace(path, backup)
            print(f"警告: {path} 无法解析, 已备份为 {backup} / unreadable ratings file moved to {backup}")

    all_data.append({
        "timestamp": timestamp,
        "scene": scene_name,
        "ratings": mapped_ratings,
    })

    # Write next to the target and atomically replace it, so a crash mid-write
    # cannot leave a truncated log behind.
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(
        prefix=f".{os.path.basename(path)}.", suffix=".tmp", dir=directory
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(all_data, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    return "✓ 评分已保存 / Ratings saved"
def create_video_survey_app():
    """Build the Gradio user-study UI.

    Each scene shows up to four anonymised videos ("Method A".."Method D"),
    each with three 1-5 sliders: dynamic quality, static consistency and
    overall quality. "Next" saves the current scene's ratings through
    save_ratings() and advances; "Previous" moves back without saving.

    Returns:
        The gr.Blocks app, or None when no scene could be loaded.
    """
    num_slots = 4  # video players (and rating groups) rendered per scene
    default_rating = 3  # slider value shown for a freshly opened scene
    slot_names = [f"Method {chr(65 + i)}" for i in range(num_slots)]  # Method A..D
    # (key stored in the ratings log, slider label), in display order.
    rating_fields = (
        ("dynamic_quality", "动态质量 / Dynamic Quality"),
        ("static_consistency", "静态一致性 / Static Consistency"),
        ("overall_quality", "整体质量 / Overall Quality"),
    )
    rating_keys = [key for key, _ in rating_fields]

    # Preload the dataset.
    print("初始化应用...")
    load_videos_from_huggingface()
    question_folders = get_question_folders()
    if not question_folders:
        print("错误: 没有找到任何场景数据")
        return None
    num_questions = len(question_folders)
    print(f"找到 {num_questions} 个场景")

    instructions = "\n".join([
        "",
        "### 说明 / Instructions:",
        "- 请观看每个视频并进行评分 / Please watch each video and rate them",
        "- 评分标准 / Rating criteria:",
        "- **动态生成质量** / Dynamic Generation Quality: 视频中物体运动的流畅性和真实性",
        "- **静态一致性** / Static Consistency: 视频中静态物体的稳定性和一致性",
        "- **整体质量** / Overall Quality: 视频的整体观感",
        "- 评分范围:1-5分(5分最好)/ Rating scale: 1-5 (5 = Best)",
        "",
    ])

    with gr.Blocks(title="视频生成质量用户研究", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🎬 视频生成质量用户研究 / Video Generation Quality User Study")
        gr.Markdown(instructions)

        # Per-session state. The index becomes num_questions once every scene
        # has been rated, which stops further "Next" clicks from re-saving.
        current_question_idx = gr.State(0)
        current_method_mapping = gr.State({})

        # Progress / navigation bar.
        with gr.Row():
            prev_btn = gr.Button("⬅️ 上一题 / Previous", size="sm")
            question_text = gr.Markdown(f"**场景 1 / {num_questions}**")
            next_btn = gr.Button("下一题 / Next ➡️", size="sm", variant="primary")
        # Visible so participants actually see the save confirmation and the
        # final thank-you message that are written here.
        status_text = gr.Textbox(label="状态 / Status", interactive=False)

        # One video player plus one slider per rating field for each slot.
        video_components = []
        all_rating_inputs = []  # flat: rating_fields order within each slot, slots in order
        for slot_name in slot_names:
            with gr.Group():
                gr.Markdown(f"### 🎥 {slot_name}")
                video_components.append(gr.Video(label="", height=300))
                with gr.Row():
                    for _, slider_label in rating_fields:
                        all_rating_inputs.append(gr.Slider(
                            minimum=1, maximum=5, step=1, value=default_rating,
                            label=slider_label,
                            info="1=差 / Poor, 5=优秀 / Excellent",
                        ))

        # Shared output list for every event handler below.
        outputs = [
            current_question_idx,
            current_method_mapping,
            question_text,
            status_text,
        ] + video_components + all_rating_inputs

        def reset_sliders():
            """Slider updates putting every rating back to the default."""
            return [gr.Slider(value=default_rating) for _ in all_rating_inputs]

        def update_question(question_idx, status_msg=""):
            """Outputs showing scene question_idx (clamped to a valid index)."""
            question_idx = max(0, min(question_idx, num_questions - 1))
            scene_name = question_folders[question_idx]
            videos, method_mapping = get_videos_for_question(scene_name)
            # Hide the player of any slot this scene has no video for.
            video_updates = [
                gr.Video(value=videos[name], visible=True) if name in videos
                else gr.Video(value=None, visible=False)
                for name in slot_names
            ]
            question_markdown = f"**场景 {question_idx + 1} / {num_questions}**: `{scene_name}`"
            return (
                [question_idx, method_mapping, question_markdown, status_msg]
                + video_updates
                + reset_sliders()
            )

        def completed_outputs(status_msg):
            """Outputs for the end screen; the index sentinel marks completion."""
            return (
                [num_questions, {}, "**✅ 所有场景已完成!/ All scenes completed!**", status_msg]
                + [gr.Video() for _ in range(num_slots)]
                + reset_sliders()
            )

        def collect_ratings(*rating_values):
            """Group the flat slider values into {slot name: {rating key: value}}."""
            per_slot = len(rating_keys)
            return {
                slot_name: dict(zip(rating_keys, rating_values[i * per_slot:(i + 1) * per_slot]))
                for i, slot_name in enumerate(slot_names)
            }

        def on_next(question_idx, method_mapping, *rating_values):
            """Save the current scene's ratings and advance to the next scene."""
            thanks = "🎉 感谢参与!/ Thank you for participating!"
            if question_idx >= num_questions:
                # Already finished: do not append another duplicate entry.
                return completed_outputs(thanks)

            all_ratings = collect_ratings(*rating_values)
            # Slots without a video still carry default slider values; only
            # save ratings for videos that were actually shown.
            shown_ratings = {
                name: ratings for name, ratings in all_ratings.items()
                if name in method_mapping
            }
            save_msg = save_ratings(question_folders[question_idx], shown_ratings, method_mapping)

            new_idx = question_idx + 1
            if new_idx >= num_questions:
                return completed_outputs(save_msg + "\n" + thanks)
            return update_question(new_idx, save_msg)

        def on_prev(question_idx):
            """Go back one scene without saving (clamped at the first scene)."""
            return update_question(question_idx - 1)

        next_btn.click(
            on_next,
            inputs=[current_question_idx, current_method_mapping] + all_rating_inputs,
            outputs=outputs,
        )
        prev_btn.click(on_prev, inputs=[current_question_idx], outputs=outputs)

        # Show the first scene when the page loads.
        demo.load(lambda: update_question(0), outputs=outputs)

    return demo
if __name__ == "__main__":
    app = create_video_survey_app()
    if not app:
        # create_video_survey_app() returns None when no scene data was found.
        print("应用初始化失败 / App initialization failed")
    else:
        # Serve on all interfaces at port 7860, without a public share link.
        app.launch(server_name="0.0.0.0", server_port=7860, share=False)