iammraat committed on
Commit 0ea7492 · verified · 1 Parent(s): 4815eac

Update app.py

Files changed (1)
  1. app.py +24 -209
app.py CHANGED
@@ -1,179 +1,6 @@
- # import gradio as gr
- # import logging
- # import os
- # import numpy as np
- # import torch
- # from PIL import Image, ImageDraw
- # from transformers import TrOCRProcessor, VisionEncoderDecoderModel
-
- # # --- SURYA IMPORTS ---
- # try:
- #     from surya.detection import batch_text_detection
- #     from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
- # except ImportError:
- #     from surya.detection import batch_inference as batch_text_detection
- #     from surya.model.detection.segformer import load_model as load_det_model, load_processor as load_det_processor
-
- # # ==========================================
- # # 1. SETUP MODELS
- # # ==========================================
- # device = "cpu"
- # logging.basicConfig(level=logging.INFO)
- # logger = logging.getLogger(__name__)
-
- # logger.info("⏳ Loading Models...")
-
- # # A. SURYA DETECTION
- # det_processor = load_det_processor()
- # det_model = load_det_model().to(device)
-
- # # B. TROCR RECOGNITION
- # # NOTE: We do NOT use quantization here. It destroys the attention mechanism in ViT
- # # encoders on CPU, leading to "mode collapse" (hallucinations).
- # trocr_processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
- # trocr_model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten').to(device)
-
- # logger.info("✅ All Models Loaded.")
-
- # # ==========================================
- # # 2. HELPER FUNCTIONS
- # # ==========================================
- # def recognize_batch(crops):
- #     """
- #     Feeds raw crops directly to TrOCR.
- #     """
- #     if not crops: return []
-
- #     # Ensure crops are valid
- #     valid_crops = [c for c in crops if c.size[0] > 0 and c.size[1] > 0]
- #     if not valid_crops: return []
-
- #     pixel_values = trocr_processor(images=valid_crops, return_tensors="pt").pixel_values.to(device)
-
- #     with torch.no_grad():
- #         # Using a slightly lower max_length prevents it from rambling if it gets confused
- #         generated_ids = trocr_model.generate(pixel_values, max_length=64)
- #     text = trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)
- #     return text
-
- # def draw_boxes(image, prediction_objects):
- #     draw = ImageDraw.Draw(image)
- #     for obj in prediction_objects:
- #         if hasattr(obj, "bbox"):
- #             draw.rectangle(obj.bbox, outline="red", width=2)
- #         else:
- #             # Fallback if obj is just a list/tuple
- #             draw.rectangle(obj, outline="red", width=2)
- #     return image
-
- # # ==========================================
- # # 3. MAIN WORKFLOW
- # # ==========================================
- # def hybrid_ocr_workflow(image):
- #     if image is None: return None, "Please upload an image."
-
- #     # CRITICAL FIX: Ensure image is RGB (TrOCR fails on RGBA/P modes silently)
- #     if image.mode != "RGB":
- #         image = image.convert("RGB")
-
- #     # 1. DETECT (Surya)
- #     logger.info("Step 1: Detecting Lines with Surya...")
- #     # Surya expects list of images
- #     predictions = batch_text_detection([image], det_model, det_processor)
- #     result = predictions[0]
-
- #     # Extract BBoxes
- #     lines_objects = []
- #     if hasattr(result, "bboxes"):
- #         lines_objects = result.bboxes
- #     elif hasattr(result, "text_lines"):
- #         lines_objects = result.text_lines
-
- #     # Sort by Y-coordinate (top to bottom)
- #     lines_objects.sort(key=lambda x: x.bbox[1])
-
- #     # 2. CROP & RECOGNIZE
- #     logger.info(f"Step 2: Recognizing {len(lines_objects)} lines with TrOCR...")
-
- #     line_crops = []
- #     w, h = image.size
-
- #     for obj in lines_objects:
- #         bbox = obj.bbox
-
- #         # Crop the full line
- #         pad = 6
- #         x1 = max(0, int(bbox[0]) - pad)
- #         y1 = max(0, int(bbox[1]) - pad)
- #         x2 = min(w, int(bbox[2]) + pad)
- #         y2 = min(h, int(bbox[3]) + pad)
-
- #         line_crop = image.crop((x1, y1, x2, y2))
- #         line_crops.append(line_crop)
-
- #     # Batch processing
- #     full_text_lines = []
- #     batch_size = 4
-
- #     for i in range(0, len(line_crops), batch_size):
- #         batch = line_crops[i:i+batch_size]
- #         try:
- #             batch_results = recognize_batch(batch)
- #             full_text_lines.extend(batch_results)
- #         except Exception as e:
- #             logger.error(f"Batch failed: {e}")
- #             full_text_lines.append("[Error processing line]")
-
- #     final_text = "\n".join(full_text_lines)
-
- #     # Visualize
- #     vis_img = draw_boxes(image.copy(), lines_objects)
-
- #     return vis_img, final_text
-
- # # ==========================================
- # # 4. GRADIO UI
- # # ==========================================
- # custom_css = """
- # .gen-button { background-color: #ff4081 !important; color: white !important; font-weight: bold !important; }
- # """
-
- # with gr.Blocks(css=custom_css) as demo:
- #     gr.Markdown("# 🚀 Hybrid OCR: Surya (Raw) + TrOCR (Corrected)")
-
- #     with gr.Row():
- #         ocr_input = gr.Image(type="pil", label="Upload Image")
- #         ocr_output_img = gr.Image(type="pil", label="Surya Detections")
-
- #     ocr_text = gr.Textbox(label="Recognized Text", lines=20)
- #     ocr_button = gr.Button("Run Hybrid OCR", elem_classes="gen-button")
-
- #     ocr_button.click(hybrid_ocr_workflow, inputs=[ocr_input], outputs=[ocr_output_img, ocr_text])
-
- # if __name__ == "__main__":
- #     demo.launch(theme=gr.themes.Soft(), css=custom_css)
-
-
-
-
-
-
-
-
-
- import os
-
- # ==========================================
- # 0. SURYA CONFIGURATION
- # ==========================================
- # MUST be set before importing surya to take effect.
- # 1. Lower text threshold (0.6 -> 0.50) to catch faint handwriting strokes
- os.environ["DETECTOR_TEXT_THRESHOLD"] = "0.50"
- # 2. Raise blank threshold (0.7 -> 0.80) to prevent splitting wavy lines
- os.environ["DETECTOR_BLANK_THRESHOLD"] = "0.80"
-
  import gradio as gr
  import logging
+ import os
  import numpy as np
  import torch
  from PIL import Image, ImageDraw
@@ -201,7 +28,8 @@ det_processor = load_det_processor()
  det_model = load_det_model().to(device)

  # B. TROCR RECOGNITION
- # NOTE: Quantization removed to prevent "prior domination" hallucinations.
+ # NOTE: We do NOT use quantization here. It destroys the attention mechanism in ViT
+ # encoders on CPU, leading to "mode collapse" (hallucinations).
  trocr_processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
  trocr_model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten').to(device)

@@ -210,43 +38,20 @@ logger.info("✅ All Models Loaded.")
  # ==========================================
  # 2. HELPER FUNCTIONS
  # ==========================================
- def pad_to_square(image):
-     """
-     Pads a crop to be roughly square (or at least 4:3) to prevent
-     the ViT encoder from squashing long text strips into nonsense.
-     """
-     w, h = image.size
-     # If already roughly square or tall, leave it
-     if w <= h * 1.5:
-         return image
-
-     # Target a 2:1 aspect ratio roughly (or just make it taller)
-     target_h = int(w * 0.5)
-     if target_h <= h: return image
-
-     # Create white background
-     new_img = Image.new("RGB", (w, target_h), (255, 255, 255))
-     paste_y = (target_h - h) // 2
-     new_img.paste(image, (0, paste_y))
-     return new_img
-
  def recognize_batch(crops):
      """
-     Feeds processed crops to TrOCR.
+     Feeds raw crops directly to TrOCR.
      """
      if not crops: return []

-     # Filter invalid crops
+     # Ensure crops are valid
      valid_crops = [c for c in crops if c.size[0] > 0 and c.size[1] > 0]
      if not valid_crops: return []

-     # Pre-process: Pad to avoid aspect ratio distortion
-     processed_crops = [pad_to_square(c) for c in valid_crops]
-
-     pixel_values = trocr_processor(images=processed_crops, return_tensors="pt").pixel_values.to(device)
+     pixel_values = trocr_processor(images=valid_crops, return_tensors="pt").pixel_values.to(device)

      with torch.no_grad():
-         # max_length=64 prevents rambling if model gets confused
+         # Using a slightly lower max_length prevents it from rambling if it gets confused
          generated_ids = trocr_model.generate(pixel_values, max_length=64)
      text = trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)
      return text
@@ -257,6 +62,7 @@ def draw_boxes(image, prediction_objects):
          if hasattr(obj, "bbox"):
              draw.rectangle(obj.bbox, outline="red", width=2)
          else:
+             # Fallback if obj is just a list/tuple
              draw.rectangle(obj, outline="red", width=2)
      return image

@@ -266,12 +72,13 @@ def draw_boxes(image, prediction_objects):
  def hybrid_ocr_workflow(image):
      if image is None: return None, "Please upload an image."

-     # CRITICAL: TrOCR fails silently on RGBA/P modes
+     # CRITICAL FIX: Ensure image is RGB (TrOCR fails on RGBA/P modes silently)
      if image.mode != "RGB":
          image = image.convert("RGB")

      # 1. DETECT (Surya)
      logger.info("Step 1: Detecting Lines with Surya...")
+     # Surya expects list of images
      predictions = batch_text_detection([image], det_model, det_processor)
      result = predictions[0]

@@ -282,7 +89,7 @@ def hybrid_ocr_workflow(image):
      elif hasattr(result, "text_lines"):
          lines_objects = result.text_lines

-     # Sort by Y-coordinate
+     # Sort by Y-coordinate (top to bottom)
      lines_objects.sort(key=lambda x: x.bbox[1])

      # 2. CROP & RECOGNIZE
@@ -294,7 +101,7 @@ def hybrid_ocr_workflow(image):
      for obj in lines_objects:
          bbox = obj.bbox

-         # Crop with slight padding
+         # Crop the full line
          pad = 6
          x1 = max(0, int(bbox[0]) - pad)
          y1 = max(0, int(bbox[1]) - pad)
@@ -306,7 +113,7 @@ def hybrid_ocr_workflow(image):

      # Batch processing
      full_text_lines = []
-     batch_size = 4
+     batch_size = 4

      for i in range(0, len(line_crops), batch_size):
          batch = line_crops[i:i+batch_size]
@@ -332,11 +139,11 @@ custom_css = """
  """

  with gr.Blocks(css=custom_css) as demo:
-     gr.Markdown("# 🚀 Hybrid OCR: Surya (Optimized) + TrOCR (Corrected)")
+     gr.Markdown("# 🚀 Hybrid OCR: Surya (Raw) + TrOCR (Corrected)")

      with gr.Row():
          ocr_input = gr.Image(type="pil", label="Upload Image")
-         ocr_output_img = gr.Image(type="pil", label="Surya Detections (Tuned)")
+         ocr_output_img = gr.Image(type="pil", label="Surya Detections")

      ocr_text = gr.Textbox(label="Recognized Text", lines=20)
      ocr_button = gr.Button("Run Hybrid OCR", elem_classes="gen-button")
@@ -344,4 +151,12 @@ with gr.Blocks(css=custom_css) as demo:
      ocr_button.click(hybrid_ocr_workflow, inputs=[ocr_input], outputs=[ocr_output_img, ocr_text])

  if __name__ == "__main__":
-     demo.launch(theme=gr.themes.Soft(), css=custom_css)
+     demo.launch(theme=gr.themes.Soft(), css=custom_css)
+
+
+
+
+
+
+
+