MultiThread
Browse files
- main.py +24 -24
- poster/figures.py +16 -4
- poster/poster.py +58 -12
main.py
CHANGED
|
@@ -39,45 +39,45 @@ def generate_paper_poster(
|
|
| 39 |
figures_cap_cache = f"{pdf_stem}_figures_cap.json"
|
| 40 |
|
| 41 |
figures = []
|
| 42 |
-
figures_cap = []
|
| 43 |
print("开始提取图片...")
|
| 44 |
if os.path.exists(figures_cache) and os.path.exists(figures_cap_cache):
|
| 45 |
print(f"使用缓存的图片: {figures_cache}")
|
| 46 |
with open(figures_cache, "r") as f:
|
| 47 |
figures = json.load(f)
|
| 48 |
-
with open(figures_cap_cache, "r") as f:
|
| 49 |
-
|
| 50 |
else:
|
| 51 |
figures_img = extract_figures(url, pdf, task="figure")
|
| 52 |
figures_table = extract_figures(url, pdf, task="table")
|
| 53 |
-
img_caption = extract_figures(url, pdf, task="figurecaption")
|
| 54 |
-
table_caption = extract_figures(url, pdf, task="tablecaption")
|
| 55 |
-
threshold = 0.
|
| 56 |
-
while True:
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
figures_cap = [
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
]
|
| 67 |
-
print(f"{threshold:.2f} 提取到 {len(figures)} / {len(figures_cap)} 张图像")
|
| 68 |
-
if len(figures) == len(figures_cap):
|
| 69 |
-
|
| 70 |
-
threshold -= 0.05
|
| 71 |
|
| 72 |
with open(figures_cache, "w") as f:
|
| 73 |
json.dump(figures, f, ensure_ascii=False)
|
| 74 |
-
with open(figures_cap_cache, "w") as f:
|
| 75 |
-
|
| 76 |
|
| 77 |
while True:
|
| 78 |
try:
|
| 79 |
result = generate_poster_v3(
|
| 80 |
-
vendor, model, text_prompt, figures_prompt, pdf,
|
| 81 |
)
|
| 82 |
|
| 83 |
poster = result["image_based_poster"]
|
|
|
|
| 39 |
figures_cap_cache = f"{pdf_stem}_figures_cap.json"
|
| 40 |
|
| 41 |
figures = []
|
| 42 |
+
# figures_cap = []
|
| 43 |
print("开始提取图片...")
|
| 44 |
if os.path.exists(figures_cache) and os.path.exists(figures_cap_cache):
|
| 45 |
print(f"使用缓存的图片: {figures_cache}")
|
| 46 |
with open(figures_cache, "r") as f:
|
| 47 |
figures = json.load(f)
|
| 48 |
+
# with open(figures_cap_cache, "r") as f:
|
| 49 |
+
# figures_cap = json.load(f)
|
| 50 |
else:
|
| 51 |
figures_img = extract_figures(url, pdf, task="figure")
|
| 52 |
figures_table = extract_figures(url, pdf, task="table")
|
| 53 |
+
# img_caption = extract_figures(url, pdf, task="figurecaption")
|
| 54 |
+
# table_caption = extract_figures(url, pdf, task="tablecaption")
|
| 55 |
+
threshold = 0.75
|
| 56 |
+
# while True:
|
| 57 |
+
figures = [
|
| 58 |
+
image
|
| 59 |
+
for image, score in figures_img + figures_table
|
| 60 |
+
if score >= threshold
|
| 61 |
+
]
|
| 62 |
+
# figures_cap = [
|
| 63 |
+
# image
|
| 64 |
+
# for image, score in img_caption + table_caption
|
| 65 |
+
# if score >= threshold
|
| 66 |
+
# ]
|
| 67 |
+
# print(f"{threshold:.2f} 提取到 {len(figures)} / {len(figures_cap)} 张图像")
|
| 68 |
+
# if len(figures) == len(figures_cap):
|
| 69 |
+
# break
|
| 70 |
+
# threshold -= 0.05
|
| 71 |
|
| 72 |
with open(figures_cache, "w") as f:
|
| 73 |
json.dump(figures, f, ensure_ascii=False)
|
| 74 |
+
# with open(figures_cap_cache, "w") as f:
|
| 75 |
+
# json.dump(figures_cap, f, ensure_ascii=False)
|
| 76 |
|
| 77 |
while True:
|
| 78 |
try:
|
| 79 |
result = generate_poster_v3(
|
| 80 |
+
vendor, model, text_prompt, figures_prompt, pdf, figures, figures
|
| 81 |
)
|
| 82 |
|
| 83 |
poster = result["image_based_poster"]
|
poster/figures.py
CHANGED
|
@@ -2,6 +2,7 @@ import base64
|
|
| 2 |
import requests
|
| 3 |
import os
|
| 4 |
from pathlib import Path
|
|
|
|
| 5 |
|
| 6 |
from io import BytesIO
|
| 7 |
from PIL import Image
|
|
@@ -31,14 +32,25 @@ def _extract_figures(
|
|
| 31 |
|
| 32 |
|
| 33 |
def extract_figures(
|
| 34 |
-
url: str, pdf: str, task: str = "figure"
|
| 35 |
) -> list[tuple[str, float]]:
|
| 36 |
loader = ImagePDFLoader(pdf)
|
| 37 |
images = loader.load()
|
| 38 |
|
| 39 |
figures = []
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
base64_figures = []
|
| 44 |
for figure, score in figures:
|
|
@@ -52,7 +64,7 @@ def extract_figures(
|
|
| 52 |
|
| 53 |
|
| 54 |
if __name__ == "__main__":
|
| 55 |
-
url = ""
|
| 56 |
pdf = "1.pdf"
|
| 57 |
|
| 58 |
output_dir = Path("output")
|
|
|
|
| 2 |
import requests
|
| 3 |
import os
|
| 4 |
from pathlib import Path
|
| 5 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 6 |
|
| 7 |
from io import BytesIO
|
| 8 |
from PIL import Image
|
|
|
|
| 32 |
|
| 33 |
|
| 34 |
def extract_figures(
|
| 35 |
+
url: str, pdf: str, task: str = "figure", max_workers: int = 4
|
| 36 |
) -> list[tuple[str, float]]:
|
| 37 |
loader = ImagePDFLoader(pdf)
|
| 38 |
images = loader.load()
|
| 39 |
|
| 40 |
figures = []
|
| 41 |
+
|
| 42 |
+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
| 43 |
+
future_to_image = {
|
| 44 |
+
executor.submit(_extract_figures, url, image, task): image
|
| 45 |
+
for image in images
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
for future in as_completed(future_to_image):
|
| 49 |
+
try:
|
| 50 |
+
result = future.result()
|
| 51 |
+
figures.extend(result)
|
| 52 |
+
except Exception as exc:
|
| 53 |
+
print(f'图像处理时发生错误: {exc}')
|
| 54 |
|
| 55 |
base64_figures = []
|
| 56 |
for figure, score in figures:
|
|
|
|
| 64 |
|
| 65 |
|
| 66 |
if __name__ == "__main__":
|
| 67 |
+
url = "https://kr4t0n--yolo-layout-detection-temp-layoutdetection-predict.modal.run"
|
| 68 |
pdf = "1.pdf"
|
| 69 |
|
| 70 |
output_dir = Path("output")
|
poster/poster.py
CHANGED
|
@@ -6,6 +6,7 @@ import re
|
|
| 6 |
import subprocess
|
| 7 |
import time
|
| 8 |
import cairosvg
|
|
|
|
| 9 |
|
| 10 |
from PIL import Image
|
| 11 |
from pdf2image import convert_from_path
|
|
@@ -388,8 +389,12 @@ def generate_html_v2(vendor: str, model: str, poster: BaseModel, figures: list[s
|
|
| 388 |
/ poster_total_size
|
| 389 |
)
|
| 390 |
|
| 391 |
-
max_attempts =
|
| 392 |
-
attempt =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
|
| 394 |
while True:
|
| 395 |
body = re.search(r"```html\n(.*?)\n```", output, re.DOTALL).group(1)
|
|
@@ -401,6 +406,13 @@ def generate_html_v2(vendor: str, model: str, poster: BaseModel, figures: list[s
|
|
| 401 |
section_sizes = get_sizes("section", html_with_figures)
|
| 402 |
|
| 403 |
proportion = calculate_blank_proportion(poster_sizes, section_sizes)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
if proportion <= 0.1:
|
| 405 |
print(
|
| 406 |
f"Attempted {attempt} times, remaining {proportion:.0%} blank spaces."
|
|
@@ -409,10 +421,16 @@ def generate_html_v2(vendor: str, model: str, poster: BaseModel, figures: list[s
|
|
| 409 |
|
| 410 |
attempt += 1
|
| 411 |
if attempt > max_attempts:
|
| 412 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 413 |
|
| 414 |
react = [
|
| 415 |
-
# AIMessage(""),
|
| 416 |
HumanMessage(
|
| 417 |
content=f"""# Previous Body
|
| 418 |
{body}
|
|
@@ -514,10 +532,16 @@ def generate_poster_v3(
|
|
| 514 |
model=model,
|
| 515 |
temperature=1,
|
| 516 |
max_tokens=8000,
|
| 517 |
-
# model_kwargs={
|
| 518 |
-
# "extra_body": {"chat_template_kwargs": {"enable_thinking": False}}
|
| 519 |
-
# },
|
| 520 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 521 |
loader = PyMuPDFLoader(pdf)
|
| 522 |
pages = loader.load()
|
| 523 |
paper_content = "\n".join([page.page_content for page in pages])
|
|
@@ -629,16 +653,38 @@ Paper content:
|
|
| 629 |
figures_with_descriptions = f.read()
|
| 630 |
else:
|
| 631 |
figure_chain = figures_description_prompt | (mllm if use_claude else llm)
|
| 632 |
-
|
|
|
|
|
|
|
| 633 |
figure_description_response = figure_chain.invoke({"image_data": figure})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 634 |
figures_with_descriptions += f"""
|
| 635 |
<figure_{i}>
|
| 636 |
-
{
|
| 637 |
</figure_{i}>
|
| 638 |
"""
|
| 639 |
-
figure_list.append(
|
| 640 |
-
|
| 641 |
-
|
|
|
|
|
|
|
| 642 |
if use_claude:
|
| 643 |
with open(figures_description_cache, "w") as f:
|
| 644 |
f.write(figures_with_descriptions)
|
|
|
|
| 6 |
import subprocess
|
| 7 |
import time
|
| 8 |
import cairosvg
|
| 9 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 10 |
|
| 11 |
from PIL import Image
|
| 12 |
from pdf2image import convert_from_path
|
|
|
|
| 389 |
/ poster_total_size
|
| 390 |
)
|
| 391 |
|
| 392 |
+
max_attempts = 6
|
| 393 |
+
attempt = 0
|
| 394 |
+
|
| 395 |
+
min_proportion = float('inf')
|
| 396 |
+
min_html = None
|
| 397 |
+
min_html_with_figures = None
|
| 398 |
|
| 399 |
while True:
|
| 400 |
body = re.search(r"```html\n(.*?)\n```", output, re.DOTALL).group(1)
|
|
|
|
| 406 |
section_sizes = get_sizes("section", html_with_figures)
|
| 407 |
|
| 408 |
proportion = calculate_blank_proportion(poster_sizes, section_sizes)
|
| 409 |
+
|
| 410 |
+
print(f"当前比例: {proportion:.0%}")
|
| 411 |
+
if proportion < min_proportion:
|
| 412 |
+
min_proportion = proportion
|
| 413 |
+
min_html = html
|
| 414 |
+
min_html_with_figures = html_with_figures
|
| 415 |
+
|
| 416 |
if proportion <= 0.1:
|
| 417 |
print(
|
| 418 |
f"Attempted {attempt} times, remaining {proportion:.0%} blank spaces."
|
|
|
|
| 421 |
|
| 422 |
attempt += 1
|
| 423 |
if attempt > max_attempts:
|
| 424 |
+
if min_proportion <= 0.2:
|
| 425 |
+
print(
|
| 426 |
+
f"Reached max attempts ({max_attempts}), returning best result with {min_proportion:.0%} blank spaces."
|
| 427 |
+
)
|
| 428 |
+
return {"html": min_html, "html_with_figures": min_html_with_figures}
|
| 429 |
+
else:
|
| 430 |
+
raise ValueError(f"Invalid blank spaces: {min_proportion:.0%}")
|
| 431 |
+
|
| 432 |
|
| 433 |
react = [
|
|
|
|
| 434 |
HumanMessage(
|
| 435 |
content=f"""# Previous Body
|
| 436 |
{body}
|
|
|
|
| 532 |
model=model,
|
| 533 |
temperature=1,
|
| 534 |
max_tokens=8000,
|
|
|
|
|
|
|
|
|
|
| 535 |
)
|
| 536 |
+
elif vendor == "azure":
|
| 537 |
+
llm = AzureChatOpenAI(
|
| 538 |
+
azure_deployment=model,
|
| 539 |
+
temperature=1,
|
| 540 |
+
max_tokens=8000,
|
| 541 |
+
)
|
| 542 |
+
else:
|
| 543 |
+
raise ValueError(f"Unsupported vendor: {vendor}")
|
| 544 |
+
|
| 545 |
loader = PyMuPDFLoader(pdf)
|
| 546 |
pages = loader.load()
|
| 547 |
paper_content = "\n".join([page.page_content for page in pages])
|
|
|
|
| 653 |
figures_with_descriptions = f.read()
|
| 654 |
else:
|
| 655 |
figure_chain = figures_description_prompt | (mllm if use_claude else llm)
|
| 656 |
+
|
| 657 |
+
def process_single_figure(figure_data):
|
| 658 |
+
figure, index = figure_data
|
| 659 |
figure_description_response = figure_chain.invoke({"image_data": figure})
|
| 660 |
+
return {
|
| 661 |
+
"index": index,
|
| 662 |
+
"figure": figure,
|
| 663 |
+
"description": figure_description_response.content
|
| 664 |
+
}
|
| 665 |
+
|
| 666 |
+
figure_data_list = [(figure, i) for i, figure in enumerate(figures)]
|
| 667 |
+
|
| 668 |
+
with ThreadPoolExecutor(max_workers=4) as executor:
|
| 669 |
+
results = list(tqdm(
|
| 670 |
+
executor.map(process_single_figure, figure_data_list),
|
| 671 |
+
total=len(figure_data_list),
|
| 672 |
+
desc=f"处理图片 {pdf}"
|
| 673 |
+
))
|
| 674 |
+
|
| 675 |
+
for result in results:
|
| 676 |
+
i = result["index"]
|
| 677 |
+
print(f"处理图片 {i} 完成")
|
| 678 |
figures_with_descriptions += f"""
|
| 679 |
<figure_{i}>
|
| 680 |
+
{result["description"]}
|
| 681 |
</figure_{i}>
|
| 682 |
"""
|
| 683 |
+
figure_list.append({
|
| 684 |
+
"figure": result["figure"],
|
| 685 |
+
"description": result["description"]
|
| 686 |
+
})
|
| 687 |
+
|
| 688 |
if use_claude:
|
| 689 |
with open(figures_description_cache, "w") as f:
|
| 690 |
f.write(figures_with_descriptions)
|