Spaces:
Sleeping
Sleeping
import functools
import json
import os
import tempfile
from datetime import datetime

import gradio as gr
from datasets import load_dataset
# Load videos from the Hugging Face dataset
@functools.lru_cache(maxsize=1)
def load_videos_from_huggingface():
    """Load question folders and their video rows from the Hugging Face dataset.

    The result is cached for the lifetime of the process. The UI calls this on
    every question change (through get_question_folders and
    get_videos_for_question), and reloading the dataset each time is slow.
    A failed load is cached too, so the app keeps using a single data source
    (Hugging Face or the local ``videos`` folder) for the whole run. Call
    ``load_videos_from_huggingface.cache_clear()`` to force a retry.

    Returns:
        ``(folders, videos)``. ``folders`` is the sorted list of folder names.
        ``videos`` maps ``folder -> {"Method N": row}``, numbering rows per
        folder in the order they appear in the ``train`` split. On any load
        error ``(None, None)`` is returned so callers fall back to the local
        folder. The returned objects are shared between callers; do not mutate
        them.
    """
    try:
        dataset = load_dataset("WenjiaWang/videoforuser")
        print("成功加载数据集: WenjiaWang/videoforuser / Successfully loaded dataset")
        all_videos = {}  # {folder: {display_name: dataset row}}
        if 'train' in dataset:
            for item in dataset['train']:
                # The path column name depends on how the dataset was built.
                # Use the first known column that holds a non-empty value, so a
                # single row with a missing path cannot abort the whole load.
                file_path = next(
                    (item[key] for key in ('file_path', 'path', 'video_path') if item.get(key)),
                    None,
                )
                if not file_path:
                    continue
                folder_name = os.path.basename(os.path.dirname(file_path))
                folder_videos = all_videos.setdefault(folder_name, {})
                # Anonymous display name, numbered per folder in row order.
                folder_videos[f"Method {len(folder_videos) + 1}"] = item
        return sorted(all_videos), all_videos
    except Exception as e:
        # Deliberate best-effort boundary: any failure means "use local files".
        print(f"Hugging Face数据集加载失败: {e} / Hugging Face dataset loading failed")
        print("回退到本地videos文件夹 / Falling back to local videos folder")
        return None, None
# Get all question folders
def get_question_folders():
    """Return the sorted names of all question folders.

    Folders from the Hugging Face dataset take priority. If the dataset yields
    nothing, the subdirectories of the local ``videos`` directory are listed
    instead, giving an empty list when that directory does not exist.
    """
    remote_folders, remote_videos = load_videos_from_huggingface()
    if remote_folders and remote_videos:
        return remote_folders

    # Local fallback.
    local_root = "videos"
    if not os.path.exists(local_root):
        print(f"视频目录不存在: {local_root} / Video directory not found")
        return []

    return sorted(
        entry
        for entry in os.listdir(local_root)
        if os.path.isdir(os.path.join(local_root, entry))
    )
# Get all videos in a question folder
# Real method names, matched against the start of a video's file name.
_KNOWN_METHOD_PREFIXES = (
    ("stage12_new", "Ours (stage12_new)"),
    ("gen3c", "Gen3c"),
    ("svc", "SVC"),
    ("trajattn", "TrajAttn"),
)


def _real_method_name(file_name):
    """Return the real method name for a video file name.

    Known prefixes map to their canonical names. Anything else falls back to
    the file name without its ``.mp4`` extension.
    """
    for prefix, method_name in _KNOWN_METHOD_PREFIXES:
        if file_name.startswith(prefix):
            return method_name
    return file_name[:-len(".mp4")] if file_name.endswith(".mp4") else file_name


def get_videos_for_question(question_folder):
    """Collect one question's videos under anonymous display names.

    Args:
        question_folder: folder name as returned by get_question_folders().

    Returns:
        ``(videos, method_mapping)``.
        ``videos`` maps a display name ("Method 1", "Method 2", ...) to a
        playable video file path.
        ``method_mapping`` maps the same display names to the real method
        names, which are recorded when ratings are saved. Both sources (the
        Hugging Face dataset and the local folder) now derive real names with
        the same file-name rule.
    """
    hf_folders, hf_videos = load_videos_from_huggingface()
    if hf_folders and hf_videos and question_folder in hf_videos:
        videos = {}
        method_mapping = {}
        for display_name, item in hf_videos[question_folder].items():
            # Same path-column resolution as load_videos_from_huggingface, so
            # rows using 'path' or 'video_path' are shown and mapped as well.
            file_path = next(
                (item[key] for key in ("file_path", "path", "video_path") if item.get(key)),
                None,
            )
            if item.get("file_content"):
                # Video bytes stored in the dataset: write them to a file that
                # Gradio can serve.
                # NOTE(review): these temp files are never deleted, and a new
                # one is written on every call; cache them if disk use matters.
                with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
                    tmp.write(item["file_content"])
                videos[display_name] = tmp.name
            elif file_path:
                videos[display_name] = file_path
            else:
                continue
            if file_path:
                # Match on the file name only, so folder names cannot cause
                # false matches.
                method_mapping[display_name] = _real_method_name(os.path.basename(file_path))
        return videos, method_mapping

    # Fall back to the local folder.
    video_dir = os.path.join("videos", question_folder)
    videos = {}
    method_mapping = {}  # display name -> real method name
    if os.path.isdir(video_dir):
        for file_name in os.listdir(video_dir):
            if not file_name.endswith(".mp4"):
                continue
            # Anonymous display name, numbered in directory-listing order.
            display_name = f"Method {len(videos) + 1}"
            videos[display_name] = os.path.join(video_dir, file_name)
            method_mapping[display_name] = _real_method_name(file_name)
    return videos, method_mapping
# Save rating data
def save_ratings(current_question, ratings_data, method_mapping, data_file="ratings_data.json"):
    """Append one question's ratings to the JSON ratings log.

    Args:
        current_question: question folder name the ratings belong to.
        ratings_data: ``{display_name: {criterion: score}}`` as collected by
            the UI.
        method_mapping: ``{display_name: real_method_name}``. Display names
            missing from it are stored under the display name itself.
        data_file: JSON file holding a list of all rating entries. It is
            created if missing. A file that cannot be read, or that does not
            hold a list, is renamed to ``<data_file>.corrupt-<time>`` rather
            than silently overwritten.

    Returns:
        A status message naming the file the ratings were written to.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # Store ratings under the real method names, not the anonymous ones.
    mapped_ratings = {
        method_mapping.get(display_name, display_name): ratings
        for display_name, ratings in ratings_data.items()
    }

    all_data = []
    if os.path.exists(data_file):
        try:
            with open(data_file, "r", encoding="utf-8") as f:
                existing = json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError and bad encoding
            print(f"无法读取评分文件 {data_file}: {e} / Could not read ratings file")
            existing = None
        if isinstance(existing, list):
            all_data = existing
        else:
            # Unreadable or unexpected content: move it aside instead of
            # overwriting previously collected ratings with a fresh list.
            backup = f"{data_file}.corrupt-{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}"
            try:
                os.replace(data_file, backup)
                print(f"已备份无法读取的评分文件到 {backup} / Backed up unreadable ratings file")
            except OSError as e:
                print(f"备份评分文件失败: {e} / Failed to back up ratings file")

    all_data.append({
        "timestamp": timestamp,
        "question": current_question,
        "ratings": mapped_ratings,
    })
    with open(data_file, "w", encoding="utf-8") as f:
        json.dump(all_data, f, ensure_ascii=False, indent=2)
    # Report the file that was actually written. The old message named a
    # per-timestamp file that was never created.
    return f"评分已保存到 {data_file}"
# Create the Gradio interface
def create_video_survey_app():
    """Build the Gradio Blocks app for the video generation quality user study.

    Each question is a folder with up to four videos, shown under anonymous
    names ("Method 1" .. "Method 4"). Every shown video is rated 1-5 on three
    criteria. "Next" saves the current ratings via save_ratings() and moves to
    the next question; "Previous" goes back without saving.

    Returns:
        The ``gr.Blocks`` app (not launched).
    """
    question_folders = get_question_folders()
    num_questions = len(question_folders)
    method_display_names = ["Method 1", "Method 2", "Method 3", "Method 4"]
    # Criteria per method, in the order their sliders are created below.
    rating_keys = ("dynamic_quality", "static_consistency", "overall_quality")
    default_rating = 3

    with gr.Blocks(title="视频生成质量用户研究") as demo:
        # Session state must be created inside the Blocks context so that it
        # is registered with this app and usable as event input/output.
        current_question_idx = gr.State(0)
        # Display name -> real method name for the videos currently shown.
        current_method_mapping = gr.State({})

        gr.Markdown("# 视频生成质量用户研究 / Video Generation Quality User Study")
        gr.Markdown("Please rate each video based on dynamic generation quality, static consistency, and overall quality. 对每个视频从动态生成能力、静态物体的一致性、整体质量三方面进行1~5的评分")
        gr.Markdown("Rating scale: 1-5 (5 = Best) 评分等级:1-5分(5分为最好)")

        with gr.Row():
            prev_btn = gr.Button("上一题 / Previous", visible=False)
            question_text = gr.Markdown(f"问题 1 / {len(question_folders)} / Question 1 / {len(question_folders)}")
            next_btn = gr.Button("下一题 / Next")

        # Video display area
        video_cols = []
        with gr.Row():
            for method in method_display_names:
                with gr.Column():
                    gr.Markdown(f"### {method}")
                    video_cols.append(gr.Video(visible=False))

        # Rating area. This is a flat list of sliders, len(rating_keys) per
        # method, in method order, so handlers receive them as *rating_values.
        all_rating_components = []
        with gr.Row():
            for method in method_display_names:
                with gr.Column():
                    gr.Markdown(f"#### {method} 评分 / Rating")
                    all_rating_components.append(gr.Slider(
                        minimum=1, maximum=5, step=1, value=default_rating,
                        label="Dynamic Generation Quality / 动态生成能力"
                    ))
                    all_rating_components.append(gr.Slider(
                        minimum=1, maximum=5, step=1, value=default_rating,
                        label="Static Consistency / 静态一致性"
                    ))
                    all_rating_components.append(gr.Slider(
                        minimum=1, maximum=5, step=1, value=default_rating,
                        label="Overall Quality / 整体质量"
                    ))

        status_text = gr.Textbox(label="状态 / Status", interactive=False)

        def show_question(question_idx):
            """Return updates for `question_outputs` showing question_idx.

            An index past the last question (or an empty study) shows the
            end screen instead.
            """
            if 0 <= question_idx < num_questions:
                question_folder = question_folders[question_idx]
                videos, method_mapping = get_videos_for_question(question_folder)
                video_updates = [
                    gr.update(value=videos[name], visible=True) if name in videos
                    else gr.update(value=None, visible=False)
                    for name in method_display_names
                ]
                # Only shown slots are rated. Unmapped names fall back to the
                # display name, matching save_ratings' own fallback.
                shown_mapping = {
                    name: method_mapping.get(name, name)
                    for name in method_display_names
                    if name in videos
                }
                question_markdown = f"问题 {question_idx + 1} / {len(question_folders)}: {question_folder} / Question {question_idx + 1} / {len(question_folders)}: {question_folder}"
            else:
                video_updates = [gr.update(value=None, visible=False) for _ in method_display_names]
                shown_mapping = {}
                question_markdown = (
                    "所有问题已完成!/ All questions completed!" if num_questions
                    else "未找到问题 / No questions found"
                )
            return (
                video_updates
                + [
                    gr.update(value=question_markdown),
                    shown_mapping,  # raw value for the gr.State output
                    gr.update(visible=question_idx > 0),
                    gr.update(visible=question_idx < num_questions),
                ]
                # Start every question from the default rating.
                + [gr.update(value=default_rating) for _ in all_rating_components]
            )

        def save_current_ratings(question_idx, method_mapping, *rating_values):
            """Save the slider values of the shown videos for question_idx."""
            if question_idx >= num_questions:
                return "没有更多问题了 / No more questions"
            per_method = len(rating_keys)
            ratings = {}
            for i, name in enumerate(method_display_names):
                if name not in method_mapping:
                    continue  # no video in this slot; its sliders are meaningless
                values = rating_values[i * per_method:(i + 1) * per_method]
                ratings[name] = dict(zip(rating_keys, values))
            return save_ratings(question_folders[question_idx], ratings, method_mapping)

        # Event handling
        def on_next_click(question_idx, method_mapping, *rating_values):
            """Save the current ratings, then advance to the next question."""
            save_message = save_current_ratings(question_idx, method_mapping, *rating_values)
            if question_idx >= num_questions:
                return [question_idx, save_message] + show_question(question_idx)
            new_idx = question_idx + 1
            if new_idx >= num_questions:
                status = save_message + "\n所有问题已完成!/ All questions completed!"
            else:
                status = save_message + f"\n当前问题: {question_folders[new_idx]}/ Current question: {question_folders[new_idx]}"
            return [new_idx, status] + show_question(new_idx)

        def on_prev_click(question_idx):
            """Go back one question (not below the first) without saving."""
            new_idx = max(question_idx - 1, 0)
            return [new_idx, gr.update()] + show_question(new_idx)

        def show_first_question():
            """Populate the page with the first question on load."""
            return show_question(0)

        question_outputs = (
            video_cols
            + [question_text, current_method_mapping, prev_btn, next_btn]
            + all_rating_components
        )
        nav_outputs = [current_question_idx, status_text] + question_outputs

        # Bind events
        demo.load(show_first_question, inputs=None, outputs=question_outputs)
        next_btn.click(
            on_next_click,
            inputs=[current_question_idx, current_method_mapping] + all_rating_components,
            outputs=nav_outputs
        )
        prev_btn.click(
            on_prev_click,
            inputs=[current_question_idx],
            outputs=nav_outputs
        )
    return demo
if __name__ == "__main__":
    # Build the survey UI and serve it; share=True asks Gradio for a public link.
    create_video_survey_app().launch(share=True)