Spaces:
Sleeping
Sleeping
| """ | |
| Hugging Face Space: Video Model Evaluator | |
| 视频生成模型评估系统 - 支持 Prompt、模型、视频的评估和评分 | |
| """ | |
| import os | |
| import gradio as gr | |
| import pandas as pd | |
| from datetime import datetime, timedelta | |
| from typing import Optional, Dict, Any, List, Tuple | |
| import json | |
| import logging | |
| from huggingface_hub import HfApi | |
| import tempfile | |
| import time | |
| import asyncio | |
| import nest_asyncio | |
| import warnings | |
| # 抑制 asyncio 事件循环清理时的警告(这是 Gradio 6.0 的已知问题,不影响功能) | |
| warnings.filterwarnings('ignore', category=RuntimeWarning, module='asyncio') | |
| # 应用 nest_asyncio 以解决事件循环嵌套问题 | |
| nest_asyncio.apply() | |
| # 导入视频生成服务 | |
| from pollo_service_single import PolloAIService, get_pollo_service | |
| # 导入S3工具 | |
| from s3_utils import S3Utils | |
# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Quotas and external resource identifiers
MAX_DAILY_CALLS = 4  # maximum generation requests per user per day
MAX_POLLO_CONCURRENCY = 5  # maximum number of concurrent Pollo API polls
DATASET_REPO_ID = "video-model-evaluator-cuti/video-evaluations"  # private dataset name

# Secrets are read from the Space settings (environment variables)
HF_TOKEN = os.environ.get("HF_TOKEN", "")
API_KEY = os.environ.get("API_KEY", "")

# Single catalog of supported models: (Pollo model id, display name).
# Both public structures below are derived from it, in this order.
_MODEL_CATALOG = (
    ("sora/sora-2-pro", "Sora 2 pro"),  # Sora 2 pro (corrected path)
    ("bytedance/seedance-pro", "Seedance Pro"),
    ("google/veo3", "Veo 3.1"),
    ("kling-ai/kling-v2-6", "Kling 2.6"),
)

# Model ids that are called for every generation request
MODELS_TO_CALL = [model_id for model_id, _ in _MODEL_CATALOG]

# Model id -> human readable name
MODEL_DISPLAY_NAMES = dict(_MODEL_CATALOG)
# Hub client; stays None when no token is configured
hf_api = None
if HF_TOKEN:
    hf_api = HfApi(token=HF_TOKEN)

# Video generation service; only created when an API key is available.
# The key is exported as POLLO_API_KEY before the service is requested.
if API_KEY:
    os.environ['POLLO_API_KEY'] = API_KEY
    video_service = get_pollo_service("bytedance/seedance-pro")
else:
    video_service = None
    logger.warning("API_KEY 未设置,视频生成功能将不可用")

# S3 helper used for image and video uploads
s3_utils = S3Utils()
class DatasetManager:
    """Read and write per-user data stored in a private Hub dataset.

    Layout used inside the dataset repository:

    * ``call_counts/{YYYY-MM-DD}/{username}.json`` -- daily call counter
    * ``evaluations/{username}/{YYYYmmdd_HHMMSS}.jsonl`` -- one submission

    All public methods are best-effort: failures are logged and reported
    through the return value (``0`` / ``False`` / ``[]``) instead of raising.
    """

    def __init__(self, repo_id: str, hf_token: str):
        """Store the repo id and token; the API client is None without a token."""
        self.repo_id = repo_id
        self.hf_token = hf_token
        self.api = HfApi(token=hf_token) if hf_token else None

    @staticmethod
    def _decode_json_field(value: Any, default: Any = None) -> Any:
        """Decode a field that may have been stored as a JSON string.

        ``save_evaluation`` serialises ``scores`` and ``video_urls`` with
        ``json.dumps`` before writing, so readers receive strings. Strings are
        parsed back; any other value is returned unchanged. Unparseable
        strings yield ``default``.
        """
        if isinstance(value, str):
            try:
                return json.loads(value)
            except ValueError:
                return default
        return value

    def _upload_json(
        self,
        payload: Dict[str, Any],
        path_in_repo: str,
        suffix: str,
        indent: Optional[int] = None
    ) -> None:
        """Write ``payload`` to a temporary file and upload it to the dataset.

        The temporary file is always removed, whether or not the upload
        succeeds. Upload errors propagate to the caller.
        """
        with tempfile.NamedTemporaryFile(
            mode='w', suffix=suffix, delete=False, encoding='utf-8'
        ) as tmp:
            json.dump(payload, tmp, ensure_ascii=False, indent=indent)
            temp_path = tmp.name
        try:
            self.api.upload_file(
                path_or_fileobj=temp_path,
                path_in_repo=path_in_repo,
                repo_id=self.repo_id,
                repo_type="dataset",
                token=self.hf_token
            )
        finally:
            if os.path.exists(temp_path):
                os.remove(temp_path)

    def get_user_calls_today(self, username: str) -> int:
        """
        Return how many calls the user has made today (read from the counter file).

        Args:
            username: Hugging Face username
        Returns:
            Today's call count. 0 is returned when the manager is not
            configured, when no counter file exists yet, and also when the
            lookup fails (the quota check is therefore fail-open).
        """
        if not self.api or not self.repo_id:
            logger.warning("Dataset API 未配置,无法检查调用次数")
            return 0
        today = datetime.now().strftime("%Y-%m-%d")
        count_file = f"call_counts/{today}/{username}.json"
        try:
            # Imported locally so the module's import block stays untouched.
            from huggingface_hub.utils import EntryNotFoundError
            try:
                local_path = self.api.hf_hub_download(
                    repo_id=self.repo_id,
                    filename=count_file,
                    repo_type="dataset",
                    token=self.hf_token
                )
            except EntryNotFoundError:
                # No counter file yet: the user has not called today.
                return 0
            with open(local_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            return data.get('count', 0)
        except Exception as e:
            # Unexpected failures (network, corrupt file, ...) are logged
            # instead of being silently treated as "no calls".
            logger.error(f"获取用户调用次数失败: {e}")
            return 0

    def increment_user_calls(self, username: str) -> bool:
        """
        Increase today's call counter for the user by one.

        The counter is read, incremented and re-uploaded; this is not atomic,
        so concurrent requests from the same user may be counted once.

        Args:
            username: Hugging Face username
        Returns:
            True when the new counter was uploaded, False otherwise.
        """
        if not self.api or not self.repo_id:
            logger.warning("Dataset API 未配置,无法更新调用次数")
            return False
        try:
            today = datetime.now().strftime("%Y-%m-%d")
            count_file = f"call_counts/{today}/{username}.json"
            new_count = self.get_user_calls_today(username) + 1
            count_data = {
                'user': username,
                'date': today,
                'count': new_count,
                'last_updated': datetime.now().isoformat()
            }
            self._upload_json(count_data, count_file, suffix='.json', indent=2)
            logger.info(f"用户 {username} 调用次数已更新: {new_count}")
            return True
        except Exception as e:
            logger.error(f"更新用户调用次数失败: {e}")
            return False

    def get_user_history(self, username: str, limit: int = 10) -> List[Dict]:
        """
        Return the user's most recent evaluation records.

        Args:
            username: Hugging Face username
            limit: maximum number of records to return
        Returns:
            List of dicts with keys ``timestamp``, ``prompt``, ``scores`` and
            ``video_urls``, newest first. ``scores`` and ``video_urls`` are
            decoded back into dicts when they were stored as JSON strings.
        """
        if not self.api or not self.repo_id:
            logger.warning("Dataset API 未配置,无法获取历史记录")
            return []
        try:
            files = self.api.list_repo_files(
                repo_id=self.repo_id,
                repo_type="dataset",
                token=self.hf_token
            )
            # Evaluations are grouped per user: evaluations/{username}/{timestamp}.jsonl
            prefix = f"evaluations/{username}/"
            user_files = [
                f for f in files
                if f.startswith(prefix) and f.endswith('.jsonl')
            ]
            # File names are timestamps, so reverse lexical order is newest first.
            user_files.sort(reverse=True)

            history = []
            for file_path in user_files[:limit]:  # download only what is needed
                try:
                    local_path = self.api.hf_hub_download(
                        repo_id=self.repo_id,
                        filename=file_path,
                        repo_type="dataset",
                        token=self.hf_token
                    )
                    with open(local_path, 'r', encoding='utf-8') as f:
                        data = json.load(f)
                    # Keep only records that really belong to this user.
                    if data.get('user') != username:
                        continue
                    history.append({
                        'timestamp': data.get('timestamp'),
                        'prompt': data.get('prompt', ''),
                        'scores': self._decode_json_field(data.get('scores', {}), {}),
                        'video_urls': self._decode_json_field(data.get('video_urls', {}), {})
                    })
                except Exception as e:
                    logger.warning(f"读取历史记录失败 {file_path}: {e}")
                    continue
            return history
        except Exception as e:
            logger.error(f"获取用户历史记录失败: {e}")
            return []

    def save_evaluation(
        self,
        username: str,
        prompt: str,
        model_results: Dict[str, Any],
        scores: Dict[str, float],
        video_urls: Dict[str, str]
    ) -> bool:
        """
        Save one evaluation to the private dataset.

        Args:
            username: user name
            prompt: prompt text
            model_results: per-model result dict
            scores: score dict {model_name: score}
            video_urls: video URL dict {model_name: url}
        Returns:
            True when the record was uploaded, False otherwise.
        """
        if not self.api or not self.repo_id:
            logger.warning("Dataset API 未配置,无法保存数据")
            return False
        try:
            # One clock reading so the record timestamp and file name agree.
            now = datetime.now()
            data = {
                "timestamp": now.strftime("%Y-%m-%d %H:%M:%S"),
                "user": username,
                "prompt": prompt,
                # Nested structures are stored as JSON strings (existing format).
                "scores": json.dumps(scores, ensure_ascii=False),
                "video_urls": json.dumps(video_urls, ensure_ascii=False),
                "model_results": json.dumps(model_results, ensure_ascii=False, default=str)
            }
            # Grouped per user so a single user's history is cheap to list.
            filename = f"evaluations/{username}/{now.strftime('%Y%m%d_%H%M%S')}.jsonl"
            self._upload_json(data, filename, suffix='.jsonl')
            logger.info(f"成功保存评分数据: {filename}")
            return True
        except Exception as e:
            logger.error(f"保存评分数据失败: {e}")
            return False
# Shared DatasetManager instance; None unless both the dataset repo id and
# the Hub token are configured.
dataset_manager = None
if DATASET_REPO_ID and HF_TOKEN:
    dataset_manager = DatasetManager(DATASET_REPO_ID, HF_TOKEN)
def _username_from_session(request: Any) -> Optional[str]:
    """Return the logged-in username found in the request session, or None."""
    try:
        session = request.request.session
    except (AttributeError, AssertionError):
        # AttributeError: no underlying request / session attribute.
        # AssertionError: Starlette raises it when SessionMiddleware is not
        # installed (e.g. local runs without OAuth); hasattr() would not
        # swallow that.
        return None
    if session is None:
        return None

    username = None
    if 'oauth_info' in session:
        # OAuth data lives in session['oauth_info']['userinfo'].
        oauth_info = session.get('oauth_info') or {}
        userinfo = oauth_info.get('userinfo') or {}
        # Try the known username fields in order of preference.
        username = (
            userinfo.get('preferred_username') or
            userinfo.get('name') or
            userinfo.get('sub')
        )
        if username:
            logger.info(f"从 session['oauth_info']['userinfo'] 获取用户名: {username}")
    elif 'user' in session:
        # Alternative location: session['user'].
        user_info = session.get('user') or {}
        # Debug level only: the raw profile should not end up in INFO logs.
        logger.debug(f"user 内容: {user_info}")
        username = (
            user_info.get('preferred_username') or
            user_info.get('name') or
            user_info.get('username')
        )
        if username:
            logger.info(f"从 session['user'] 获取用户名: {username}")
    return username or None


def check_user_access(request: gr.Request) -> Tuple[str, bool]:
    """
    Identify the logged-in user and check the daily quota.

    Args:
        request: Gradio Request object
    Returns:
        (username, has_access) tuple. ``username`` is "" when nobody is
        logged in. ``has_access`` is False when there is no user or when the
        user has already used MAX_DAILY_CALLS calls today.
    """
    if not request:
        return "", False

    # On a Hugging Face Space the OAuth login stores the user in the
    # Starlette session of the underlying request.
    username = _username_from_session(request)
    if not username:
        logger.warning("无法获取用户名,请确保:")
        logger.warning("1. 已在 Space Settings 中启用 OAuth (hf_oauth: true)")
        logger.warning("2. 用户已通过 'Login with Hugging Face' 按钮登录")
        return "", False

    # Enforce the daily quota (skipped when no dataset is configured).
    if dataset_manager:
        calls_today = dataset_manager.get_user_calls_today(username)
        if calls_today >= MAX_DAILY_CALLS:
            return username, False
    return username, True
async def _generate_single_video_async(
    model_name: str,
    prompt: str,
    image_url: Optional[str],
    semaphore: asyncio.Semaphore,
    *,
    max_polls: int = 60,
    poll_interval: float = 10
) -> Tuple[str, Dict[str, Any], Optional[str], str]:
    """
    Generate one video with one model and wait for the result.

    The task is submitted immediately; only the polling phase (including the
    download/upload of the finished video) is guarded by ``semaphore``.
    Blocking service calls run in the default thread pool executor.
    This coroutine never raises: every failure is turned into an error result.

    Args:
        model_name: model id
        prompt: prompt text
        image_url: image URL (optional); when given the "i2v" mode is used,
            otherwise "t2v"
        semaphore: asyncio semaphore limiting concurrent polling
        max_polls: maximum number of status polls before reporting a timeout
        poll_interval: seconds to wait between two polls
    Returns:
        (model_name, model_result, video_url, status_message) tuple;
        ``video_url`` is None unless the generation succeeded.
    """
    display_name = MODEL_DISPLAY_NAMES.get(model_name, model_name)
    try:
        logger.info(f"开始生成视频: {display_name} ({model_name}), 提示词: {prompt[:50]}...")
        service = get_pollo_service(model_name)
        mode = "i2v" if image_url else "t2v"

        # We are inside a coroutine, so the running loop is the right one.
        loop = asyncio.get_running_loop()

        # Submitting is quick and is therefore not limited by the semaphore.
        result = await loop.run_in_executor(
            None,
            lambda: service.generate_video(
                prompt=prompt,
                mode=mode,
                input_image_path=image_url if image_url else None,
                video_length=5,
                width=1280,
                height=720
            )
        )
        task_id = result.get('pollo_task_id')
        if not task_id:
            raise Exception("未获取到任务ID")
        logger.info(f"{display_name}: 任务已提交,task_id={task_id}")

        # Polling holds one concurrency slot until a final state is reached.
        async with semaphore:
            logger.info(f"{display_name}: 开始轮询(当前并发槽位已占用)")
            for attempt in range(max_polls):
                poll_result = await loop.run_in_executor(
                    None,
                    service.poll_task_result,
                    task_id
                )
                status = poll_result['status']

                if status == 'completed':
                    pollo_video_url = poll_result.get('video_url')
                    if not pollo_video_url:
                        # Completed without a URL: stop polling and report
                        # the generic error raised after the loop.
                        break
                    # Copy the video to S3 (blocking, so run in the executor).
                    logger.info(f"{display_name}: 下载视频并上传到S3: {pollo_video_url}")
                    s3_video_url = await loop.run_in_executor(
                        None,
                        s3_utils.download_and_upload_video,
                        pollo_video_url
                    )
                    if s3_video_url:
                        model_result = {
                            'status': 'success',
                            'task_id': task_id,
                            'video_url': s3_video_url,
                            'pollo_video_url': pollo_video_url
                        }
                        status_message = f"✅ {display_name}: 生成成功并已保存到S3"
                        logger.info(f"{display_name}: 完成,释放并发槽位")
                        return model_name, model_result, s3_video_url, status_message
                    # S3 upload failed: fall back to the original URL.
                    logger.warning(f"{display_name}: S3上传失败,使用原始URL: {pollo_video_url}")
                    model_result = {
                        'status': 'success',
                        'task_id': task_id,
                        'video_url': pollo_video_url,
                        'warning': 'S3上传失败,使用临时URL'
                    }
                    status_message = f"✅ {display_name}: 生成成功(S3上传失败)"
                    logger.info(f"{display_name}: 完成,释放并发槽位")
                    return model_name, model_result, pollo_video_url, status_message

                if status == 'failed':
                    error_msg = poll_result.get('error_message', '未知错误')
                    model_result = {
                        'status': 'failed',
                        'error': error_msg
                    }
                    status_message = f"❌ {display_name}: {error_msg}"
                    logger.info(f"{display_name}: 失败,释放并发槽位")
                    return model_name, model_result, None, status_message

                # Still processing: give up after the last poll, else wait.
                if attempt == max_polls - 1:
                    model_result = {
                        'status': 'timeout',
                        'error': '任务超时'
                    }
                    status_message = f"⏱️ {display_name}: 任务超时"
                    logger.info(f"{display_name}: 超时,释放并发槽位")
                    return model_name, model_result, None, status_message
                await asyncio.sleep(poll_interval)

        # Reaching this point means polling ended without a usable result.
        raise Exception("轮询未返回有效结果")
    except Exception as e:
        logger.error(f"生成视频失败 ({display_name}): {e}")
        model_result = {
            'status': 'error',
            'error': str(e)
        }
        status_message = f"❌ {display_name}: {str(e)}"
        return model_name, model_result, None, status_message
def generate_videos(prompt: str, input_image: Optional[str], request: gr.Request) -> Tuple[str, Dict[str, Any], Dict[str, str]]:
    """
    Generate videos with every configured model in parallel.

    The call is counted against the user's daily quota before generation
    starts, so failed generations are counted as well.

    Args:
        prompt: prompt text
        input_image: path of the input image (optional)
        request: Gradio Request object
    Returns:
        (status_message, model_results, video_urls) tuple. Both dicts are
        empty when the request is rejected or generation fails as a whole;
        ``video_urls`` only contains models that produced a video.
    """
    # Access checks
    username, has_access = check_user_access(request)
    if not has_access:
        if not username:
            return "❌ 请先登录 Hugging Face 账户", {}, {}
        calls_today = dataset_manager.get_user_calls_today(username) if dataset_manager else 0
        return f"❌ 您今天的调用次数已用完({calls_today}/{MAX_DAILY_CALLS}),请明天再试", {}, {}
    if not video_service:
        return "❌ 视频生成服务未配置,请联系管理员", {}, {}
    if not prompt or not prompt.strip():
        return "❌ 请输入提示词", {}, {}

    # Count the call up front so failed generations still use quota.
    if dataset_manager and username:
        try:
            counted = dataset_manager.increment_user_calls(username)
        except Exception as e:
            logger.warning(f"更新用户调用次数失败: {e}")
        else:
            # increment_user_calls logs its own failure; only confirm success.
            if counted:
                logger.info(f"用户 {username} 调用次数+1")

    try:
        # Upload the optional image to S3 and use the returned URL for the
        # generation requests.
        image_url = None
        if input_image:
            logger.info("上传图片到S3...")
            image_url = s3_utils.upload_image_from_path(input_image)
            if not image_url:
                return "❌ 图片上传到S3失败,请检查S3配置", {}, {}
            logger.info(f"图片已上传到S3: {image_url}")

        async def run_parallel_generation():
            # Create the semaphore inside the running loop: before Python 3.10
            # asyncio primitives bind to the loop current at construction time,
            # which is not the loop asyncio.run() starts.
            semaphore = asyncio.Semaphore(MAX_POLLO_CONCURRENCY)
            tasks = [
                _generate_single_video_async(model_name, prompt, image_url, semaphore)
                for model_name in MODELS_TO_CALL
            ]
            return await asyncio.gather(*tasks, return_exceptions=True)

        logger.info(f"开始并行生成视频,最大并发数: {MAX_POLLO_CONCURRENCY}")
        results = asyncio.run(run_parallel_generation())

        # Collect the per-model outcomes
        model_results = {}
        video_urls = {}
        status_messages = []
        for result in results:
            if isinstance(result, Exception):
                logger.error(f"任务执行异常: {result}")
                status_messages.append(f"❌ 任务异常: {str(result)}")
                continue
            model_name, model_result, video_url, status_message = result
            model_results[model_name] = model_result
            if video_url:
                video_urls[model_name] = video_url
            status_messages.append(status_message)

        status_message = "\n".join(status_messages) if status_messages else "生成完成"
        return status_message, model_results, video_urls
    except Exception as e:
        logger.error(f"生成视频异常: {e}")
        return f"❌ 生成视频失败: {str(e)}", {}, {}
def submit_evaluation(
    prompt: str,
    ranks: Dict[str, int],
    model_results: Dict[str, Any],
    video_urls: Dict[str, str],
    request: gr.Request
) -> str:
    """
    Validate a ranking and store it in the dataset.

    Args:
        prompt: prompt text
        ranks: ranking dict {model_name: rank}
        model_results: per-model results
        video_urls: video URLs of the successfully generated models
        request: Gradio Request object
    Returns:
        Status message for the UI.
    """
    username, _ = check_user_access(request)
    if not username:
        return "❌ 请先登录 Hugging Face 账户"
    if not ranks:
        return "❌ 请至少为一个模型选择排名"

    chosen = [rank for rank in ranks.values() if rank is not None]

    # Every rank may be used only once.
    if len(set(chosen)) != len(chosen):
        return "❌ 排名不能重复!每个模型的排名必须不同"

    # Every model that produced a video must be ranked.
    expected = len(video_urls.keys())
    if len(chosen) != expected:
        return f"❌ 请为所有 {expected} 个成功生成的模型选择排名"

    try:
        if not dataset_manager:
            return "❌ Dataset 未配置,无法保存排名"
        saved = dataset_manager.save_evaluation(
            username=username,
            prompt=prompt,
            model_results=model_results,
            scores=ranks,  # ranks are stored under the legacy "scores" key
            video_urls=video_urls
        )
        if not saved:
            return "❌ 保存排名失败,请重试"
        return f"✅ 排名已保存!感谢 {username} 的反馈"
    except Exception as e:
        logger.error(f"提交排名失败: {e}")
        return f"❌ 提交排名失败: {str(e)}"
def get_user_info(request: gr.Request = None) -> str:
    """Return the login and quota summary shown in the user status box."""
    # Without a request there is nothing to inspect yet.
    if not request:
        return "⏳ 正在检查登录状态..."

    username, has_access = check_user_access(request)
    if not username:
        return "❌ 请先登录 Hugging Face 账户\n\n提示:点击右上角的 'Login with Hugging Face' 按钮登录"

    if dataset_manager:
        calls_today = dataset_manager.get_user_calls_today(username)
    else:
        calls_today = 0
    remaining = max(0, MAX_DAILY_CALLS - calls_today)

    if has_access:
        status_icon = "✅"
    else:
        status_icon = "❌"
    return f"{status_icon} 用户: {username}\n📊 今日已用: {calls_today}/{MAX_DAILY_CALLS}\n✨ 剩余次数: {remaining}"
def get_history_html(request: gr.Request = None) -> str:
    """Render the logged-in user's recent evaluations as an HTML fragment.

    All values that originate from users or from stored records (user name,
    prompt, timestamps, URLs) are HTML-escaped before being embedded.
    ``scores`` and ``video_urls`` are accepted either as dicts or as the JSON
    strings written by ``DatasetManager.save_evaluation``.
    """
    # Local import: keeps the module's import block untouched.
    from html import escape

    def as_mapping(value):
        """Return ``value`` as a dict, decoding JSON strings; {} on failure."""
        if isinstance(value, str):
            try:
                value = json.loads(value)
            except ValueError:
                return {}
        return value if isinstance(value, dict) else {}

    def rank_number(rank):
        """Return ``rank`` as int, or None when it is not numeric."""
        try:
            return int(rank)
        except (TypeError, ValueError):
            return None

    if not request:
        return "<p>⏳ 正在加载...</p>"
    username, _ = check_user_access(request)
    if not username:
        return "<p>❌ 请先登录 Hugging Face 账户</p>"
    if not dataset_manager:
        return "<p>❌ Dataset 未配置</p>"

    medals = ("🥇", "🥈", "🥉", "4️⃣")
    try:
        history = dataset_manager.get_user_history(username, limit=10)
        safe_user = escape(str(username))
        if not history:
            return f"<p>📭 暂无历史记录(用户: {safe_user})</p>"

        parts = [
            f"<h3>用户: {safe_user} 的历史记录</h3>",
            f"<p>共 {len(history)} 条记录</p>",
            "<div style='max-height: 600px; overflow-y: auto;'>",
        ]
        for i, record in enumerate(history, 1):
            timestamp = escape(str(record.get('timestamp') or '未知时间'))
            prompt = str(record.get('prompt') or '')
            scores = as_mapping(record.get('scores', {}))
            video_urls = as_mapping(record.get('video_urls', {}))

            # Truncate first, then escape, so no entity is cut in half.
            snippet = escape(prompt[:100]) + ('...' if len(prompt) > 100 else '')
            parts.append(
                "<div style='border: 1px solid #ddd; padding: 15px; margin: 10px 0; "
                "border-radius: 8px; background: #f9f9f9;'>"
            )
            parts.append(f"<h4>📝 记录 #{i} - {timestamp}</h4>")
            parts.append(f"<p><strong>提示词:</strong> {snippet}</p>")
            parts.append("<p><strong>排名:</strong></p>")
            parts.append("<ul>")

            # Best rank first; entries without a numeric rank go last.
            ranked = sorted(
                scores.items(),
                key=lambda item: (rank_number(item[1]) is None, rank_number(item[1]) or 0)
            )
            for model, rank in ranked:
                display_name = escape(str(MODEL_DISPLAY_NAMES.get(model, model)))
                number = rank_number(rank)
                if number is not None and 1 <= number <= len(medals):
                    rank_emoji = medals[number - 1]
                else:
                    rank_emoji = "❓"
                parts.append(f"<li>{rank_emoji} 第 {escape(str(rank))} 名: {display_name}</li>")
            parts.append("</ul>")

            if video_urls:
                parts.append("<p><strong>生成的视频:</strong></p>")
                for model, url in video_urls.items():
                    display_name = escape(str(MODEL_DISPLAY_NAMES.get(model, model)))
                    safe_url = escape(str(url), quote=True)
                    parts.append(f'<p>🎬 <a href="{safe_url}" target="_blank">{display_name}</a></p>')
            parts.append("</div>")
        parts.append("</div>")
        return "\n".join(parts)
    except Exception as e:
        logger.error(f"获取历史记录失败: {e}")
        return f"<p>❌ 获取历史记录失败: {escape(str(e))}</p>"
def _video_card_html(display_name: str, video_url: str) -> str:
    """Return the HTML card (player plus links) for a generated video."""
    return f"""
    <div style="margin: 10px 0; padding: 15px; border: 2px solid #4CAF50; border-radius: 10px; background: #f1f8f4;">
    <h4 style="margin-top: 0; color: #4CAF50;">✅ {display_name}</h4>
    <video controls style="width: 100%; max-width: 640px; border-radius: 8px;">
    <source src="{video_url}" type="video/mp4">
    Your browser does not support the video tag.
    </video>
    <p style="margin-top: 10px;">
    <a href="{video_url}" target="_blank" download style="color: #007bff; text-decoration: none;">
    📥 下载视频
    </a> |
    <a href="{video_url}" target="_blank" style="color: #007bff; text-decoration: none;">
    🔗 在新标签页打开
    </a>
    </p>
    </div>
    """


def _failed_card_html(display_name: str) -> str:
    """Return the HTML card shown when a model produced no video."""
    return f"""
    <div style="margin: 10px 0; padding: 15px; border: 2px solid #ff9800; border-radius: 10px; background: #fff3e0;">
    <h4 style="margin-top: 0; color: #ff9800;">⚠️ {display_name}</h4>
    <p style="color: #666;">视频生成失败或未完成</p>
    </div>
    """


# Build the Gradio interface
with gr.Blocks(title="Video Model Evaluator") as demo:
    gr.Markdown("""
    # 🎬 视频生成模型评估系统
    欢迎使用视频生成模型评估系统!请先登录您的 Hugging Face 账户。
    **功能说明:**
    - 每个用户每天最多可调用 4 次
    - 支持多个模型同时生成视频(Sora 2 pro, Seedance Pro, Veo 3.1, Kling 2.6)
    - 可以对生成结果进行排名比较
    - 排名数据将保存到 Private Dataset
    **评估流程:**
    1. 上传输入图片(可选)
    2. 输入提示词(Prompt)
    3. 系统调用多个模型生成视频
    4. 对每个模型的生成结果进行排名(1=最好,4=最差)
    5. 提交排名,数据自动保存
    """)

    # Hugging Face login button (Gradio OAuth support).
    # Only shown on a Space: SPACE_ID exists only there, and the button
    # fails in local test runs.
    if os.getenv("SPACE_ID"):
        login_btn = gr.LoginButton()

    # User status area (own row so it stands out)
    with gr.Row():
        user_info = gr.Textbox(
            label="👤 用户状态(点击生成视频时统计次数,提交评分不额外计数)",
            interactive=False,
            value="⏳ 等待检查登录状态...",
            scale=9,
            lines=2
        )
        refresh_user_btn = gr.Button("🔄 刷新", size="sm", scale=1)

    with gr.Tabs():
        with gr.Tab("🎬 生成视频"):
            with gr.Row():
                with gr.Column(scale=1):
                    prompt_input = gr.Textbox(
                        label="提示词 (Prompt)",
                        placeholder="输入视频生成的提示词...",
                        lines=3
                    )
                    input_image = gr.Image(
                        label="输入图片 (可选,用于图生视频 i2v)",
                        type="filepath",
                        sources=["upload"]
                    )
                    generate_btn = gr.Button("🚀 生成视频", variant="primary", size="lg")
                    status_output = gr.Textbox(
                        label="生成状态",
                        interactive=False,
                        lines=5
                    )
                with gr.Column(scale=1):
                    # Video display area
                    gr.Markdown("### 🎬 生成的视频")
                    # One HTML video card and one rank selector per model
                    video_components = {}
                    rank_components = {}
                    for model_name in MODELS_TO_CALL:
                        display_name = MODEL_DISPLAY_NAMES.get(model_name, model_name)
                        video_components[model_name] = gr.HTML(
                            label=display_name,
                            visible=False
                        )
                        rank_components[model_name] = gr.Radio(
                            label=f"{display_name} 排名",
                            choices=["1 (最好)", "2", "3", "4 (最差)"],
                            value=None,
                            visible=False
                        )
                    # Ranking area
                    gr.Markdown("### 🏆 排名 (选择每个视频的排名)")
                    gr.Markdown("💡 提示:给每个模型选择排名(1=最好,4=最差),不能有重复排名")
                    submit_btn = gr.Button("💾 提交排名", variant="secondary", visible=False)
                    submit_status = gr.Textbox(
                        label="提交状态",
                        interactive=False
                    )

        # History tab
        with gr.Tab("📜 历史记录"):
            gr.Markdown("""
            ### 📜 我的生成历史
            这里显示您最近的视频生成记录(最多10条)
            """)
            history_refresh_btn = gr.Button("🔄 刷新历史记录", variant="primary")
            history_output = gr.HTML(label="历史记录", value="<p>点击刷新按钮加载历史记录</p>")

    # Intermediate data shared between the generate and submit steps
    model_results_state = gr.State({})
    video_urls_state = gr.State({})
    prompt_state = gr.State("")

    # Event handlers
    def on_generate(prompt, image, request: gr.Request):
        """Run the generation and build the update for every output component."""
        status, results, urls = generate_videos(prompt, image, request)
        # Read the quota text AFTER generating so it includes this call.
        user_info_text = get_user_info(request)

        outputs = [user_info_text, status]
        has_results = len(urls) > 0

        # Two updates per model: the video card and the rank selector.
        for model_name in MODELS_TO_CALL:
            display_name = MODEL_DISPLAY_NAMES.get(model_name, model_name)
            video_value = urls.get(model_name, None)
            if video_value:
                # Plain HTML <video> tag; the value is always set explicitly
                # so that earlier content is overwritten.
                outputs.append(gr.update(
                    value=_video_card_html(display_name, video_value),
                    visible=True
                ))
                # Show the selector and clear any rank left from a previous run.
                outputs.append(gr.update(value=None, visible=True))
            else:
                # Model failed or produced nothing: show the failure card.
                outputs.append(gr.update(
                    value=_failed_card_html(display_name),
                    visible=True
                ))
                # Hide the selector and clear its stale value.
                outputs.append(gr.update(value=None, visible=False))

        outputs.extend([
            gr.update(visible=has_results),  # submit_btn visibility
            prompt,                          # prompt_state
            results,                         # model_results_state
            urls                             # video_urls_state
        ])
        return tuple(outputs)

    def on_submit(results, urls, prompt, request: gr.Request, *rank_values):
        """Convert the selected radio values to ranks and submit them."""
        ranks = {}
        for model_name, rank_str in zip(MODELS_TO_CALL, rank_values):
            # Only models with a generated video and a chosen rank count.
            if model_name in urls and rank_str:
                # Choices look like "1 (最好)": the number is the first token.
                ranks[model_name] = int(rank_str.split()[0])
        return submit_evaluation(prompt, ranks, results, urls, request)

    # Component lists for the event bindings; the order must match the
    # order of the values returned by on_generate / expected by on_submit.
    generate_outputs = [user_info, status_output]
    for model_name in MODELS_TO_CALL:
        generate_outputs.append(video_components[model_name])  # video card
        generate_outputs.append(rank_components[model_name])   # rank selector
    generate_outputs.extend([submit_btn, prompt_state, model_results_state, video_urls_state])

    submit_inputs = [model_results_state, video_urls_state, prompt_state]
    submit_inputs.extend(rank_components[model_name] for model_name in MODELS_TO_CALL)

    # Bind events
    generate_btn.click(
        fn=on_generate,
        inputs=[prompt_input, input_image],
        outputs=generate_outputs
    )
    # Note: Gradio injects gr.Request into on_submit's `request` parameter.
    submit_btn.click(
        fn=on_submit,
        inputs=submit_inputs,
        outputs=[submit_status]
    )
    # Refresh user info button
    refresh_user_btn.click(
        fn=get_user_info,
        inputs=None,
        outputs=[user_info]
    )
    # Refresh history button
    history_refresh_btn.click(
        fn=get_history_html,
        inputs=None,
        outputs=[history_output]
    )
    # Load the user info when the page opens (gr.Request is injected)
    demo.load(
        fn=get_user_info,
        inputs=None,
        outputs=[user_info]
    )

    # Usage notes
    gr.Markdown("""
    ---
    ### 📌 使用提示
    1. **登录**: 请确保已登录 Hugging Face 账户,登录后点击🔄按钮刷新状态
    2. **上传图片**: 可选,不上传则为文生视频(t2v),上传则为图生视频(i2v)
    3. **输入提示词**: 描述你希望视频展现的内容
    4. **等待生成**: 视频生成可能需要几分钟,请耐心等待
    5. **查看视频**: 生成完成后,视频会显示在右侧
    6. **排名提交**: 为每个模型选择排名(1=最好,4=最差),不能有重复排名
    """)
# Start the application.
# Note: on a Hugging Face Space, Gradio handles authentication itself;
# OAuth has to be enabled in the Space settings.
if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "theme": gr.themes.Soft(),  # Gradio 6.0: theme is passed to launch()
        "ssr_mode": False,          # SSR disabled to avoid asyncio errors
    }
    demo.launch(**launch_options)