NeuroDong committed on
Commit
39f5059
·
1 Parent(s): 73eaf2d
Files changed (2) hide show
  1. app.py +127 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import gradio as gr
4
+ import pypdfium2 as pdfium
5
+ from PIL import Image
6
+ import torch
7
+ from transformers import AutoProcessor, VisionEncoderDecoderModel
8
+
9
# ========= Configuration =========
# Every value below can be overridden with an environment variable
# (Space Settings -> Variables).

# Checkpoint to load. nougat-small is CC-BY-4.0; nougat-base is CC-BY-NC-4.0.
MODEL_ID = os.environ.get("MODEL_ID", "facebook/nougat-small")
# Default render resolution (96-288); higher is sharper but slower.
DEFAULT_DPI = int(os.environ.get("DEFAULT_DPI", "144"))
# Upper bound on pages handled per request, to avoid timeouts.
MAX_PAGES = int(os.environ.get("MAX_PAGES", "20"))

# ========= Model loading =========
# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID)
model = model.to(device)
19
+
20
+ # ========= 工具函数 =========
21
def rasterize_pages(pdf_bytes: bytes, dpi: int = DEFAULT_DPI):
    """Render every page of an in-memory PDF to an RGB ``PIL.Image``.

    The document is opened straight from ``pdf_bytes``; no temporary file is
    created, which also avoids reopening a still-open ``NamedTemporaryFile``
    by name (not possible on Windows).

    Args:
        pdf_bytes: Raw contents of the PDF file.
        dpi: Target resolution. The render scale handed to pypdfium2 is
            ``dpi / 72``.

    Returns:
        A list with one RGB image per page, in document order.
    """
    scale = dpi / 72.0
    images = []
    doc = pdfium.PdfDocument(pdf_bytes)
    try:
        for i in range(len(doc)):
            page = doc.get_page(i)
            try:
                bitmap = page.render(scale=scale)
                try:
                    # NOTE(review): to_pil() may share memory with the bitmap;
                    # convert("RGB") is relied on to produce an independent
                    # image before the bitmap is closed - confirm against the
                    # pypdfium2 documentation.
                    images.append(bitmap.to_pil().convert("RGB"))
                finally:
                    bitmap.close()
            finally:
                page.close()
    finally:
        # Release the document handle even if rendering a page failed.
        doc.close()
    return images
40
+
41
def parse_pages_arg(pages_str: str, n_pages: int):
    """Parse a 1-based page selection string into sorted 0-based indices.

    The accepted syntax is a comma-separated list of page numbers and
    inclusive ranges, for example ``"1-4,7"``. An empty or ``None`` value, or
    the word ``"all"`` (case-insensitive), selects every page. A range may
    omit one end (``"3-"`` means page 3 to the last page, ``"-5"`` means the
    first page to page 5). Blank items, such as the one produced by a
    trailing comma, are ignored.

    Range bounds are clamped to ``1..n_pages`` and single page numbers
    outside that interval are dropped, so the result may be empty.

    Args:
        pages_str: The selection string typed by the user.
        n_pages: Number of pages in the document.

    Returns:
        Sorted list of unique 0-based page indices.

    Raises:
        ValueError: If an item is neither a number nor a ``start-end`` range.
    """
    if not pages_str or pages_str.strip().lower() == "all":
        return list(range(n_pages))

    keep = set()
    for raw in pages_str.split(","):
        span = raw.strip()
        if not span:
            # Tolerate "1,2," and "1,,2".
            continue
        try:
            if "-" in span:
                # partition() never fails to unpack, unlike split("-") on
                # input such as "1-2-3"; the leftover is rejected by int().
                first, _, last = span.partition("-")
                first, last = first.strip(), last.strip()
                if not first and not last:
                    raise ValueError(span)
                start = max(1, int(first) if first else 1)
                stop = min(n_pages, int(last) if last else n_pages)
                keep.update(range(start - 1, stop))
            else:
                k = int(span) - 1
                if 0 <= k < n_pages:
                    keep.add(k)
        except ValueError:
            raise ValueError(f"invalid page specification: {span!r}") from None
    return sorted(keep)
60
+
61
+ # ========= 核心推理函数(UI 与 API 共用) =========
62
def _read_upload_bytes(upload) -> bytes:
    """Return the raw bytes of an upload given as bytes, a path, or a file object."""
    if isinstance(upload, (bytes, bytearray)):
        return bytes(upload)
    if isinstance(upload, (str, os.PathLike)):
        path = upload
    else:
        # File wrappers expose the on-disk location through ``.name``.
        path = getattr(upload, "name", None)
    if isinstance(path, (str, os.PathLike)) and os.path.isfile(path):
        with open(path, "rb") as fh:
            return fh.read()
    # Last resort: a plain file-like object without a usable path.
    return upload.read()


def convert_pdf(pdf_file, pages="all", dpi=DEFAULT_DPI):
    """Convert the selected pages of an uploaded PDF to Nougat Markdown.

    Shared by the UI button and the API endpoint.

    Args:
        pdf_file: The uploaded PDF as delivered by the Gradio ``File``
            component. A path string, an object with a ``.name`` path, a
            readable file object, or raw bytes are all accepted.
        pages: ``"all"`` or a selection such as ``"1-4,7"`` (1-based).
        dpi: Render resolution passed to ``rasterize_pages``.

    Returns:
        A tuple ``(out_path, preview)``: the path of the generated ``.mmd``
        file (for download) and the Markdown of the first three converted
        pages.

    Raises:
        gr.Error: If no file was uploaded, the upload cannot be read, the
            page selection is malformed, or it selects no page.
    """
    if pdf_file is None:
        raise gr.Error("请上传 PDF 文件")

    # Read the PDF bytes and render the pages to images.
    try:
        pdf_bytes = _read_upload_bytes(pdf_file)
    except (OSError, AttributeError) as err:
        raise gr.Error("无法读取上传的 PDF 文件") from err
    images_all = rasterize_pages(pdf_bytes, dpi=int(dpi))

    # Page selection and limit.
    try:
        idx = parse_pages_arg(pages, len(images_all))
    except ValueError as err:
        # Show a readable message instead of a raw traceback.
        raise gr.Error(f"页码格式无效:{pages}") from err
    if not idx:
        raise gr.Error("页码选择为空")
    # Only the first MAX_PAGES selected pages are processed.
    idx = idx[:MAX_PAGES]

    # Run the Nougat model page by page.
    md_pages = []
    for k in idx:
        inputs = processor(images=[images_all[k]], return_tensors="pt").to(device)
        ids = model.generate(**inputs, max_length=4096)
        md_pages.append(processor.batch_decode(ids, skip_special_tokens=True)[0])

    # Write to a uniquely named temporary .mmd file. A single fixed file name
    # would be shared by every request, so one user's result could be
    # overwritten by another's before it is downloaded.
    with tempfile.NamedTemporaryFile(
        "w",
        suffix=".mmd",
        prefix="nougat_output_",
        delete=False,
        encoding="utf-8",
    ) as f:
        f.write("\n\n".join(md_pages))
        out_path = f.name

    # Preview: the first few pages only.
    preview = "\n\n".join(md_pages[:3])
    return out_path, preview
103
+
104
# ========= Gradio app (UI + API) =========
with gr.Blocks(title="Nougat OCR → Markdown") as demo:
    gr.Markdown(
        "# Nougat:PDF → Markdown\n"
        f"**模型**:`{MODEL_ID}` (small 为 CC‑BY‑4.0;base 为 CC‑BY‑NC‑4.0)。\n"
        "上传 PDF,选择页码与 DPI,点击转换即可下载 `.mmd`。\n"
    )
    with gr.Row():
        pdf = gr.File(label="上传 PDF", file_types=[".pdf"])
        pages = gr.Textbox(value="all", label="页码(如 1-4,7 或 all)")
        dpi = gr.Slider(96, 288, value=DEFAULT_DPI, step=12, label="渲染 DPI")
    btn = gr.Button("转换", variant="primary")
    out_file = gr.File(label="下载 Markdown(.mmd)")
    out_preview = gr.Markdown(label="预览(前几页)")

    # Queueing avoids congestion under concurrent use and lets API calls wait
    # in line. Gradio 3.x takes the worker count as `concurrency_count`.
    # NOTE(review): Gradio 4 is understood to raise DeprecationWarning for that
    # argument and later releases to reject it with TypeError; in both cases
    # fall back to the newer `default_concurrency_limit` - confirm against the
    # Gradio version actually installed.
    try:
        demo.queue(max_size=32, concurrency_count=1)
    except (TypeError, DeprecationWarning):
        demo.queue(max_size=32, default_concurrency_limit=1)

    # Bind the click event inside the Blocks context and give it an api_name so
    # the same function can be called over the HTTP API.
    btn.click(
        convert_pdf,
        inputs=[pdf, pages, dpi],
        outputs=[out_file, out_preview],
        api_name="predict",
    )

# Entry point when the script is run directly (the original author notes it is
# intended for local debugging).
if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=3.40
2
+ transformers>=4.34.0
3
+ torch>=2.0.0
4
+ pillow>=9.0.0
5
+ pypdfium2>=5.0.0
6
+ huggingface-hub>=0.16.0
7
+ accelerate>=0.20.0
8
+ safetensors>=0.3.0
9
+ sentencepiece>=0.1.99