Spaces:
Running
Running
Update src/refinement/refinement.py
Browse files- src/refinement/refinement.py +41 -26
src/refinement/refinement.py
CHANGED
|
@@ -5,6 +5,7 @@ import json
|
|
| 5 |
import time
|
| 6 |
import PIL.Image
|
| 7 |
import shutil
|
|
|
|
| 8 |
from PIL import Image
|
| 9 |
from pathlib import Path
|
| 10 |
from openai import OpenAI
|
|
@@ -401,8 +402,7 @@ def refine_one_slide(input_path, output_path, prompts, outline, max_iterations,
|
|
| 401 |
take_screenshot(current_input, final_screenshot_path)
|
| 402 |
print(f"\n📷 Final screenshot saved: {final_screenshot_path}")
|
| 403 |
|
| 404 |
-
|
| 405 |
-
def refinement_ppt(input_index, prompts, max_iterations=3, model="gpt-4o", config=None):
|
| 406 |
# 1. 定义路径
|
| 407 |
outline_path = os.path.join(input_index, "outline.json")
|
| 408 |
output_index = os.path.join(input_index, "final")
|
|
@@ -412,7 +412,6 @@ def refinement_ppt(input_index, prompts, max_iterations=3, model="gpt-4o", confi
|
|
| 412 |
os.makedirs(output_index, exist_ok=True)
|
| 413 |
|
| 414 |
# 将图片复制到final/images目录下
|
| 415 |
-
import shutil
|
| 416 |
source_images_dir = os.path.join(input_index, "images")
|
| 417 |
if os.path.exists(source_images_dir):
|
| 418 |
shutil.copytree(source_images_dir, output_index_images, dirs_exist_ok=True)
|
|
@@ -441,25 +440,40 @@ def refinement_ppt(input_index, prompts, max_iterations=3, model="gpt-4o", confi
|
|
| 441 |
|
| 442 |
# 3.2 定义排序 Key:直接提取开头的数字
|
| 443 |
def get_file_number(filename):
|
| 444 |
-
# 因为上一步已经过滤过了,这里可以直接提取
|
| 445 |
return int(filename.split('_')[0])
|
| 446 |
|
| 447 |
# 3.3 执行排序 (这步是关键,确保 2 在 10 前面)
|
| 448 |
sorted_files = sorted(target_files, key=get_file_number)
|
| 449 |
|
| 450 |
-
|
| 451 |
-
print(f"👀 排序后文件列表前5个: {sorted_files[:5]}")
|
| 452 |
|
| 453 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 454 |
for file_name in sorted_files:
|
| 455 |
-
# 直接提取序号
|
| 456 |
num = str(get_file_number(file_name))
|
| 457 |
|
| 458 |
# 获取当前 html 对应的 outline
|
| 459 |
outline = outline_full.get(int(num)-1)
|
| 460 |
|
| 461 |
-
# 【容错逻辑】处理索引偏移
|
| 462 |
-
# 如果 outline 为空,且 num-1 存在,则尝试自动回退
|
| 463 |
if outline is None and str(int(num)-1) in outline_full:
|
| 464 |
print(f"ℹ️ 尝试修正索引: 文件 {num} -> 使用大纲 {int(num)-1}")
|
| 465 |
outline = outline_full.get(str(int(num)-1))
|
|
@@ -468,27 +482,28 @@ def refinement_ppt(input_index, prompts, max_iterations=3, model="gpt-4o", confi
|
|
| 468 |
print(f"⚠️ 跳过 {file_name}: 在 outline.json 中找不到序号 {num} 或 {int(num)-1}")
|
| 469 |
continue
|
| 470 |
|
| 471 |
-
# 构建路径
|
| 472 |
html_file_path = os.path.join(input_index, file_name)
|
| 473 |
html_file_path_refine = os.path.join(output_index, file_name)
|
|
|
|
|
|
|
| 474 |
|
| 475 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 476 |
|
| 477 |
-
#
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
config=config
|
| 487 |
-
)
|
| 488 |
-
except Exception as e:
|
| 489 |
-
print(f"❌ 处理 {file_name} 时出错: {e}")
|
| 490 |
|
| 491 |
-
print(f"✅ 所有文件处理完成,结果保存在: {output_index}")
|
| 492 |
|
| 493 |
def refinement_poster(input_html_path, prompts, output_html_path, model, config=None):
|
| 494 |
# ---------------- 0. 配置准备 ----------------
|
|
|
|
| 5 |
import time
|
| 6 |
import PIL.Image
|
| 7 |
import shutil
|
| 8 |
+
import concurrent.futures
|
| 9 |
from PIL import Image
|
| 10 |
from pathlib import Path
|
| 11 |
from openai import OpenAI
|
|
|
|
| 402 |
take_screenshot(current_input, final_screenshot_path)
|
| 403 |
print(f"\n📷 Final screenshot saved: {final_screenshot_path}")
|
| 404 |
|
| 405 |
+
def refinement_ppt(input_index, prompts, max_iterations=3, model="gpt-4o", config=None, max_workers=5):
|
|
|
|
| 406 |
# 1. 定义路径
|
| 407 |
outline_path = os.path.join(input_index, "outline.json")
|
| 408 |
output_index = os.path.join(input_index, "final")
|
|
|
|
| 412 |
os.makedirs(output_index, exist_ok=True)
|
| 413 |
|
| 414 |
# 将图片复制到final/images目录下
|
|
|
|
| 415 |
source_images_dir = os.path.join(input_index, "images")
|
| 416 |
if os.path.exists(source_images_dir):
|
| 417 |
shutil.copytree(source_images_dir, output_index_images, dirs_exist_ok=True)
|
|
|
|
| 440 |
|
| 441 |
# 3.2 定义排序 Key:直接提取开头的数字
|
| 442 |
def get_file_number(filename):
|
|
|
|
| 443 |
return int(filename.split('_')[0])
|
| 444 |
|
| 445 |
# 3.3 执行排序 (这步是关键,确保 2 在 10 前面)
|
| 446 |
sorted_files = sorted(target_files, key=get_file_number)
|
| 447 |
|
| 448 |
+
print(f"👀 找到文件 {len(sorted_files)} 个,准备进行并发优化...")
|
|
|
|
| 449 |
|
| 450 |
+
# 定义单个任务的处理逻辑
|
| 451 |
+
def _worker(file_name, num, outline, html_input, html_output):
|
| 452 |
+
print(f"📝 [并发处理中] 正在优化: {file_name} (对应大纲 Key: {num})")
|
| 453 |
+
try:
|
| 454 |
+
refine_one_slide(
|
| 455 |
+
input_path=html_input,
|
| 456 |
+
output_path=html_output,
|
| 457 |
+
prompts=prompts,
|
| 458 |
+
outline=outline,
|
| 459 |
+
max_iterations=max_iterations,
|
| 460 |
+
model=model,
|
| 461 |
+
config=config
|
| 462 |
+
)
|
| 463 |
+
return file_name, True, None
|
| 464 |
+
except Exception as e:
|
| 465 |
+
return file_name, False, str(e)
|
| 466 |
+
|
| 467 |
+
# 4. 收集需要并发执行的任务
|
| 468 |
+
tasks = []
|
| 469 |
for file_name in sorted_files:
|
| 470 |
+
# 直接提取序号
|
| 471 |
num = str(get_file_number(file_name))
|
| 472 |
|
| 473 |
# 获取当前 html 对应的 outline
|
| 474 |
outline = outline_full.get(int(num)-1)
|
| 475 |
|
| 476 |
+
# 【容错逻辑】处理索引偏移
|
|
|
|
| 477 |
if outline is None and str(int(num)-1) in outline_full:
|
| 478 |
print(f"ℹ️ 尝试修正索引: 文件 {num} -> 使用大纲 {int(num)-1}")
|
| 479 |
outline = outline_full.get(str(int(num)-1))
|
|
|
|
| 482 |
print(f"⚠️ 跳过 {file_name}: 在 outline.json 中找不到序号 {num} 或 {int(num)-1}")
|
| 483 |
continue
|
| 484 |
|
| 485 |
+
# 构建路径并添加到任务列表
|
| 486 |
html_file_path = os.path.join(input_index, file_name)
|
| 487 |
html_file_path_refine = os.path.join(output_index, file_name)
|
| 488 |
+
|
| 489 |
+
tasks.append((file_name, num, outline, html_file_path, html_file_path_refine))
|
| 490 |
|
| 491 |
+
# 5. 使用线程池并发执行任务
|
| 492 |
+
print(f"⚡ 启动线程池,最大并发数: {max_workers}")
|
| 493 |
+
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
| 494 |
+
# 提交所有任务
|
| 495 |
+
futures = {executor.submit(_worker, *task): task for task in tasks}
|
| 496 |
|
| 497 |
+
# 收集结果
|
| 498 |
+
for future in concurrent.futures.as_completed(futures):
|
| 499 |
+
file_name, success, error_msg = future.result()
|
| 500 |
+
if success:
|
| 501 |
+
print(f"✅ 成功完成: {file_name}")
|
| 502 |
+
else:
|
| 503 |
+
print(f"❌ 处理 {file_name} 时出错: {error_msg}")
|
| 504 |
+
|
| 505 |
+
print(f"🎉 所有文件处理完成,结果保存在: {output_index}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 506 |
|
|
|
|
| 507 |
|
| 508 |
def refinement_poster(input_html_path, prompts, output_html_path, model, config=None):
|
| 509 |
# ---------------- 0. 配置准备 ----------------
|