KarthiEz committed on
Commit
a8a9f71
·
verified ·
1 Parent(s): 55774c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +201 -31
app.py CHANGED
@@ -26,7 +26,7 @@ from paddleocr import PaddleOCR
26
 
27
  # --------- Config knobs ----------
28
  LANG = os.getenv("OCR_LANG", "en")
29
- USE_GPU = os.getenv("OCR_USE_GPU", "false").lower() == "true"
30
  DET = os.getenv("OCR_DET_MODEL", "ch_PP-OCRv4_det")
31
  REC = os.getenv("OCR_REC_MODEL", "en_PP-OCRv4")
32
  CLS = True
@@ -46,6 +46,7 @@ def _build_ocr(use_cls: bool) -> PaddleOCR:
46
  show_log=False
47
  )
48
 
 
49
  _OCR = _build_ocr(CLS)
50
 
51
  class TextBlock:
@@ -53,10 +54,11 @@ class TextBlock:
53
  def __init__(self, text: str, confidence: float, bbox: List[List[int]], img_width: int, img_height: int):
54
  self.text = text
55
  self.confidence = confidence
56
- self.bbox = bbox
57
  self.img_width = img_width
58
  self.img_height = img_height
59
 
 
60
  x_coords = [p[0] for p in bbox]
61
  y_coords = [p[1] for p in bbox]
62
  self.x_min = min(x_coords)
@@ -68,62 +70,211 @@ class TextBlock:
68
  self.width = self.x_max - self.x_min
69
  self.height = self.y_max - self.y_min
70
 
71
- def _sort_blocks_by_layout(blocks: List[TextBlock], tolerance: float = 0.05) -> List[TextBlock]:
72
- """Sort text blocks to preserve reading order (top-to-bottom, left-to-right)."""
 
 
73
  if not blocks:
74
  return blocks
75
 
76
- max_height = max(b.y_max for b in blocks) if blocks else 1
77
-
78
  def get_sort_key(block: TextBlock):
79
- line_group = int(block.y_min / (max_height * tolerance))
80
- return (line_group, block.x_min)
 
81
 
82
  return sorted(blocks, key=get_sort_key)
83
 
84
  def _reconstruct_text_with_layout(blocks: List[TextBlock], img_width: int, img_height: int) -> str:
85
- """Reconstruct text preserving layout structure using bounding box positions."""
 
 
 
86
  if not blocks:
87
  return ""
88
 
89
  sorted_blocks = _sort_blocks_by_layout(blocks)
90
 
 
 
 
 
 
 
 
 
 
91
  lines = []
92
  current_line = []
93
  current_y = None
94
- line_tolerance = img_height * 0.02 # 2% of image height for line grouping
95
 
96
  for block in sorted_blocks:
97
- if current_y is None or abs(block.y_min - current_y) > line_tolerance:
98
- if current_line:
99
- lines.append(_format_line(current_line, img_width))
100
  current_line = [block]
101
- current_y = block.y_min
102
  else:
103
- current_line.append(block)
104
- current_y = sum(b.y_min for b in current_line) / len(current_line)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
 
106
  if current_line:
107
- lines.append(_format_line(current_line, img_width))
108
 
109
- return "\n".join(lines)
 
 
 
 
 
 
 
110
 
111
- def _format_line(blocks: List[TextBlock], img_width: int) -> str:
112
- """Format a line of text blocks, preserving horizontal spacing."""
 
 
 
113
  if not blocks:
114
  return ""
115
 
 
116
  blocks = sorted(blocks, key=lambda b: b.x_min)
117
 
118
- parts = []
119
- prev_x_end = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  for block in blocks:
122
- gap = block.x_min - prev_x_end
 
 
 
 
 
 
123
 
124
- if gap > img_width * 0.05:
125
- num_spaces = max(1, int(gap / (img_width * 0.01)))
126
- parts.append(" " * min(num_spaces, 10))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
  parts.append(block.text)
129
  prev_x_end = block.x_max
@@ -131,7 +282,12 @@ def _format_line(blocks: List[TextBlock], img_width: int) -> str:
131
  return "".join(parts)
132
 
133
  def ocr_image(pil_img: Image.Image, preserve_layout: bool = True) -> Tuple[List[TextBlock], str]:
134
- """Perform OCR on image and return both structured blocks and formatted text."""
 
 
 
 
 
135
  img_cv = _pil_to_cv(pil_img)
136
  img_width, img_height = pil_img.size
137
 
@@ -143,6 +299,7 @@ def ocr_image(pil_img: Image.Image, preserve_layout: bool = True) -> Tuple[List[
143
  except RuntimeError as e:
144
  msg = str(e).lower()
145
  if "primitive" in msg or "mkldnn" in msg or "predictor.run" in msg:
 
146
  fallback_ocr = _build_ocr(False)
147
  result = _run(fallback_ocr, False)
148
  else:
@@ -153,7 +310,7 @@ def ocr_image(pil_img: Image.Image, preserve_layout: bool = True) -> Tuple[List[
153
  return blocks, ""
154
 
155
  for line in result[0]:
156
- bbox = line[0]
157
  txt = line[1][0]
158
  conf = float(line[1][1])
159
 
@@ -161,9 +318,11 @@ def ocr_image(pil_img: Image.Image, preserve_layout: bool = True) -> Tuple[List[
161
  block = TextBlock(txt, conf, bbox, img_width, img_height)
162
  blocks.append(block)
163
 
 
164
  if preserve_layout:
165
  formatted_text = _reconstruct_text_with_layout(blocks, img_width, img_height)
166
  else:
 
167
  formatted_text = "\n".join([b.text for b in blocks])
168
 
169
  return blocks, formatted_text
@@ -183,7 +342,12 @@ def read_pdf_pages(filepath: str):
183
  return pages
184
 
185
  def extract_text_from_file(filepath: str, preserve_layout: bool = True) -> Tuple[str, Dict[str, Any]]:
186
- """Extract text from file with layout preservation."""
 
 
 
 
 
187
  lower = filepath.lower()
188
  all_blocks = []
189
  all_texts = []
@@ -220,7 +384,12 @@ def extract_text_from_file(filepath: str, preserve_layout: bool = True) -> Tuple
220
  return final_text or "[No text detected]", metadata
221
 
222
  def infer(file_obj, preserve_layout: bool) -> Tuple[str, str]:
223
- """Main inference function."""
 
 
 
 
 
224
  try:
225
  if file_obj is None:
226
  return "No file uploaded.", "{}"
@@ -274,4 +443,5 @@ if __name__ == "__main__":
274
  server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
275
  server_port=int(os.getenv("PORT", "7860")),
276
  show_error=True
277
- )
 
 
26
 
27
  # --------- Config knobs ----------
28
  LANG = os.getenv("OCR_LANG", "en")
29
+ USE_GPU = os.getenv("OCR_USE_GPU", "false").lower() == "true" # Spaces CPU → keep false
30
  DET = os.getenv("OCR_DET_MODEL", "ch_PP-OCRv4_det")
31
  REC = os.getenv("OCR_REC_MODEL", "en_PP-OCRv4")
32
  CLS = True
 
46
  show_log=False
47
  )
48
 
49
+ # Primary OCR instance (CLS on). If CLS crashes, we'll rebuild w/o CLS just-in-time.
50
  _OCR = _build_ocr(CLS)
51
 
52
  class TextBlock:
 
54
  def __init__(self, text: str, confidence: float, bbox: List[List[int]], img_width: int, img_height: int):
55
  self.text = text
56
  self.confidence = confidence
57
+ self.bbox = bbox # [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
58
  self.img_width = img_width
59
  self.img_height = img_height
60
 
61
+ # Calculate bounding box properties
62
  x_coords = [p[0] for p in bbox]
63
  y_coords = [p[1] for p in bbox]
64
  self.x_min = min(x_coords)
 
70
  self.width = self.x_max - self.x_min
71
  self.height = self.y_max - self.y_min
72
 
73
def _sort_blocks_by_layout(blocks: List[TextBlock]) -> List[TextBlock]:
    """
    Return the blocks in reading order: top-to-bottom, then left-to-right.

    An empty input is handed back unchanged; otherwise a new list is built,
    ordered by each block's ``y_min`` with ``x_min`` breaking ties.
    """
    if blocks:
        return sorted(blocks, key=lambda blk: (blk.y_min, blk.x_min))
    return blocks
86
 
87
def _reconstruct_text_with_layout(blocks: List[TextBlock], img_width: int, img_height: int) -> str:
    """
    Rebuild page text from OCR blocks, emitting one output line per visual row.

    Blocks are put in reading order, grouped into rows by vertical proximity
    or vertical overlap, and each row is rendered by
    ``_format_line_with_positions``.

    Args:
        blocks: OCR text blocks; each must expose ``y_min``, ``y_max``,
            ``center_y`` and ``height``.
        img_width: Image width in pixels, forwarded to the row formatter.
        img_height: Image height in pixels, used as a floor for the row tolerance.

    Returns:
        The rendered rows joined with newlines, or "" when there are no blocks.
    """
    if not blocks:
        return ""

    sorted_blocks = _sort_blocks_by_layout(blocks)

    # The tolerance follows the text size rather than the page size, with
    # floors of 1% of the image height and 8 pixels.
    avg_height = sum(b.height for b in blocks) / len(blocks)
    line_tolerance = max(avg_height * 0.5, img_height * 0.01, 8)

    lines = []
    current_line = []
    current_y = None

    for block in sorted_blocks:
        if current_line:
            y_diff = abs(block.center_y - current_y)
            # A block belongs to the row if it vertically overlaps any member.
            overlaps = any(
                block.y_max >= b.y_min and block.y_min <= b.y_max
                for b in current_line
            )
            if y_diff <= line_tolerance or overlaps:
                current_line.append(block)
                # Row centre is the height-weighted mean of its members; fall
                # back to a plain mean when every member has zero height so
                # the division cannot fail.
                total_height = sum(b.height for b in current_line)
                if total_height > 0:
                    current_y = sum(b.center_y * b.height for b in current_line) / total_height
                else:
                    current_y = sum(b.center_y for b in current_line) / len(current_line)
                continue
            lines.append(current_line)

        # First block overall, or the first block of a new row.
        current_line = [block]
        current_y = block.center_y

    # blocks is non-empty here, so the last row always has content.
    lines.append(current_line)

    return "\n".join(_format_line_with_positions(row, img_width) for row in lines)
150
 
151
def _format_line_with_positions(blocks: List[TextBlock], img_width: int) -> str:
    """
    Render one row of text blocks on a character grid.

    Each block's text is written contiguously, starting at the grid column
    that corresponds to its ``x_min``. If that column is already occupied by
    the previous block, the text is shifted right so exactly one space
    separates the two; no characters are dropped and no spaces are inserted
    inside a block's text.

    Args:
        blocks: Blocks of a single row; each must expose ``text``, ``x_min``
            and ``width``.
        img_width: Image width in pixels.

    Returns:
        The rendered row without trailing spaces. When the grid rendering is
        more than three times longer than the text it holds, the result of
        ``_format_line_with_gaps`` is returned instead.
    """
    if not blocks:
        return ""

    # Left-to-right order
    blocks = sorted(blocks, key=lambda b: b.x_min)

    # Pixels per character, averaged over the blocks that carry text
    char_widths = [b.width / len(b.text) for b in blocks if b.text]
    if char_widths:
        avg_char_width = sum(char_widths) / len(char_widths)
    else:
        avg_char_width = img_width / 100

    # Grid cell is somewhat narrower than a glyph; the floor keeps the grid
    # to roughly 150 columns across the image at most.
    char_width = max(avg_char_width * 0.7, img_width / 150)
    if char_width <= 0:
        # Degenerate geometry gives no usable scale: just join the texts.
        return " ".join(b.text for b in blocks if b.text)

    max_col = int(img_width / char_width)
    line_chars = []

    for block in blocks:
        text = block.text
        if not text:
            continue

        start = max(0, min(int(block.x_min / char_width), max_col))
        if line_chars:
            # line_chars always ends with the previous block's last character;
            # keep at least one space between neighbouring blocks.
            start = max(start, len(line_chars) + 1)

        line_chars.extend(" " * (start - len(line_chars)))
        line_chars.extend(text)

    result = "".join(line_chars).rstrip()

    # Too sparse a row reads badly; the gap-based formatter caps its spacing.
    text_length = sum(len(b.text) for b in blocks)
    if len(result) > text_length * 3:
        return _format_line_with_gaps(blocks, img_width)

    return result
219
+
220
def _format_line_with_gaps(blocks: List[TextBlock], img_width: int) -> str:
    """
    Render one row of text blocks, turning pixel gaps into runs of spaces.

    The median pixels-per-character of the row is used as the width of one
    space. Neighbouring blocks closer than 1.5 character widths (including
    overlapping ones) are separated by a single space; wider gaps become
    proportional runs capped at 30 spaces. A first block that starts more
    than 2% of the image width from the left edge gets up to 20 leading
    spaces.

    Args:
        blocks: Blocks of a single row; each must expose ``text``, ``x_min``,
            ``x_max`` and ``width``.
        img_width: Image width in pixels.

    Returns:
        The rendered row, or "" when there are no blocks.
    """
    if not blocks:
        return ""

    # Left-to-right order
    blocks = sorted(blocks, key=lambda b: b.x_min)

    # Median rather than mean so one odd block does not skew the scale
    char_widths = sorted(b.width / len(b.text) for b in blocks if b.text)
    if char_widths:
        mid = len(char_widths) // 2
        if len(char_widths) % 2 == 1:
            avg_char_width = char_widths[mid]
        else:
            avg_char_width = (char_widths[mid - 1] + char_widths[mid]) / 2
    else:
        avg_char_width = img_width / 100

    if avg_char_width <= 0:
        # Zero-width blocks would otherwise cause a division by zero below.
        avg_char_width = img_width / 100 if img_width > 0 else 1

    parts = []
    prev_x_end = None

    for block in blocks:
        if prev_x_end is None:
            # First block: indent it when it starts away from the left edge.
            if block.x_min > img_width * 0.02:
                num_spaces = max(1, int(block.x_min / avg_char_width))
                parts.append(" " * min(num_spaces, 20))
        else:
            gap = block.x_min - prev_x_end
            if gap < avg_char_width * 1.5:
                # Overlapping, touching or close blocks: one space.
                parts.append(" ")
            else:
                num_spaces = max(1, int(gap / avg_char_width))
                parts.append(" " * min(num_spaces, 30))

        parts.append(block.text)
        prev_x_end = block.x_max

    return "".join(parts)
283
 
284
  def ocr_image(pil_img: Image.Image, preserve_layout: bool = True) -> Tuple[List[TextBlock], str]:
285
+ """
286
+ Perform OCR on image and return both structured blocks and formatted text.
287
+
288
+ Returns:
289
+ Tuple of (list of TextBlock objects, formatted text string)
290
+ """
291
  img_cv = _pil_to_cv(pil_img)
292
  img_width, img_height = pil_img.size
293
 
 
299
  except RuntimeError as e:
300
  msg = str(e).lower()
301
  if "primitive" in msg or "mkldnn" in msg or "predictor.run" in msg:
302
+ # One-time fallback without angle classifier
303
  fallback_ocr = _build_ocr(False)
304
  result = _run(fallback_ocr, False)
305
  else:
 
310
  return blocks, ""
311
 
312
  for line in result[0]:
313
+ bbox = line[0] # [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
314
  txt = line[1][0]
315
  conf = float(line[1][1])
316
 
 
318
  block = TextBlock(txt, conf, bbox, img_width, img_height)
319
  blocks.append(block)
320
 
321
+ # Generate formatted text
322
  if preserve_layout:
323
  formatted_text = _reconstruct_text_with_layout(blocks, img_width, img_height)
324
  else:
325
+ # Simple concatenation (original behavior)
326
  formatted_text = "\n".join([b.text for b in blocks])
327
 
328
  return blocks, formatted_text
 
342
  return pages
343
 
344
  def extract_text_from_file(filepath: str, preserve_layout: bool = True) -> Tuple[str, Dict[str, Any]]:
345
+ """
346
+ Extract text from file with layout preservation.
347
+
348
+ Returns:
349
+ Tuple of (formatted text, metadata dict with blocks info)
350
+ """
351
  lower = filepath.lower()
352
  all_blocks = []
353
  all_texts = []
 
384
  return final_text or "[No text detected]", metadata
385
 
386
  def infer(file_obj, preserve_layout: bool) -> Tuple[str, str]:
387
+ """
388
+ Main inference function.
389
+
390
+ Returns:
391
+ Tuple of (formatted text, metadata JSON string)
392
+ """
393
  try:
394
  if file_obj is None:
395
  return "No file uploaded.", "{}"
 
443
  server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
444
  server_port=int(os.getenv("PORT", "7860")),
445
  show_error=True
446
+ )
447
+