Wenjiawang0312 commited on
Commit
1c13596
·
1 Parent(s): 1c160df

Add application file

Browse files
Files changed (1) hide show
  1. app.py +334 -0
app.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import json
4
+ from datetime import datetime
5
+ from datasets import load_dataset
6
+ import random
7
+
8
+ # 全局变量存储数据集
9
+ DATASET = None
10
+ VIDEO_DATA = None
11
+
12
+ # 从Hugging Face dataset加载视频
13
+ def load_videos_from_huggingface():
14
+ global DATASET, VIDEO_DATA
15
+
16
+ try:
17
+ print("正在加载数据集: WenjiaWang/videoforuser...")
18
+ DATASET = load_dataset("WenjiaWang/videoforuser", split="train")
19
+ print(f"成功加载数据集,共 {len(DATASET)} 个视频")
20
+
21
+ # 组织视频数据:按场景分组
22
+ VIDEO_DATA = {}
23
+
24
+ for idx, item in enumerate(DATASET):
25
+ # 获取视频路径信息
26
+ if 'video' in item:
27
+ video_path = item['video']
28
+ elif 'path' in item:
29
+ video_path = item['path']
30
+ else:
31
+ print(f"警告: 第 {idx} 项没有视频路径字段")
32
+ continue
33
+
34
+ # 从路径中提取场景名和方法名
35
+ # 假设路径格式类似: "videos/scene_name/method.mp4"
36
+ path_parts = video_path.split('/')
37
+ if len(path_parts) >= 2:
38
+ scene_name = path_parts[-2] # 倒数第二部分是场景名
39
+ file_name = path_parts[-1] # 最后部分是文件名
40
+
41
+ # 提取方法名
42
+ method_name = file_name.replace('.mp4', '')
43
+
44
+ if scene_name not in VIDEO_DATA:
45
+ VIDEO_DATA[scene_name] = {}
46
+
47
+ # 存储视频信息(包括在dataset中的索引)
48
+ VIDEO_DATA[scene_name][method_name] = {
49
+ 'index': idx,
50
+ 'path': video_path,
51
+ 'item': item
52
+ }
53
+
54
+ print(f"组织完成,共 {len(VIDEO_DATA)} 个场景")
55
+ return True
56
+
57
+ except Exception as e:
58
+ print(f"加载数据集失败: {e}")
59
+ import traceback
60
+ traceback.print_exc()
61
+ return False
62
+
63
+ # 获取所有场景列表
64
+ def get_question_folders():
65
+ if VIDEO_DATA is None:
66
+ success = load_videos_from_huggingface()
67
+ if not success:
68
+ return []
69
+
70
+ return sorted(list(VIDEO_DATA.keys()))
71
+
72
+ # 获取某个场景的所有视频
73
+ def get_videos_for_question(scene_name):
74
+ if VIDEO_DATA is None or scene_name not in VIDEO_DATA:
75
+ return {}, {}
76
+
77
+ scene_videos = VIDEO_DATA[scene_name]
78
+
79
+ # 创建方法名到真实名称的映射
80
+ method_names = list(scene_videos.keys())
81
+
82
+ # 随机打乱顺序以匿名化
83
+ shuffled_methods = method_names.copy()
84
+ random.shuffle(shuffled_methods)
85
+
86
+ videos = {}
87
+ method_mapping = {}
88
+
89
+ for i, method_name in enumerate(shuffled_methods):
90
+ display_name = f"Method {chr(65+i)}" # Method A, B, C, D
91
+
92
+ # 获取视频数据
93
+ video_info = scene_videos[method_name]
94
+ video_item = video_info['item']
95
+
96
+ # 从dataset item中获取视频文件
97
+ if 'video' in video_item:
98
+ videos[display_name] = video_item['video'] # 这应该是视频文件路径或对象
99
+
100
+ method_mapping[display_name] = method_name
101
+
102
+ return videos, method_mapping
103
+
104
+ # 保存评分数据
105
+ def save_ratings(scene_name, ratings_data, method_mapping):
106
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
107
+
108
+ # 将显示名称映射到真实方法名
109
+ mapped_ratings = {}
110
+ for display_name, ratings in ratings_data.items():
111
+ real_method = method_mapping.get(display_name, display_name)
112
+ mapped_ratings[real_method] = ratings
113
+
114
+ # 读取现有数据
115
+ all_data = []
116
+ if os.path.exists("ratings_data.json"):
117
+ try:
118
+ with open("ratings_data.json", "r", encoding="utf-8") as f:
119
+ all_data = json.load(f)
120
+ except:
121
+ all_data = []
122
+
123
+ # 添加新数据
124
+ entry = {
125
+ "timestamp": timestamp,
126
+ "scene": scene_name,
127
+ "ratings": mapped_ratings
128
+ }
129
+ all_data.append(entry)
130
+
131
+ # 保存数据
132
+ with open("ratings_data.json", "w", encoding="utf-8") as f:
133
+ json.dump(all_data, f, ensure_ascii=False, indent=2)
134
+
135
+ return f"✓ 评分已保存 / Ratings saved"
136
+
137
+ # 创建Gradio界面
138
+ def create_video_survey_app():
139
+ # 预加载数据集
140
+ print("初始化应用...")
141
+ load_videos_from_huggingface()
142
+ question_folders = get_question_folders()
143
+
144
+ if not question_folders:
145
+ print("错误: 没有找到任何场景数据")
146
+ return None
147
+
148
+ print(f"找到 {len(question_folders)} 个场景")
149
+
150
+ with gr.Blocks(title="视频生成质量用户研究", theme=gr.themes.Soft()) as demo:
151
+ gr.Markdown("# 🎬 视频生成质量用户研究 / Video Generation Quality User Study")
152
+ gr.Markdown("""
153
+ ### 说明 / Instructions:
154
+ - ��观看每个视频并进行评分 / Please watch each video and rate them
155
+ - 评分标准 / Rating criteria:
156
+ - **动态生成质量** / Dynamic Generation Quality: 视频中物体运动的流畅性和真实性
157
+ - **静态一致性** / Static Consistency: 视频中静态物体的稳定性和一致性
158
+ - **整体质量** / Overall Quality: 视频的整体观感
159
+ - 评分范围:1-5分(5分最好)/ Rating scale: 1-5 (5 = Best)
160
+ """)
161
+
162
+ # 状态变量
163
+ current_question_idx = gr.State(0)
164
+ current_method_mapping = gr.State({})
165
+
166
+ # 进度显示
167
+ with gr.Row():
168
+ prev_btn = gr.Button("⬅️ 上一题 / Previous", size="sm")
169
+ question_text = gr.Markdown(f"**场景 1 / {len(question_folders)}**")
170
+ next_btn = gr.Button("下一题 / Next ➡️", size="sm", variant="primary")
171
+
172
+ status_text = gr.Textbox(label="状态 / Status", interactive=False, visible=False)
173
+
174
+ # 视频显示区域(4个视频)
175
+ video_components = []
176
+ rating_components = []
177
+
178
+ for i in range(4):
179
+ method_name = f"Method {chr(65+i)}"
180
+
181
+ with gr.Group():
182
+ gr.Markdown(f"### 🎥 {method_name}")
183
+
184
+ video = gr.Video(label="", height=300)
185
+ video_components.append(video)
186
+
187
+ with gr.Row():
188
+ dynamic = gr.Slider(
189
+ minimum=1, maximum=5, step=1, value=3,
190
+ label="动态质量 / Dynamic Quality",
191
+ info="1=差 / Poor, 5=优秀 / Excellent"
192
+ )
193
+ static = gr.Slider(
194
+ minimum=1, maximum=5, step=1, value=3,
195
+ label="静态一致性 / Static Consistency",
196
+ info="1=差 / Poor, 5=优秀 / Excellent"
197
+ )
198
+ overall = gr.Slider(
199
+ minimum=1, maximum=5, step=1, value=3,
200
+ label="整体质量 / Overall Quality",
201
+ info="1=差 / Poor, 5=优秀 / Excellent"
202
+ )
203
+
204
+ rating_components.append({
205
+ "dynamic": dynamic,
206
+ "static": static,
207
+ "overall": overall
208
+ })
209
+
210
+ # 更新问题显示
211
+ def update_question(question_idx, save_previous=False, prev_ratings=None, prev_mapping=None):
212
+ if question_idx < 0:
213
+ question_idx = 0
214
+ if question_idx >= len(question_folders):
215
+ question_idx = len(question_folders) - 1
216
+
217
+ # 如果需要,保存上一题的评分
218
+ save_msg = ""
219
+ if save_previous and prev_ratings and prev_mapping:
220
+ prev_scene = question_folders[question_idx - 1] if question_idx > 0 else None
221
+ if prev_scene:
222
+ save_msg = save_ratings(prev_scene, prev_ratings, prev_mapping)
223
+
224
+ scene_name = question_folders[question_idx]
225
+ videos, method_mapping = get_videos_for_question(scene_name)
226
+
227
+ # 更新视频显示
228
+ video_updates = []
229
+ for i in range(4):
230
+ method_name = f"Method {chr(65+i)}"
231
+ if method_name in videos:
232
+ video_updates.append(gr.Video(value=videos[method_name], visible=True))
233
+ else:
234
+ video_updates.append(gr.Video(value=None, visible=False))
235
+
236
+ # 重置评分
237
+ rating_updates = [gr.Slider(value=3) for _ in range(12)] # 4个视频 x 3个评分
238
+
239
+ question_markdown = f"**场景 {question_idx + 1} / {len(question_folders)}**: `{scene_name}`"
240
+
241
+ return (
242
+ [question_idx, method_mapping, question_markdown, save_msg] +
243
+ video_updates +
244
+ rating_updates
245
+ )
246
+
247
+ # 收集当前评分
248
+ def collect_ratings(*rating_values):
249
+ ratings = {}
250
+ for i in range(4):
251
+ method_name = f"Method {chr(65+i)}"
252
+ base_idx = i * 3
253
+ ratings[method_name] = {
254
+ "dynamic_quality": rating_values[base_idx],
255
+ "static_consistency": rating_values[base_idx + 1],
256
+ "overall_quality": rating_values[base_idx + 2]
257
+ }
258
+ return ratings
259
+
260
+ # 下一题按钮
261
+ def on_next(question_idx, method_mapping, *rating_values):
262
+ # 收集当前评分
263
+ current_ratings = collect_ratings(*rating_values)
264
+
265
+ # 保存当前评分
266
+ scene_name = question_folders[question_idx]
267
+ save_msg = save_ratings(scene_name, current_ratings, method_mapping)
268
+
269
+ # 移动到下一题
270
+ new_idx = question_idx + 1
271
+ if new_idx >= len(question_folders):
272
+ return [
273
+ question_idx,
274
+ method_mapping,
275
+ f"**✅ 所有场景已完成!/ All scenes completed!**",
276
+ save_msg + "\n🎉 感谢参与!/ Thank you for participating!"
277
+ ] + [gr.Video()] * 4 + [gr.Slider(value=3)] * 12
278
+
279
+ return update_question(new_idx)
280
+
281
+ # 上一题按钮
282
+ def on_prev(question_idx, *args):
283
+ new_idx = question_idx - 1
284
+ if new_idx < 0:
285
+ new_idx = 0
286
+ return update_question(new_idx)
287
+
288
+ # 收集所有评分组件
289
+ all_rating_inputs = []
290
+ for comp in rating_components:
291
+ all_rating_inputs.extend([comp["dynamic"], comp["static"], comp["overall"]])
292
+
293
+ # 绑定事件
294
+ next_btn.click(
295
+ on_next,
296
+ inputs=[current_question_idx, current_method_mapping] + all_rating_inputs,
297
+ outputs=[
298
+ current_question_idx,
299
+ current_method_mapping,
300
+ question_text,
301
+ status_text
302
+ ] + video_components + all_rating_inputs
303
+ )
304
+
305
+ prev_btn.click(
306
+ on_prev,
307
+ inputs=[current_question_idx] + all_rating_inputs,
308
+ outputs=[
309
+ current_question_idx,
310
+ current_method_mapping,
311
+ question_text,
312
+ status_text
313
+ ] + video_components + all_rating_inputs
314
+ )
315
+
316
+ # 初始化第一个问题
317
+ demo.load(
318
+ lambda: update_question(0),
319
+ outputs=[
320
+ current_question_idx,
321
+ current_method_mapping,
322
+ question_text,
323
+ status_text
324
+ ] + video_components + all_rating_inputs
325
+ )
326
+
327
+ return demo
328
+
329
+ if __name__ == "__main__":
330
+ app = create_video_survey_app()
331
+ if app:
332
+ app.launch(server_name="0.0.0.0", server_port=7860, share=False)
333
+ else:
334
+ print("应用初始化失败 / App initialization failed")