imperiusrex committed on
Commit
5828679
·
verified ·
1 Parent(s): a0ee49a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -368
app.py CHANGED
@@ -1,362 +1,45 @@
1
-
2
- from paddleocr import PaddleOCR
3
-
4
  from transformers import CLIPProcessor, CLIPModel
 
5
  from PIL import Image
6
  import torch
7
- from google.colab import files
8
- import json
9
- import os
10
- import glob
11
- from IPython.display import Image as ColabImage, display
12
- from paddleocr import PaddleOCR
13
- from paddleocr import TextDetection
14
-
15
-
16
- uploaded = files.upload()
17
- img_path = next(iter(uploaded.keys()))
18
- image = Image.open(img_path).convert("RGB")
19
-
20
- image = Image.open(img_path)
21
- width, height = image.size
22
- total_pixels = width * height
23
-
24
- print(f"Width: {width}, Height: {height}, Total pixels: {total_pixels}")
25
-
26
-
27
-
28
-
29
- # Initialize array for bounding boxes
30
- arr = []
31
-
32
- # Load and run the text detection model
33
- model = TextDetection(model_name="PP-OCRv5_server_det")
34
- output = model.predict(img_path, batch_size=1)
35
-
36
- # Extract bounding boxes
37
- for res in output:
38
- polys = res['dt_polys'] # NumPy array of shape (N, 4, 2)
39
- if polys is not None:
40
- arr.extend(polys.tolist())
41
-
42
- # Sort the bounding boxes in reading order
43
- arr = sorted(arr, key=lambda box: (box[0][1], box[0][0]))
44
-
45
- # Print the sorted bounding box coordinates
46
- print("Extracted bounding box coordinates (in reading order):")
47
- for box in arr:
48
- print(box)
49
-
50
- print(f"Number of detected text regions: {len(arr)}")
51
-
52
- import cv2
53
  import numpy as np
 
54
  import os
55
- import json
56
-
57
- # Load original image
58
-
59
- img = cv2.imread(img_path)
60
 
61
- # Output setup
62
- output_dir = "./output/crops_warped"
63
- os.makedirs(output_dir, exist_ok=True)
64
 
65
- cropped_images = []
66
-
67
- for i, box in enumerate(arr):
68
- box = np.array(box, dtype=np.float32) # shape: (4, 2)
69
-
70
- # Compute width and height of the new image
71
- width_a = np.linalg.norm(box[0] - box[1])
72
- width_b = np.linalg.norm(box[2] - box[3])
73
- height_a = np.linalg.norm(box[0] - box[3])
74
- height_b = np.linalg.norm(box[1] - box[2])
75
-
76
- width = int(max(width_a, width_b))
77
- height = int(max(height_a, height_b))
78
-
79
- # Destination rectangle
80
- dst_rect = np.array([
81
- [0, 0],
82
- [width - 1, 0],
83
- [width - 1, height - 1],
84
- [0, height - 1]
85
- ], dtype=np.float32)
86
-
87
- # Perspective transform
88
- M = cv2.getPerspectiveTransform(box, dst_rect)
89
- warped = cv2.warpPerspective(img, M, (width, height))
90
-
91
- cropped_images.append(warped)
92
-
93
- # Save warped image
94
- cv2.imwrite(os.path.join(output_dir, f"crop_{i}.png"), warped)
95
-
96
- print(f"Cropped {len(cropped_images)} perspective-warped regions.")
97
-
98
- # cropped_images.reverse()
99
-
100
-
101
- import matplotlib.pyplot as plt
102
-
103
- # Display all cropped images in a grid
104
- for i, crop in enumerate(cropped_images):
105
- plt.figure(figsize=(4, 4))
106
- plt.imshow(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)) # Convert BGR to RGB
107
- plt.title(f'Cropped Image {i}')
108
- plt.axis('off')
109
- plt.show()
110
-
111
- # Load CLIP model
112
- clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
113
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
114
 
 
 
 
 
 
115
  # Candidate language phrases for detection
116
  candidates = [
117
  "This is English text",
118
- # "This is Hindi text",
119
- # "This is Tamil text",
120
  "This is Telugu text",
121
- # "This is Bengali text",
122
- # "This is Arabic text",
123
  "This is Chinese text",
124
- # "This is Japanese text",
125
  "This is Korean text",
126
- # "This is Russian text",
127
- # "This is Kannada text",
128
- # "This is Malayalam text",
129
- # "This is Marathi text",
130
- # "This is Urdu text",
131
- # "This is French text",
132
- # "This is Spanish text",
133
- # "This is Italian text",
134
- # "This is Portuguese text",
135
- # "This is Romanian text",
136
- # "This is Hungarian text",
137
- # "This is Indonesian text",
138
- # "This is Lithuanian text",
139
- # "This is Chinese Traditional text",
140
- # "This is Malay text",
141
- # "This is Dutch text",
142
- # "This is Norwegian text",
143
- # "This is Bosnian text",
144
- # "This is Polish text",
145
- # "This is Czech text",
146
- # "This is Slovak text",
147
- # "This is Welsh text",
148
- # "This is Slovenian text",
149
- # "This is Danish text",
150
- # "This is Albanian text",
151
- # "This is Estonian text",
152
- # "This is Swedish text",
153
- # "This is Irish text",
154
- # "This is Swahili text",
155
- # "This is Croatian text",
156
- # "This is Uzbek text",
157
- # "This is Turkish text",
158
- # "This is Latin text",
159
- # "This is Belarusian text",
160
- # "This is Ukrainian text"
161
  ]
162
 
163
  # Map detected languages to PaddleOCR language codes
164
  lang_map = {
165
  "english": "en",
166
- # "hindi": "hi",
167
- # "tamil": "ta",
168
  "telugu": "te",
169
- # "bengali": "bn",
170
- # "arabic": "ar",
171
  "chinese": "ch",
172
- # "japanese": "japan",
173
  "korean": "korean",
174
- # "russian": "ru",
175
- # "kannada": "kn",
176
- # "malayalam": "ml",
177
- # "marathi": "mr",
178
- # "urdu": "ur",
179
- # "french": "fr",
180
- # "spanish": "es",
181
- # "italian": "it",
182
- # "portuguese": "pt",
183
- # "romanian": "ro",
184
- # "hungarian": "hu",
185
- # "indonesian": "id",
186
- # "lithuanian": "lt",
187
- # "chinese traditional": "chinese_cht",
188
- # "malay": "ms",
189
- # "dutch": "nl",
190
- # "norwegian": "no",
191
- # "bosnian": "bs",
192
- # "polish": "pl",
193
- # "czech": "cs",
194
- # "slovak": "sk",
195
- # "welsh": "cy",
196
- # "slovenian": "sl",
197
- # "danish": "da",
198
- # "albanian": "sq",
199
- # "estonian": "et",
200
- # "swedish": "sv",
201
- # "irish": "ga",
202
- # "swahili": "sw",
203
- # "croatian": "hr",
204
- # "uzbek": "uz",
205
- # "turkish": "tr",
206
- # "latin": "la",
207
- # "belarusian": "be",
208
- # "ukrainian": "uk"
209
  }
210
- for img in cropped_images:
211
- # Get probabilities
212
- inputs = processor(text=candidates, images=img, return_tensors="pt", padding=True)
213
- with torch.no_grad():
214
- logits_per_image = clip_model(**inputs).logits_per_image
215
- probs = logits_per_image.softmax(dim=1)
216
-
217
- # Get best language match
218
- best = probs.argmax().item()
219
- detected_lang_phrase = candidates[best]
220
- detected_lang = detected_lang_phrase.split()[-2].lower()
221
- lang_code = lang_map.get(detected_lang, "en")
222
-
223
- print(f"\n✅ Detected script/language: {detected_lang_phrase} → PaddleOCR lang='{lang_code}'")
224
-
225
-
226
- import numpy as np # Ensure numpy is imported
227
- import os # Ensure os is imported
228
-
229
- ocr = PaddleOCR(
230
- use_doc_orientation_classify=False, # Enable orientation classification for auto lang detection
231
- use_doc_unwarping=False,
232
- use_textline_orientation=False, # Enable textline orientation for auto lang detection
233
- lang=lang_code, # Use paddleOCR's auto language detection
234
- device="cpu"
235
- )
236
-
237
- for i, img in enumerate(cropped_images):
238
- # Define output diarectory and make sure it exists
239
-
240
- # Get base name of uploaded image (without extension)
241
- # Use a unique name for each cropped image
242
- base_name = f"cropped_image_{i}"
243
- bounding_boxes_image_path=os.path.join("/content/", f"{base_name}_bounding_boxes.jpg")
244
- json_file_path=os.path.join("/content/", f"{base_name}.json")
245
-
246
-
247
- # Convert the PIL Image to a NumPy array
248
- # image_np = np.array(img) # img is already a numpy array from cv2
249
-
250
- # Skip small images that might cause errors
251
- if img.shape[0] < 10 or img.shape[1] < 10:
252
- print(f"Skipping small image {i} with shape {img.shape}")
253
- continue
254
-
255
- # Perform OCR and save results
256
- result = ocr.predict(img) # Pass the NumPy array
257
- # print(f"\n====json_output for cropped image {i}====\n")
258
-
259
-
260
- # Assuming the first element of the result contains the overall detection
261
- if result and result[0]: # Check if result is not empty and has at least one element
262
- # Print results for each detected element
263
- # for res in result:
264
- # res.print()
265
-
266
- # Save the combined result to image and json
267
- result[0].save_to_img(bounding_boxes_image_path)
268
- result[0].save_to_json(json_file_path)
269
-
270
- else:
271
- print(f"No OCR results found for cropped image {i}.")
272
-
273
- # Construct the expected saved image path
274
- saved_image_path = bounding_boxes_image_path
275
-
276
- if os.path.exists(saved_image_path):
277
- display(ColabImage(filename=saved_image_path))
278
- else:
279
- print(f"No OCR image found at: {saved_image_path}")
280
-
281
- # print("\n===== Markdown Layout Preview =====\n")
282
-
283
- # Construct the expected saved JSON path
284
- saved_json_path = json_file_path
285
-
286
- if os.path.exists(saved_json_path):
287
- with open(saved_json_path, "r", encoding="utf-8") as f:
288
- data = json.load(f)
289
-
290
- texts = data["rec_texts"]
291
- boxes = data["rec_boxes"]
292
-
293
- elements = []
294
- for text, box in zip(texts, boxes):
295
- x1, y1, x2, y2 = box
296
- elements.append({"text": text, "x": x1, "y": y1, "line_y": (y1 + y2) / 2})
297
-
298
- elements.sort(key=lambda e: (round(e["line_y"] / 10), e["x"]))
299
-
300
- lines = []
301
- current_line_y = None
302
- line = []
303
-
304
- for elem in elements:
305
- if current_line_y is None or abs(elem["line_y"] - current_line_y) <= 10:
306
- line.append(elem["text"])
307
- current_line_y = elem["line_y"]
308
- else:
309
- lines.append(line)
310
- line = [elem["text"]]
311
- current_line_y = elem["line_y"]
312
-
313
- if line:
314
- lines.append(line)
315
-
316
- markdown_output = ""
317
-
318
- for line in lines:
319
- markdown_output += " ".join(line) #+ "\n\n"
320
-
321
-
322
- print(markdown_output)
323
-
324
- else:
325
- print(f"No JSON file found at: {saved_json_path}")
326
-
327
- import json
328
- import os
329
-
330
- predicted_texts = []
331
-
332
- # Iterate through the number of cropped images we have
333
- for i in range(len(cropped_images)):
334
- json_file_path = f"/content/cropped_image_{i}.json"
335
- text_for_this_image = ""
336
-
337
- if os.path.exists(json_file_path):
338
- with open(json_file_path, 'r', encoding='utf-8') as f:
339
- try:
340
- # Load the JSON data
341
- data = json.load(f)
342
- # The result from predict() is a dictionary, not a list
343
- if data and 'rec_texts' in data:
344
- text_for_this_image = " ".join(data['rec_texts'])
345
- except json.JSONDecodeError:
346
- print(f"Warning: Could not decode JSON from {json_file_path}")
347
- except KeyError as e:
348
- print(f"Warning: Unexpected JSON structure in {json_file_path}: {e}")
349
-
350
-
351
- predicted_texts.append(text_for_this_image)
352
-
353
- # Display the final list of predicted texts
354
- print("Predicted Texts Array (from JSON files):")
355
- print(predicted_texts)
356
-
357
- import json
358
- import os
359
 
 
360
  def get_box_center(box):
361
  """Calculates the center of a bounding box."""
362
  x_coords = [p[0] for p in box]
@@ -365,52 +48,134 @@ def get_box_center(box):
365
  center_y = sum(y_coords) / len(y_coords)
366
  return center_x, center_y
367
 
368
- # --- Step 1: Read all text and their centroid coordinates ---
369
- all_text_blocks = []
370
- for i, box in enumerate(arr):
371
- json_file_path = f"/content/cropped_image_{i}.json"
372
- if os.path.exists(json_file_path):
373
- with open(json_file_path, 'r', encoding='utf-8') as f:
374
- result = json.load(f)
375
-
376
- if result and 'rec_texts' in result and result['rec_texts']:
377
- text = " ".join(result['rec_texts'])
378
- center_x, center_y = get_box_center(box)
379
- all_text_blocks.append({
380
- "text": text,
381
- "center_x": center_x,
382
- "center_y": center_y
383
- })
384
-
385
- # --- Step 2: Sort by y-coordinate, then by x-coordinate, and group into lines ---
386
- if all_text_blocks:
387
- # Sort by center_y, then by center_x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
388
  sorted_blocks = sorted(all_text_blocks, key=lambda item: (item["center_y"], item["center_x"]))
389
-
390
  lines = []
391
  if sorted_blocks:
392
  current_line = [sorted_blocks[0]]
393
  for block in sorted_blocks[1:]:
394
- # Check if the vertical centers are close enough to be on the same line
395
- if abs(block["center_y"] - current_line[-1]["center_y"]) < 40: # Y-threshold
396
  current_line.append(block)
397
  else:
398
- # Sort the current line by x-coordinate and add it to the lines list
399
  current_line.sort(key=lambda item: item["center_x"])
400
  lines.append(" ".join([item["text"] for item in current_line]))
401
  current_line = [block]
402
 
403
- # Add the last line
404
  if current_line:
405
  current_line.sort(key=lambda item: item["center_x"])
406
  lines.append(" ".join([item["text"] for item in current_line]))
 
 
 
 
 
 
 
 
 
 
 
 
 
407
 
408
- # --- Step 3: Print the final reconstructed text ---
409
- if lines:
410
- for line in lines:
411
- print(line)
412
- else:
413
- print("No text was reconstructed.")
414
- else:
415
- print("No text blocks were found.")
416
-
 
1
+ import gradio as gr
 
 
2
  from transformers import CLIPProcessor, CLIPModel
3
+ from paddleocr import PaddleOCR, TextDetection
4
  from PIL import Image
5
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  import numpy as np
7
+ import cv2
8
  import os
9
+ import spaces
 
 
 
 
10
 
11
# --- Global setup for models and data ---
# This section runs once when the app starts.
print("Initializing models...")

# Load the CLIP model once and move it to the GPU when one is available.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(
    "cuda" if torch.cuda.is_available() else "cpu"
)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

# PaddleOCR text-detection model, downloaded and loaded into memory once.
# NOTE(review): paddleocr 3.x predictors select hardware via `device=`
# (the earlier notebook version of this code used `device="cpu"`); the
# 2.x-era `use_gpu=` keyword is not accepted there — confirm against the
# installed paddleocr version. torch.cuda availability is used as a proxy
# for a GPU-enabled environment.
det_model = TextDetection(
    model_name="PP-OCRv5_server_det",
    device="gpu" if torch.cuda.is_available() else "cpu",
)

# Candidate language phrases for detection
candidates = [
    "This is English text",
    "This is Telugu text",
    "This is Chinese text",
    "This is Korean text",
    # Add other languages as needed
]

# Map detected languages to PaddleOCR language codes
lang_map = {
    "english": "en",
    "telugu": "te",
    "chinese": "ch",
    "korean": "korean",
}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
+ # --- Utility Functions ---
43
def get_box_center(box):
    """Return the (x, y) centroid of a polygon given as a sequence of points."""
    xs = [point[0] for point in box]
    ys = [point[1] for point in box]
    return sum(xs) / len(xs), sum(ys) / len(ys)
50
 
51
# --- Main OCR Pipeline Function ---
@spaces.GPU  # Execute on the Space's assigned GPU.
def ocr_pipeline(image: Image.Image) -> str:
    """
    Perform OCR on an input image using a multi-step pipeline.

    Steps:
      1. Detect text regions with the PaddleOCR detection model.
      2. Perspective-warp each region into an axis-aligned crop.
      3. Identify each crop's language with CLIP and run a
         language-specific PaddleOCR instance (cached per language).
      4. Reassemble the recognized text in reading order.

    Args:
        image: A PIL Image object from the Gradio interface (may be None).

    Returns:
        The reconstructed text, or a short status message when nothing
        could be detected or extracted.
    """
    if image is None:
        return "No image provided."

    print("Starting OCR pipeline...")

    # PIL -> NumPy. Note the resulting array is in RGB channel order.
    img_np = np.array(image.convert("RGB"))

    # Step 1: Text detection with PaddleOCR's model.
    output = det_model.predict(img_np, batch_size=1)

    arr = []
    for res in output:
        polys = res['dt_polys']  # quadrilaterals, shape (N, 4, 2), or None
        if polys is not None:
            arr.extend(polys.tolist())

    # Sort the bounding boxes into rough reading order
    # (top-to-bottom, then left-to-right by the first corner).
    arr = sorted(arr, key=lambda box: (box[0][1], box[0][0]))

    if not arr:
        print("No text regions detected.")
        return "No text regions detected."

    # Step 2: perspective-warp each quadrilateral into a rectangular crop.
    cropped_images = []
    for box in arr:
        quad = np.array(box, dtype=np.float32)  # shape (4, 2)
        width_a = np.linalg.norm(quad[0] - quad[1])
        width_b = np.linalg.norm(quad[2] - quad[3])
        height_a = np.linalg.norm(quad[0] - quad[3])
        height_b = np.linalg.norm(quad[1] - quad[2])
        width = int(max(width_a, width_b))
        height = int(max(height_a, height_b))
        if width < 1 or height < 1:
            # Degenerate box: keep a placeholder so indices stay aligned with `arr`.
            cropped_images.append(None)
            continue
        dst_rect = np.array([
            [0, 0],
            [width - 1, 0],
            [width - 1, height - 1],
            [0, height - 1]
        ], dtype=np.float32)
        M = cv2.getPerspectiveTransform(quad, dst_rect)
        cropped_images.append(cv2.warpPerspective(img_np, M, (width, height)))

    # Step 3: language detection with CLIP and OCR on each crop.
    all_text_blocks = []
    ocr_cache = {}  # lang_code -> PaddleOCR instance; avoids re-initializing per crop

    for i, crop in enumerate(cropped_images):
        if crop is None or crop.shape[0] < 10 or crop.shape[1] < 10:
            continue  # skip degenerate/tiny crops that would break the models

        # The crop is already RGB (img_np came from a PIL RGB image), so no
        # BGR->RGB conversion is needed before handing it to CLIP.
        pil_img = Image.fromarray(crop)

        inputs = processor(text=candidates, images=pil_img, return_tensors="pt", padding=True)
        # Move tensors to wherever the CLIP model lives (GPU or CPU).
        inputs = {k: v.to(clip_model.device) for k, v in inputs.items()}
        with torch.no_grad():
            logits_per_image = clip_model(**inputs).logits_per_image
            probs = logits_per_image.softmax(dim=1)

        best = probs.argmax().item()
        detected_lang_phrase = candidates[best]
        # Phrases look like "This is English text" -> second-to-last word.
        detected_lang = detected_lang_phrase.split()[-2].lower()
        lang_code = lang_map.get(detected_lang, "en")

        ocr = ocr_cache.get(lang_code)
        if ocr is None:
            # NOTE(review): paddleocr 3.x takes `device=` and
            # `use_textline_orientation=` (as the earlier notebook version
            # of this code did), not the legacy `use_gpu=`/`use_angle_cls=`
            # keywords — confirm against the installed version.
            ocr = PaddleOCR(
                lang=lang_code,
                device="gpu" if torch.cuda.is_available() else "cpu",
                use_doc_orientation_classify=False,
                use_doc_unwarping=False,
                use_textline_orientation=False,
            )
            ocr_cache[lang_code] = ocr

        result = ocr.predict(crop)

        # Extract the recognized text from the OCR result.
        text_for_this_image = ""
        if result and result[0] and result[0].get('rec_texts'):
            text_for_this_image = " ".join(result[0]['rec_texts'])

        # Store text together with the detection box's centroid.
        center_x, center_y = get_box_center(arr[i])
        all_text_blocks.append({
            "text": text_for_this_image,
            "center_x": center_x,
            "center_y": center_y
        })

    # Step 4: reconstruct the text in reading order.
    if not all_text_blocks:
        print("No text could be extracted.")
        return "No text could be extracted."

    sorted_blocks = sorted(all_text_blocks, key=lambda item: (item["center_y"], item["center_x"]))

    lines = []
    current_line = [sorted_blocks[0]]
    for block in sorted_blocks[1:]:
        # Blocks whose vertical centers are within 40px form one line.
        if abs(block["center_y"] - current_line[-1]["center_y"]) < 40:
            current_line.append(block)
        else:
            current_line.sort(key=lambda item: item["center_x"])
            lines.append(" ".join(item["text"] for item in current_line))
            current_line = [block]

    if current_line:
        current_line.sort(key=lambda item: item["center_x"])
        lines.append(" ".join(item["text"] for item in current_line))

    final_text = "\n".join(lines)
    print("OCR pipeline finished successfully.")
    return final_text
170
+
171
# --- Gradio Interface ---
# Wire the OCR pipeline to a simple image-in / text-out UI.
image_input = gr.Image(type="pil", label="Upload Image")
text_output = gr.Textbox(label="Recognized Text")

iface = gr.Interface(
    fn=ocr_pipeline,
    inputs=image_input,
    outputs=text_output,
    title="Printed Text OCR with PaddleOCR and CLIP",
    description="Upload a printed text image. The app will detect text regions, identify the language with CLIP, and perform OCR to return the text in reading order. This space uses an H200 GPU for high-speed processing."
)

if __name__ == "__main__":
    iface.launch()