Laramie2 committed on
Commit
c91a28c
·
verified ·
1 Parent(s): bea0e3a

Update src/refinement/refinement.py

Browse files
Files changed (1) hide show
  1. src/refinement/refinement.py +835 -818
src/refinement/refinement.py CHANGED
@@ -1,818 +1,835 @@
1
- import base64
2
- import os
3
- import re
4
- import json
5
- import time
6
- import PIL.Image
7
- import shutil
8
- from PIL import Image
9
- from pathlib import Path
10
- from openai import OpenAI
11
- from google import genai
12
- from google.genai import types
13
- from .html_revise import HTMLMapper, apply_html_modifications, HTMLModificationError
14
- from playwright.sync_api import sync_playwright
15
-
16
-
17
class VLMCommenter:
    """Vision-language-model critic that reviews a rendered slide screenshot."""

    def __init__(self, api_key, prompt, provider="openai", model_name=None):
        """
        :param api_key: API key for the chosen provider.
        :param prompt: critique prompt text sent with every evaluation.
        :param provider: "openai" or "gemini".
        :param model_name: optional explicit model name; falls back to a
            provider-specific default.
        """
        self.provider = provider.lower()
        self.api_key = api_key
        self.model_name = model_name
        self.prompt_text = prompt

        if self.provider == "openai":
            self.client = OpenAI(api_key=api_key)
            self.model = model_name if model_name else "gpt-4o"
        elif self.provider == "gemini":
            self.client = genai.Client(api_key=api_key)
            self.model = model_name if model_name else "gemini-1.5-flash"
        else:
            raise ValueError("Unsupported provider. Choose 'openai' or 'gemini'.")

    def _encode_image(self, image_path):
        """Return the file at *image_path* as a base64-encoded UTF-8 string."""
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')

    def evaluate_slide(self, image_path, outline, pre_comments):
        """Ask the VLM to critique one slide screenshot.

        :param image_path: path to the rendered slide PNG.
        :param outline: outline text/JSON describing the slide.
        :param pre_comments: accumulated critiques from earlier iterations.
        :return: critique text, or an error string on failure.
        """
        # BUG FIX: the old check tested the *assembled* prompt, which always
        # contains literal separator text and therefore could never be empty.
        # Validate the configured prompt itself, before building the message.
        if not self.prompt_text:
            return "Error: Commenter prompt is empty."

        # Separator whitespace normalized to single spaces; content unchanged.
        full_prompt = (
            f"{self.prompt_text}\n"
            " *****previous comments******"
            f"\n{pre_comments} "
            "*****begin of the outline*****"
            f"\n{outline} "
            "*****end of the outline*****"
            "*****the following is the image,not the outline*****"
        )

        if self.provider == "openai":
            base64_image = self._encode_image(image_path)
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": [
                            {"type": "text", "text": full_prompt},
                            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64_image}"}}
                        ]}
                    ],
                    max_tokens=300
                )
                return response.choices[0].message.content
            except Exception as e:
                return f"Error using OpenAI VLM: {e}"

        elif self.provider == "gemini":
            try:
                img = PIL.Image.open(image_path)

                response = self.client.models.generate_content(
                    model=self.model,
                    contents=[full_prompt, img]
                )
                return response.text
            except Exception as e:
                return f"Error using Gemini VLM (google-genai): {e}"
89
-
90
-
91
class LLMReviser:
    """LLM that turns a VLM critique plus an HTML structure tree into a JSON
    modification plan."""

    def __init__(self, api_key, prompt, provider="openai", model_name=None):
        """
        :param api_key: API key for the chosen provider.
        :param prompt: system prompt describing the revision task.
        :param provider: "openai" or "gemini".
        :param model_name: optional explicit model name; falls back to a
            provider-specific default.
        """
        self.provider = provider.lower()
        self.api_key = api_key
        self.model_name = model_name
        self.system_prompt = prompt

        if self.provider == "openai":
            self.client = OpenAI(api_key=api_key)
            self.model = model_name if model_name else "gpt-4"
        elif self.provider == "gemini":
            self.client = genai.Client(api_key=api_key)
            self.model = model_name if model_name else "gemini-1.5-pro"
        else:
            raise ValueError("Unsupported provider. Choose 'openai' or 'gemini'.")

    def generate_revision_plan(self, current_structure_json, vlm_critique):
        """Return a modification plan for the given critique.

        :param current_structure_json: structure tree of the current HTML.
        :param vlm_critique: critique text from the VLM commenter.
        :return: parsed modification dict, or None when the critique is a
            short "PASS", the prompt is missing, or the call/parse fails.
        """
        # A short critique containing PASS means no changes are needed.
        if "PASS" in vlm_critique.upper() and len(vlm_critique) < 10:
            return None

        prompt_system = self.system_prompt
        if not prompt_system:
            print("Error: Reviser prompt is empty.")
            return None

        user_content = f"""
--- CURRENT STRUCTURE JSON ---
{json.dumps(current_structure_json, indent=2)}

--- VISUAL CRITIQUE ---
{vlm_critique}

--- INSTRUCTION ---
Generate the modification JSON based on the system instructions.
"""

        if self.provider == "openai":
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": prompt_system},
                        {"role": "user", "content": user_content}
                    ],
                    response_format={"type": "json_object"}
                )
                return json.loads(response.choices[0].message.content)
            except Exception as e:
                print(f"OpenAI Error: {e}")
                return None

        elif self.provider == "gemini":
            try:
                # Combine system prompt and user content into one prompt.
                full_prompt = f"{prompt_system}\n\n{user_content}"

                response = self.client.models.generate_content(
                    model=self.model,
                    contents=full_prompt,
                    config=types.GenerateContentConfig(
                        response_mime_type="application/json"
                    )
                )

                text_response = response.text

                # Some models wrap output in ```json fences even when a JSON
                # mime type was requested.
                # BUG FIX: the old cleanup used replace("json", ""), which
                # also deleted every occurrence of "json" *inside* the
                # payload, corrupting valid responses. Strip only the
                # surrounding fences instead.
                if text_response.startswith("```"):
                    text_response = re.sub(r"^```(?:json)?\s*", "", text_response)
                    text_response = re.sub(r"\s*```\s*$", "", text_response).strip()

                return json.loads(text_response)
            except json.JSONDecodeError:
                print(f"Gemini returned invalid JSON: {response.text}")
                return None
            except Exception as e:
                print(f"Gemini Error (google-genai): {e}")
                return None
179
-
180
-
181
def take_screenshot(html_path, output_path):
    """Render *html_path* in headless Chromium and screenshot the slide.

    Captures only the ``.slideImage`` element at a 960x540 viewport with
    device_scale_factor=1 so output pixels match CSS pixels.

    :param html_path: path to the slide HTML file.
    :param output_path: where the PNG screenshot is written.
    :raises FileNotFoundError: if *html_path* does not exist.
    """
    # BUG FIX: previously a missing file was only printed and execution fell
    # through to Playwright, which then failed with a confusing navigation
    # error. Fail fast instead (consistent with take_screenshot_poster).
    if not os.path.exists(html_path):
        raise FileNotFoundError(f"文件不存在: {html_path}")

    abs_path = Path(os.path.abspath(html_path)).as_uri()

    with sync_playwright() as p:
        browser = p.chromium.launch()
        context = browser.new_context(
            viewport={'width': 960, 'height': 540},
            device_scale_factor=1  # keep CSS pixel == device pixel
        )
        page = context.new_page()

        # networkidle waits until images and fonts have finished loading
        page.goto(abs_path, wait_until="networkidle")

        # Screenshot the slide element itself rather than the full page.
        element = page.locator(".slideImage")
        element.screenshot(path=output_path)

        browser.close()
205
-
206
def take_screenshot_poster(html_path, output_path):
    """Screenshot an HTML poster laid out as .poster/#flow (Playwright, sync).

    Renders *html_path* at a fixed 1400x900 viewport, waits until fonts,
    images and the page's own fit()/reflow logic have settled, then captures
    only the .poster element.

    :param html_path: path to the poster HTML file.
    :param output_path: where the PNG screenshot is written.
    :raises FileNotFoundError: if *html_path* does not exist.
    """
    if not os.path.exists(html_path):
        raise FileNotFoundError(f"文件不存在: {html_path}")

    abs_uri = Path(os.path.abspath(html_path)).as_uri()

    with sync_playwright() as p:
        browser = p.chromium.launch(headless=True, args=["--disable-dev-shm-usage"])
        context = browser.new_context(
            # The poster CSS fixes --poster-width/height at 1400x900.
            viewport={"width": 1400, "height": 900},
            device_scale_factor=1
        )
        page = context.new_page()

        # 1) Wait only for DOMContentLoaded; "networkidle" can hang forever.
        page.goto(abs_uri, wait_until="domcontentloaded")

        # 2) Make sure the key containers are attached to the DOM.
        page.wait_for_selector(".poster", state="attached", timeout=30000)
        page.wait_for_selector("#flow", state="attached", timeout=30000)

        # 3) Wait for fonts (the page's fit() depends heavily on font
        #    metrics); best-effort, failures are ignored.
        try:
            page.evaluate("() => document.fonts ? document.fonts.ready : Promise.resolve()")
        except Exception:
            pass

        # 4) Wait for every <img> inside #flow to finish loading (resolves
        #    immediately when there are no images).
        page.evaluate(r"""
            () => {
              const flow = document.getElementById("flow");
              if (!flow) return Promise.resolve();
              const imgs = Array.from(flow.querySelectorAll("img"));
              if (imgs.length === 0) return Promise.resolve();
              return Promise.all(imgs.map(img => {
                if (img.complete) return Promise.resolve();
                return new Promise(res => {
                  img.addEventListener("load", res, { once: true });
                  img.addEventListener("error", res, { once: true });
                });
              }));
            }
        """)

        # 5) Wait for the page's fit() to run and layout to stabilise: poll
        #    scrollWidth once per animation frame until it stops changing
        #    for several consecutive frames.
        page.evaluate(r"""
            () => new Promise((resolve) => {
              const flow = document.getElementById("flow");
              if (!flow) return resolve();

              let last = -1;
              let stableCount = 0;

              function tick() {
                const cur = flow.scrollWidth; // multi-column 溢出判据
                if (cur === last) stableCount += 1;
                else stableCount = 0;

                last = cur;

                // 连续若干帧稳定,就认为 fit/重排结束
                if (stableCount >= 10) return resolve();
                requestAnimationFrame(tick);
              }

              // load 事件/fit 一点点启动时间
              setTimeout(() => requestAnimationFrame(tick), 50);
            })
        """)

        # 6) Screenshot the .poster element only (not the stage background).
        poster = page.locator(".poster").first
        poster.screenshot(path=output_path, timeout=60000)

        browser.close()
283
-
284
-
285
def load_prompt(prompt_path="prompt.json", prompt_name="poster_prompt"):
    """Fetch a named prompt string from a JSON prompt file.

    :param prompt_path: path to the JSON file holding all prompts.
    :param prompt_name: key of the prompt to fetch.
    :return: the stored prompt, or "" when the key is absent.
    """
    with open(prompt_path, "r", encoding="utf-8") as handle:
        prompts = json.load(handle)
    if prompt_name in prompts:
        return prompts[prompt_name]
    return ""
289
-
290
-
291
def refine_one_slide(input_path, output_path, prompts, outline, max_iterations, model, config):
    """
    Automated repair loop for one slide:
    screenshot -> VLM critique -> LLM revision plan -> apply HTML edits -> repeat.

    :param input_path: path of the slide HTML to refine.
    :param output_path: where the refined HTML is written.
    :param prompts: (commenter_prompt, reviser_prompt) pair.
    :param outline: outline entry describing this slide.
    :param max_iterations: maximum number of critique/revise rounds.
    :param model: model name; "gemini" in the name selects the Gemini stack.
    :param config: dict providing config['api_keys'][...].
    """
    is_gemini = "gemini" in model.lower()

    if is_gemini:
        api_key = config['api_keys'].get('gemini_api_key')
    else:
        api_key = config['api_keys'].get('openai_api_key')

    commenter_prompt = prompts[0]
    reviser_prompt = prompts[1]

    platform = "gemini" if "gemini" in model.lower() else "openai"

    vlm = VLMCommenter(api_key, commenter_prompt, provider=platform, model_name=model)
    reviser = LLMReviser(api_key, reviser_prompt, provider=platform, model_name=model)

    current_input = input_path
    critic_his = ""  # accumulated critique history fed back to the VLM

    for i in range(max_iterations):
        print(f"\n=== Iteration {i+1} ===")

        # 1. Render the current HTML and capture a screenshot.
        screenshot_path = f"{Path(output_path).parent}/{Path(current_input).stem}_{i+1}.png"  # per-iteration screenshot path
        take_screenshot(current_input, screenshot_path)
        print(f"Screenshot taken: {screenshot_path}")

        # 2. Visual diagnosis by the VLM.
        critique = vlm.evaluate_slide(screenshot_path, outline, critic_his)
        critic_his = critic_his + f"this is the {i}th comment: {critique}"
        print(f"VLM Critique: {critique}")

        if "PASS" in critique:
            print("Layout looks good! Stopping loop.")

            # Make sure the final result ends up at output_path even when no
            # modification round ever wrote there.
            if os.path.abspath(current_input) != os.path.abspath(output_path):
                os.makedirs(os.path.dirname(output_path), exist_ok=True)
                shutil.copy2(current_input, output_path)
                print(f"Final result saved to: {output_path}")
            else:
                print(f"Result is already at output path: {output_path}")

            break

        # 3. Read the current HTML structure tree.
        mapper = HTMLMapper(current_input)
        current_tree = mapper.get_structure_tree()

        # 4. Ask the LLM for a modification plan.
        modification_json = reviser.generate_revision_plan(current_tree, critique)

        if not modification_json:
            print("Reviser suggested no changes. Stopping.")
            break

        print(f"Proposed Changes: {json.dumps(modification_json, indent=2)}")

        # 5. Apply the modifications.
        try:
            # apply_html_modifications raises HTMLModificationError on
            # serious failures.
            apply_html_modifications(current_input, output_path, modification_json)
            print("Modifications applied to HTML.")

        # NOTE(review): listing HTMLModificationError alongside Exception is
        # redundant — the bare Exception already matches it.
        except (HTMLModificationError, Exception) as e:
            print(f" Error applying modifications at iteration {i+1}: {e}")

            # ====== save a debug copy of the failing input ======
            try:
                # original (pre-refinement) file stem
                base_name = Path(input_path).stem

                # path for the failure backup
                error_backup_path = f"{Path(output_path).parent}/{base_name}_FAILED_iter{i+1}.html"

                # copy out the HTML that triggered the failure
                shutil.copy2(current_input, error_backup_path)

                print(f"⚠️ 已自动保存出错前的 HTML 副本: {error_backup_path}")
                print(f" 你可以打开此文件,并使用控制台打印的 JSON 尝试复现问题。")

            except Exception as copy_err:
                print(f"❌ 尝试保存错误副本时失败: {copy_err}")
            # ============================

            # abort the loop after a failure
            break
        current_input = output_path

        # brief pause to avoid read/write races
        time.sleep(1)

    # after refinement, capture a final screenshot
    final_screenshot_path = f"{Path(output_path).parent}/{Path(current_input).stem}_final.png"
    take_screenshot(current_input, final_screenshot_path)
    print(f"\n📷 Final screenshot saved: {final_screenshot_path}")
390
-
391
-
392
def refinement_ppt(input_index, prompts, max_iterations=3, model="gpt-4o", config=None):
    """Refine every numbered slide (``<n>_ppt.html``) under *input_index*.

    Refined HTML goes to ``<input_index>/final`` together with a copy of the
    ``images`` directory so image links in the refined HTML keep working.
    """
    # 1. Paths
    outline_path = os.path.join(input_index, "outline.json")
    output_index = os.path.join(input_index, "final")
    output_index_images = os.path.join(output_index, "images")  # image subdir referenced by the refined HTML

    # make sure the output directory exists
    os.makedirs(output_index, exist_ok=True)

    # copy images into final/images
    import shutil  # NOTE(review): redundant — shutil is already imported at module level
    source_images_dir = os.path.join(input_index, "images")
    if os.path.exists(source_images_dir):
        shutil.copytree(source_images_dir, output_index_images, dirs_exist_ok=True)
        print(f"📁 Copied images to: {output_index_images}")

    # 2. Load the outline data
    with open(outline_path, 'r', encoding='utf-8') as f:
        outline_data = json.load(f)
    if isinstance(outline_data, list):
        # Convert the list to a dict keyed by the (string) index, keeping
        # the original ordering; list[i] maps to key str(i).
        outline_full = {str(i): item for i, item in enumerate(outline_data)}
    else:
        outline_full = outline_data

    # ================= main processing =================

    print(f"🚀 开始扫描目录: {input_index}")

    # 3.1 keep only files named "<number>_ppt.html"
    target_files = []
    for f in os.listdir(input_index):
        # strict match: leading digits + _ppt.html suffix
        if re.search(r'^\d+_ppt\.html$', f):
            target_files.append(f)

    # 3.2 sort key: the leading number
    def get_file_number(filename):
        # the filter above guarantees the format, so this is safe
        return int(filename.split('_')[0])

    # 3.3 numeric sort (so 2 comes before 10)
    sorted_files = sorted(target_files, key=get_file_number)

    # debug: show the first few files to confirm the order
    print(f"👀 排序后文件列表前5个: {sorted_files[:5]}")

    # 4. iterate in sorted order
    for file_name in sorted_files:
        # extract the slide number (format already validated)
        num = str(get_file_number(file_name))

        # NOTE(review): outline_full is keyed by *strings*, so this
        # int-keyed lookup always misses and the string-keyed fallback below
        # does the real work — confirm intended indexing before changing.
        outline = outline_full.get(int(num)-1)

        # Fallback for index offset (file 1_ppt vs list starting at 0):
        # if nothing was found and num-1 exists as a string key, use it.
        if outline is None and str(int(num)-1) in outline_full:
            print(f"ℹ️ 尝试修正索引: 文件 {num} -> 使用大纲 {int(num)-1}")
            outline = outline_full.get(str(int(num)-1))

        if outline is None:
            print(f"⚠️ 跳过 {file_name}: outline.json 中找不到序号 {num} 或 {int(num)-1}")
            continue

        # build input/output paths for this slide
        html_file_path = os.path.join(input_index, file_name)
        html_file_path_refine = os.path.join(output_index, file_name)

        print(f"📝 [顺序处理中] 正在优化: {file_name} (对应大纲 Key: {num})")

        # 6. run the refinement loop for this slide; failures are logged
        #    and the remaining slides are still processed
        try:
            refine_one_slide(
                input_path=html_file_path,
                output_path=html_file_path_refine,
                prompts=prompts,
                outline=outline,
                max_iterations=max_iterations,
                model=model,
                config=config
            )
        except Exception as e:
            print(f"❌ 处理 {file_name} 时出错: {e}")

    print(f"✅ 所有文件处理完成,结果保存在: {output_index}")
479
-
480
def refinement_poster(input_html_path, prompts, output_html_path, model, config=None):
    """One-shot poster refinement: screenshot -> vision LLM -> rewritten HTML.

    Sends the current poster HTML plus its rendered screenshot to the chosen
    vision model, writes the model's rewritten HTML to *output_html_path*,
    and captures a final screenshot of the result.
    """
    # ---------------- 0. Configuration ----------------
    if config is None:
        config = {}

    api_keys_conf = config.get('api_keys', {})

    # pick the provider from the model name
    is_gemini = "gemini" in model.lower()

    # ---------------- 1. Paths and files ----------------
    auto_path = Path(input_html_path).parent
    final_index = os.path.join(auto_path, "final")
    final_index_image = os.path.join(final_index, "images")
    os.makedirs(final_index, exist_ok=True)

    # copy the images directory alongside the refined output (only once)
    source_images_dir = os.path.join(auto_path, "images")
    if os.path.exists(source_images_dir):
        if not os.path.exists(final_index_image):
            shutil.copytree(source_images_dir, final_index_image, dirs_exist_ok=True)
            print(f"📁 Images copied to: {final_index_image}")

    with open(input_html_path, 'r', encoding='utf-8') as f:
        current_html = f.read()

    # ---------------- 2. Screenshot ----------------
    screenshot_name = Path(input_html_path).stem + ".png"
    screenshot_path = os.path.join(final_index, screenshot_name)

    print(f"📸 Taking screenshot of {input_html_path}...")
    # take_screenshot_poster is defined earlier in this module
    take_screenshot_poster(input_html_path, screenshot_path)

    if not os.path.exists(screenshot_path):
        raise FileNotFoundError("Screenshot failed to generate.")

    # read the screenshot bytes for the vision model
    with open(screenshot_path, "rb") as f:
        image_bytes = f.read()

    generated_text = ""

    # ---------------- 3. Call the LLM ----------------
    print(f"🤖 Sending to Vision Model ({model}) on {'Gemini' if is_gemini else 'OpenAI'}...")

    try:
        if is_gemini:
            # === Gemini client setup ===
            api_key = api_keys_conf.get('gemini_api_key') or os.getenv("GOOGLE_API_KEY")

            client = genai.Client(api_key=api_key)

            # Build the contents list the new google.genai SDK expects:
            # prompt text, current HTML, then the screenshot bytes.
            response = client.models.generate_content(
                model=model,
                contents=[
                    types.Part.from_text(text=prompts),
                    types.Part.from_text(text=f"--- CURRENT HTML ---\n{current_html}"),
                    types.Part.from_bytes(data=image_bytes, mime_type="image/png"),
                ]
            )

            if response.text:
                generated_text = response.text
            else:
                raise RuntimeError("Gemini returned empty text.")

        else:
            # === OpenAI client setup ===
            api_key = api_keys_conf.get('openai_api_key') or os.getenv("OPENAI_API_KEY")

            client = OpenAI(api_key=api_key)

            # OpenAI expects the image base64-encoded in a data URL
            base64_image = base64.b64encode(image_bytes).decode('utf-8')

            messages = [
                {
                    "role": "system",
                    "content": "You are an expert web designer and code refiner."
                },
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": f"{prompts}\n\n--- CURRENT HTML ---\n{current_html}"
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{base64_image}",
                                "detail": "high"
                            }
                        }
                    ]
                }
            ]

            response = client.chat.completions.create(
                model=model,
                messages=messages,
                max_tokens=4096
            )
            generated_text = response.choices[0].message.content

        # ---------------- 4. Parse and save ----------------
        # strip markdown code-fence markers if present
        if "```html" in generated_text:
            final_html = generated_text.split("```html")[1].split("```")[0].strip()
        elif "```" in generated_text:
            final_html = generated_text.split("```")[1].strip()
        else:
            final_html = generated_text

        with open(output_html_path, 'w', encoding='utf-8') as f:
            f.write(final_html)

        print(f"✅ Refined poster saved to: {output_html_path}")

        # final screenshot of the refined poster
        final_screenshot_name = Path(input_html_path).stem + "_final" + ".png"
        final_screenshot_path = os.path.join(final_index, final_screenshot_name)

        print(f"📸 Taking final poster screenshot of {output_html_path}...")
        take_screenshot_poster(output_html_path, final_screenshot_path)

    except Exception as e:
        print(f"❌ Error during AI generation: {e}")
611
-
612
def refinement_pr(pr_path: str, pr_refine_path: str, prompts: dict, model: str, config: dict):
    """
    Extract specific sections from a Markdown paper summary, refine them
    with an LLM following the instructions in *prompts*, and reassemble the
    file. The original Markdown structure, image references and any trailing
    content (e.g. the Hashtag block) are strictly preserved.
    """

    # 1. Read configuration; the API key is extracted just before each call
    #    (no environment-variable dependency except the OpenAI fallback).
    if config is None:
        config = {}
    api_keys = config.get('api_keys', {})

    # 2. Read the original file.
    if not os.path.exists(pr_path):
        raise FileNotFoundError(f"文件未找到: {pr_path}")

    with open(pr_path, 'r', encoding='utf-8') as f:
        original_content = f.read()

    # 3. Regex patterns for the section headers we refine.
    section_headers = {
        "Key Question": r"🔍 \*\*Key Question\*\*",
        "Brilliant Idea": r"💡 \*\*Brilliant Idea\*\*",
        "Core Methods": r"🚀 \*\*Core Methods\*\*",
        "Core Results": r"📊 \*\*Core Results\*\*",
        "Significance/Impact": r"🧠 \*\*Significance/Impact\*\*"
    }

    footer_pattern = r"🏷️\s*\*\*Hashtag\*\*"

    # 4. Locate the core headers (first occurrence of each).
    matches = []
    for key, pattern in section_headers.items():
        found = list(re.finditer(pattern, original_content))
        if found:
            match = found[0]
            matches.append({
                "key": key,
                "header_start": match.start(),
                "header_end": match.end(),
                "header_text": match.group()
            })

    matches.sort(key=lambda x: x["header_start"])

    if not matches:
        # nothing to refine: copy the file through unchanged
        print("未检测到目标章节,直接复制文件。")
        with open(pr_refine_path, 'w', encoding='utf-8') as f:
            f.write(original_content)
        return

    # Locate the footer (Hashtag) marker, which bounds the last section.
    footer_match = re.search(footer_pattern, original_content)
    if footer_match:
        global_content_end_limit = footer_match.start()
    else:
        print("Warning: 未检测到 '🏷️ **Hashtag**' 标记,最后一个章节将读取至文件末尾。")
        global_content_end_limit = len(original_content)

    # 5. Compute the exact content span of each section: from the end of its
    #    header to the start of the next header (or the footer / EOF).
    content_ranges = {}
    for i, match in enumerate(matches):
        key = match["key"]
        content_start = match["header_end"]
        if i < len(matches) - 1:
            content_end = matches[i+1]["header_start"]
        else:
            content_end = max(content_start, global_content_end_limit)

        content_ranges[key] = {
            "start": content_start,
            "end": content_end,
            "text": original_content[content_start:content_end].strip()
        }

    # 6. Build the LLM request payload.
    extracted_data = {k: v["text"] for k, v in content_ranges.items()}

    system_prompt = (
        "You are an expert academic editor. Your task is to refine the content of specific sections of a paper summary based on user instructions.\n"
        "Input Format: JSON object {Section Name: Content}.\n"
        "Output Format: JSON object {Section Name: Refined Content}.\n"
        "CRITICAL RULES:\n"
        "1. **KEYS**: Keep the JSON keys EXACTLY the same as the input.\n"
        "2. **PURE BODY TEXT**: The output value must be pure body text. No Headers.\n"
        "3. **IMAGES**: Do NOT remove or modify markdown image links.\n"
        "4. **JSON ONLY**: Output pure JSON string.\n"
        "5. **FORMAT**: Use bolding ONLY for emphasis."
    )

    user_message = f"""
[Refinement Instructions]
{json.dumps(prompts, ensure_ascii=False)}

[Content to Refine]
{json.dumps(extracted_data, ensure_ascii=False)}
"""

    # === Dispatch to the right provider based on the model name ===
    llm_output = ""
    try:
        if "gemini" in model.lower():
            # --- Google Gemini (new google.genai SDK) ---
            api_key = api_keys.get("gemini_api_key", "").strip()

            if not api_key:
                raise ValueError("Missing config['api_keys']['gemini_api_key']")

            from google import genai
            from google.genai import types

            # configure the client
            client = genai.Client(api_key=api_key)

            response = client.models.generate_content(
                model=model,
                contents=user_message,
                config=types.GenerateContentConfig(
                    system_instruction=system_prompt,
                    temperature=0.2,
                    response_mime_type="application/json"  # force JSON mode for stability
                )
            )
            llm_output = response.text

        else:
            # --- OpenAI ---
            api_key = api_keys.get("openai_api_key", "").strip()

            if not api_key:
                # compatibility fallback: read the environment variable
                api_key = os.getenv("OPENAI_API_KEY")

            if not api_key:
                raise ValueError("Missing config['api_keys']['openai_api_key']")

            from openai import OpenAI
            client = OpenAI(api_key=api_key)

            response = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_message}
                ],
                temperature=0.2
            )
            llm_output = response.choices[0].message.content.strip()

    except Exception as e:
        print(f"LLM API 调用失败: {e}")
        return

    # 7. Clean the JSON returned by the LLM.
    try:
        # strip possible markdown code-fence markers
        cleaned_output = llm_output.replace("```json", "").replace("```", "").strip()
        refined_data = json.loads(cleaned_output)
    except json.JSONDecodeError:
        print("解析 LLM 返回的 JSON 失败。Raw output:", llm_output)
        return

    # 8. Reassemble the file by splicing refined text into the original.
    new_file_parts = []
    current_idx = 0

    # process sections in their order of appearance in the original file
    sorted_matches = sorted(matches, key=lambda x: x["header_start"])

    for item in sorted_matches:
        key = item["key"]
        range_info = content_ranges[key]
        c_start = range_info["start"]
        c_end = range_info["end"]

        # 1. append the untouched span (end of previous section up to the
        #    start of this section's content, header included)
        pre_content = original_content[current_idx:c_start]
        new_file_parts.append(pre_content)

        # 2. append the refined content (or the original text if the LLM
        #    did not return this key)
        if key in refined_data:
            new_text = refined_data[key]
            # simple formatting: guarantee a newline boundary on both sides
            if new_file_parts[-1] and not new_file_parts[-1].endswith('\n'):
                new_text = "\n" + new_text
            new_text = "\n" + new_text.strip() + "\n"
            new_file_parts.append(new_text)
        else:
            new_file_parts.append(original_content[c_start:c_end])

        new_file_parts.append('\n')

        # 3. advance the cursor past this section's content
        current_idx = c_end

    # 9. append whatever remains of the file (e.g. the Hashtag footer)
    new_file_parts.append(original_content[current_idx:])

    final_markdown = "".join(new_file_parts)

    # 10. Save the result.
    os.makedirs(os.path.dirname(os.path.abspath(pr_refine_path)), exist_ok=True)

    with open(pr_refine_path, 'w', encoding='utf-8') as f:
        f.write(final_markdown)

    print(f"文件优化完成,已保存至: {pr_refine_path}")
818
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import os
3
+ import re
4
+ import json
5
+ import time
6
+ import PIL.Image
7
+ import shutil
8
+ from PIL import Image
9
+ from pathlib import Path
10
+ from openai import OpenAI
11
+ from google import genai
12
+ from google.genai import types
13
+ from .html_revise import HTMLMapper, apply_html_modifications, HTMLModificationError
14
+ from playwright.sync_api import sync_playwright
15
+
16
+
17
class VLMCommenter:
    """Vision-language-model critic that reviews a rendered slide screenshot."""

    def __init__(self, api_key, prompt, provider="openai", model_name=None, config=None):
        """
        :param api_key: API key for the chosen provider.
        :param prompt: critique prompt text sent with every evaluation.
        :param provider: "openai" or "gemini".
        :param model_name: optional explicit model name.
        :param config: optional dict; config['api_base_url'] may point the
            Gemini client at a proxy / alternate endpoint.
        """
        self.provider = provider.lower()
        self.api_key = api_key
        self.model_name = model_name
        self.prompt_text = prompt

        if self.provider == "openai":
            self.client = OpenAI(api_key=api_key)
            self.model = model_name if model_name else "gpt-4o"
        elif self.provider == "gemini":
            # BUG FIX: `config` was referenced here without being defined in
            # this scope (NameError at runtime). It is now an optional
            # constructor argument; existing callers that omit it keep the
            # default Gemini endpoint.
            raw_url = (config or {}).get('api_base_url', '').strip().rstrip("/")
            if raw_url.endswith("/v1"):
                base_url = raw_url[:-3].rstrip("/")  # drop the trailing /v1
            else:
                base_url = raw_url
            self.client = genai.Client(api_key=api_key, http_options={'base_url': base_url} if base_url else None)
            self.model = model_name if model_name else "gemini-1.5-flash"
        else:
            raise ValueError("Unsupported provider. Choose 'openai' or 'gemini'.")

    def _encode_image(self, image_path):
        """Return the file at *image_path* as a base64-encoded UTF-8 string."""
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')

    def evaluate_slide(self, image_path, outline, pre_comments):
        """Ask the VLM to critique one slide screenshot.

        :param image_path: path to the rendered slide PNG.
        :param outline: outline text/JSON describing the slide.
        :param pre_comments: accumulated critiques from earlier iterations.
        :return: critique text, or an error string on failure.
        """
        # BUG FIX: the old check tested the *assembled* prompt, which always
        # contains literal separator text and could never be empty. Validate
        # the configured prompt itself, before building the message.
        if not self.prompt_text:
            return "Error: Commenter prompt is empty."

        # Separator whitespace normalized to single spaces; content unchanged.
        full_prompt = (
            f"{self.prompt_text}\n"
            " *****previous comments******"
            f"\n{pre_comments} "
            "*****begin of the outline*****"
            f"\n{outline} "
            "*****end of the outline*****"
            "*****the following is the image,not the outline*****"
        )

        if self.provider == "openai":
            base64_image = self._encode_image(image_path)
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": [
                            {"type": "text", "text": full_prompt},
                            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64_image}"}}
                        ]}
                    ],
                    max_tokens=300
                )
                return response.choices[0].message.content
            except Exception as e:
                return f"Error using OpenAI VLM: {e}"

        elif self.provider == "gemini":
            try:
                img = PIL.Image.open(image_path)

                response = self.client.models.generate_content(
                    model=self.model,
                    contents=[full_prompt, img]
                )
                return response.text
            except Exception as e:
                return f"Error using Gemini VLM (google-genai): {e}"
94
+
95
+
96
class LLMReviser:
    """LLM that turns a VLM critique plus an HTML structure tree into a JSON
    modification plan."""

    def __init__(self, api_key, prompt, provider="openai", model_name=None, config=None):
        """
        :param api_key: API key for the chosen provider.
        :param prompt: system prompt describing the revision task.
        :param provider: "openai" or "gemini".
        :param model_name: optional explicit model name.
        :param config: optional dict; config['api_base_url'] may point the
            Gemini client at a proxy / alternate endpoint.
        """
        self.provider = provider.lower()
        self.api_key = api_key
        self.model_name = model_name
        self.system_prompt = prompt

        if self.provider == "openai":
            self.client = OpenAI(api_key=api_key)
            self.model = model_name if model_name else "gpt-4"
        elif self.provider == "gemini":
            # BUG FIX: `config` was referenced here without being defined in
            # this scope (NameError at runtime). It is now an optional
            # constructor argument; existing callers that omit it keep the
            # default Gemini endpoint.
            raw_url = (config or {}).get('api_base_url', '').strip().rstrip("/")
            if raw_url.endswith("/v1"):
                base_url = raw_url[:-3].rstrip("/")  # drop the trailing /v1
            else:
                base_url = raw_url
            self.client = genai.Client(api_key=api_key, http_options={'base_url': base_url} if base_url else None)
            self.model = model_name if model_name else "gemini-1.5-pro"
        else:
            raise ValueError("Unsupported provider. Choose 'openai' or 'gemini'.")

    def generate_revision_plan(self, current_structure_json, vlm_critique):
        """Return a modification plan for the given critique.

        :param current_structure_json: structure tree of the current HTML.
        :param vlm_critique: critique text from the VLM commenter.
        :return: parsed modification dict, or None when the critique is a
            short "PASS", the prompt is missing, or the call/parse fails.
        """
        # A short critique containing PASS means no changes are needed.
        if "PASS" in vlm_critique.upper() and len(vlm_critique) < 10:
            return None

        prompt_system = self.system_prompt
        if not prompt_system:
            print("Error: Reviser prompt is empty.")
            return None

        user_content = f"""
--- CURRENT STRUCTURE JSON ---
{json.dumps(current_structure_json, indent=2)}

--- VISUAL CRITIQUE ---
{vlm_critique}

--- INSTRUCTION ---
Generate the modification JSON based on the system instructions.
"""

        if self.provider == "openai":
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": prompt_system},
                        {"role": "user", "content": user_content}
                    ],
                    response_format={"type": "json_object"}
                )
                return json.loads(response.choices[0].message.content)
            except Exception as e:
                print(f"OpenAI Error: {e}")
                return None

        elif self.provider == "gemini":
            try:
                # Combine system prompt and user content into one prompt.
                full_prompt = f"{prompt_system}\n\n{user_content}"

                response = self.client.models.generate_content(
                    model=self.model,
                    contents=full_prompt,
                    config=types.GenerateContentConfig(
                        response_mime_type="application/json"
                    )
                )

                text_response = response.text

                # Some models wrap output in ```json fences even when a JSON
                # mime type was requested.
                # BUG FIX: the old cleanup used replace("json", ""), which
                # also deleted every occurrence of "json" *inside* the
                # payload, corrupting valid responses. Strip only the
                # surrounding fences instead.
                if text_response.startswith("```"):
                    text_response = re.sub(r"^```(?:json)?\s*", "", text_response)
                    text_response = re.sub(r"\s*```\s*$", "", text_response).strip()

                return json.loads(text_response)
            except json.JSONDecodeError:
                print(f"Gemini returned invalid JSON: {response.text}")
                return None
            except Exception as e:
                print(f"Gemini Error (google-genai): {e}")
                return None
189
+
190
+
191
def take_screenshot(html_path, output_path):
    """Render a slide HTML file headlessly and save a PNG of its slide area.

    :param html_path: path to the slide HTML file.
    :param output_path: destination path for the PNG screenshot.
    :raises FileNotFoundError: if ``html_path`` does not exist.  (Previously
        the error was only printed and rendering proceeded anyway, failing
        later inside ``page.goto``; raising here matches the behavior of
        ``take_screenshot_poster``.)
    """
    if not os.path.exists(html_path):
        raise FileNotFoundError(f"文件不存在: {html_path}")
    abs_path = Path(os.path.abspath(html_path)).as_uri()

    with sync_playwright() as p:
        # 1. Fixed 960x540 viewport with device_scale_factor=1 so the PNG
        #    size matches the slide's CSS pixel dimensions exactly.
        browser = p.chromium.launch()
        context = browser.new_context(
            viewport={'width': 960, 'height': 540},
            device_scale_factor=1
        )
        page = context.new_page()

        # 2. Wait for network idle so images and web fonts finish loading.
        page.goto(abs_path, wait_until="networkidle")

        # 3. Screenshot only the slide element rather than the full page —
        #    more robust against surrounding chrome/background.
        element = page.locator(".slideImage")
        element.screenshot(path=output_path)

        browser.close()
215
+
216
def take_screenshot_poster(html_path, output_path):
    """Screenshot a ``.poster``/``#flow`` HTML poster (Playwright, sync API).

    Renders the poster at its fixed 1400x900 design size, waits for fonts,
    images and the page's own fit()/reflow logic to settle, then captures
    only the ``.poster`` element (not the stage background).

    :param html_path: path to the poster HTML file.
    :param output_path: destination path for the PNG screenshot.
    :raises FileNotFoundError: if ``html_path`` does not exist.
    """
    if not os.path.exists(html_path):
        raise FileNotFoundError(f"文件不存在: {html_path}")

    abs_uri = Path(os.path.abspath(html_path)).as_uri()

    with sync_playwright() as p:
        browser = p.chromium.launch(headless=True, args=["--disable-dev-shm-usage"])
        context = browser.new_context(
            # The poster CSS fixes --poster-width/--poster-height at 1400x900.
            viewport={"width": 1400, "height": 900},
            device_scale_factor=1
        )
        page = context.new_page()

        # 1) Wait only for DOMContentLoaded; "networkidle" can hang forever
        #    on pages that keep connections open.
        page.goto(abs_uri, wait_until="domcontentloaded")

        # 2) Ensure the key containers are attached before probing them.
        page.wait_for_selector(".poster", state="attached", timeout=30000)
        page.wait_for_selector("#flow", state="attached", timeout=30000)

        # 3) Wait for web fonts — the page's fit() logic is highly sensitive
        #    to font metrics, so layout is unstable until fonts are ready.
        try:
            page.evaluate("() => document.fonts ? document.fonts.ready : Promise.resolve()")
        except Exception:
            pass

        # 4) Wait until every <img> inside #flow has loaded.  Resolves
        #    immediately when there are no images; load errors also resolve,
        #    so this can never dead-lock.
        page.evaluate(r"""
            () => {
              const flow = document.getElementById("flow");
              if (!flow) return Promise.resolve();
              const imgs = Array.from(flow.querySelectorAll("img"));
              if (imgs.length === 0) return Promise.resolve();
              return Promise.all(imgs.map(img => {
                if (img.complete) return Promise.resolve();
                return new Promise(res => {
                  img.addEventListener("load", res, { once: true });
                  img.addEventListener("error", res, { once: true });
                });
              }));
            }
        """)

        # 5) Wait for the page's fit() to run and the layout to stabilise:
        #    poll #flow's scrollWidth each animation frame until it stays
        #    unchanged for several consecutive frames.
        page.evaluate(r"""
            () => new Promise((resolve) => {
              const flow = document.getElementById("flow");
              if (!flow) return resolve();

              let last = -1;
              let stableCount = 0;

              function tick() {
                const cur = flow.scrollWidth; // multi-column 溢出判据
                if (cur === last) stableCount += 1;
                else stableCount = 0;

                last = cur;

                // 连续若干帧稳定,就认为 fit/重排结束
                if (stableCount >= 10) return resolve();
                requestAnimationFrame(tick);
              }

              // 给 load 事件/fit 一点点启动时间
              setTimeout(() => requestAnimationFrame(tick), 50);
            })
        """)

        # 6) Capture only .poster (excludes the stage background).
        poster = page.locator(".poster").first
        poster.screenshot(path=output_path, timeout=60000)

        browser.close()
293
+
294
+
295
def load_prompt(prompt_path="prompt.json", prompt_name="poster_prompt"):
    """Load one named prompt string from a JSON prompt file.

    :param prompt_path: path to the JSON file holding the prompts.
    :param prompt_name: key of the prompt to fetch.
    :return: the prompt text, or ``""`` when the key is absent.
    """
    with open(prompt_path, "r", encoding="utf-8") as handle:
        prompt_table = json.load(handle)
    if prompt_name in prompt_table:
        return prompt_table[prompt_name]
    return ""
299
+
300
+
301
def refine_one_slide(input_path, output_path, prompts, outline, max_iterations, model, config):
    """
    Automatic repair loop for one slide: screenshot -> critique -> revise -> repeat.

    :param input_path: path of the slide HTML to refine.
    :param output_path: path where the refined HTML is written.
    :param prompts: two-item sequence ``[commenter_prompt, reviser_prompt]``.
    :param outline: outline entry for this slide, shown to the VLM critic.
    :param max_iterations: maximum number of critique/revise rounds.
    :param model: model name; "gemini" in the name selects the Gemini stack,
        anything else uses OpenAI.
    :param config: dict with an ``api_keys`` mapping holding the provider keys.
    """
    is_gemini = "gemini" in model.lower()

    if is_gemini:
        api_key = config['api_keys'].get('gemini_api_key')
    else:
        api_key = config['api_keys'].get('openai_api_key')

    commenter_prompt = prompts[0]
    reviser_prompt = prompts[1]

    platform = "gemini" if "gemini" in model.lower() else "openai"

    vlm = VLMCommenter(api_key, commenter_prompt, provider=platform, model_name=model)
    reviser = LLMReviser(api_key, reviser_prompt, provider=platform, model_name=model)

    current_input = input_path
    critic_his = ""  # running history of all previous critiques

    for i in range(max_iterations):
        print(f"\n=== Iteration {i+1} ===")

        # 1. Render the current HTML and screenshot it for the critic.
        screenshot_path = f"{Path(output_path).parent}/{Path(current_input).stem}_{i+1}.png"  # per-iteration screenshot path
        take_screenshot(current_input, screenshot_path)
        print(f"Screenshot taken: {screenshot_path}")

        # 2. VLM visual critique (sees the outline plus all earlier comments).
        critique = vlm.evaluate_slide(screenshot_path, outline, critic_his)
        critic_his = critic_his + f"this is the {i}th comment: {critique}"
        print(f"VLM Critique: {critique}")

        if "PASS" in critique:
            print("Layout looks good! Stopping loop.")

            # Ensure the accepted HTML ends up at output_path even when no
            # revision round ever wrote there.
            if os.path.abspath(current_input) != os.path.abspath(output_path):
                os.makedirs(os.path.dirname(output_path), exist_ok=True)
                shutil.copy2(current_input, output_path)
                print(f"Final result saved to: {output_path}")
            else:
                print(f"Result is already at output path: {output_path}")

            break

        # 3. Read the current HTML structure as a tree.
        mapper = HTMLMapper(current_input)
        current_tree = mapper.get_structure_tree()

        # 4. Ask the LLM reviser for a modification plan.
        modification_json = reviser.generate_revision_plan(current_tree, critique)

        if not modification_json:
            print("Reviser suggested no changes. Stopping.")
            break

        print(f"Proposed Changes: {json.dumps(modification_json, indent=2)}")

        # 5. Apply the modifications.
        try:
            # apply_html_modifications raises HTMLModificationError on
            # serious failures.
            apply_html_modifications(current_input, output_path, modification_json)
            print("Modifications applied to HTML.")

        except (HTMLModificationError, Exception) as e:
            # Catches the custom error plus anything unexpected.
            # NOTE(review): the tuple is redundant — Exception already
            # subsumes HTMLModificationError.
            print(f"❌ Error applying modifications at iteration {i+1}: {e}")

            # ====== save a debug copy of the failing HTML ======
            try:
                # Stem of the original input file.
                base_name = Path(input_path).stem

                # Path for the failure backup copy.
                error_backup_path = f"{Path(output_path).parent}/{base_name}_FAILED_iter{i+1}.html"

                # Copy out the HTML (current_input) that triggered the error
                # so the problem can be reproduced offline.
                shutil.copy2(current_input, error_backup_path)

                print(f"⚠️ 已自动保存出错前的 HTML 副本: {error_backup_path}")
                print(f"   你可以打开此文件,并使用控制台打印的 JSON 尝试复现问题。")

            except Exception as copy_err:
                print(f"❌ 尝试保存错误副本时失败: {copy_err}")
            # ============================

            # Abort the loop after a failed modification.
            break
        current_input = output_path

        # Short pause to avoid concurrent read/write races on the file.
        time.sleep(1)

    # After refinement, capture a final screenshot of the result.
    final_screenshot_path = f"{Path(output_path).parent}/{Path(current_input).stem}_final.png"
    take_screenshot(current_input, final_screenshot_path)
    print(f"\n📷 Final screenshot saved: {final_screenshot_path}")
400
+
401
+
402
def refinement_ppt(input_index, prompts, max_iterations=3, model="gpt-4o", config=None):
    """Refine every generated slide (``<n>_ppt.html``) in a directory.

    Scans ``input_index`` for files named ``<number>_ppt.html``, processes
    them in numeric order, and writes refined copies (plus screenshots) to
    ``input_index/final``.

    :param input_index: directory holding ``outline.json``, ``images/`` and
        the generated ``<n>_ppt.html`` slides.
    :param prompts: ``[commenter_prompt, reviser_prompt]`` forwarded to
        ``refine_one_slide``.
    :param max_iterations: maximum critique/revise rounds per slide.
    :param model: model name ("gemini" in the name selects the Gemini stack).
    :param config: configuration dict with an ``api_keys`` mapping.
    """
    # 1. Paths.
    outline_path = os.path.join(input_index, "outline.json")
    output_index = os.path.join(input_index, "final")
    # Images are duplicated under final/images so refined HTML still renders.
    output_index_images = os.path.join(output_index, "images")

    os.makedirs(output_index, exist_ok=True)

    # Copy the image assets next to the refined slides.  (shutil is imported
    # at module level; the previous in-function import was redundant.)
    source_images_dir = os.path.join(input_index, "images")
    if os.path.exists(source_images_dir):
        shutil.copytree(source_images_dir, output_index_images, dirs_exist_ok=True)
        print(f"📁 Copied images to: {output_index_images}")

    # 2. Load the outline and normalise it to a {str(index): entry} mapping.
    with open(outline_path, 'r', encoding='utf-8') as f:
        outline_data = json.load(f)
    if isinstance(outline_data, list):
        # list[i] becomes key str(i); slide files are matched against these
        # string keys below.
        outline_full = {str(i): item for i, item in enumerate(outline_data)}
    else:
        outline_full = outline_data

    print(f"🚀 开始扫描目录: {input_index}")

    # 3.1 Collect files strictly matching "<number>_ppt.html".
    target_files = []
    for f in os.listdir(input_index):
        if re.search(r'^\d+_ppt\.html$', f):
            target_files.append(f)

    # 3.2 Sort key: the leading slide number.
    def get_file_number(filename):
        # Filenames are pre-filtered, so the split is always safe.
        return int(filename.split('_')[0])

    # 3.3 Numeric sort so that 2 comes before 10.
    sorted_files = sorted(target_files, key=get_file_number)

    # Debug: show the first few files to confirm ordering.
    print(f"👀 排序后文件列表前5个: {sorted_files[:5]}")

    # 4. Process slides in order.
    for file_name in sorted_files:
        num = str(get_file_number(file_name))

        # Slides are 1-indexed (1_ppt.html, ...) while the outline mapping is
        # 0-indexed, so slide N maps to outline key N-1.  BUGFIX: the key
        # must be looked up as a *string* — keys are produced by str(i) (or
        # json.load, which never yields int keys), so the previous int-keyed
        # ``outline_full.get(int(num)-1)`` could never match and the code
        # silently relied on a fallback lookup every time.
        outline = outline_full.get(str(int(num) - 1))

        if outline is None:
            print(f"⚠️ 跳过 {file_name}: 在 outline.json 中找不到序号 {num} 或 {int(num)-1}")
            continue

        # Build source and destination paths.
        html_file_path = os.path.join(input_index, file_name)
        html_file_path_refine = os.path.join(output_index, file_name)

        print(f"📝 [顺序处理中] 正在优化: {file_name} (对应大纲 Key: {num})")

        # 5. Refine the slide; one failing slide must not abort the batch.
        try:
            refine_one_slide(
                input_path=html_file_path,
                output_path=html_file_path_refine,
                prompts=prompts,
                outline=outline,
                max_iterations=max_iterations,
                model=model,
                config=config
            )
        except Exception as e:
            print(f"❌ 处理 {file_name} 时出错: {e}")

    print(f" 所有文件处理完成,结果保存在: {output_index}")
489
+
490
def refinement_poster(input_html_path, prompts, output_html_path, model, config=None):
    """Single-shot refinement of a poster HTML using a vision model.

    Screenshots the current poster, sends the screenshot plus the HTML source
    to the selected vision model, then writes the refined HTML (and a final
    screenshot) under ``<input dir>/final``.

    :param input_html_path: path of the poster HTML to refine.
    :param prompts: instruction text prepended to the model request.
    :param output_html_path: path where the refined HTML is written.
    :param model: model name; "gemini" in the name selects the Gemini stack.
    :param config: optional dict with ``api_keys`` / ``api_base_url``.
    """
    # ---------------- 0. Configuration ----------------
    if config is None:
        config = {}

    api_keys_conf = config.get('api_keys', {})

    # Select the provider from the model name.
    is_gemini = "gemini" in model.lower()

    # ---------------- 1. Paths and assets ----------------
    auto_path = Path(input_html_path).parent
    final_index = os.path.join(auto_path, "final")
    final_index_image = os.path.join(final_index, "images")
    os.makedirs(final_index, exist_ok=True)

    # Copy the image assets so the refined HTML keeps resolving them.
    source_images_dir = os.path.join(auto_path, "images")
    if os.path.exists(source_images_dir):
        if not os.path.exists(final_index_image):
            shutil.copytree(source_images_dir, final_index_image, dirs_exist_ok=True)
            print(f"📁 Images copied to: {final_index_image}")

    with open(input_html_path, 'r', encoding='utf-8') as f:
        current_html = f.read()

    # ---------------- 2. Screenshot ----------------
    screenshot_name = Path(input_html_path).stem + ".png"
    screenshot_path = os.path.join(final_index, screenshot_name)

    print(f"📸 Taking screenshot of {input_html_path}...")
    # take_screenshot_poster renders the poster headlessly and writes a PNG.
    take_screenshot_poster(input_html_path, screenshot_path)

    if not os.path.exists(screenshot_path):
        raise FileNotFoundError("Screenshot failed to generate.")

    # Read the raw screenshot bytes for the multimodal request.
    with open(screenshot_path, "rb") as f:
        image_bytes = f.read()

    generated_text = ""

    # ---------------- 3. Call the vision model ----------------
    print(f"🤖 Sending to Vision Model ({model}) on {'Gemini' if is_gemini else 'OpenAI'}...")

    try:
        if is_gemini:
            raw_url = config.get('api_base_url', '').strip().rstrip("/")
            if raw_url.endswith("/v1"):
                base_url = raw_url[:-3].rstrip("/")  # drop the trailing /v1
            else:
                base_url = raw_url
            # === Gemini client setup ===
            api_key = api_keys_conf.get('gemini_api_key') or os.getenv("GOOGLE_API_KEY")

            client = genai.Client(api_key=api_key, http_options={'base_url': base_url} if base_url else None)

            # Build the multimodal contents (prompt + HTML + screenshot)
            # using the new google-genai SDK's recommended Part constructors.
            response = client.models.generate_content(
                model=model,
                contents=[
                    types.Part.from_text(text=prompts),
                    types.Part.from_text(text=f"--- CURRENT HTML ---\n{current_html}"),
                    types.Part.from_bytes(data=image_bytes, mime_type="image/png"),
                ]
            )

            if response.text:
                generated_text = response.text
            else:
                raise RuntimeError("Gemini returned empty text.")

        else:
            # === OpenAI client setup ===
            api_key = api_keys_conf.get('openai_api_key') or os.getenv("OPENAI_API_KEY")

            client = OpenAI(api_key=api_key)

            # OpenAI vision expects the image as base64 in a data URL.
            base64_image = base64.b64encode(image_bytes).decode('utf-8')

            messages = [
                {
                    "role": "system",
                    "content": "You are an expert web designer and code refiner."
                },
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": f"{prompts}\n\n--- CURRENT HTML ---\n{current_html}"
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{base64_image}",
                                "detail": "high"
                            }
                        }
                    ]
                }
            ]

            response = client.chat.completions.create(
                model=model,
                messages=messages,
                max_tokens=4096
            )
            generated_text = response.choices[0].message.content

        # ---------------- 4. Parse and save the result ----------------
        # Strip a possible markdown code-fence wrapper around the HTML.
        if "```html" in generated_text:
            final_html = generated_text.split("```html")[1].split("```")[0].strip()
        elif "```" in generated_text:
            final_html = generated_text.split("```")[1].strip()
        else:
            final_html = generated_text

        with open(output_html_path, 'w', encoding='utf-8') as f:
            f.write(final_html)

        print(f"✅ Refined poster saved to: {output_html_path}")

        # Capture a final screenshot of the refined poster.
        final_screenshot_name = Path(input_html_path).stem + "_final" + ".png"
        final_screenshot_path = os.path.join(final_index, final_screenshot_name)

        print(f"📸 Taking final poster screenshot of {output_html_path}...")
        take_screenshot_poster(output_html_path, final_screenshot_path)

    except Exception as e:
        # NOTE(review): this swallows every failure after logging, so callers
        # cannot tell that refinement failed — confirm this best-effort
        # behavior is intentional.
        print(f"❌ Error during AI generation: {e}")
626
+
627
def refinement_pr(pr_path: str, pr_refine_path: str, prompts: dict, model: str, config: dict):
    """
    Extract specific sections from a Markdown paper-summary file, refine them
    with an LLM according to ``prompts``, and reassemble the file.

    The original Markdown structure, image references and trailing content
    not covered by a section (e.g. the Hashtag footer) are preserved verbatim.

    :param pr_path: path of the source Markdown file.
    :param pr_refine_path: path where the refined Markdown is written.
    :param prompts: refinement instructions forwarded to the LLM.
    :param model: model name; "gemini" in the name selects the Gemini stack.
    :param config: dict with ``api_keys`` (and optional ``api_base_url``).
    """

    # 1. Read config; the concrete API key is extracted right before the call
    #    (no environment-variable dependency up front).
    if config is None:
        config = {}
    api_keys = config.get('api_keys', {})

    # 2. Load the original markdown.
    if not os.path.exists(pr_path):
        raise FileNotFoundError(f"文件未找到: {pr_path}")

    with open(pr_path, 'r', encoding='utf-8') as f:
        original_content = f.read()

    # 3. Section header patterns (emoji + bold title, regex-escaped).
    section_headers = {
        "Key Question": r"🔍 \*\*Key Question\*\*",
        "Brilliant Idea": r"💡 \*\*Brilliant Idea\*\*",
        "Core Methods": r"🚀 \*\*Core Methods\*\*",
        "Core Results": r"📊 \*\*Core Results\*\*",
        "Significance/Impact": r"🧠 \*\*Significance/Impact\*\*"
    }

    footer_pattern = r"🏷️\s*\*\*Hashtag\*\*"

    # 4. Locate each section header (first occurrence only).
    matches = []
    for key, pattern in section_headers.items():
        found = list(re.finditer(pattern, original_content))
        if found:
            match = found[0]
            matches.append({
                "key": key,
                "header_start": match.start(),
                "header_end": match.end(),
                "header_text": match.group()
            })

    matches.sort(key=lambda x: x["header_start"])

    if not matches:
        print("未检测到目标章节,直接复制文件。")
        with open(pr_refine_path, 'w', encoding='utf-8') as f:
            f.write(original_content)
        return

    # Locate the footer (Hashtag) — section bodies never extend past it.
    footer_match = re.search(footer_pattern, original_content)
    if footer_match:
        global_content_end_limit = footer_match.start()
    else:
        print("Warning: 未检测到 '🏷️ **Hashtag**' 标记,最后一个章节将读取至文件末尾。")
        global_content_end_limit = len(original_content)

    # 5. Compute the exact body (character) range of each section: from its
    #    header end to the next header start, or to the footer for the last.
    content_ranges = {}
    for i, match in enumerate(matches):
        key = match["key"]
        content_start = match["header_end"]
        if i < len(matches) - 1:
            content_end = matches[i+1]["header_start"]
        else:
            content_end = max(content_start, global_content_end_limit)

        content_ranges[key] = {
            "start": content_start,
            "end": content_end,
            "text": original_content[content_start:content_end].strip()
        }

    # 6. Build the LLM request: section bodies keyed by section name.
    extracted_data = {k: v["text"] for k, v in content_ranges.items()}

    system_prompt = (
        "You are an expert academic editor. Your task is to refine the content of specific sections of a paper summary based on user instructions.\n"
        "Input Format: JSON object {Section Name: Content}.\n"
        "Output Format: JSON object {Section Name: Refined Content}.\n"
        "CRITICAL RULES:\n"
        "1. **KEYS**: Keep the JSON keys EXACTLY the same as the input.\n"
        "2. **PURE BODY TEXT**: The output value must be pure body text. No Headers.\n"
        "3. **IMAGES**: Do NOT remove or modify markdown image links.\n"
        "4. **JSON ONLY**: Output pure JSON string.\n"
        "5. **FORMAT**: Use bolding ONLY for emphasis."
    )

    user_message = f"""
    [Refinement Instructions]
    {json.dumps(prompts, ensure_ascii=False)}

    [Content to Refine]
    {json.dumps(extracted_data, ensure_ascii=False)}
    """

    # === Dispatch the call by model type ===
    llm_output = ""
    try:
        if "gemini" in model.lower():
            raw_url = config.get('api_base_url', '').strip().rstrip("/")
            if raw_url.endswith("/v1"):
                base_url = raw_url[:-3].rstrip("/")  # drop the trailing /v1
            else:
                base_url = raw_url
            # --- Google Gemini (new google-genai SDK) ---
            api_key = api_keys.get("gemini_api_key", "").strip()

            if not api_key:
                raise ValueError("Missing config['api_keys']['gemini_api_key']")

            # Configure the client (optional custom base URL).
            client = genai.Client(api_key=api_key, http_options={'base_url': base_url} if base_url else None)

            response = client.models.generate_content(
                model=model,
                contents=user_message,
                config=types.GenerateContentConfig(
                    system_instruction=system_prompt,
                    temperature=0.2,
                    response_mime_type="application/json"  # force JSON mode for stability
                )
            )
            llm_output = response.text

        else:
            # --- OpenAI ---
            api_key = api_keys.get("openai_api_key", "").strip()

            if not api_key:
                # Compatibility fallback: environment variable.
                api_key = os.getenv("OPENAI_API_KEY")

            if not api_key:
                raise ValueError("Missing config['api_keys']['openai_api_key']")

            from openai import OpenAI
            client = OpenAI(api_key=api_key)

            response = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_message}
                ],
                temperature=0.2
            )
            llm_output = response.choices[0].message.content.strip()

    except Exception as e:
        print(f"LLM API 调用失败: {e}")
        return

    # 7. Clean the returned JSON (strip possible markdown fence markers).
    try:
        cleaned_output = llm_output.replace("```json", "").replace("```", "").strip()
        refined_data = json.loads(cleaned_output)
    except json.JSONDecodeError:
        print("解析 LLM 返回的 JSON 失败。Raw output:", llm_output)
        return

    # 8. Reassemble the file from character offsets into the original text.
    new_file_parts = []
    current_idx = 0

    # Process sections in their original order of appearance.
    sorted_matches = sorted(matches, key=lambda x: x["header_start"])

    for item in sorted_matches:
        key = item["key"]
        range_info = content_ranges[key]
        c_start = range_info["start"]
        c_end = range_info["end"]

        # 1. Unchanged slice: end of previous section body up to (and
        #    including) this section's header.
        pre_content = original_content[current_idx:c_start]
        new_file_parts.append(pre_content)

        # 2. Refined body; fall back to the original body when the LLM
        #    omitted this key.
        if key in refined_data:
            new_text = refined_data[key]
            # Keep the body separated from the header by a newline.
            if new_file_parts[-1] and not new_file_parts[-1].endswith('\n'):
                new_text = "\n" + new_text
            new_text = "\n" + new_text.strip() + "\n"
            new_file_parts.append(new_text)
        else:
            new_file_parts.append(original_content[c_start:c_end])

        new_file_parts.append('\n')

        # 3. Advance the cursor past this section's body.
        current_idx = c_end

    # 9. Append everything after the last section (footer, hashtags, ...).
    new_file_parts.append(original_content[current_idx:])

    final_markdown = "".join(new_file_parts)

    # 10. Save the result.
    os.makedirs(os.path.dirname(os.path.abspath(pr_refine_path)), exist_ok=True)

    with open(pr_refine_path, 'w', encoding='utf-8') as f:
        f.write(final_markdown)

    print(f"文件优化完成,已保存至: {pr_refine_path}")
834
+ print(f"文件优化完成,已保存至: {pr_refine_path}")
835
+