# userstudy / app_backup.py
# Wenjiawang0312
# store
# d1d5132
import gradio as gr
import os
import json
from datetime import datetime
from datasets import load_dataset
import tempfile
# Load videos from Hugging Face dataset

# Successful result of load_videos_from_huggingface(), reused by later calls so
# the dataset is not re-loaded on every question change. Failures are not cached.
_HF_VIDEOS_CACHE = {}


def load_videos_from_huggingface():
    """Index the videos of the ``WenjiaWang/videoforuser`` dataset by question.

    Each item of the ``train`` split is expected to carry its location under
    ``file_path``, ``path`` or ``video_path`` (the first key present wins);
    items with none of these are skipped. The name of the path's parent
    directory is the question folder, and the videos of a folder get anonymous
    display names ``"Method 1"``, ``"Method 2"``, ... in dataset order.

    The first successful result is stored in ``_HF_VIDEOS_CACHE`` and returned
    by every later call, so callers must treat it as read-only. A failed load is
    not cached; the next call tries again.

    Returns:
        ``(sorted_folder_names, {folder: {display_name: item}})`` on success
        (both empty when the dataset has no ``train`` split), or
        ``(None, None)`` if loading fails for any reason.
    """
    if "result" in _HF_VIDEOS_CACHE:
        return _HF_VIDEOS_CACHE["result"]
    try:
        dataset = load_dataset("WenjiaWang/videoforuser")
        print("成功加载数据集: WenjiaWang/videoforuser / Successfully loaded dataset")
        all_videos = {}  # {folder: {display_name: item}}
        # NOTE(review): the path field names are guesses at the dataset schema.
        if 'train' in dataset:
            for item in dataset['train']:
                # First path-like field present wins; items without one are skipped.
                for key in ('file_path', 'path', 'video_path'):
                    if key in item:
                        file_path = item[key]
                        break
                else:
                    continue
                folder_name = os.path.basename(os.path.dirname(file_path))
                folder_videos = all_videos.setdefault(folder_name, {})
                folder_videos[f"Method {len(folder_videos) + 1}"] = item
        result = (sorted(all_videos), all_videos)
    except Exception as e:
        # Deliberate best-effort: any failure makes callers fall back to the
        # local ``videos`` folder.
        print(f"Hugging Face数据集加载失败: {e} / Hugging Face dataset loading failed")
        print("回退到本地videos文件夹 / Falling back to local videos folder")
        return None, None
    _HF_VIDEOS_CACHE["result"] = result
    return result
# Get all question folders
def get_question_folders():
    """Return the sorted list of question folder names.

    The Hugging Face dataset is used when it loads with at least one folder;
    otherwise the sub-directories of the local ``videos`` directory are listed.
    Returns ``[]`` when the local directory does not exist either.
    """
    hf_folders, hf_videos = load_videos_from_huggingface()
    use_hf = bool(hf_folders) and bool(hf_videos)
    if use_hf:
        return hf_folders

    # Local fallback.
    video_dir = "videos"
    if not os.path.exists(video_dir):
        print(f"视频目录不存在: {video_dir} / Video directory not found")
        return []
    return sorted(
        entry
        for entry in os.listdir(video_dir)
        if os.path.isdir(os.path.join(video_dir, entry))
    )
# Get all videos in the question folder

# Known method keys and the real method names recorded in the ratings file,
# checked in this order (so "stage12_new" is tried before the other keys).
_KNOWN_METHODS = (
    ("stage12_new", "Ours (stage12_new)"),
    ("gen3c", "Gen3c"),
    ("svc", "SVC"),
    ("trajattn", "TrajAttn"),
)

# Keys a dataset item may store its video path under, in lookup order; the
# same keys load_videos_from_huggingface accepts.
_PATH_KEYS = ('file_path', 'path', 'video_path')


def _real_method_name(file_path, *, prefix_only):
    """Map a video path to the real method name used in saved ratings.

    With ``prefix_only`` the file name must start with a known key (local
    files); otherwise the key may appear anywhere in the path (dataset items).
    Unrecognised videos fall back to their file name with ``.mp4`` removed.
    """
    file_name = os.path.basename(file_path)
    for key, real_name in _KNOWN_METHODS:
        matched = file_name.startswith(key) if prefix_only else key in file_path
        if matched:
            return real_name
    return file_name.replace('.mp4', '')


def get_videos_for_question(question_folder):
    """Collect one question's videos and the anonymous-to-real name mapping.

    Uses the Hugging Face dataset when it loaded and contains
    ``question_folder``; otherwise scans ``videos/<question_folder>`` for
    ``.mp4`` files, numbering them in directory-listing order.

    Args:
        question_folder: Name of the question folder.

    Returns:
        ``(videos, method_mapping)``: ``videos`` maps display names
        (``"Method 1"``, ...) to playable file paths, and ``method_mapping``
        maps the same display names to real method names. Both are empty when
        the local folder does not exist.
    """
    hf_folders, hf_videos = load_videos_from_huggingface()
    if hf_folders and hf_videos and question_folder in hf_videos:
        videos = {}
        method_mapping = {}
        for display_name, item in hf_videos[question_folder].items():
            file_path = next((item[key] for key in _PATH_KEYS if key in item), None)
            if 'file_content' in item:
                # Video bytes stored inline: write them to a file Gradio can serve.
                # NOTE(review): these temp files are never deleted.
                with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp:
                    tmp.write(item['file_content'])
                videos[display_name] = tmp.name
            elif file_path is not None:
                videos[display_name] = file_path
            # Map every item, whichever path key it uses, so saved ratings keep
            # the method identity instead of the anonymous display name.
            if file_path is not None:
                method_mapping[display_name] = _real_method_name(file_path, prefix_only=False)
        return videos, method_mapping

    # Local fallback: videos/<question_folder>/*.mp4
    video_dir = os.path.join("videos", question_folder)
    videos = {}
    method_mapping = {}  # display name -> real method name
    if os.path.exists(video_dir):
        for file in os.listdir(video_dir):
            if not file.endswith('.mp4'):
                continue
            # Anonymous display name so raters cannot tell the methods apart.
            display_name = f"Method {len(videos) + 1}"
            videos[display_name] = os.path.join(video_dir, file)
            method_mapping[display_name] = _real_method_name(file, prefix_only=True)
    return videos, method_mapping
# Save rating data
def _read_saved_entries(path, timestamp):
    """Return the list of entries already stored in ``path`` ([] if absent).

    An unreadable file (I/O error, invalid JSON or UTF-8, or a top-level value
    that is not a list) is moved aside to ``<path>.corrupt_<timestamp>`` rather
    than being silently overwritten, so earlier ratings are never lost. If it
    cannot be moved aside, the ``OSError`` propagates.
    """
    if not os.path.exists(path):
        return []
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:  # JSONDecodeError/UnicodeDecodeError are ValueErrors
        problem = repr(e)
    else:
        if isinstance(data, list):
            return data
        problem = "top-level JSON value is not a list"
    backup = f"{path}.corrupt_{timestamp}"
    os.replace(path, backup)
    print(f"无法读取 {path} ({problem}), 已备份到 {backup} / "
          f"Could not read {path} ({problem}); moved it to {backup}")
    return []


def save_ratings(current_question, ratings_data, method_mapping, path="ratings_data.json"):
    """Append one question's ratings to the JSON ratings file.

    Display names (``"Method 1"``, ...) are translated to real method names via
    ``method_mapping``; unmapped names are kept as-is. If two display names map
    to the same real method, the later one is stored as
    ``"<real method> [<display name>]"`` so neither rating is overwritten.

    The file holds a JSON list of ``{"timestamp", "question", "ratings"}``
    entries. It is rewritten through a temporary file and ``os.replace`` so an
    interrupted write cannot leave it half-written.
    NOTE(review): the read-modify-write is not locked, so concurrent writers
    can still lose each other's entries.

    Args:
        current_question: Question folder name stored with the entry.
        ratings_data: ``{display_name: ratings}`` for this question.
        method_mapping: ``{display_name: real_method_name}``.
        path: Ratings file to append to.

    Returns:
        A status message naming the file the ratings were written to.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    mapped_ratings = {}
    for display_name, ratings in ratings_data.items():
        real_method = method_mapping.get(display_name, display_name)
        if real_method in mapped_ratings:
            real_method = f"{real_method} [{display_name}]"
        mapped_ratings[real_method] = ratings

    all_data = _read_saved_entries(path, timestamp)
    all_data.append({
        "timestamp": timestamp,
        "question": current_question,
        "ratings": mapped_ratings,
    })

    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(all_data, f, ensure_ascii=False, indent=2)
    os.replace(tmp_path, path)
    # Report the file actually written (the old message named a
    # ratings_<timestamp>.json that was never created).
    return f"评分已保存到 {path}"
# Create Gradio interface
def create_video_survey_app():
    """Build the Gradio Blocks app for the video-quality user study.

    Questions (folders from ``get_question_folders``) are shown one at a time
    with up to four anonymised videos ("Method 1".."Method 4"), each rated 1-5
    on dynamic quality, static consistency and overall quality.

    * "Next" saves the ratings of the videos currently shown via
      ``save_ratings`` and advances. After the last question it reports
      completion and hides itself.
    * "Previous" (visible from the second question on) goes back without
      saving; rating a question again appends another entry.
    * Sliders are reset to 3 whenever a question is displayed.

    The current question index, its display-name mapping and the list of
    displayed methods are kept in ``gr.State`` values.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    question_folders = get_question_folders()
    num_questions = len(question_folders)
    method_display_names = ["Method 1", "Method 2", "Method 3", "Method 4"]
    # (ratings key, slider label); this order defines the slider order per method.
    rating_criteria = [
        ("dynamic_quality", "Dynamic Generation Quality / 动态生成能力"),
        ("static_consistency", "Static Consistency / 静态一致性"),
        ("overall_quality", "Overall Quality / 整体质量"),
    ]
    default_rating = 3

    def question_label(idx):
        """Header text for question ``idx``."""
        folder = question_folders[idx]
        return (f"问题 {idx + 1} / {num_questions}: {folder} / "
                f"Question {idx + 1} / {num_questions}: {folder}")

    # Render the first question at build time so it is visible on page load
    # (previously nothing was shown until "Next" was clicked).
    if question_folders:
        initial_videos, initial_mapping = get_videos_for_question(question_folders[0])
        initial_label = question_label(0)
    else:
        initial_videos, initial_mapping = {}, {}
        initial_label = "没有找到问题 / No questions found"

    with gr.Blocks(title="视频生成质量用户研究") as demo:
        # Created inside the Blocks context (as Gradio components must be) so
        # the event listeners below can read and update them.
        current_question_idx = gr.State(0)
        current_method_mapping = gr.State(initial_mapping)
        current_shown_methods = gr.State(list(initial_videos))

        gr.Markdown("# 视频生成质量用户研究 / Video Generation Quality User Study")
        gr.Markdown("Please rate each video based on dynamic generation quality, static consistency, and overall quality. 对每个视频从动态生成能力、静态物体的一致性、整体质量三方面进行1~5的评分")
        gr.Markdown("Rating scale: 1-5 (5 = Best) 评分等级:1-5分(5分为最好)")

        with gr.Row():
            prev_btn = gr.Button("上一题 / Previous", visible=False)
            question_text = gr.Markdown(initial_label)
            next_btn = gr.Button("下一题 / Next", visible=bool(question_folders))

        # Video display area
        video_cols = []
        with gr.Row():
            for name in method_display_names:
                with gr.Column():
                    gr.Markdown(f"### {name}")
                    video_cols.append(gr.Video(
                        value=initial_videos.get(name),
                        visible=name in initial_videos,
                    ))

        # Rating area: flat list of sliders, method-major, then criterion order.
        all_rating_components = []
        with gr.Row():
            for name in method_display_names:
                with gr.Column():
                    gr.Markdown(f"#### {name} 评分 / Rating")
                    for _key, label in rating_criteria:
                        all_rating_components.append(gr.Slider(
                            minimum=1, maximum=5, step=1, value=default_rating,
                            label=label,
                        ))

        status_text = gr.Textbox(label="状态 / Status", interactive=False)

        # Both buttons update the same outputs; every handler returns exactly
        # one value per entry, in this order.
        nav_outputs = (
            [status_text, current_question_idx, current_method_mapping,
             current_shown_methods, question_text, prev_btn, next_btn]
            + video_cols + all_rating_components
        )
        num_display_outputs = len(video_cols) + len(all_rating_components)

        def show_question(idx):
            """Values for every nav output except status when showing question ``idx``."""
            videos, method_mapping = get_videos_for_question(question_folders[idx])
            video_updates = [
                gr.update(value=videos[name], visible=True) if name in videos
                else gr.update(value=None, visible=False)
                for name in method_display_names
            ]
            # Fresh update dicts per slider (Gradio may consume them in place).
            slider_resets = [gr.update(value=default_rating) for _ in all_rating_components]
            return (
                [idx, method_mapping, list(videos), gr.update(value=question_label(idx)),
                 gr.update(visible=idx > 0), gr.update(visible=True)]
                + video_updates + slider_resets
            )

        def stay(message, question_idx, method_mapping, shown, *, hide_next=False):
            """Values that keep the current question and only touch status/next."""
            next_update = gr.update(visible=False) if hide_next else gr.update()
            return (
                [message, question_idx, method_mapping, shown,
                 gr.update(), gr.update(), next_update]
                + [gr.update() for _ in range(num_display_outputs)]
            )

        def on_next_click(question_idx, method_mapping, shown, *rating_values):
            """Save the ratings of the displayed videos, then advance one question."""
            if not 0 <= question_idx < num_questions:
                return stay("没有更多问题了 / No more questions",
                            question_idx, method_mapping, shown, hide_next=True)
            per_method = len(rating_criteria)
            ratings = {
                name: {
                    key: rating_values[i * per_method + j]
                    for j, (key, _label) in enumerate(rating_criteria)
                }
                for i, name in enumerate(method_display_names)
                if name in shown  # empty slots had no video to rate
            }
            save_message = save_ratings(question_folders[question_idx], ratings, method_mapping)
            new_idx = question_idx + 1
            if new_idx >= num_questions:
                return stay(save_message + "\n所有问题已完成!/ All questions completed!",
                            question_idx, method_mapping, shown, hide_next=True)
            current = question_folders[new_idx]
            return ([save_message + f"\n当前问题: {current}/ Current question: {current}"]
                    + show_question(new_idx))

        def on_prev_click(question_idx, method_mapping, shown):
            """Go back one question without saving."""
            new_idx = question_idx - 1
            if new_idx < 0:
                return stay(gr.update(), question_idx, method_mapping, shown)
            return [gr.update()] + show_question(new_idx)

        # Bind events
        next_btn.click(
            on_next_click,
            inputs=[current_question_idx, current_method_mapping, current_shown_methods]
            + all_rating_components,
            outputs=nav_outputs,
        )
        prev_btn.click(
            on_prev_click,
            inputs=[current_question_idx, current_method_mapping, current_shown_methods],
            outputs=nav_outputs,
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the survey UI and start the Gradio server
    # with share=True.
    survey_app = create_video_survey_app()
    survey_app.launch(share=True)