Doramong committed on
Commit
6f243e4
·
verified ·
1 Parent(s): 19fd1c5

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +281 -0
README.md ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ```
2
+ import json
3
+ import copy
4
+ from PIL import Image
5
+ from pypdf import PdfReader
6
+ from vllm import LLM, SamplingParams
7
+ from ocrflux.image_utils import get_page_image
8
+ from ocrflux.table_format import table_matrix2html
9
+ from ocrflux.prompts import PageResponse, build_page_to_markdown_prompt, build_element_merge_detect_prompt, build_html_table_merge_prompt
10
+ import requests
11
+
12
+ import base64
13
+ from io import BytesIO
14
+ from PIL import Image
15
+ import httpx
16
+
17
+ import asyncio
18
+
19
def pil_to_base64(img: Image.Image, format: str = "PNG") -> str:
    """Serialize a PIL image to a base64 string of its encoded *format* bytes."""
    with BytesIO() as buffer:
        img.save(buffer, format=format)
        raw_bytes = buffer.getvalue()
    return base64.b64encode(raw_bytes).decode("utf-8")
25
+
26
async def get_response(messages, temperature):
    """POST a chat-completion request to the local OCRFlux vLLM server.

    Returns the assistant message text of the first choice. Raises
    httpx.HTTPStatusError on a non-2xx response.
    """
    endpoint = "http://127.0.0.1:8000/v1/chat/completions"
    request_body = {
        "model": "ChatDOC/OCRFlux-3B",
        "temperature": temperature,
        "messages": messages,
        "stream": False,
        "max_tokens": 4096,
    }

    # Overall per-request deadline: 60 seconds.
    async with httpx.AsyncClient(timeout=httpx.Timeout(60.0)) as client:
        response = await client.post(
            endpoint,
            json=request_body,
            headers={"Content-Type": "application/json"},
        )
        response.raise_for_status()
        payload = response.json()
        return payload["choices"][0]["message"]["content"]
43
+
44
def build_qwen2_5_vl_prompt(question):
    """Build a Qwen2.5-VL chat message list: system prompt plus a user turn
    whose text carries one vision placeholder ahead of *question*."""
    user_text = f"<|vision_start|><|image_pad|><|vision_end|>{question}"
    return [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": [{"type": "text", "text": user_text}]},
    ]
50
+
51
+
52
def build_page_to_markdown_query(file_path: str, page_number: int, target_longest_image_dim: int = 1024, image_rotation: int = 0) -> list:
    """Render one page of *file_path* and build the page-to-markdown chat prompt.

    Args:
        file_path: path to the PDF or image file.
        page_number: 1-based page number to render.
        target_longest_image_dim: longest side of the rendered page image, in pixels.
        image_rotation: rotation applied before rendering; must be 0/90/180/270.

    Returns:
        The chat message list (fixed annotation: the original said ``dict`` but
        a list of messages is what is built and returned).
    """
    assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"
    page_image = get_page_image(file_path, page_number, target_longest_image_dim=target_longest_image_dim, image_rotation=image_rotation)
    prompt = build_qwen2_5_vl_prompt(build_page_to_markdown_prompt())
    # Attach the rendered page as an inline base64 data URL on the user turn.
    prompt[-1]['content'].append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{pil_to_base64(page_image)}"}})
    return prompt
59
+
60
def build_element_merge_detect_query(text_list_1, text_list_2) -> list:
    """Build the cross-page element-merge-detect prompt for two adjacent pages.

    The model requires a vision input, so a dummy 28x28 black image is attached.
    Returns the chat message list (fixed annotation: the original said ``dict``
    but a list of messages is returned).
    """
    placeholder = Image.new('RGB', (28, 28), color='black')
    prompt = build_qwen2_5_vl_prompt(build_element_merge_detect_prompt(text_list_1, text_list_2))
    prompt[-1]['content'].append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{pil_to_base64(placeholder)}"}})
    return prompt
66
+
67
def build_html_table_merge_query(text_1, text_2) -> list:
    """Build the HTML-table-merge prompt for two table fragments.

    The model requires a vision input, so a dummy 28x28 black image is attached.
    Returns the chat message list (fixed annotation: the original said ``dict``
    but a list of messages is returned).
    """
    placeholder = Image.new('RGB', (28, 28), color='black')
    prompt = build_qwen2_5_vl_prompt(build_html_table_merge_prompt(text_1, text_2))
    prompt[-1]['content'].append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{pil_to_base64(placeholder)}"}})
    return prompt
73
+
74
def bulid_document_text(page_to_markdown_result, element_merge_detect_result, html_table_merge_result):
    """Assemble the final document markdown from per-page elements and the
    cross-page merge decisions.

    NOTE: mutates *page_to_markdown_result* in place (merged tail elements are
    blanked out so they are emitted only once).
    """
    # Apply merged HTML tables first, walking page pairs back to front.
    for page_1, page_2, idx_1, idx_2 in sorted(html_table_merge_result.keys(), key=lambda k: -k[0]):
        page_to_markdown_result[page_1][idx_1] = html_table_merge_result[(page_1, page_2, idx_1, idx_2)]
        page_to_markdown_result[page_2][idx_2] = ''

    # Join elements that continue across a page boundary, again back to front.
    for page_1, page_2 in sorted(element_merge_detect_result.keys(), key=lambda k: -k[0]):
        for idx_1, idx_2 in element_merge_detect_result[(page_1, page_2)]:
            head = page_to_markdown_result[page_1][idx_1]
            tail = page_to_markdown_result[page_2][idx_2]
            # No separator after a hyphenated line break or a CJK character;
            # otherwise join with a single space.
            if not head or head[-1] == '-' or '\u4e00' <= head[-1] <= '\u9fff':
                joiner = ''
            else:
                joiner = ' '
            page_to_markdown_result[page_1][idx_1] = head + joiner + tail
            page_to_markdown_result[page_2][idx_2] = ''

    pieces = []
    for page in list(page_to_markdown_result.keys()):
        pieces.extend(element for element in page_to_markdown_result[page] if element)
    return "\n\n".join(pieces)
96
+
97
async def parse(file_path, skip_cross_page_merge=False, max_page_retries=0):
    """Parse a PDF (or a single image) into markdown via the OCRFlux server.

    Args:
        file_path: path to a ``.pdf`` file or an image file.
        skip_cross_page_merge: when True, skip stages 2/3 and simply
            concatenate the per-page markdown.
        max_page_retries: extra rounds (with slightly increasing temperature)
            for pages whose first response failed to parse.

    Returns:
        dict with keys orig_path / num_pages / document_text / page_texts /
        fallback_pages, or None when the PDF cannot be opened.
    """
    import ast  # stdlib; used to safely parse model output in stage 2

    if file_path.lower().endswith(".pdf"):
        try:
            reader = PdfReader(file_path)
            num_pages = reader.get_num_pages()
        except Exception:
            # Unreadable/corrupt PDF: signal failure to the caller.
            return None
    else:
        num_pages = 1

    def _markdown_elements(raw_response):
        """Parse one model response into a list of markdown elements.

        Raises on malformed JSON / schema so the caller can schedule a retry.
        """
        page_response = PageResponse(**json.loads(raw_response))
        elements = []
        for text in page_response.natural_text.split('\n\n'):
            if text.startswith("<Image>") and text.endswith("</Image>"):
                continue  # drop image placeholders entirely
            if text.startswith("<table>") and text.endswith("</table>"):
                try:
                    text = table_matrix2html(text)
                except Exception:
                    # Fall back to the raw matrix with cell markers stripped.
                    text = text.replace("<t>", "").replace("<l>", "").replace("<lt>", "")
            elements.append(text)
        return elements

    # Stage 1: Page to Markdown (page numbers are 1-based for get_page_image).
    queries = [build_page_to_markdown_query(file_path, page_num) for page_num in range(1, num_pages + 1)]
    responses = await asyncio.gather(*[get_response(query, 0.0) for query in queries])
    page_to_markdown_result = {}
    retry_list = []  # 0-based response indices whose parse failed
    for i, result in enumerate(responses):
        try:
            page_to_markdown_result[i + 1] = _markdown_elements(result)
        except Exception:
            retry_list.append(i)

    attempt = 0
    while retry_list and attempt < max_page_retries:
        # BUGFIX: response index i corresponds to page i+1; the original
        # re-queried page i (off by one), fetching the wrong page on retry.
        retry_queries = [build_page_to_markdown_query(file_path, i + 1) for i in retry_list]
        # Nudge the sampling temperature up a little each retry round.
        responses = await asyncio.gather(*[get_response(query, 0.1 * attempt) for query in retry_queries])
        next_retry_list = []
        for i, result in zip(retry_list, responses):
            try:
                page_to_markdown_result[i + 1] = _markdown_elements(result)
            except Exception:
                next_retry_list.append(i)
        retry_list = next_retry_list
        attempt += 1

    page_texts = {}
    fallback_pages = []  # 0-based pages that never produced a parseable response
    for page_number in range(1, num_pages + 1):
        if page_number not in page_to_markdown_result:
            fallback_pages.append(page_number - 1)
        else:
            page_texts[str(page_number - 1)] = "\n\n".join(page_to_markdown_result[page_number])

    if skip_cross_page_merge:
        document_text_list = [page_texts[str(i)] for i in range(num_pages) if i not in fallback_pages]
        return {
            "orig_path": file_path,
            "num_pages": num_pages,
            "document_text": "\n\n".join(document_text_list),
            "page_texts": page_texts,
            "fallback_pages": fallback_pages,
        }

    # Stage 2: Element Merge Detect — ask the model which elements continue
    # across each pair of adjacent (successfully parsed) pages.
    element_merge_detect_keys = []
    element_merge_detect_query_list = []
    for page_num in range(1, num_pages):
        if page_num in page_to_markdown_result and page_num + 1 in page_to_markdown_result:
            element_merge_detect_query_list.append(
                build_element_merge_detect_query(page_to_markdown_result[page_num], page_to_markdown_result[page_num + 1]))
            element_merge_detect_keys.append((page_num, page_num + 1))

    responses = await asyncio.gather(*[get_response(query, 0.0) for query in element_merge_detect_query_list])
    element_merge_detect_result = {}
    for key, result in zip(element_merge_detect_keys, responses):
        try:
            # SECURITY FIX: the model reply is untrusted text — parse it as a
            # Python literal instead of eval()-ing arbitrary code.
            element_merge_detect_result[key] = ast.literal_eval(result)
        except Exception:
            pass  # best-effort: an unparseable pair simply is not merged

    # Stage 3: HTML Table Merge — merge tables split across adjacent pages.
    html_table_merge_keys = []
    for (page_1, page_2), pairs in element_merge_detect_result.items():
        for elem_idx_1, elem_idx_2 in pairs:
            text_1 = page_to_markdown_result[page_1][elem_idx_1]
            text_2 = page_to_markdown_result[page_2][elem_idx_2]
            if text_1.startswith("<table>") and text_1.endswith("</table>") and text_2.startswith("<table>") and text_2.endswith("</table>"):
                html_table_merge_keys.append((page_1, page_2, elem_idx_1, elem_idx_2))

    # Work back to front so chained merges (a table spanning 3+ pages) see the
    # already-merged tail from a previous batch.
    html_table_merge_keys = sorted(html_table_merge_keys, key=lambda x: -x[0])

    html_table_merge_result = {}
    page_to_markdown_result_tmp = copy.deepcopy(page_to_markdown_result)
    i = 0
    while i < len(html_table_merge_keys):
        # Batch keys until one's source element is the target of a merge
        # already scheduled in this batch (it must wait for the next round).
        scheduled_targets = set()
        keys = []
        while i < len(html_table_merge_keys):
            page_1, page_2, elem_idx_1, elem_idx_2 = html_table_merge_keys[i]
            if (page_2, elem_idx_2) in scheduled_targets:
                break
            scheduled_targets.add((page_1, elem_idx_1))
            keys.append((page_1, page_2, elem_idx_1, elem_idx_2))
            i += 1

        html_table_merge_query_list = [
            build_html_table_merge_query(page_to_markdown_result_tmp[p1][e1], page_to_markdown_result_tmp[p2][e2])
            for p1, p2, e1, e2 in keys
        ]
        responses = await asyncio.gather(*[get_response(query, 0.0) for query in html_table_merge_query_list])
        for key, result in zip(keys, responses):
            if result.startswith("<table>") and result.endswith("</table>"):
                html_table_merge_result[key] = result
                # BUGFIX: unpack the key — the original wrote through stale
                # loop variables left over from the batching loop above, so
                # every result in a batch landed on the same (wrong) element.
                merge_page, _, merge_elem, _ = key
                page_to_markdown_result_tmp[merge_page][merge_elem] = result

    document_text = bulid_document_text(page_to_markdown_result, element_merge_detect_result, html_table_merge_result)
    return {
        "orig_path": file_path,
        "num_pages": num_pages,
        "document_text": document_text,
        "page_texts": page_texts,
        "fallback_pages": fallback_pages,
    }
270
+
271
+
272
+ file_path = '/content/test.pdf'
273
+ result = await parse(file_path)
274
+ if result != None:
275
+ document_markdown = result['document_text']
276
+ print(document_markdown)
277
+ with open('test.md','w') as f:
278
+ f.write(document_markdown)
279
+ else:
280
+ print("Parse failed.")
281
+ ```