File size: 13,456 Bytes
d1d5132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
import gradio as gr
import os
import json
from datetime import datetime
from datasets import load_dataset
import tempfile

# 从Hugging Face dataset加载视频 / Load videos from Hugging Face dataset
def load_videos_from_huggingface():
    """Load the study videos from the ``WenjiaWang/videoforuser`` HF dataset.

    Every record of the ``train`` split that carries a ``file_path``, ``path``
    or ``video_path`` field (checked in that order) is grouped by the name of
    its parent directory. Within a folder, records receive anonymous display
    names "Method 1", "Method 2", ... in dataset order. Records without any
    of those fields are skipped.

    Returns:
        (folders, videos): ``folders`` is the sorted list of folder names and
        ``videos`` maps ``folder -> {display_name: record}``. Both are empty
        when the dataset has no ``train`` split. ``(None, None)`` is returned
        if anything raises, so callers fall back to the local folder.
    """
    path_keys = ("file_path", "path", "video_path")
    try:
        dataset = load_dataset("WenjiaWang/videoforuser")
        print("成功加载数据集: WenjiaWang/videoforuser / Successfully loaded dataset")

        videos_by_folder = {}
        if 'train' in dataset:
            for record in dataset['train']:
                present_key = next((key for key in path_keys if key in record), None)
                if present_key is None:
                    continue

                # The question folder is the directory that directly contains the file.
                folder = os.path.basename(os.path.dirname(record[present_key]))
                slots = videos_by_folder.setdefault(folder, {})
                slots[f"Method {len(slots) + 1}"] = record

        return sorted(videos_by_folder), videos_by_folder

    except Exception as e:
        print(f"Hugging Face数据集加载失败: {e} / Hugging Face dataset loading failed")
        print("回退到本地videos文件夹 / Falling back to local videos folder")
        return None, None

# 获取所有问题文件夹 / Get all question folders
def get_question_folders():
    """Return the sorted names of all question folders.

    The Hugging Face dataset is preferred. If it produced no folders or no
    videos, the sub-directories of the local ``videos`` directory are used
    instead. An empty list is returned when that directory does not exist.
    """
    hf_folders, hf_videos = load_videos_from_huggingface()
    if hf_folders and hf_videos:
        return hf_folders

    # Local fallback: every sub-directory of ./videos is one question.
    local_root = "videos"
    if os.path.exists(local_root):
        return sorted(
            entry
            for entry in os.listdir(local_root)
            if os.path.isdir(os.path.join(local_root, entry))
        )

    print(f"视频目录不存在: {local_root} / Video directory not found")
    return []

# 获取问题文件夹中的所有视频 / Get all videos in the question folder
# File-name prefix -> real method name, checked in this order.
_METHOD_PREFIXES = (
    ("stage12_new", "Ours (stage12_new)"),
    ("gen3c", "Gen3c"),
    ("svc", "SVC"),
    ("trajattn", "TrajAttn"),
)


def _resolve_method_name(file_name):
    """Return the real method name for a video *file name* (not a full path).

    Known prefixes map to fixed labels. Any other file falls back to its name
    without extension, so saved ratings still identify the source video.
    """
    for prefix, label in _METHOD_PREFIXES:
        if file_name.startswith(prefix):
            return label
    return os.path.splitext(file_name)[0]


def get_videos_for_question(question_folder):
    """Collect the videos of one question under anonymous display names.

    The Hugging Face dataset is tried first. When it did not load, or does not
    contain *question_folder*, ``videos/<question_folder>/*.mp4`` is scanned.

    Args:
        question_folder: folder name as returned by ``get_question_folders``.

    Returns:
        (videos, method_mapping): ``videos`` maps "Method N" to a playable file
        path, and ``method_mapping`` maps the same display names to the real
        method name derived from the video's file name.
    """
    hf_folders, hf_videos = load_videos_from_huggingface()
    if hf_folders and hf_videos and question_folder in hf_videos:
        videos = {}
        method_mapping = {}

        for display_name, item in hf_videos[question_folder].items():
            # Same key precedence as load_videos_from_huggingface, so every record
            # the loader accepted (including 'video_path') resolves to a path here.
            source_path = next(
                (item[key] for key in ("file_path", "path", "video_path") if key in item),
                None,
            )

            if 'file_content' in item:
                # The video bytes are stored inline; Gradio needs a file on disk.
                # delete=False so the file outlives this function and can be served.
                with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp:
                    tmp.write(item['file_content'])
                    videos[display_name] = tmp.name
            elif source_path:
                videos[display_name] = source_path

            if source_path:
                # Match on the file name only: folder names in the path must not
                # cause false matches (same rule as the local-folder branch below).
                method_mapping[display_name] = _resolve_method_name(
                    os.path.basename(source_path)
                )

        return videos, method_mapping

    # Fall back to the local folder.
    video_dir = os.path.join("videos", question_folder)
    videos = {}
    method_mapping = {}  # anonymous display name -> real method name

    if os.path.exists(video_dir):
        for file_name in os.listdir(video_dir):
            if not file_name.endswith('.mp4'):
                continue
            # Anonymous display names keep raters blind to the method.
            display_name = f"Method {len(videos) + 1}"
            videos[display_name] = os.path.join(video_dir, file_name)
            method_mapping[display_name] = _resolve_method_name(file_name)

    return videos, method_mapping

# 保存评分数据 / Save rating data
def save_ratings(current_question, ratings_data, method_mapping, output_path="ratings_data.json"):
    """Append one question's ratings to the JSON rating log.

    The log at *output_path* is a JSON list of entries with the keys
    ``timestamp``, ``question`` and ``ratings``. Anonymous display names in
    *ratings_data* are replaced by the real method names from
    *method_mapping*; names missing from the mapping are kept as-is.

    Args:
        current_question: name of the question folder being rated.
        ratings_data: ``{display_name: {criterion: score}}``.
        method_mapping: ``{display_name: real_method_name}``; may be empty or None.
        output_path: log file to append to (defaults to ``ratings_data.json``).

    Returns:
        A status message naming the file the ratings were actually written to.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    method_mapping = method_mapping or {}

    # Map anonymous display names back to real method names.
    mapped_ratings = {
        method_mapping.get(display_name, display_name): ratings
        for display_name, ratings in ratings_data.items()
    }

    all_data = _load_rating_log(output_path, timestamp)
    all_data.append({
        "timestamp": timestamp,
        "question": current_question,
        "ratings": mapped_ratings,
    })
    _write_json_atomically(output_path, all_data)

    return f"评分已保存到 {output_path}"


def _load_rating_log(path, timestamp):
    """Return the existing log entries, preserving (not discarding) unreadable files.

    A file that cannot be read or parsed is renamed to ``<path>.corrupt_<timestamp>``
    so the next write does not silently overwrite previously collected ratings.
    """
    if not os.path.exists(path):
        return []
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:  # JSONDecodeError/UnicodeDecodeError are ValueErrors
        backup = f"{path}.corrupt_{timestamp}"
        print(f"无法读取评分文件 {path}: {e} / Could not read rating log, backing up to {backup}")
        try:
            os.replace(path, backup)
        except OSError as backup_error:
            print(f"备份失败 / Backup failed: {backup_error}")
        return []
    # Valid JSON that is not a list is kept as the first entry so append() works
    # and nothing already stored is lost.
    return data if isinstance(data, list) else [data]


def _write_json_atomically(path, data):
    """Write *data* as JSON to *path* via a temp file + os.replace (no truncated logs)."""
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, path)
    except BaseException:
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

# 创建Gradio界面 / Create Gradio interface
def create_video_survey_app():
    """Build the Gradio Blocks app for the video-quality user study.

    Each page shows up to four anonymised videos ("Method 1".."Method 4") of
    one question folder, with three 1-5 sliders per video. "Next" appends the
    ratings of the shown videos to the rating log via ``save_ratings`` and
    advances. "Previous" goes back one question without saving. Sliders are
    reset to the default on every page change.

    Returns:
        gr.Blocks: the app, not yet launched.
    """
    question_folders = get_question_folders()
    total_questions = len(question_folders)
    method_display_names = ["Method 1", "Method 2", "Method 3", "Method 4"]
    criteria = ("dynamic_quality", "static_consistency", "overall_quality")
    default_rating = 3

    with gr.Blocks(title="视频生成质量用户研究") as demo:
        # Per-session state must be created inside the Blocks context so Gradio
        # registers it and can feed it into / out of the event handlers.
        current_question_idx = gr.State(0)
        current_method_mapping = gr.State({})

        gr.Markdown("# 视频生成质量用户研究 / Video Generation Quality User Study")
        gr.Markdown("Please rate each video based on dynamic generation quality, static consistency, and overall quality. 对每个视频从动态生成能力、静态物体的一致性、整体质量三方面进行1~5的评分")
        gr.Markdown("Rating scale: 1-5 (5 = Best) 评分等级:1-5分(5分为最好)")

        with gr.Row():
            prev_btn = gr.Button("上一题 / Previous", visible=False)
            question_text = gr.Markdown(f"问题 1 / {len(question_folders)} / Question 1 / {len(question_folders)}")
            next_btn = gr.Button("下一题 / Next")

        # Video display area: one (initially hidden) player per method slot.
        with gr.Row():
            video_cols = []
            for method in method_display_names:
                with gr.Column():
                    gr.Markdown(f"### {method}")
                    video_cols.append(gr.Video(visible=False))

        # Rating area: three sliders per method slot.
        with gr.Row():
            ratings_cols = []
            for method in method_display_names:
                with gr.Column():
                    gr.Markdown(f"#### {method} 评分 / Rating")
                    ratings_cols.append({
                        "dynamic_quality": gr.Slider(
                            minimum=1, maximum=5, step=1, value=default_rating,
                            label="Dynamic Generation Quality / 动态生成能力"
                        ),
                        "static_consistency": gr.Slider(
                            minimum=1, maximum=5, step=1, value=default_rating,
                            label="Static Consistency / 静态一致性"
                        ),
                        "overall_quality": gr.Slider(
                            minimum=1, maximum=5, step=1, value=default_rating,
                            label="Overall Quality / 整体质量"
                        ),
                    })

        status_text = gr.Textbox(label="状态 / Status", interactive=False)

        # Flattened as [M1 dyn, M1 static, M1 overall, M2 dyn, ...]; the
        # i*3 indexing in save_current_ratings relies on this order.
        all_rating_components = [col[key] for col in ratings_cols for key in criteria]

        def question_header(question_idx):
            """Return the header text for *question_idx* (or a terminal message)."""
            if not question_folders:
                return "未找到问题 / No questions found"
            if question_idx >= total_questions:
                return "所有问题已完成!/ All questions completed!"
            folder = question_folders[question_idx]
            return f"问题 {question_idx + 1} / {total_questions}: {folder} / Question {question_idx + 1} / {total_questions}: {folder}"

        def update_question(question_idx):
            """Return updates for the 4 video slots, the header, and the method mapping."""
            if 0 <= question_idx < total_questions:
                videos, method_mapping = get_videos_for_question(question_folders[question_idx])
            else:
                videos, method_mapping = {}, {}

            video_outputs = [
                gr.update(value=videos[name], visible=True) if name in videos
                else gr.update(value=None, visible=False)
                for name in method_display_names
            ]
            # Mapping covers exactly the shown videos. Unknown real names fall back
            # to the display name (the same fallback save_ratings applies).
            shown_mapping = {
                name: method_mapping.get(name, name)
                for name in method_display_names
                if name in videos
            }
            return video_outputs + [question_header(question_idx), shown_mapping]

        def save_current_ratings(question_idx, method_mapping, *rating_values):
            """Save the slider values of the videos shown for *question_idx*."""
            if not 0 <= question_idx < total_questions:
                return "没有更多问题了 / No more questions"

            method_mapping = method_mapping or {}
            ratings = {}
            for i, method_display in enumerate(method_display_names):
                if method_display not in method_mapping:
                    continue  # no video in this slot; its sliders are meaningless
                base_idx = i * 3
                ratings[method_display] = {
                    "dynamic_quality": rating_values[base_idx],
                    "static_consistency": rating_values[base_idx + 1],
                    "overall_quality": rating_values[base_idx + 2],
                }

            return save_ratings(question_folders[question_idx], ratings, method_mapping)

        # Every navigation event updates the same components, in this order.
        navigation_outputs = (
            [current_question_idx, status_text]
            + video_cols
            + [question_text, current_method_mapping, prev_btn, next_btn]
            + all_rating_components
        )

        def navigate_to(question_idx, status):
            """Return one value per entry of navigation_outputs for *question_idx*."""
            return (
                [question_idx, status]
                + update_question(question_idx)
                + [
                    gr.update(visible=question_idx > 0),
                    gr.update(visible=question_idx < total_questions),
                ]
                + [default_rating] * len(all_rating_components)
            )

        def on_next_click(question_idx, method_mapping, *rating_values):
            # Save the current ratings first, then advance (clamped at "completed").
            save_message = save_current_ratings(question_idx, method_mapping, *rating_values)
            new_idx = min(question_idx + 1, total_questions)
            if new_idx >= total_questions:
                status = save_message + "\n所有问题已完成!/ All questions completed!"
            else:
                status = save_message + f"\n当前问题: {question_folders[new_idx]}/ Current question: {question_folders[new_idx]}"
            return navigate_to(new_idx, status)

        def on_prev_click(question_idx):
            # Going back does not save; the status text is left unchanged.
            return navigate_to(max(question_idx - 1, 0), gr.update())

        def on_load():
            # Show the first question (videos, header, mapping) when a session starts.
            return navigate_to(0, gr.update())

        demo.load(on_load, inputs=None, outputs=navigation_outputs)

        next_btn.click(
            on_next_click,
            inputs=[current_question_idx, current_method_mapping] + all_rating_components,
            outputs=navigation_outputs,
        )

        prev_btn.click(
            on_prev_click,
            inputs=[current_question_idx],
            outputs=navigation_outputs,
        )

    return demo

if __name__ == "__main__":
    # Build the survey UI and serve it with Gradio's share link enabled.
    create_video_survey_app().launch(share=True)