iammraat commited on
Commit
f2f27bb
·
verified ·
1 Parent(s): b02cd5a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -356
app.py CHANGED
@@ -66,248 +66,9 @@
66
 
67
 
68
 
69
- # import gradio as gr
70
- # import torch
71
- # import numpy as np
72
- # import cv2
73
- # from PIL import Image
74
- # from transformers import TrOCRProcessor, VisionEncoderDecoderModel
75
- # from craft_text_detector import Craft
76
-
77
- # # ----------------------------
78
- # # Device
79
- # # ----------------------------
80
- # device = "cuda" if torch.cuda.is_available() else "cpu"
81
-
82
- # # ----------------------------
83
- # # Load TrOCR
84
- # # ----------------------------
85
- # print("Loading TrOCR model...")
86
- # processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-handwritten")
87
- # model = VisionEncoderDecoderModel.from_pretrained(
88
- # "microsoft/trocr-small-handwritten"
89
- # )
90
- # model.to(device)
91
- # model.eval()
92
-
93
- # # ----------------------------
94
- # # Load CRAFT
95
- # # ----------------------------
96
- # print("Loading CRAFT text detector...")
97
- # craft = Craft(
98
- # output_dir=None,
99
- # crop_type="poly",
100
- # cuda=(device == "cuda"),
101
- # )
102
-
103
- # # ----------------------------
104
- # # Sort boxes (reading order)
105
- # # ----------------------------
106
- # def get_sorted_boxes(boxes):
107
- # items = []
108
- # for box in boxes:
109
- # cx = np.mean(box[:, 0])
110
- # cy = np.mean(box[:, 1])
111
- # items.append((cy, cx, box))
112
-
113
- # # group by line (roughly)
114
- # items.sort(key=lambda x: (int(x[0] // 20), x[1]))
115
- # return [b for _, _, b in items]
116
-
117
- # # ----------------------------
118
- # # OCR Pipeline
119
- # # ----------------------------
120
- # def process_full_page(image: Image.Image):
121
- # # ALWAYS return (image_or_None, text)
122
- # if image is None:
123
- # return None, "Please upload an image."
124
-
125
- # image_np = np.array(image)
126
-
127
- # prediction = craft.detect_text(image_np)
128
- # boxes = prediction.get("boxes", [])
129
-
130
- # if not boxes:
131
- # return image, "No text detected."
132
-
133
- # sorted_boxes = get_sorted_boxes(boxes)
134
- # annotated = image_np.copy()
135
- # texts = []
136
-
137
- # for box in sorted_boxes:
138
- # box = box.astype(int)
139
- # cv2.polylines(annotated, [box], True, (255, 0, 0), 2)
140
-
141
- # x_min = max(0, box[:, 0].min())
142
- # x_max = min(image_np.shape[1], box[:, 0].max())
143
- # y_min = max(0, box[:, 1].min())
144
- # y_max = min(image_np.shape[0], box[:, 1].max())
145
-
146
- # if x_max - x_min < 5 or y_max - y_min < 5:
147
- # continue
148
-
149
- # crop = image_np[y_min:y_max, x_min:x_max]
150
- # pil_crop = Image.fromarray(crop).convert("RGB")
151
-
152
- # with torch.no_grad():
153
- # pixels = processor(
154
- # images=pil_crop,
155
- # return_tensors="pt"
156
- # ).pixel_values.to(device)
157
-
158
- # ids = model.generate(pixels)
159
- # text = processor.batch_decode(
160
- # ids, skip_special_tokens=True
161
- # )[0]
162
-
163
- # if text.strip():
164
- # texts.append(text)
165
-
166
- # final_text = " ".join(texts)
167
- # return Image.fromarray(annotated), final_text
168
-
169
- # # ----------------------------
170
- # # Gradio UI
171
- # # ----------------------------
172
- # with gr.Blocks(theme=gr.themes.Soft()) as demo:
173
- # gr.Markdown("# 🕵️‍♀️ Full-Page Handwritten OCR")
174
- # gr.Markdown("**CRAFT ➜ TrOCR** (Detection + Recognition)")
175
-
176
- # with gr.Row():
177
- # input_img = gr.Image(type="pil", label="Upload Full Page")
178
-
179
- # with gr.Row():
180
- # vis_output = gr.Image(label="Detections")
181
- # text_output = gr.Textbox(label="Extracted Text", lines=10)
182
-
183
- # btn = gr.Button("Process Page", variant="primary")
184
- # btn.click(
185
- # fn=process_full_page,
186
- # inputs=input_img,
187
- # outputs=[vis_output, text_output],
188
- # )
189
-
190
- # if __name__ == "__main__":
191
- # demo.launch(
192
- # server_name="0.0.0.0",
193
- # server_port=7860,
194
- # show_api=False,
195
- # )
196
-
197
-
198
-
199
-
200
-
201
-
202
-
203
- # import gradio as gr
204
- # import torch
205
- # import numpy as np
206
- # import cv2
207
- # from PIL import Image
208
- # from transformers import TrOCRProcessor, VisionEncoderDecoderModel
209
- # from craft_text_detector import Craft
210
-
211
- # # PATCH: Fix NumPy inhomogeneous array crash
212
- # import craft_text_detector.craft_utils as craft_utils_module
213
-
214
- # _original_adjust = craft_utils_module.adjustResultCoordinates
215
-
216
- # def fixed_adjustResultCoordinates(polys, ratio_w, ratio_h):
217
- # if not polys:
218
- # return []
219
- # adjusted = []
220
- # for poly in polys:
221
- # if poly is None or len(poly) == 0:
222
- # continue
223
- # poly = np.array(poly).reshape(-1, 2)
224
- # poly[:, 0] *= ratio_w
225
- # poly[:, 1] *= ratio_h
226
- # adjusted.append(poly)
227
- # return adjusted
228
 
229
- # craft_utils_module.adjustResultCoordinates = fixed_adjustResultCoordinates
230
 
231
- # # Device
232
- # device = "cuda" if torch.cuda.is_available() else "cpu"
233
-
234
- # # Load TrOCR
235
- # print("Loading TrOCR model...")
236
- # processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-handwritten")
237
- # model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-small-handwritten")
238
- # model.to(device)
239
- # model.eval()
240
 
241
- # # Load CRAFT
242
- # print("Loading CRAFT text detector...")
243
- # craft = Craft(output_dir=None, crop_type="poly", cuda=(device == "cuda"))
244
-
245
- # # Sort boxes (reading order)
246
- # def get_sorted_boxes(boxes):
247
- # items = []
248
- # for box in boxes:
249
- # cx = np.mean(box[:, 0])
250
- # cy = np.mean(box[:, 1])
251
- # items.append((cy, cx, box))
252
- # items.sort(key=lambda x: (int(x[0] // 20), x[1]))
253
- # return [b for _, _, b in items]
254
-
255
- # # OCR Pipeline
256
- # def process_full_page(image):
257
- # if image is None:
258
- # return None, "Please upload an image."
259
-
260
- # image_np = np.array(image)
261
- # prediction = craft.detect_text(image_np)
262
- # boxes = prediction.get("boxes", [])
263
-
264
- # if not boxes:
265
- # return image, "No text detected."
266
-
267
- # sorted_boxes = get_sorted_boxes(boxes)
268
- # annotated = image_np.copy()
269
- # texts = []
270
-
271
- # for box in sorted_boxes:
272
- # box = box.astype(int)
273
- # cv2.polylines(annotated, [box], True, (255, 0, 0), 2)
274
-
275
- # x_min = max(0, box[:, 0].min())
276
- # x_max = min(image_np.shape[1], box[:, 0].max())
277
- # y_min = max(0, box[:, 1].min())
278
- # y_max = min(image_np.shape[0], box[:, 1].max())
279
-
280
- # if x_max - x_min < 5 or y_max - y_min < 5:
281
- # continue
282
-
283
- # crop = image_np[y_min:y_max, x_min:x_max]
284
- # pil_crop = Image.fromarray(crop).convert("RGB")
285
-
286
- # with torch.no_grad():
287
- # pixels = processor(images=pil_crop, return_tensors="pt").pixel_values.to(device)
288
- # ids = model.generate(pixels)
289
- # text = processor.batch_decode(ids, skip_special_tokens=True)[0]
290
-
291
- # if text.strip():
292
- # texts.append(text)
293
-
294
- # final_text = " ".join(texts)
295
- # return Image.fromarray(annotated), final_text
296
-
297
- # # Gradio UI
298
- # demo = gr.Interface(
299
- # fn=process_full_page,
300
- # inputs=gr.Image(type="pil", label="Upload Full Page"),
301
- # outputs=[
302
- # gr.Image(label="Detections"),
303
- # gr.Textbox(label="Extracted Text", lines=10)
304
- # ],
305
- # title="🕵️‍♀️ Full-Page Handwritten OCR",
306
- # description="CRAFT ➜ TrOCR (Detection + Recognition)"
307
- # )
308
-
309
- # if __name__ == "__main__":
310
- # demo.launch(server_name="0.0.0.0", server_port=7860)
311
 
312
 
313
 
@@ -326,7 +87,7 @@
326
  # from craft_text_detector import Craft
327
 
328
  # # ==========================================
329
- # # 🔧 PATCH 1: Fix Torchvision (From your code)
330
  # # ==========================================
331
  # import torchvision.models.vgg
332
  # if not hasattr(torchvision.models.vgg, 'model_urls'):
@@ -335,7 +96,7 @@
335
  # }
336
 
337
  # # ==========================================
338
- # # 🔧 PATCH 2: The Logic Fix (Ratio Net)
339
  # # ==========================================
340
  # import craft_text_detector.craft_utils as craft_utils_module
341
 
@@ -348,11 +109,10 @@
348
  # if poly is None or len(poly) == 0:
349
  # continue
350
 
351
- # # Safe numpy conversion
352
  # p = np.array(poly).reshape(-1, 2)
353
 
354
- # # CRITICAL FIX: Multiply by ratio_net (defaults to 2)
355
- # # This scales the 1/2 size heatmap output back to full image size
356
  # p[:, 0] *= (ratio_w * ratio_net)
357
  # p[:, 1] *= (ratio_h * ratio_net)
358
 
@@ -360,21 +120,24 @@
360
 
361
  # return adjusted
362
 
363
- # # Apply the patch
364
  # craft_utils_module.adjustResultCoordinates = fixed_adjustResultCoordinates
365
  # # ==========================================
366
 
367
- # # --- Load TrOCR (Recognition) ---
 
368
  # device = "cuda" if torch.cuda.is_available() else "cpu"
369
- # print(f"Loading TrOCR on {device}...")
370
- # processor = TrOCRProcessor.from_pretrained('microsoft/trocr-small-handwritten')
371
- # model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-small-handwritten').to(device).eval()
 
 
 
372
 
373
- # # --- Load CRAFT (Detection) ---
374
  # print("Loading CRAFT...")
375
- # # crop_type="box" ensures we get clean rectangles
376
  # craft = Craft(output_dir=None, crop_type="box", cuda=(device == "cuda"))
377
 
 
 
378
  # def get_sorted_boxes(boxes):
379
  # """Sorts boxes top-to-bottom (lines), then left-to-right."""
380
  # if not boxes: return []
@@ -384,54 +147,60 @@
384
  # cx = np.mean(box[:, 0])
385
  # items.append((cy, cx, box))
386
 
387
- # # Sort by Y (grouping by 40px lines) then X
388
- # items.sort(key=lambda x: (int(x[0] // 40), x[1]))
389
  # return [x[2] for x in items]
390
 
391
  # def process_image(image):
392
  # if image is None:
393
- # return None, "Please upload an image."
394
 
395
- # # Convert to numpy
 
396
  # image_np = np.array(image.convert("RGB"))
397
 
398
  # # 1. DETECT
399
- # # The patch we added above will now auto-multiply coordinates by 2 * ratio
400
- # # fixing the "tiny box" issue.
401
  # prediction = craft.detect_text(image_np)
402
  # boxes = prediction.get("boxes", [])
403
 
404
  # if not boxes:
405
- # return image, "No text detected."
406
 
407
- # # 2. VISUALIZE & CROP
408
  # sorted_boxes = get_sorted_boxes(boxes)
409
  # annotated_img = image_np.copy()
410
  # results = []
 
411
 
 
412
  # for box in sorted_boxes:
413
- # # Cast to int for drawing
414
  # box_int = box.astype(np.int32)
415
 
416
- # # Draw on image (Blue, thickness 3)
417
  # cv2.polylines(annotated_img, [box_int], True, (255, 0, 0), 3)
418
 
419
- # # Get Crop Coordinates
420
- # x_min = max(0, np.min(box_int[:, 0]))
421
- # x_max = min(image_np.shape[1], np.max(box_int[:, 0]))
422
- # y_min = max(0, np.min(box_int[:, 1]))
423
- # y_max = min(image_np.shape[0], np.max(box_int[:, 1]))
424
 
425
- # # Filter noise
426
- # if (x_max - x_min) < 10 or (y_max - y_min) < 10:
 
 
 
 
 
427
  # continue
428
 
429
  # crop = image_np[y_min:y_max, x_min:x_max]
430
- # if crop.size == 0: continue
431
 
 
432
  # pil_crop = Image.fromarray(crop)
433
 
434
- # # 3. RECOGNIZE (TrOCR)
 
 
 
435
  # with torch.no_grad():
436
  # pixel_values = processor(images=pil_crop, return_tensors="pt").pixel_values.to(device)
437
  # generated_ids = model.generate(pixel_values)
@@ -441,22 +210,28 @@
441
  # results.append(text)
442
 
443
  # full_text = "\n".join(results)
444
- # return Image.fromarray(annotated_img), full_text
 
445
 
446
- # # --- Gradio UI ---
447
- # with gr.Blocks(title="Handwritten OCR Fixed") as demo:
448
- # gr.Markdown("# 📝 Handwritten OCR (Fixed Pipeline)")
 
449
 
450
  # with gr.Row():
451
- # with gr.Column():
452
  # input_img = gr.Image(type="pil", label="Upload Image")
453
  # btn = gr.Button("Transcribe", variant="primary")
454
 
455
- # with gr.Column():
456
  # output_img = gr.Image(label="Detections")
457
- # output_txt = gr.Textbox(label="Result", lines=20)
 
 
 
 
458
 
459
- # btn.click(process_image, input_img, [output_img, output_txt])
460
 
461
  # if __name__ == "__main__":
462
  # demo.launch()
@@ -468,76 +243,43 @@
468
 
469
 
470
 
 
 
471
  import gradio as gr
472
  import torch
473
  import numpy as np
474
  import cv2
475
  from PIL import Image
476
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
477
- from craft_text_detector import Craft
478
-
479
- # ==========================================
480
- # 🔧 PATCH 1: Fix Torchvision Compatibility
481
- # ==========================================
482
- import torchvision.models.vgg
483
- if not hasattr(torchvision.models.vgg, 'model_urls'):
484
- torchvision.models.vgg.model_urls = {
485
- 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'
486
- }
487
-
488
- # ==========================================
489
- # 🔧 PATCH 2: The "Ratio Net" Logic Fix
490
- # ==========================================
491
- import craft_text_detector.craft_utils as craft_utils_module
492
-
493
- def fixed_adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net=2):
494
- if not polys:
495
- return []
496
-
497
- adjusted = []
498
- for poly in polys:
499
- if poly is None or len(poly) == 0:
500
- continue
501
-
502
- # Convert to numpy and reshape
503
- p = np.array(poly).reshape(-1, 2)
504
-
505
- # Scale correctly using ratio_net
506
- p[:, 0] *= (ratio_w * ratio_net)
507
- p[:, 1] *= (ratio_h * ratio_net)
508
-
509
- adjusted.append(p)
510
-
511
- return adjusted
512
-
513
- craft_utils_module.adjustResultCoordinates = fixed_adjustResultCoordinates
514
- # ==========================================
515
-
516
 
517
- # --- 1. SETUP MODEL (Switched to BASE for stability) ---
518
  device = "cuda" if torch.cuda.is_available() else "cpu"
519
- print(f"Loading TrOCR-Base on {device}...")
520
 
521
- # We use the 'base' model because 'small' hallucinates Wikipedia text on tight crops
522
- MODEL_ID = "microsoft/trocr-base-handwritten"
523
- processor = TrOCRProcessor.from_pretrained(MODEL_ID)
524
- model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID).to(device).eval()
525
 
526
- print("Loading CRAFT...")
527
- craft = Craft(output_dir=None, crop_type="box", cuda=(device == "cuda"))
 
 
 
528
 
529
-
530
- # --- 2. HELPER FUNCTIONS ---
531
  def get_sorted_boxes(boxes):
532
  """Sorts boxes top-to-bottom (lines), then left-to-right."""
533
  if not boxes: return []
534
  items = []
535
  for box in boxes:
 
 
 
536
  cy = np.mean(box[:, 1])
537
  cx = np.mean(box[:, 0])
538
  items.append((cy, cx, box))
539
 
540
- # Sort by line (approx 20px tolerance) then by column
541
  items.sort(key=lambda x: (int(x[0] // 20), x[1]))
542
  return [x[2] for x in items]
543
 
@@ -545,18 +287,21 @@ def process_image(image):
545
  if image is None:
546
  return None, [], "Please upload an image."
547
 
548
- # Convert to standard RGB Numpy array
549
- # We use the FULL resolution image (no resizing) to keep text sharp
550
  image_np = np.array(image.convert("RGB"))
551
 
552
- # 1. DETECT
553
- # The patch ensures coordinates map perfectly to this full-res image
554
- prediction = craft.detect_text(image_np)
555
- boxes = prediction.get("boxes", [])
556
 
557
- if not boxes:
 
558
  return image, [], "No text detected."
559
-
 
 
 
560
  sorted_boxes = get_sorted_boxes(boxes)
561
  annotated_img = image_np.copy()
562
  results = []
@@ -566,31 +311,27 @@ def process_image(image):
566
  for box in sorted_boxes:
567
  box_int = box.astype(np.int32)
568
 
569
- # Draw the box (Visual verification)
570
- cv2.polylines(annotated_img, [box_int], True, (255, 0, 0), 3)
571
 
572
- # --- CROP WITH PADDING (Crucial Fix) ---
573
- # TrOCR needs 'breathing room' or it hallucinates.
574
- PADDING = 10
575
 
576
  x_min = max(0, np.min(box_int[:, 0]) - PADDING)
577
  x_max = min(image_np.shape[1], np.max(box_int[:, 0]) + PADDING)
578
  y_min = max(0, np.min(box_int[:, 1]) - PADDING)
579
  y_max = min(image_np.shape[0], np.max(box_int[:, 1]) + PADDING)
580
 
581
- # Skip noise
582
- if (x_max - x_min) < 20 or (y_max - y_min) < 10:
583
  continue
584
 
585
  crop = image_np[y_min:y_max, x_min:x_max]
586
-
587
- # Convert to PIL for Model
588
  pil_crop = Image.fromarray(crop)
589
-
590
- # Add to debug gallery so user can see what the model sees
591
  debug_crops.append(pil_crop)
592
 
593
- # 3. RECOGNIZE
594
  with torch.no_grad():
595
  pixel_values = processor(images=pil_crop, return_tensors="pt").pixel_values.to(device)
596
  generated_ids = model.generate(pixel_values)
@@ -603,10 +344,10 @@ def process_image(image):
603
 
604
  return Image.fromarray(annotated_img), debug_crops, full_text
605
 
606
- # --- 3. GRADIO UI ---
607
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
608
- gr.Markdown("# 📝 Robust Handwritten OCR (Base Model)")
609
- gr.Markdown("Includes padding and a stronger model to prevent hallucinations.")
610
 
611
  with gr.Row():
612
  with gr.Column(scale=1):
@@ -614,14 +355,13 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
614
  btn = gr.Button("Transcribe", variant="primary")
615
 
616
  with gr.Column(scale=1):
617
- output_img = gr.Image(label="Detections")
618
  output_txt = gr.Textbox(label="Extracted Text", lines=15, show_copy_button=True)
619
-
620
  with gr.Row():
621
- # Gallery to check if crops are valid or empty
622
- crop_gallery = gr.Gallery(label="Debug: See what the model sees (Crops)", columns=6, height=200)
623
 
624
- btn.click(process_image, input_img, [output_img, crop_gallery, output_txt])
625
 
626
  if __name__ == "__main__":
627
  demo.launch()
 
66
 
67
 
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
 
70
 
 
 
 
 
 
 
 
 
 
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
 
74
 
 
87
  # from craft_text_detector import Craft
88
 
89
  # # ==========================================
90
+ # # 🔧 PATCH 1: Fix Torchvision Compatibility
91
  # # ==========================================
92
  # import torchvision.models.vgg
93
  # if not hasattr(torchvision.models.vgg, 'model_urls'):
 
96
  # }
97
 
98
  # # ==========================================
99
+ # # 🔧 PATCH 2: The "Ratio Net" Logic Fix
100
  # # ==========================================
101
  # import craft_text_detector.craft_utils as craft_utils_module
102
 
 
109
  # if poly is None or len(poly) == 0:
110
  # continue
111
 
112
+ # # Convert to numpy and reshape
113
  # p = np.array(poly).reshape(-1, 2)
114
 
115
+ # # Scale correctly using ratio_net
 
116
  # p[:, 0] *= (ratio_w * ratio_net)
117
  # p[:, 1] *= (ratio_h * ratio_net)
118
 
 
120
 
121
  # return adjusted
122
 
 
123
  # craft_utils_module.adjustResultCoordinates = fixed_adjustResultCoordinates
124
  # # ==========================================
125
 
126
+
127
+ # # --- 1. SETUP MODEL (Switched to BASE for stability) ---
128
  # device = "cuda" if torch.cuda.is_available() else "cpu"
129
+ # print(f"Loading TrOCR-Base on {device}...")
130
+
131
+ # # We use the 'base' model because 'small' hallucinates Wikipedia text on tight crops
132
+ # MODEL_ID = "microsoft/trocr-base-handwritten"
133
+ # processor = TrOCRProcessor.from_pretrained(MODEL_ID)
134
+ # model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID).to(device).eval()
135
 
 
136
  # print("Loading CRAFT...")
 
137
  # craft = Craft(output_dir=None, crop_type="box", cuda=(device == "cuda"))
138
 
139
+
140
+ # # --- 2. HELPER FUNCTIONS ---
141
  # def get_sorted_boxes(boxes):
142
  # """Sorts boxes top-to-bottom (lines), then left-to-right."""
143
  # if not boxes: return []
 
147
  # cx = np.mean(box[:, 0])
148
  # items.append((cy, cx, box))
149
 
150
+ # # Sort by line (approx 20px tolerance) then by column
151
+ # items.sort(key=lambda x: (int(x[0] // 20), x[1]))
152
  # return [x[2] for x in items]
153
 
154
  # def process_image(image):
155
  # if image is None:
156
+ # return None, [], "Please upload an image."
157
 
158
+ # # Convert to standard RGB Numpy array
159
+ # # We use the FULL resolution image (no resizing) to keep text sharp
160
  # image_np = np.array(image.convert("RGB"))
161
 
162
  # # 1. DETECT
163
+ # # The patch ensures coordinates map perfectly to this full-res image
 
164
  # prediction = craft.detect_text(image_np)
165
  # boxes = prediction.get("boxes", [])
166
 
167
  # if not boxes:
168
+ # return image, [], "No text detected."
169
 
 
170
  # sorted_boxes = get_sorted_boxes(boxes)
171
  # annotated_img = image_np.copy()
172
  # results = []
173
+ # debug_crops = []
174
 
175
+ # # 2. PROCESS BOXES
176
  # for box in sorted_boxes:
 
177
  # box_int = box.astype(np.int32)
178
 
179
+ # # Draw the box (Visual verification)
180
  # cv2.polylines(annotated_img, [box_int], True, (255, 0, 0), 3)
181
 
182
+ # # --- CROP WITH PADDING (Crucial Fix) ---
183
+ # # TrOCR needs 'breathing room' or it hallucinates.
184
+ # PADDING = 10
 
 
185
 
186
+ # x_min = max(0, np.min(box_int[:, 0]) - PADDING)
187
+ # x_max = min(image_np.shape[1], np.max(box_int[:, 0]) + PADDING)
188
+ # y_min = max(0, np.min(box_int[:, 1]) - PADDING)
189
+ # y_max = min(image_np.shape[0], np.max(box_int[:, 1]) + PADDING)
190
+
191
+ # # Skip noise
192
+ # if (x_max - x_min) < 20 or (y_max - y_min) < 10:
193
  # continue
194
 
195
  # crop = image_np[y_min:y_max, x_min:x_max]
 
196
 
197
+ # # Convert to PIL for Model
198
  # pil_crop = Image.fromarray(crop)
199
 
200
+ # # Add to debug gallery so user can see what the model sees
201
+ # debug_crops.append(pil_crop)
202
+
203
+ # # 3. RECOGNIZE
204
  # with torch.no_grad():
205
  # pixel_values = processor(images=pil_crop, return_tensors="pt").pixel_values.to(device)
206
  # generated_ids = model.generate(pixel_values)
 
210
  # results.append(text)
211
 
212
  # full_text = "\n".join(results)
213
+
214
+ # return Image.fromarray(annotated_img), debug_crops, full_text
215
 
216
+ # # --- 3. GRADIO UI ---
217
+ # with gr.Blocks(theme=gr.themes.Soft()) as demo:
218
+ # gr.Markdown("# 📝 Robust Handwritten OCR (Base Model)")
219
+ # gr.Markdown("Includes padding and a stronger model to prevent hallucinations.")
220
 
221
  # with gr.Row():
222
+ # with gr.Column(scale=1):
223
  # input_img = gr.Image(type="pil", label="Upload Image")
224
  # btn = gr.Button("Transcribe", variant="primary")
225
 
226
+ # with gr.Column(scale=1):
227
  # output_img = gr.Image(label="Detections")
228
+ # output_txt = gr.Textbox(label="Extracted Text", lines=15, show_copy_button=True)
229
+
230
+ # with gr.Row():
231
+ # # Gallery to check if crops are valid or empty
232
+ # crop_gallery = gr.Gallery(label="Debug: See what the model sees (Crops)", columns=6, height=200)
233
 
234
+ # btn.click(process_image, input_img, [output_img, crop_gallery, output_txt])
235
 
236
  # if __name__ == "__main__":
237
  # demo.launch()
 
243
 
244
 
245
 
246
+
247
+
248
  import gradio as gr
249
  import torch
250
  import numpy as np
251
  import cv2
252
  from PIL import Image
253
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
254
+ from paddleocr import PaddleOCR
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
 
256
+ # --- 1. SETUP TR-OCR (Recognition) ---
257
  device = "cuda" if torch.cuda.is_available() else "cpu"
258
+ print(f"Loading TrOCR on {device}...")
259
 
260
+ # Using the 'base' model for better accuracy on the crops
261
+ processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
262
+ model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten').to(device).eval()
 
263
 
264
+ # --- 2. SETUP PADDLEOCR (Detection Only) ---
265
+ print("Loading PaddleOCR (DBNet)...")
266
+ # use_angle_cls=True helps if the page is slightly rotated
267
+ # lang='en' loads the English detection model
268
+ detector = PaddleOCR(use_angle_cls=True, lang='en', show_log=False)
269
 
 
 
270
  def get_sorted_boxes(boxes):
271
  """Sorts boxes top-to-bottom (lines), then left-to-right."""
272
  if not boxes: return []
273
  items = []
274
  for box in boxes:
275
+ # Paddle returns boxes as list of points [[x1,y1], [x2,y2], ...]
276
+ # We convert to numpy for easier calc
277
+ box = np.array(box).astype(np.float32)
278
  cy = np.mean(box[:, 1])
279
  cx = np.mean(box[:, 0])
280
  items.append((cy, cx, box))
281
 
282
+ # Sort by Y (line tolerance 20px) then X
283
  items.sort(key=lambda x: (int(x[0] // 20), x[1]))
284
  return [x[2] for x in items]
285
 
 
287
  if image is None:
288
  return None, [], "Please upload an image."
289
 
290
+ # Convert to standard RGB Numpy array (Full Resolution)
 
291
  image_np = np.array(image.convert("RGB"))
292
 
293
+ # 1. DETECT with PaddleOCR
294
+ # cls=False because we don't need orientation classification for just boxes
295
+ # rec=False because we ONLY want boxes (we will use TrOCR to read)
296
+ result = detector.ocr(image_np, cls=False, rec=False)
297
 
298
+ # Paddle returns a list of results (one per page). We just have 1 page.
299
+ if not result or result[0] is None:
300
  return image, [], "No text detected."
301
+
302
+ # Extract boxes from result
303
+ boxes = result[0] # [[x1, y1], [x2, y2], ...]
304
+
305
  sorted_boxes = get_sorted_boxes(boxes)
306
  annotated_img = image_np.copy()
307
  results = []
 
311
  for box in sorted_boxes:
312
  box_int = box.astype(np.int32)
313
 
314
+ # Draw the box (Red, thickness 2)
315
+ cv2.polylines(annotated_img, [box_int], True, (255, 0, 0), 2)
316
 
317
+ # --- CROP WITH PADDING ---
318
+ # Padding helps TrOCR see the start/end of letters
319
+ PADDING = 8
320
 
321
  x_min = max(0, np.min(box_int[:, 0]) - PADDING)
322
  x_max = min(image_np.shape[1], np.max(box_int[:, 0]) + PADDING)
323
  y_min = max(0, np.min(box_int[:, 1]) - PADDING)
324
  y_max = min(image_np.shape[0], np.max(box_int[:, 1]) + PADDING)
325
 
326
+ # Skip tiny noise
327
+ if (x_max - x_min) < 15 or (y_max - y_min) < 10:
328
  continue
329
 
330
  crop = image_np[y_min:y_max, x_min:x_max]
 
 
331
  pil_crop = Image.fromarray(crop)
 
 
332
  debug_crops.append(pil_crop)
333
 
334
+ # 3. RECOGNIZE (TrOCR)
335
  with torch.no_grad():
336
  pixel_values = processor(images=pil_crop, return_tensors="pt").pixel_values.to(device)
337
  generated_ids = model.generate(pixel_values)
 
344
 
345
  return Image.fromarray(annotated_img), debug_crops, full_text
346
 
347
+ # --- GRADIO UI ---
348
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
349
+ gr.Markdown("# PaddleOCR + TrOCR")
350
+ gr.Markdown("Using **PaddleOCR (DBNet)** for sharp detection on cramped text, and **TrOCR** for reading.")
351
 
352
  with gr.Row():
353
  with gr.Column(scale=1):
 
355
  btn = gr.Button("Transcribe", variant="primary")
356
 
357
  with gr.Column(scale=1):
358
+ output_img = gr.Image(label="Detections (Paddle)")
359
  output_txt = gr.Textbox(label="Extracted Text", lines=15, show_copy_button=True)
360
+
361
  with gr.Row():
362
+ gallery = gr.Gallery(label="Line Crops", columns=6, height=200)
 
363
 
364
+ btn.click(process_image, input_img, [output_img, gallery, output_txt])
365
 
366
  if __name__ == "__main__":
367
  demo.launch()