Upload 2 files
Browse files- .gitattributes +1 -0
- abot-ocr-infer.py +258 -0
- metric.png +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
metric.png filter=lfs diff=lfs merge=lfs -text
|
abot-ocr-infer.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from vllm import LLM, SamplingParams
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from transformers import AutoProcessor
|
| 7 |
+
|
| 8 |
+
def generate_prompt(image_path):
|
| 9 |
+
PROMPT = '''You are an AI assistant specialized in converting PDF images to Markdown format. Please follow these instructions for the conversion:
|
| 10 |
+
|
| 11 |
+
1. Text Processing:
|
| 12 |
+
- Accurately recognize all text content in the PDF image without guessing or inferring.
|
| 13 |
+
- Convert the recognized text into Markdown format.
|
| 14 |
+
- Maintain the original document structure, including headings, paragraphs, lists, etc.
|
| 15 |
+
|
| 16 |
+
2. Mathematical Formula Processing:
|
| 17 |
+
- Convert all mathematical formulas to LaTeX format.
|
| 18 |
+
- Enclose inline formulas with,(,). For example: This is an inline formula,( E = mc^2,)
|
| 19 |
+
- Enclose block formulas with,\[,\]. For example:,[,frac{-b,pm,sqrt{b^2 - 4ac}}{2a},]
|
| 20 |
+
|
| 21 |
+
3. Table Processing:
|
| 22 |
+
- Convert tables to HTML format.
|
| 23 |
+
- Wrap the entire table with <table> and </table>.
|
| 24 |
+
|
| 25 |
+
4. Figure Handling:
|
| 26 |
+
- Ignore figures content in the PDF image. Do not attempt to describe or convert images.
|
| 27 |
+
|
| 28 |
+
5. Output Format:
|
| 29 |
+
- Ensure the output Markdown document has a clear structure with appropriate line breaks between elements.
|
| 30 |
+
- For complex layouts, try to maintain the original document's structure and format as closely as possible.
|
| 31 |
+
|
| 32 |
+
Please strictly follow these guidelines to ensure accuracy and consistency in the conversion. Your task is to accurately convert the content of the PDF image into Markdown format without adding any extra explanations or comments.
|
| 33 |
+
'''
|
| 34 |
+
user_conversation = [
|
| 35 |
+
{
|
| 36 |
+
"role": "user",
|
| 37 |
+
"content": [
|
| 38 |
+
{"type": "image", "image": image_path},
|
| 39 |
+
{"type": "text", "text": PROMPT},
|
| 40 |
+
],
|
| 41 |
+
},
|
| 42 |
+
]
|
| 43 |
+
return user_conversation
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
MODEL_PATH = str(Path(__file__).resolve().parent / "abot-ocr")
|
| 47 |
+
|
| 48 |
+
TOKENIZER_PATH: str | None = None
|
| 49 |
+
|
| 50 |
+
SUPPORTED_IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def append_failed_image_log(image_path: Path, reason: str, log_file: str = "failed_images.log") -> None:
|
| 54 |
+
with open(log_file, "a", encoding="utf-8") as f:
|
| 55 |
+
f.write(f"{image_path}\t{reason}\n")
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def post_process_text(text: str) -> str:
|
| 59 |
+
n = len(text)
|
| 60 |
+
if n < 8000:
|
| 61 |
+
return text
|
| 62 |
+
for length in range(2, n // 10 + 1):
|
| 63 |
+
candidate = text[-length:]
|
| 64 |
+
count = 0
|
| 65 |
+
i = n - length
|
| 66 |
+
while i >= 0 and text[i:i + length] == candidate:
|
| 67 |
+
count += 1
|
| 68 |
+
i -= length
|
| 69 |
+
if count >= 10:
|
| 70 |
+
return text[:n - length * (count - 1)]
|
| 71 |
+
return text
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def resolve_markdown_path(
|
| 75 |
+
image_path: Path,
|
| 76 |
+
output_dir: Path | None = None,
|
| 77 |
+
input_root: Path | None = None,
|
| 78 |
+
) -> Path:
|
| 79 |
+
"""
|
| 80 |
+
计算 Markdown 输出路径。
|
| 81 |
+
|
| 82 |
+
- 未指定 output_dir:与图片同目录,仅替换后缀为 .md。
|
| 83 |
+
- 指定 output_dir:
|
| 84 |
+
- 若提供 input_root(目录输入场景),保留相对 input_root 的子目录结构;
|
| 85 |
+
- 否则直接输出到 output_dir 下。
|
| 86 |
+
"""
|
| 87 |
+
if output_dir is None:
|
| 88 |
+
return image_path.parent / (image_path.stem + ".md")
|
| 89 |
+
if input_root is None:
|
| 90 |
+
return output_dir / (image_path.stem + ".md")
|
| 91 |
+
relative_parent = image_path.relative_to(input_root).parent
|
| 92 |
+
return output_dir / relative_parent / (image_path.stem + ".md")
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def load_model(model_path: str = MODEL_PATH, tokenizer_path: str | None = TOKENIZER_PATH):
|
| 96 |
+
"""初始化 vLLM 引擎和处理器,返回 (llm, processor, sampling_params)。"""
|
| 97 |
+
tok = tokenizer_path or model_path
|
| 98 |
+
llm = LLM(model=model_path, tokenizer=tok, trust_remote_code=False)
|
| 99 |
+
processor = AutoProcessor.from_pretrained(tok)
|
| 100 |
+
sampling_params = SamplingParams(temperature=0, max_tokens=8192)
|
| 101 |
+
return llm, processor, sampling_params
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def collect_pending_images(
|
| 105 |
+
image_files: list[Path],
|
| 106 |
+
output_dir: Path | None = None,
|
| 107 |
+
input_root: Path | None = None,
|
| 108 |
+
) -> list[Path]:
|
| 109 |
+
"""过滤掉已存在对应 Markdown 的图片,返回待处理列表。"""
|
| 110 |
+
pending = [
|
| 111 |
+
f for f in image_files
|
| 112 |
+
if not resolve_markdown_path(f, output_dir, input_root).exists()
|
| 113 |
+
]
|
| 114 |
+
skipped = len(image_files) - len(pending)
|
| 115 |
+
if skipped > 0:
|
| 116 |
+
print(f"[~] 跳过已完成 {skipped} 张,待处理 {len(pending)} 张")
|
| 117 |
+
return pending
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def batch_infer(
|
| 121 |
+
image_files: list[Path], llm, processor, sampling_params
|
| 122 |
+
) -> list[tuple[Path, str]]:
|
| 123 |
+
"""
|
| 124 |
+
批量推理:一次性将所有图片送入 vLLM,充分利用连续批处理加速。
|
| 125 |
+
|
| 126 |
+
Returns:
|
| 127 |
+
(图片路径, 文本结果) 列表,仅包含成功读取并推理的图片。
|
| 128 |
+
"""
|
| 129 |
+
inputs = []
|
| 130 |
+
valid_image_files: list[Path] = []
|
| 131 |
+
for image_file in image_files:
|
| 132 |
+
try:
|
| 133 |
+
# verify() 先做完整性校验,随后重新 open 才可正常读取像素
|
| 134 |
+
with Image.open(image_file) as probe:
|
| 135 |
+
probe.verify()
|
| 136 |
+
img = Image.open(image_file).convert("RGB")
|
| 137 |
+
except Exception as e:
|
| 138 |
+
reason = f"{type(e).__name__}: {e}"
|
| 139 |
+
print(f"[跳过] 坏图或不可读: {image_file} ({reason})")
|
| 140 |
+
append_failed_image_log(image_file, reason)
|
| 141 |
+
continue
|
| 142 |
+
messages = generate_prompt(str(image_file))
|
| 143 |
+
prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 144 |
+
inputs.append({"prompt": prompt, "multi_modal_data": {"image": [img]}})
|
| 145 |
+
valid_image_files.append(image_file)
|
| 146 |
+
|
| 147 |
+
if not inputs:
|
| 148 |
+
return []
|
| 149 |
+
|
| 150 |
+
outputs = llm.generate(inputs, sampling_params)
|
| 151 |
+
texts = [post_process_text(output.outputs[0].text) for output in outputs]
|
| 152 |
+
return list(zip(valid_image_files, texts))
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def save_markdown(
|
| 156 |
+
image_path: Path,
|
| 157 |
+
text: str,
|
| 158 |
+
output_dir: Path | None = None,
|
| 159 |
+
input_root: Path | None = None,
|
| 160 |
+
) -> str:
|
| 161 |
+
"""将推理结果写入 Markdown 文件,返回文件路径。"""
|
| 162 |
+
markdown_file = resolve_markdown_path(image_path, output_dir, input_root)
|
| 163 |
+
markdown_file.parent.mkdir(parents=True, exist_ok=True)
|
| 164 |
+
with open(markdown_file, "w", encoding="utf-8") as f:
|
| 165 |
+
f.write(text)
|
| 166 |
+
print(f"[OK] {image_path.name} -> {markdown_file}")
|
| 167 |
+
return str(markdown_file)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def run_infer(
|
| 171 |
+
input_path: str,
|
| 172 |
+
llm,
|
| 173 |
+
processor,
|
| 174 |
+
sampling_params,
|
| 175 |
+
batch_size: int = 0,
|
| 176 |
+
output_dir: str | None = None,
|
| 177 |
+
) -> list[str]:
|
| 178 |
+
"""
|
| 179 |
+
对单张图片或目录(含多级子目录)批量执行 OCR。
|
| 180 |
+
Markdown 文件名与图片一致,后缀为 .md。
|
| 181 |
+
提供 output_dir 时输出到该目录;不提供则默认输出到输入目录。
|
| 182 |
+
|
| 183 |
+
Args:
|
| 184 |
+
input_path: 单张图片路径,或包含图片的目录路径(支持多级子目录)。
|
| 185 |
+
llm: 已初始化的 vLLM 引擎。
|
| 186 |
+
processor: 已加载的处理器。
|
| 187 |
+
sampling_params: 采样参数。
|
| 188 |
+
batch_size: 每批推理的图片数量,0 表示全部一次性推理(显存充足时最快)。
|
| 189 |
+
output_dir: 指定 Markdown 输出目录;不传时默认输出到输入图片所在目录。
|
| 190 |
+
|
| 191 |
+
Returns:
|
| 192 |
+
所有生成的 Markdown 文件路径列表。
|
| 193 |
+
"""
|
| 194 |
+
input_path = Path(input_path)
|
| 195 |
+
output_dir_path = Path(output_dir) if output_dir else None
|
| 196 |
+
|
| 197 |
+
if input_path.is_file():
|
| 198 |
+
if input_path.suffix.lower() not in SUPPORTED_IMAGE_EXTENSIONS:
|
| 199 |
+
raise ValueError(f"不支持的文件格式: {input_path.suffix}")
|
| 200 |
+
image_files = [input_path]
|
| 201 |
+
input_root = None
|
| 202 |
+
elif input_path.is_dir():
|
| 203 |
+
# rglob 递归遍历多级子目录
|
| 204 |
+
image_files = sorted([
|
| 205 |
+
f for f in input_path.rglob("*")
|
| 206 |
+
if f.is_file() and f.suffix.lower() in SUPPORTED_IMAGE_EXTENSIONS
|
| 207 |
+
])
|
| 208 |
+
if not image_files:
|
| 209 |
+
print(f"[!] 目录中未找到支持的图片文件: {input_path}")
|
| 210 |
+
return []
|
| 211 |
+
input_root = input_path
|
| 212 |
+
else:
|
| 213 |
+
raise FileNotFoundError(f"路径不存在: {input_path}")
|
| 214 |
+
|
| 215 |
+
print(f"[*] 共找到 {len(image_files)} 张图片")
|
| 216 |
+
pending_files = collect_pending_images(image_files, output_dir_path, input_root)
|
| 217 |
+
if not pending_files:
|
| 218 |
+
print("[完成] 所有图片均已处理,无需重新推理。")
|
| 219 |
+
return [
|
| 220 |
+
str(resolve_markdown_path(f, output_dir_path, input_root))
|
| 221 |
+
for f in image_files
|
| 222 |
+
]
|
| 223 |
+
|
| 224 |
+
# 按 batch_size 分批推理;batch_size=0 时一次性全量推理
|
| 225 |
+
chunk_size = batch_size if batch_size > 0 else len(pending_files)
|
| 226 |
+
output_files = []
|
| 227 |
+
for start in range(0, len(pending_files), chunk_size):
|
| 228 |
+
chunk = pending_files[start:start + chunk_size]
|
| 229 |
+
print(f"[*] 推理第 {start + 1}~{start + len(chunk)} 张...")
|
| 230 |
+
pairs = batch_infer(chunk, llm, processor, sampling_params)
|
| 231 |
+
for image_file, text in pairs:
|
| 232 |
+
output_files.append(
|
| 233 |
+
save_markdown(image_file, text, output_dir_path, input_root)
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
# 已跳过的文件也加入返回列表
|
| 237 |
+
skipped_files = [
|
| 238 |
+
str(resolve_markdown_path(f, output_dir_path, input_root))
|
| 239 |
+
for f in image_files if f not in pending_files
|
| 240 |
+
]
|
| 241 |
+
all_output_files = skipped_files + output_files
|
| 242 |
+
print(f"\n[完成] 新生成 {len(output_files)} 个,共 {len(all_output_files)} 个 Markdown 文件。")
|
| 243 |
+
return all_output_files
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
# ── 调用示例 ──────────────────────────────────────────────────────────────────
|
| 247 |
+
|
| 248 |
+
if __name__ == "__main__":
|
| 249 |
+
llm, processor, sampling_params = load_model()
|
| 250 |
+
|
| 251 |
+
run_infer(
|
| 252 |
+
input_path="images",
|
| 253 |
+
llm=llm,
|
| 254 |
+
processor=processor,
|
| 255 |
+
sampling_params=sampling_params,
|
| 256 |
+
batch_size=8,
|
| 257 |
+
output_dir="./abot-ocr-infer-output"
|
| 258 |
+
)
|
metric.png
ADDED
|
Git LFS Details
|