testcoder-ui commited on
Commit
a0d3b53
·
1 Parent(s): eb120a3

feat: 并行视频生成 + 更新数据集路径

Browse files

- 更新数据集路径: video-model-evaluator-cuti/video-evaluations
- 添加异步并行视频生成,使用 asyncio.gather
- 限制 Pollo API 并发数为 5(使用 Semaphore)
- 性能提升:4个模型从串行20+分钟降至并行5-6分钟
- 提交阶段快速并行,轮询阶段受信号量限制

Files changed (1) hide show
  1. app.py +176 -100
app.py CHANGED
@@ -35,9 +35,10 @@ logger = logging.getLogger(__name__)
35
 
36
  # 配置常量
37
  MAX_DAILY_CALLS = 4 # 每个用户每天最多调用次数
38
- DATASET_REPO_ID = "learnmlf/video-evaluations" # Private Dataset 名称
39
  HF_TOKEN = os.getenv("HF_TOKEN", "") # 从 Space Settings 获取
40
  API_KEY = os.getenv("API_KEY", "") # 从 Space Settings 获取
 
41
 
42
  # 支持的模型列表
43
  MODELS_TO_CALL = [
@@ -375,9 +376,145 @@ def check_user_access(request: gr.Request) -> Tuple[str, bool]:
375
  return username, True
376
 
377
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
378
  def generate_videos(prompt: str, input_image: Optional[str], request: gr.Request) -> Tuple[str, Dict[str, Any], Dict[str, str]]:
379
  """
380
- 生成视频(调用多个模型)
381
 
382
  Args:
383
  prompt: 提示词
@@ -412,115 +549,54 @@ def generate_videos(prompt: str, input_image: Optional[str], request: gr.Request
412
  logger.warning(f"更新用户调用次数失败: {e}")
413
 
414
  try:
415
- model_results = {}
416
- video_urls = {}
417
-
418
- # 使用配置的模型列表
419
- models = MODELS_TO_CALL
420
-
421
- status_messages = []
422
-
423
  # 处理图片上传(如果提供)
424
- image_path = None
425
  image_url = None
426
  if input_image:
427
- # Gradio 返回的是临时文件路径
428
- image_path = input_image
429
-
430
  # 上传图片到S3,获取公网URL(Pollo API需要URL)
431
  logger.info("上传图片到S3...")
432
- image_url = s3_utils.upload_image_from_path(image_path)
433
 
434
  if not image_url:
435
  return "❌ 图片上传到S3失败,请检查S3配置", {}, {}
436
 
437
  logger.info(f"图片已上传到S3: {image_url}")
438
 
439
- for model_name in models:
440
- try:
441
- display_name = MODEL_DISPLAY_NAMES.get(model_name, model_name)
442
- logger.info(f"开始生成视频: {display_name} ({model_name}), 提示词: {prompt[:50]}...")
443
-
444
- # 获取对应模型的服务实例
445
- service = get_pollo_service(model_name)
446
-
447
- # 根据是否有图片选择模式
448
- mode = "i2v" if image_url else "t2v"
449
-
450
- # 使用S3 URL而不是本地路径
451
- result = service.generate_video(
452
- prompt=prompt,
453
- mode=mode, # 根据是否有图片自动选择 i2v 或 t2v
454
- input_image_path=image_url if image_url else None, # 使用S3 URL
455
- video_length=5,
456
- width=1280,
457
- height=720
458
- )
459
-
460
- task_id = result.get('pollo_task_id')
461
- if task_id:
462
- # 轮询任务结果
463
- max_polls = 60
464
- poll_interval = 10
465
-
466
- for i in range(max_polls):
467
- poll_result = service.poll_task_result(task_id)
468
-
469
- if poll_result['status'] == 'completed':
470
- pollo_video_url = poll_result.get('video_url')
471
- if pollo_video_url:
472
- # 下载视频并上传到S3(Pollo的视频只保存一段时间)
473
- logger.info(f"下载视频并上传到S3: {pollo_video_url}")
474
- s3_video_url = s3_utils.download_and_upload_video(pollo_video_url)
475
-
476
- if s3_video_url:
477
- video_urls[model_name] = s3_video_url
478
- model_results[model_name] = {
479
- 'status': 'success',
480
- 'task_id': task_id,
481
- 'video_url': s3_video_url,
482
- 'pollo_video_url': pollo_video_url # 保留原始URL
483
- }
484
- status_messages.append(f"✅ {display_name}: 生成成功并已保存到S3")
485
- else:
486
- # 如果S3上传失败,使用原始URL
487
- logger.warning(f"S3上传失败,使用原始URL: {pollo_video_url}")
488
- video_urls[model_name] = pollo_video_url
489
- model_results[model_name] = {
490
- 'status': 'success',
491
- 'task_id': task_id,
492
- 'video_url': pollo_video_url,
493
- 'warning': 'S3上传失败,使用临时URL'
494
- }
495
- status_messages.append(f"✅ {display_name}: 生成成功(S3上传失败)")
496
- break
497
- elif poll_result['status'] == 'failed':
498
- error_msg = poll_result.get('error_message', '未知错误')
499
- model_results[model_name] = {
500
- 'status': 'failed',
501
- 'error': error_msg
502
- }
503
- status_messages.append(f"❌ {display_name}: {error_msg}")
504
- break
505
- else:
506
- # 处理中,继续等待
507
- if i == max_polls - 1:
508
- model_results[model_name] = {
509
- 'status': 'timeout',
510
- 'error': '任务超时'
511
- }
512
- status_messages.append(f"⏱️ {display_name}: 任务超时")
513
- else:
514
- time.sleep(poll_interval)
515
-
516
- except Exception as e:
517
- display_name = MODEL_DISPLAY_NAMES.get(model_name, model_name)
518
- logger.error(f"生成视频失败 ({display_name}): {e}")
519
- model_results[model_name] = {
520
- 'status': 'error',
521
- 'error': str(e)
522
- }
523
- status_messages.append(f"❌ {display_name}: {str(e)}")
524
 
525
  status_message = "\n".join(status_messages) if status_messages else "生成完成"
526
 
 
35
 
36
  # 配置常量
37
  MAX_DAILY_CALLS = 4 # 每个用户每天最多调用次数
38
+ DATASET_REPO_ID = "video-model-evaluator-cuti/video-evaluations" # Private Dataset 名称
39
  HF_TOKEN = os.getenv("HF_TOKEN", "") # 从 Space Settings 获取
40
  API_KEY = os.getenv("API_KEY", "") # 从 Space Settings 获取
41
+ MAX_POLLO_CONCURRENCY = 5 # Pollo API 最大并发数
42
 
43
  # 支持的模型列表
44
  MODELS_TO_CALL = [
 
376
  return username, True
377
 
378
 
379
+ async def _generate_single_video_async(
380
+ model_name: str,
381
+ prompt: str,
382
+ image_url: Optional[str],
383
+ semaphore: asyncio.Semaphore
384
+ ) -> Tuple[str, Dict[str, Any], Optional[str], str]:
385
+ """
386
+ 异步生成单个模型的视频(使用信号量限制并发)
387
+
388
+ Args:
389
+ model_name: 模型名称
390
+ prompt: 提示词
391
+ image_url: 图片URL(可选)
392
+ semaphore: asyncio信号量,用于限制并发数
393
+
394
+ Returns:
395
+ (model_name, model_result, video_url, status_message) 元组
396
+ """
397
+ display_name = MODEL_DISPLAY_NAMES.get(model_name, model_name)
398
+
399
+ try:
400
+ logger.info(f"开始生成视频: {display_name} ({model_name}), 提示词: {prompt[:50]}...")
401
+
402
+ # 获取对应模型的服务实例
403
+ service = get_pollo_service(model_name)
404
+
405
+ # 根据是否有图片选择模式
406
+ mode = "i2v" if image_url else "t2v"
407
+
408
+ # 提交任务(快速,不需要限制并发)
409
+ loop = asyncio.get_event_loop()
410
+ result = await loop.run_in_executor(
411
+ None,
412
+ lambda: service.generate_video(
413
+ prompt=prompt,
414
+ mode=mode,
415
+ input_image_path=image_url if image_url else None,
416
+ video_length=5,
417
+ width=1280,
418
+ height=720
419
+ )
420
+ )
421
+
422
+ task_id = result.get('pollo_task_id')
423
+ if not task_id:
424
+ raise Exception("未获取到任务ID")
425
+
426
+ logger.info(f"{display_name}: 任务已提交,task_id={task_id}")
427
+
428
+ # 使用信号量限制轮询并发数
429
+ async with semaphore:
430
+ logger.info(f"{display_name}: 开始轮询(当前并发槽位已占用)")
431
+
432
+ # 轮询任务结果
433
+ max_polls = 60
434
+ poll_interval = 10
435
+
436
+ for i in range(max_polls):
437
+ # 在线程池中执行同步的轮询操作
438
+ poll_result = await loop.run_in_executor(
439
+ None,
440
+ service.poll_task_result,
441
+ task_id
442
+ )
443
+
444
+ if poll_result['status'] == 'completed':
445
+ pollo_video_url = poll_result.get('video_url')
446
+ if pollo_video_url:
447
+ # 下载视频并上传到S3(在线程池中执行)
448
+ logger.info(f"{display_name}: 下载视频并上传到S3: {pollo_video_url}")
449
+ s3_video_url = await loop.run_in_executor(
450
+ None,
451
+ s3_utils.download_and_upload_video,
452
+ pollo_video_url
453
+ )
454
+
455
+ if s3_video_url:
456
+ model_result = {
457
+ 'status': 'success',
458
+ 'task_id': task_id,
459
+ 'video_url': s3_video_url,
460
+ 'pollo_video_url': pollo_video_url
461
+ }
462
+ status_message = f"✅ {display_name}: 生成成功并已保存到S3"
463
+ logger.info(f"{display_name}: 完成,释放并发槽位")
464
+ return model_name, model_result, s3_video_url, status_message
465
+ else:
466
+ # 如果S3上传失败,使用原始URL
467
+ logger.warning(f"{display_name}: S3上传失败,使用原始URL: {pollo_video_url}")
468
+ model_result = {
469
+ 'status': 'success',
470
+ 'task_id': task_id,
471
+ 'video_url': pollo_video_url,
472
+ 'warning': 'S3上传失败,使用临时URL'
473
+ }
474
+ status_message = f"✅ {display_name}: 生成成功(S3上传失败)"
475
+ logger.info(f"{display_name}: 完成,释放并发槽位")
476
+ return model_name, model_result, pollo_video_url, status_message
477
+ break
478
+
479
+ elif poll_result['status'] == 'failed':
480
+ error_msg = poll_result.get('error_message', '未知错误')
481
+ model_result = {
482
+ 'status': 'failed',
483
+ 'error': error_msg
484
+ }
485
+ status_message = f"❌ {display_name}: {error_msg}"
486
+ logger.info(f"{display_name}: 失败,释放并发槽位")
487
+ return model_name, model_result, None, status_message
488
+
489
+ else:
490
+ # 处理中,继续等待
491
+ if i == max_polls - 1:
492
+ model_result = {
493
+ 'status': 'timeout',
494
+ 'error': '任务超时'
495
+ }
496
+ status_message = f"⏱️ {display_name}: 任务超时"
497
+ logger.info(f"{display_name}: 超时,释放并发槽位")
498
+ return model_name, model_result, None, status_message
499
+ else:
500
+ await asyncio.sleep(poll_interval)
501
+
502
+ # 如果没有返回结果,说明出现异常
503
+ raise Exception("轮询未返回有效结果")
504
+
505
+ except Exception as e:
506
+ logger.error(f"生成视频失败 ({display_name}): {e}")
507
+ model_result = {
508
+ 'status': 'error',
509
+ 'error': str(e)
510
+ }
511
+ status_message = f"❌ {display_name}: {str(e)}"
512
+ return model_name, model_result, None, status_message
513
+
514
+
515
  def generate_videos(prompt: str, input_image: Optional[str], request: gr.Request) -> Tuple[str, Dict[str, Any], Dict[str, str]]:
516
  """
517
+ 生成视频(并行调用多个模型,限制Pollo API并发数为5)
518
 
519
  Args:
520
  prompt: 提示词
 
549
  logger.warning(f"更新用户调用次数失败: {e}")
550
 
551
  try:
 
 
 
 
 
 
 
 
552
  # 处理图片上传(如果提供)
 
553
  image_url = None
554
  if input_image:
 
 
 
555
  # 上传图片到S3,获取公网URL(Pollo API需要URL)
556
  logger.info("上传图片到S3...")
557
+ image_url = s3_utils.upload_image_from_path(input_image)
558
 
559
  if not image_url:
560
  return "❌ 图片上传到S3失败,请检查S3配置", {}, {}
561
 
562
  logger.info(f"图片已上传到S3: {image_url}")
563
 
564
+ # 使用配置的模型列表
565
+ models = MODELS_TO_CALL
566
+
567
+ # 创建信号量限制并发数
568
+ semaphore = asyncio.Semaphore(MAX_POLLO_CONCURRENCY)
569
+
570
+ # 创建异步任务列表
571
+ async def run_parallel_generation():
572
+ tasks = [
573
+ _generate_single_video_async(model_name, prompt, image_url, semaphore)
574
+ for model_name in models
575
+ ]
576
+ # 并行执行所有任务
577
+ return await asyncio.gather(*tasks, return_exceptions=True)
578
+
579
+ # 运行异步任务
580
+ logger.info(f"开始并行生成视频,最大并发数: {MAX_POLLO_CONCURRENCY}")
581
+ results = asyncio.run(run_parallel_generation())
582
+
583
+ # 整理结果
584
+ model_results = {}
585
+ video_urls = {}
586
+ status_messages = []
587
+
588
+ for result in results:
589
+ if isinstance(result, Exception):
590
+ # 捕获异常
591
+ logger.error(f"任务执行异常: {result}")
592
+ status_messages.append(f"❌ 任务异常: {str(result)}")
593
+ else:
594
+ # 正常结果
595
+ model_name, model_result, video_url, status_message = result
596
+ model_results[model_name] = model_result
597
+ if video_url:
598
+ video_urls[model_name] = video_url
599
+ status_messages.append(status_message)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
600
 
601
  status_message = "\n".join(status_messages) if status_messages else "生成完成"
602