Vik Paruchuri committed on
Commit
518215f
·
1 Parent(s): e1ef281

Integrate new texify model

Browse files
marker/processors/equation.py CHANGED
@@ -22,7 +22,7 @@ class EquationProcessor(BaseProcessor):
22
  model_max_length: Annotated[
23
  int,
24
  "The maximum number of tokens to allow for the Texify model.",
25
- ] = 384
26
  texify_batch_size: Annotated[
27
  Optional[int],
28
  "The batch size to use for the Texify model.",
@@ -65,27 +65,7 @@ class EquationProcessor(BaseProcessor):
65
  continue
66
 
67
  block = document.get_block(equation_d["block_id"])
68
- block.html = self.parse_latex_to_html(prediction)
69
-
70
- def parse_latex_to_html(self, latex: str):
71
- html_out = ""
72
- try:
73
- latex = self.parse_latex(latex)
74
- except ValueError as e:
75
- # If we have mismatched delimiters, we'll treat it as a single block
76
- # Strip the $'s from the latex
77
- latex = [
78
- {"class": "block", "content": latex.replace("$", "")}
79
- ]
80
-
81
- for el in latex:
82
- if el["class"] == "block":
83
- html_out += f'<math display="block">{el["content"]}</math>'
84
- elif el["class"] == "inline":
85
- html_out += f'<math display="inline">{el["content"]}</math>'
86
- else:
87
- html_out += f" {el['content']} "
88
- return html_out.strip()
89
 
90
  def get_batch_size(self):
91
  if self.texify_batch_size is not None:
@@ -106,71 +86,22 @@ class EquationProcessor(BaseProcessor):
106
  max_idx = min(min_idx + batch_size, len(equation_data))
107
 
108
  batch_equations = equation_data[min_idx:max_idx]
109
- max_length = max([eq["token_count"] for eq in batch_equations])
110
- max_length = min(max_length, self.model_max_length)
111
- max_length += self.token_buffer
112
-
113
  batch_images = [eq["image"] for eq in batch_equations]
114
 
115
  model_output = self.texify_model(
116
- batch_images,
117
- max_tokens=max_length
118
  )
119
 
120
  for j, output in enumerate(model_output):
121
- token_count = self.get_total_texify_tokens(output)
122
- if token_count >= max_length - 1:
123
- output = ""
124
 
125
  image_idx = i + j
126
- predictions[image_idx] = output
127
  return predictions
128
 
129
  def get_total_texify_tokens(self, text):
130
  tokenizer = self.texify_model.processor.tokenizer
131
  tokens = tokenizer(text)
132
- return len(tokens["input_ids"])
133
-
134
-
135
- @staticmethod
136
- def parse_latex(text: str):
137
- if text.count("$") % 2 != 0:
138
- raise ValueError("Mismatched delimiters in LaTeX")
139
-
140
- DELIMITERS = [
141
- ("$$", "block"),
142
- ("$", "inline")
143
- ]
144
-
145
- text = text.replace("\n", "<br>") # we can't handle \n's inside <p> properly if we don't do this
146
-
147
- i = 0
148
- stack = []
149
- result = []
150
- buffer = ""
151
-
152
- while i < len(text):
153
- for delim, class_name in DELIMITERS:
154
- if text[i:].startswith(delim):
155
- if stack and stack[-1] == delim: # Closing
156
- stack.pop()
157
- result.append({"class": class_name, "content": buffer})
158
- buffer = ""
159
- i += len(delim)
160
- break
161
- elif not stack: # Opening
162
- if buffer:
163
- result.append({"class": "text", "content": buffer})
164
- stack.append(delim)
165
- buffer = ""
166
- i += len(delim)
167
- break
168
- else:
169
- raise ValueError(f"Nested {class_name} delimiters not supported")
170
- else: # No delimiter match
171
- buffer += text[i]
172
- i += 1
173
-
174
- if buffer:
175
- result.append({"class": "text", "content": buffer})
176
- return result
 
22
  model_max_length: Annotated[
23
  int,
24
  "The maximum number of tokens to allow for the Texify model.",
25
+ ] = 768
26
  texify_batch_size: Annotated[
27
  Optional[int],
28
  "The batch size to use for the Texify model.",
 
65
  continue
66
 
67
  block = document.get_block(equation_d["block_id"])
68
+ block.html = prediction
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  def get_batch_size(self):
71
  if self.texify_batch_size is not None:
 
86
  max_idx = min(min_idx + batch_size, len(equation_data))
87
 
88
  batch_equations = equation_data[min_idx:max_idx]
 
 
 
 
89
  batch_images = [eq["image"] for eq in batch_equations]
90
 
91
  model_output = self.texify_model(
92
+ batch_images
 
93
  )
94
 
95
  for j, output in enumerate(model_output):
96
+ token_count = self.get_total_texify_tokens(output.text)
97
+ if token_count >= self.model_max_length - 1:
98
+ output.text = ""
99
 
100
  image_idx = i + j
101
+ predictions[image_idx] = output.text
102
  return predictions
103
 
104
  def get_total_texify_tokens(self, text):
105
  tokenizer = self.texify_model.processor.tokenizer
106
  tokens = tokenizer(text)
107
+ return len(tokens["input_ids"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
marker/renderers/markdown.py CHANGED
@@ -12,12 +12,16 @@ from marker.schema import BlockTypes
12
  from marker.schema.document import Document
13
 
14
 
 
 
 
15
  def cleanup_text(full_text):
16
  full_text = re.sub(r'\n{3,}', '\n\n', full_text)
17
  full_text = re.sub(r'(\n\s){3,}', '\n\n', full_text)
18
  return full_text.strip()
19
 
20
  def get_formatted_table_text(element):
 
21
  text = []
22
  for content in element.contents:
23
  if content is None:
@@ -26,13 +30,14 @@ def get_formatted_table_text(element):
26
  if isinstance(content, NavigableString):
27
  stripped = content.strip()
28
  if stripped:
29
- text.append(stripped)
30
  elif content.name == 'br':
31
  text.append('<br>')
32
  elif content.name == "math":
33
  text.append("$" + content.text + "$")
34
  else:
35
- text.append(str(content))
 
36
 
37
  full_text = ""
38
  for i, t in enumerate(text):
@@ -120,7 +125,7 @@ class Markdownify(MarkdownConverter):
120
  if r == 0 and c == 0:
121
  grid[row_idx][col_idx] = value
122
  else:
123
- grid[row_idx + r][col_idx + c] = ''
124
  except IndexError:
125
  # Sometimes the colspan/rowspan predictions can overflow
126
  print(f"Overflow in columns: {col_idx + c} >= {total_cols}")
 
12
  from marker.schema.document import Document
13
 
14
 
15
+ def escape_dollars(text):
16
+ return text.replace("$", r"\$")
17
+
18
  def cleanup_text(full_text):
19
  full_text = re.sub(r'\n{3,}', '\n\n', full_text)
20
  full_text = re.sub(r'(\n\s){3,}', '\n\n', full_text)
21
  return full_text.strip()
22
 
23
  def get_formatted_table_text(element):
24
+
25
  text = []
26
  for content in element.contents:
27
  if content is None:
 
30
  if isinstance(content, NavigableString):
31
  stripped = content.strip()
32
  if stripped:
33
+ text.append(escape_dollars(stripped))
34
  elif content.name == 'br':
35
  text.append('<br>')
36
  elif content.name == "math":
37
  text.append("$" + content.text + "$")
38
  else:
39
+ content_str = escape_dollars(str(content))
40
+ text.append(content_str)
41
 
42
  full_text = ""
43
  for i, t in enumerate(text):
 
125
  if r == 0 and c == 0:
126
  grid[row_idx][col_idx] = value
127
  else:
128
+ grid[row_idx + r][col_idx + c] = '' # Empty cell due to rowspan/colspan
129
  except IndexError:
130
  # Sometimes the colspan/rowspan predictions can overflow
131
  print(f"Overflow in columns: {col_idx + c} >= {total_cols}")