# Source: Hugging Face Space file view, commit d1d5132 (file size: 13,456 bytes).
import gradio as gr
import os
import json
from datetime import datetime
from datasets import load_dataset
import tempfile
# 从Hugging Face dataset加载视频 / Load videos from Hugging Face dataset
def load_videos_from_huggingface():
    """Load the video index from the ``WenjiaWang/videoforuser`` dataset.

    Every row of the ``train`` split that carries a ``file_path``, ``path``
    or ``video_path`` field (checked in that order) is grouped by the name
    of the directory that contains the file.  Within a folder, rows are
    labelled ``"Method 1"``, ``"Method 2"``, ... in the order they are
    encountered, so the display key never reveals the real method name.

    Returns:
        ``(folders, videos)`` where ``folders`` is the sorted list of folder
        names and ``videos`` maps ``folder -> {display_name: row}``.
        ``(None, None)`` is returned when anything fails while loading or
        scanning the dataset; callers treat that as "use the local folder".
    """
    path_keys = ("file_path", "path", "video_path")
    try:
        dataset = load_dataset("WenjiaWang/videoforuser")
        print("成功加载数据集: WenjiaWang/videoforuser / Successfully loaded dataset")

        all_videos = {}  # {folder: {display_name: row}}
        # Only the 'train' split is scanned; other splits are ignored.
        if 'train' in dataset:
            for item in dataset['train']:
                # Use the first path-like field the row provides.
                path_key = next((key for key in path_keys if key in item), None)
                if path_key is None:
                    continue

                # The parent directory name identifies the question.
                folder_name = os.path.basename(os.path.dirname(item[path_key]))
                folder_videos = all_videos.setdefault(folder_name, {})

                # Anonymous display name based on arrival order in the folder.
                display_name = f"Method {len(folder_videos) + 1}"
                folder_videos[display_name] = item

        # The folder list is exactly the key set of the index.
        return sorted(all_videos), all_videos
    except Exception as e:
        # Deliberately broad: any failure switches the app to local files.
        print(f"Hugging Face数据集加载失败: {e} / Hugging Face dataset loading failed")
        print("回退到本地videos文件夹 / Falling back to local videos folder")
        return None, None
# 获取所有问题文件夹 / Get all question folders
def get_question_folders():
    """Return the names of all question folders, sorted.

    The Hugging Face dataset is consulted first.  Only when it yields
    nothing are the sub-directories of the local ``videos`` directory
    listed instead; an empty list is returned if that directory is missing.
    """
    remote_folders, remote_videos = load_videos_from_huggingface()
    if remote_folders and remote_videos:
        return remote_folders

    # Fall back to the sub-directories of the local video folder.
    local_root = "videos"
    if os.path.exists(local_root):
        subdirs = []
        for entry in os.listdir(local_root):
            if os.path.isdir(os.path.join(local_root, entry)):
                subdirs.append(entry)
        subdirs.sort()
        return subdirs

    print(f"视频目录不存在: {local_root} / Video directory not found")
    return []
# 获取问题文件夹中的所有视频 / Get all videos in the question folder
def get_videos_for_question(question_folder):
    """Collect the videos of one question together with their real methods.

    The Hugging Face index is preferred; when it does not contain
    ``question_folder`` the local directory ``videos/<question_folder>`` is
    scanned for ``.mp4`` files instead.

    Args:
        question_folder: Name of the question folder to look up.

    Returns:
        ``(videos, method_mapping)``.  ``videos`` maps an anonymous display
        name (``"Method N"``) to a video path.  ``method_mapping`` maps the
        same display name to the real method name; in the dataset branch an
        entry is only present when the stored path contains a known marker.
    """
    # (marker, real method name), tested in this order.
    known_methods = (
        ("stage12_new", "Ours (stage12_new)"),
        ("gen3c", "Gen3c"),
        ("svc", "SVC"),
        ("trajattn", "TrajAttn"),
    )
    # Same path fields, in the same priority, that the dataset loader accepts.
    path_keys = ("file_path", "path", "video_path")
    video_suffix = '.mp4'

    videos = {}
    method_mapping = {}

    hf_folders, hf_videos = load_videos_from_huggingface()
    if hf_folders and hf_videos and question_folder in hf_videos:
        for display_name, item in hf_videos[question_folder].items():
            path_key = next((key for key in path_keys if key in item), None)
            stored_path = item[path_key] if path_key is not None else None

            if 'file_content' in item:
                # Video data is embedded in the row: spill it to a temp file
                # so it can be referenced by path.  delete=False keeps the
                # file after the handle is closed; nothing here removes it.
                # NOTE(review): assumes item['file_content'] is bytes.
                with tempfile.NamedTemporaryFile(suffix=video_suffix, delete=False) as tmp:
                    tmp.write(item['file_content'])
                videos[display_name] = tmp.name
            elif path_key is not None:
                videos[display_name] = stored_path

            # Infer the real method from whichever path field is present.
            # Non-string paths are skipped instead of raising TypeError.
            if isinstance(stored_path, str):
                for marker, real_method_name in known_methods:
                    if marker in stored_path:
                        method_mapping[display_name] = real_method_name
                        break
        return videos, method_mapping

    # Fall back to the local folder.
    video_dir = os.path.join("videos", question_folder)
    if os.path.isdir(video_dir):
        for file_name in os.listdir(video_dir):
            if not file_name.endswith(video_suffix):
                continue
            # Known prefix -> canonical name; otherwise the file name with
            # only its trailing extension removed.
            real_method_name = next(
                (name for marker, name in known_methods if file_name.startswith(marker)),
                file_name[:-len(video_suffix)],
            )
            # Anonymous display name, numbered in listing order.
            display_name = f"Method {len(videos) + 1}"
            videos[display_name] = os.path.join(video_dir, file_name)
            method_mapping[display_name] = real_method_name
    return videos, method_mapping
# 保存评分数据 / Save rating data
def save_ratings(current_question, ratings_data, method_mapping):
    """Append one question's ratings to ``ratings_data.json``.

    Args:
        current_question: Identifier of the question being rated.
        ratings_data: Mapping of display name -> ratings for that method.
        method_mapping: Mapping of display name -> real method name.  Display
            names without an entry are stored unchanged; ``None`` is treated
            as an empty mapping.

    Returns:
        A status message naming the file the ratings were written to.
    """
    ratings_path = "ratings_data.json"
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    method_mapping = method_mapping or {}

    # Translate anonymous display names back to real method names.
    mapped_ratings = {
        method_mapping.get(display_name, display_name): ratings
        for display_name, ratings in ratings_data.items()
    }

    # Read what has been stored so far.  An unreadable or malformed file is
    # treated as empty (best effort), as is a payload that is not a list.
    all_data = []
    if os.path.exists(ratings_path):
        try:
            with open(ratings_path, "r", encoding="utf-8") as f:
                all_data = json.load(f)
        except (OSError, ValueError):
            # ValueError covers both invalid JSON and undecodable bytes.
            all_data = []
    if not isinstance(all_data, list):
        all_data = []

    all_data.append({
        "timestamp": timestamp,
        "question": current_question,
        "ratings": mapped_ratings,
    })

    with open(ratings_path, "w", encoding="utf-8") as f:
        json.dump(all_data, f, ensure_ascii=False, indent=2)

    # Report the file that was actually written.
    return f"评分已保存到 {ratings_path}"
# 创建Gradio界面 / Create Gradio interface
def create_video_survey_app():
    """Build the Gradio Blocks app for the video quality user study.

    The page shows up to four anonymised videos for the current question,
    three 1-5 sliders per video, and Previous/Next buttons.  Next saves the
    current slider values via ``save_ratings`` before moving on; Previous
    only navigates.

    Returns:
        The ``gr.Blocks`` app.  It is not launched here.
    """
    question_folders = get_question_folders()
    total_questions = len(question_folders)
    method_display_names = ["Method 1", "Method 2", "Method 3", "Method 4"]

    with gr.Blocks(title="视频生成质量用户研究") as demo:
        # Per-session state is created inside the Blocks context so that it
        # belongs to this app and can be used as an event input/output.
        current_question_idx = gr.State(0)
        current_method_mapping = gr.State({})

        gr.Markdown("# 视频生成质量用户研究 / Video Generation Quality User Study")
        gr.Markdown("Please rate each video based on dynamic generation quality, static consistency, and overall quality. 对每个视频从动态生成能力、静态物体的一致性、整体质量三方面进行1~5的评分")
        gr.Markdown("Rating scale: 1-5 (5 = Best) 评分等级:1-5分(5分为最好)")

        with gr.Row():
            prev_btn = gr.Button("上一题 / Previous", visible=False)
            question_text = gr.Markdown(f"问题 1 / {total_questions} / Question 1 / {total_questions}")
            next_btn = gr.Button("下一题 / Next")

        # Video display area: one hidden player per display slot.
        with gr.Row():
            video_cols = []
            for method in method_display_names:
                with gr.Column():
                    gr.Markdown(f"### {method}")
                    video_cols.append(gr.Video(visible=False))

        # Rating area: three sliders per slot, kept in one flat list in the
        # order (dynamic, static, overall) that save_current_ratings expects.
        with gr.Row():
            all_rating_components = []
            for method in method_display_names:
                with gr.Column():
                    gr.Markdown(f"#### {method} 评分 / Rating")
                    all_rating_components.append(gr.Slider(
                        minimum=1, maximum=5, step=1, value=3,
                        label="Dynamic Generation Quality / 动态生成能力"
                    ))
                    all_rating_components.append(gr.Slider(
                        minimum=1, maximum=5, step=1, value=3,
                        label="Static Consistency / 静态一致性"
                    ))
                    all_rating_components.append(gr.Slider(
                        minimum=1, maximum=5, step=1, value=3,
                        label="Overall Quality / 整体质量"
                    ))

        status_text = gr.Textbox(label="状态 / Status", interactive=False)

        # Every navigation handler returns exactly one value per entry here.
        navigation_outputs = video_cols + [
            question_text,
            current_method_mapping,
            current_question_idx,
            prev_btn,
            next_btn,
        ]

        def keep_current_view(question_idx, method_mapping):
            """Return outputs that leave the page and both states unchanged."""
            return [gr.update() for _ in video_cols] + [
                gr.update(),
                method_mapping,
                question_idx,
                gr.update(),
                gr.update(),
            ]

        def update_question(question_idx):
            """Return the outputs that display question ``question_idx``."""
            question_folder = question_folders[question_idx]
            videos, method_mapping = get_videos_for_question(question_folder)

            video_updates = []
            for method_display in method_display_names:
                if method_display in videos:
                    video_updates.append(gr.update(value=videos[method_display], visible=True))
                else:
                    video_updates.append(gr.update(value=None, visible=False))

            question_markdown = f"问题 {question_idx + 1} / {len(question_folders)}: {question_folder} / Question {question_idx + 1} / {len(question_folders)}: {question_folder}"
            return video_updates + [
                question_markdown,
                method_mapping,
                question_idx,
                # Previous is only useful after the first question.
                gr.update(visible=question_idx > 0),
                gr.update(visible=True),
            ]

        def save_current_ratings(question_idx, method_mapping, *rating_values):
            """Persist the slider values for the question currently shown."""
            if question_idx < 0 or question_idx >= len(question_folders):
                return "没有更多问题了 / No more questions"
            question_folder = question_folders[question_idx]
            ratings = {}
            for i, method_display in enumerate(method_display_names):
                base_idx = i * 3
                ratings[method_display] = {
                    "dynamic_quality": rating_values[base_idx],
                    "static_consistency": rating_values[base_idx + 1],
                    "overall_quality": rating_values[base_idx + 2]
                }
            return save_ratings(question_folder, ratings, method_mapping)

        def on_next_click(question_idx, method_mapping, *rating_values):
            """Save the current ratings, then advance to the next question."""
            save_message = save_current_ratings(question_idx, method_mapping, *rating_values)
            new_idx = question_idx + 1
            if new_idx >= len(question_folders):
                # Last question answered: stay put and hide Next so the same
                # ratings are not saved again by repeated clicks.
                outputs = keep_current_view(question_idx, method_mapping)
                outputs[-1] = gr.update(visible=False)
                return [save_message + "\n所有问题已完成!/ All questions completed!"] + outputs
            return [
                save_message + f"\n当前问题: {question_folders[new_idx]}/ Current question: {question_folders[new_idx]}"
            ] + update_question(new_idx)

        def on_prev_click(question_idx, method_mapping):
            """Go back one question without saving."""
            new_idx = question_idx - 1
            if new_idx < 0 or new_idx >= len(question_folders):
                return keep_current_view(question_idx, method_mapping)
            return update_question(new_idx)

        # Bind events.
        next_btn.click(
            on_next_click,
            inputs=[current_question_idx, current_method_mapping] + all_rating_components,
            outputs=[status_text] + navigation_outputs
        )
        prev_btn.click(
            on_prev_click,
            inputs=[current_question_idx, current_method_mapping],
            outputs=navigation_outputs
        )

        # Show the first question as soon as the page loads.
        if question_folders:
            demo.load(lambda: update_question(0), inputs=None, outputs=navigation_outputs)

    return demo
if __name__ == "__main__":
    # Script entry point: build the survey UI and launch it with sharing on.
    create_video_survey_app().launch(share=True)
|