iammraat committed on
Commit
1ddb5bb
·
verified ·
1 Parent(s): f8bcded

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +163 -73
app.py CHANGED
@@ -66,6 +66,141 @@
66
 
67
 
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  import gradio as gr
70
  import torch
71
  import numpy as np
@@ -74,122 +209,77 @@ from PIL import Image
74
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
75
  from craft_text_detector import Craft
76
 
77
- # ----------------------------
78
- # Device
79
- # ----------------------------
80
  device = "cuda" if torch.cuda.is_available() else "cpu"
81
 
82
- # ----------------------------
83
- # Load TrOCR
84
- # ----------------------------
85
  print("Loading TrOCR model...")
86
  processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-handwritten")
87
- model = VisionEncoderDecoderModel.from_pretrained(
88
- "microsoft/trocr-small-handwritten"
89
- )
90
  model.to(device)
91
  model.eval()
92
 
93
- # ----------------------------
94
- # Load CRAFT
95
- # ----------------------------
96
  print("Loading CRAFT text detector...")
97
- craft = Craft(
98
- output_dir=None,
99
- crop_type="poly",
100
- cuda=(device == "cuda"),
101
- )
102
 
103
- # ----------------------------
104
- # Sort boxes (reading order)
105
- # ----------------------------
106
  def get_sorted_boxes(boxes):
107
  items = []
108
  for box in boxes:
109
  cx = np.mean(box[:, 0])
110
  cy = np.mean(box[:, 1])
111
  items.append((cy, cx, box))
112
-
113
- # group by line (roughly)
114
  items.sort(key=lambda x: (int(x[0] // 20), x[1]))
115
  return [b for _, _, b in items]
116
 
117
- # ----------------------------
118
- # OCR Pipeline
119
- # ----------------------------
120
- def process_full_page(image: Image.Image):
121
- # ALWAYS return (image_or_None, text)
122
  if image is None:
123
  return None, "Please upload an image."
124
-
125
  image_np = np.array(image)
126
-
127
  prediction = craft.detect_text(image_np)
128
  boxes = prediction.get("boxes", [])
129
-
130
  if not boxes:
131
  return image, "No text detected."
132
-
133
  sorted_boxes = get_sorted_boxes(boxes)
134
  annotated = image_np.copy()
135
  texts = []
136
-
137
  for box in sorted_boxes:
138
  box = box.astype(int)
139
  cv2.polylines(annotated, [box], True, (255, 0, 0), 2)
140
-
141
  x_min = max(0, box[:, 0].min())
142
  x_max = min(image_np.shape[1], box[:, 0].max())
143
  y_min = max(0, box[:, 1].min())
144
  y_max = min(image_np.shape[0], box[:, 1].max())
145
-
146
  if x_max - x_min < 5 or y_max - y_min < 5:
147
  continue
148
-
149
  crop = image_np[y_min:y_max, x_min:x_max]
150
  pil_crop = Image.fromarray(crop).convert("RGB")
151
-
152
  with torch.no_grad():
153
- pixels = processor(
154
- images=pil_crop,
155
- return_tensors="pt"
156
- ).pixel_values.to(device)
157
-
158
  ids = model.generate(pixels)
159
- text = processor.batch_decode(
160
- ids, skip_special_tokens=True
161
- )[0]
162
-
163
- if text.strip():
164
- texts.append(text)
165
-
166
  final_text = " ".join(texts)
167
  return Image.fromarray(annotated), final_text
168
 
169
- # ----------------------------
170
- # Gradio UI
171
- # ----------------------------
172
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
173
- gr.Markdown("# 🕵️‍♀️ Full-Page Handwritten OCR")
174
- gr.Markdown("**CRAFT TrOCR** (Detection + Recognition)")
175
-
176
- with gr.Row():
177
- input_img = gr.Image(type="pil", label="Upload Full Page")
178
-
179
- with gr.Row():
180
- vis_output = gr.Image(label="Detections")
181
- text_output = gr.Textbox(label="Extracted Text", lines=10)
182
-
183
- btn = gr.Button("Process Page", variant="primary")
184
- btn.click(
185
- fn=process_full_page,
186
- inputs=input_img,
187
- outputs=[vis_output, text_output],
188
- )
189
 
190
  if __name__ == "__main__":
191
- demo.launch(
192
- server_name="0.0.0.0",
193
- server_port=7860,
194
- show_api=False,
195
- )
 
66
 
67
 
68
 
69
+ # import gradio as gr
70
+ # import torch
71
+ # import numpy as np
72
+ # import cv2
73
+ # from PIL import Image
74
+ # from transformers import TrOCRProcessor, VisionEncoderDecoderModel
75
+ # from craft_text_detector import Craft
76
+
77
+ # # ----------------------------
78
+ # # Device
79
+ # # ----------------------------
80
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
81
+
82
+ # # ----------------------------
83
+ # # Load TrOCR
84
+ # # ----------------------------
85
+ # print("Loading TrOCR model...")
86
+ # processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-handwritten")
87
+ # model = VisionEncoderDecoderModel.from_pretrained(
88
+ # "microsoft/trocr-small-handwritten"
89
+ # )
90
+ # model.to(device)
91
+ # model.eval()
92
+
93
+ # # ----------------------------
94
+ # # Load CRAFT
95
+ # # ----------------------------
96
+ # print("Loading CRAFT text detector...")
97
+ # craft = Craft(
98
+ # output_dir=None,
99
+ # crop_type="poly",
100
+ # cuda=(device == "cuda"),
101
+ # )
102
+
103
+ # # ----------------------------
104
+ # # Sort boxes (reading order)
105
+ # # ----------------------------
106
+ # def get_sorted_boxes(boxes):
107
+ # items = []
108
+ # for box in boxes:
109
+ # cx = np.mean(box[:, 0])
110
+ # cy = np.mean(box[:, 1])
111
+ # items.append((cy, cx, box))
112
+
113
+ # # group by line (roughly)
114
+ # items.sort(key=lambda x: (int(x[0] // 20), x[1]))
115
+ # return [b for _, _, b in items]
116
+
117
+ # # ----------------------------
118
+ # # OCR Pipeline
119
+ # # ----------------------------
120
+ # def process_full_page(image: Image.Image):
121
+ # # ALWAYS return (image_or_None, text)
122
+ # if image is None:
123
+ # return None, "Please upload an image."
124
+
125
+ # image_np = np.array(image)
126
+
127
+ # prediction = craft.detect_text(image_np)
128
+ # boxes = prediction.get("boxes", [])
129
+
130
+ # if not boxes:
131
+ # return image, "No text detected."
132
+
133
+ # sorted_boxes = get_sorted_boxes(boxes)
134
+ # annotated = image_np.copy()
135
+ # texts = []
136
+
137
+ # for box in sorted_boxes:
138
+ # box = box.astype(int)
139
+ # cv2.polylines(annotated, [box], True, (255, 0, 0), 2)
140
+
141
+ # x_min = max(0, box[:, 0].min())
142
+ # x_max = min(image_np.shape[1], box[:, 0].max())
143
+ # y_min = max(0, box[:, 1].min())
144
+ # y_max = min(image_np.shape[0], box[:, 1].max())
145
+
146
+ # if x_max - x_min < 5 or y_max - y_min < 5:
147
+ # continue
148
+
149
+ # crop = image_np[y_min:y_max, x_min:x_max]
150
+ # pil_crop = Image.fromarray(crop).convert("RGB")
151
+
152
+ # with torch.no_grad():
153
+ # pixels = processor(
154
+ # images=pil_crop,
155
+ # return_tensors="pt"
156
+ # ).pixel_values.to(device)
157
+
158
+ # ids = model.generate(pixels)
159
+ # text = processor.batch_decode(
160
+ # ids, skip_special_tokens=True
161
+ # )[0]
162
+
163
+ # if text.strip():
164
+ # texts.append(text)
165
+
166
+ # final_text = " ".join(texts)
167
+ # return Image.fromarray(annotated), final_text
168
+
169
+ # # ----------------------------
170
+ # # Gradio UI
171
+ # # ----------------------------
172
+ # with gr.Blocks(theme=gr.themes.Soft()) as demo:
173
+ # gr.Markdown("# 🕵️‍♀️ Full-Page Handwritten OCR")
174
+ # gr.Markdown("**CRAFT ➜ TrOCR** (Detection + Recognition)")
175
+
176
+ # with gr.Row():
177
+ # input_img = gr.Image(type="pil", label="Upload Full Page")
178
+
179
+ # with gr.Row():
180
+ # vis_output = gr.Image(label="Detections")
181
+ # text_output = gr.Textbox(label="Extracted Text", lines=10)
182
+
183
+ # btn = gr.Button("Process Page", variant="primary")
184
+ # btn.click(
185
+ # fn=process_full_page,
186
+ # inputs=input_img,
187
+ # outputs=[vis_output, text_output],
188
+ # )
189
+
190
+ # if __name__ == "__main__":
191
+ # demo.launch(
192
+ # server_name="0.0.0.0",
193
+ # server_port=7860,
194
+ # show_api=False,
195
+ # )
196
+
197
+
198
+
199
+
200
+
201
+
202
+
203
+
204
  import gradio as gr
205
  import torch
206
  import numpy as np
 
209
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
210
  from craft_text_detector import Craft
211
 
 
 
 
212
# Select the compute device: prefer the GPU whenever CUDA is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the TrOCR handwriting-recognition model and its processor once at
# startup so every request reuses the same weights.
print("Loading TrOCR model...")
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-small-handwritten")
model.to(device)
model.eval()

# Load the CRAFT text detector: polygon crops, GPU when available,
# no crop files written to disk (output_dir=None).
print("Loading CRAFT text detector...")
craft = Craft(output_dir=None, crop_type="poly", cuda=(device == "cuda"))
 
223
def get_sorted_boxes(boxes, line_height=20):
    """Return detection boxes sorted into natural reading order.

    Each box is grouped into a horizontal band ``line_height`` pixels tall
    based on its vertical centre (so boxes on roughly the same text line
    stay together), then ordered left-to-right within each band.

    Args:
        boxes: Iterable of (N, 2) arrays of polygon corner points (x, y),
            as produced by CRAFT.
        line_height: Approximate pixel height of one text line; controls
            how tolerant the line grouping is. Defaults to 20, matching
            the previously hard-coded value.

    Returns:
        List of the same box arrays, in reading order.
    """
    items = []
    for box in boxes:
        cx = np.mean(box[:, 0])
        cy = np.mean(box[:, 1])
        items.append((cy, cx, box))
    # Sort by (line band, horizontal position). The key uses only the first
    # two tuple elements, so the numpy boxes themselves are never compared.
    items.sort(key=lambda item: (int(item[0] // line_height), item[1]))
    return [box for _, _, box in items]
231
 
232
def process_full_page(image):
    """Detect and recognise handwritten text on a full page.

    Runs CRAFT detection on the page, crops each detected region in
    reading order, and recognises each crop with TrOCR.

    Args:
        image: PIL image of the page, or None when nothing was uploaded.

    Returns:
        Tuple of (annotated image or None, recognised text or status
        message). Always a 2-tuple so it maps cleanly onto the two
        Gradio output components.
    """
    if image is None:
        return None, "Please upload an image."

    image_np = np.array(image)

    prediction = craft.detect_text(image_np)
    boxes = prediction.get("boxes", [])

    # `boxes` may be a list or a numpy array depending on the detector
    # version; len() handles both, whereas `not boxes` raises a
    # "truth value is ambiguous" ValueError on a non-empty ndarray.
    if boxes is None or len(boxes) == 0:
        return image, "No text detected."

    sorted_boxes = get_sorted_boxes(boxes)
    annotated = image_np.copy()
    texts = []

    for box in sorted_boxes:
        box = box.astype(int)
        cv2.polylines(annotated, [box], True, (255, 0, 0), 2)

        # Axis-aligned bounding box of the polygon, clamped to the image.
        x_min = max(0, box[:, 0].min())
        x_max = min(image_np.shape[1], box[:, 0].max())
        y_min = max(0, box[:, 1].min())
        y_max = min(image_np.shape[0], box[:, 1].max())

        # Skip degenerate detections too small to hold legible text.
        if x_max - x_min < 5 or y_max - y_min < 5:
            continue

        crop = image_np[y_min:y_max, x_min:x_max]
        pil_crop = Image.fromarray(crop).convert("RGB")

        with torch.no_grad():
            pixels = processor(images=pil_crop, return_tensors="pt").pixel_values.to(device)
            ids = model.generate(pixels)
        text = processor.batch_decode(ids, skip_special_tokens=True)[0]

        # Strip before appending so the joined output has no stray spaces.
        stripped = text.strip()
        if stripped:
            texts.append(stripped)

    final_text = " ".join(texts)
    return Image.fromarray(annotated), final_text
272
 
273
# Gradio front-end: a single-page interface wrapping the OCR pipeline.
# Two outputs: the page with detection polygons drawn, and the text.
ocr_outputs = [
    gr.Image(label="Detections"),
    gr.Textbox(label="Extracted Text", lines=10),
]
demo = gr.Interface(
    fn=process_full_page,
    inputs=gr.Image(type="pil", label="Upload Full Page"),
    outputs=ocr_outputs,
    title="🕵️‍♀️ Full-Page Handwritten OCR",
    description="CRAFT TrOCR (Detection + Recognition)",
)
 
 
 
 
 
 
 
 
 
 
283
 
284
if __name__ == "__main__":
    # Bind to all interfaces so the server is reachable from outside the
    # container (required for Hugging Face Spaces / Docker deployments).
    launch_kwargs = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_kwargs)