Laramie2 committed on
Commit
338ff5b
·
verified ·
1 Parent(s): da7bfe5

Update src/refinement/refinement.py

Browse files
Files changed (1) hide show
  1. src/refinement/refinement.py +41 -26
src/refinement/refinement.py CHANGED
@@ -5,6 +5,7 @@ import json
5
  import time
6
  import PIL.Image
7
  import shutil
 
8
  from PIL import Image
9
  from pathlib import Path
10
  from openai import OpenAI
@@ -401,8 +402,7 @@ def refine_one_slide(input_path, output_path, prompts, outline, max_iterations,
401
  take_screenshot(current_input, final_screenshot_path)
402
  print(f"\n📷 Final screenshot saved: {final_screenshot_path}")
403
 
404
-
405
- def refinement_ppt(input_index, prompts, max_iterations=3, model="gpt-4o", config=None):
406
  # 1. 定义路径
407
  outline_path = os.path.join(input_index, "outline.json")
408
  output_index = os.path.join(input_index, "final")
@@ -412,7 +412,6 @@ def refinement_ppt(input_index, prompts, max_iterations=3, model="gpt-4o", confi
412
  os.makedirs(output_index, exist_ok=True)
413
 
414
  # 将图片复制到final/images目录下
415
- import shutil
416
  source_images_dir = os.path.join(input_index, "images")
417
  if os.path.exists(source_images_dir):
418
  shutil.copytree(source_images_dir, output_index_images, dirs_exist_ok=True)
@@ -441,25 +440,40 @@ def refinement_ppt(input_index, prompts, max_iterations=3, model="gpt-4o", confi
441
 
442
  # 3.2 定义排序 Key:直接提取开头的数字
443
  def get_file_number(filename):
444
- # 因为上一步已经过滤过了,这里可以直接提取
445
  return int(filename.split('_')[0])
446
 
447
  # 3.3 执行排序 (这步是关键,确保 2 在 10 前面)
448
  sorted_files = sorted(target_files, key=get_file_number)
449
 
450
- # Debug: 打印前几个文件确认顺序
451
- print(f"👀 排序后文件列表前5个: {sorted_files[:5]}")
452
 
453
- # 4. 遍历排序后列表
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
454
  for file_name in sorted_files:
455
- # 直接提取序号 (之前已经验证过格式了)
456
  num = str(get_file_number(file_name))
457
 
458
  # 获取当前 html 对应的 outline
459
  outline = outline_full.get(int(num)-1)
460
 
461
- # 【容错逻辑】处理索引偏移 (例如文件是 1_ppt,但列表是从 0 开始)
462
- # 如果 outline 为空,且 num-1 存在,则尝试自动回退
463
  if outline is None and str(int(num)-1) in outline_full:
464
  print(f"ℹ️ 尝试修正索引: 文件 {num} -> 使用大纲 {int(num)-1}")
465
  outline = outline_full.get(str(int(num)-1))
@@ -468,27 +482,28 @@ def refinement_ppt(input_index, prompts, max_iterations=3, model="gpt-4o", confi
468
  print(f"⚠️ 跳过 {file_name}: 在 outline.json 中找不到序号 {num} 或 {int(num)-1}")
469
  continue
470
 
471
- # 构建路径
472
  html_file_path = os.path.join(input_index, file_name)
473
  html_file_path_refine = os.path.join(output_index, file_name)
 
 
474
 
475
- print(f"📝 [顺序处理中] 正在优化: {file_name} (对应大纲 Key: {num})")
 
 
 
 
476
 
477
- # 6. 调用优化函数
478
- try:
479
- refine_one_slide(
480
- input_path=html_file_path,
481
- output_path=html_file_path_refine,
482
- prompts=prompts,
483
- outline=outline,
484
- max_iterations=max_iterations,
485
- model=model,
486
- config=config
487
- )
488
- except Exception as e:
489
- print(f"❌ 处理 {file_name} 时出错: {e}")
490
 
491
- print(f"✅ 所有文件处理完成,结果保存在: {output_index}")
492
 
493
  def refinement_poster(input_html_path, prompts, output_html_path, model, config=None):
494
  # ---------------- 0. 配置准备 ----------------
 
5
  import time
6
  import PIL.Image
7
  import shutil
8
+ import concurrent.futures
9
  from PIL import Image
10
  from pathlib import Path
11
  from openai import OpenAI
 
402
  take_screenshot(current_input, final_screenshot_path)
403
  print(f"\n📷 Final screenshot saved: {final_screenshot_path}")
404
 
405
+ def refinement_ppt(input_index, prompts, max_iterations=3, model="gpt-4o", config=None, max_workers=5):
 
406
  # 1. 定义路径
407
  outline_path = os.path.join(input_index, "outline.json")
408
  output_index = os.path.join(input_index, "final")
 
412
  os.makedirs(output_index, exist_ok=True)
413
 
414
  # 将图片复制到final/images目录下
 
415
  source_images_dir = os.path.join(input_index, "images")
416
  if os.path.exists(source_images_dir):
417
  shutil.copytree(source_images_dir, output_index_images, dirs_exist_ok=True)
 
440
 
441
  # 3.2 定义排序 Key:直接提取开头的数字
442
  def get_file_number(filename):
 
443
  return int(filename.split('_')[0])
444
 
445
  # 3.3 执行排序 (这步是关键,确保 2 在 10 前面)
446
  sorted_files = sorted(target_files, key=get_file_number)
447
 
448
+ print(f"👀 找到文件 {len(sorted_files)} 个,准备进行并发优化...")
 
449
 
450
+ # 定义单个任务处理逻辑
451
+ def _worker(file_name, num, outline, html_input, html_output):
452
+ print(f"📝 [并发处理中] 正在优化: {file_name} (对应大纲 Key: {num})")
453
+ try:
454
+ refine_one_slide(
455
+ input_path=html_input,
456
+ output_path=html_output,
457
+ prompts=prompts,
458
+ outline=outline,
459
+ max_iterations=max_iterations,
460
+ model=model,
461
+ config=config
462
+ )
463
+ return file_name, True, None
464
+ except Exception as e:
465
+ return file_name, False, str(e)
466
+
467
+ # 4. 收集需要并发执行的任务
468
+ tasks = []
469
  for file_name in sorted_files:
470
+ # 直接提取序号
471
  num = str(get_file_number(file_name))
472
 
473
  # 获取当前 html 对应的 outline
474
  outline = outline_full.get(int(num)-1)
475
 
476
+ # 【容错逻辑】处理索引偏移
 
477
  if outline is None and str(int(num)-1) in outline_full:
478
  print(f"ℹ️ 尝试修正索引: 文件 {num} -> 使用大纲 {int(num)-1}")
479
  outline = outline_full.get(str(int(num)-1))
 
482
  print(f"⚠️ 跳过 {file_name}: 在 outline.json 中找不到序号 {num} 或 {int(num)-1}")
483
  continue
484
 
485
+ # 构建路径并添加到任务列表
486
  html_file_path = os.path.join(input_index, file_name)
487
  html_file_path_refine = os.path.join(output_index, file_name)
488
+
489
+ tasks.append((file_name, num, outline, html_file_path, html_file_path_refine))
490
 
491
+ # 5. 使用线程池并发执行任务
492
+ print(f"⚡ 启动线程池,最大并发数: {max_workers}")
493
+ with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
494
+ # 提交所有任务
495
+ futures = {executor.submit(_worker, *task): task for task in tasks}
496
 
497
+ # 收集结果
498
+ for future in concurrent.futures.as_completed(futures):
499
+ file_name, success, error_msg = future.result()
500
+ if success:
501
+ print(f"✅ 成功完成: {file_name}")
502
+ else:
503
+ print(f"❌ 处理 {file_name} 时出错: {error_msg}")
504
+
505
+ print(f"🎉 所有文件处理完成,结果保存在: {output_index}")
 
 
 
 
506
 
 
507
 
508
  def refinement_poster(input_html_path, prompts, output_html_path, model, config=None):
509
  # ---------------- 0. 配置准备 ----------------