imperiusrex commited on
Commit
051ce33
·
verified ·
1 Parent(s): 1bbbe26

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +401 -60
app.py CHANGED
@@ -1,76 +1,417 @@
1
- import spaces
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- # No spaces.GPU.require() here, remove it
4
 
5
- import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  import cv2
7
  import numpy as np
8
- import torch
9
- from PIL import Image
10
- from transformers import CLIPProcessor, CLIPModel
11
- from paddleocr import PaddleOCR
12
- import tempfile
13
 
14
- # Your utility functions here (run_text_detection, crop_and_warp_regions, etc.)
15
 
16
- clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
17
- clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
18
 
19
- language_map = {
20
- "english": "en",
21
- "telugu": "te",
22
- "chinese": "ch",
23
- "korean": "korean"
24
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  candidates = [
27
  "This is English text",
 
 
28
  "This is Telugu text",
 
 
29
  "This is Chinese text",
30
- "This is Korean text"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  ]
32
 
33
- @spaces.GPU # Decorate the function you want to run on GPU
34
- def process_image(image):
35
- # Your processing logic here
36
- image_pil = Image.fromarray(image).convert("RGB")
37
- img_path = "uploaded.jpg"
38
- image_pil.save(img_path)
39
-
40
- boxes = run_text_detection(img_path)
41
- crops = crop_and_warp_regions(img_path, boxes)
42
-
43
- all_results = []
44
- for crop in crops:
45
- lang = detect_language_clip(crop, clip_model, clip_processor, candidates)
46
- lang_code = language_map.get(lang, "en")
47
- ocr_model = PaddleOCR(
48
- use_doc_orientation_classify=False,
49
- use_doc_unwarping=False,
50
- use_textline_orientation=False,
51
- lang=lang_code,
52
- det=False,
53
- rec=True,
54
- cls=False,
55
- show_log=False
56
- )
57
- texts = run_paddle_ocr(crop, ocr_model)
58
- all_results.append({
59
- "lang": lang,
60
- "texts": texts,
61
- "image": crop
62
- })
63
-
64
- final_lines = group_text_by_position(all_results, boxes)
65
- return "\n".join(final_lines)
66
-
67
- interface = gr.Interface(
68
- fn=process_image,
69
- inputs=gr.Image(type="numpy", label="Upload an Image"),
70
- outputs=gr.Textbox(label="Reconstructed Text"),
71
- title="Printed Text OCR",
72
- description="Upload a printed or scanned document image. The system detects text regions, recognizes language, runs OCR, and reconstructs the output."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  )
74
 
75
- if __name__ == "__main__":
76
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# NOTE(review): the original first line was "!pip install transformers ftfy
# paddleocr paddlepaddle" — an IPython shell magic that is a SyntaxError in a
# plain .py file. Dependencies belong in requirements.txt
# (transformers, ftfy, paddleocr, paddlepaddle).

import glob
import json
import os

import torch
from google.colab import files  # Colab-only: replace for non-Colab deployments
from IPython.display import Image as ColabImage, display
from paddleocr import PaddleOCR, TextDetection  # deduplicated PaddleOCR import
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
# --- Input: upload an image and report its dimensions -----------------------
uploaded = files.upload()
img_path = next(iter(uploaded.keys()))

# Open once and normalise to RGB. (The previous revision opened the file a
# second time, silently discarding the RGB conversion.)
image = Image.open(img_path).convert("RGB")
width, height = image.size
total_pixels = width * height

print(f"Width: {width}, Height: {height}, Total pixels: {total_pixels}")
# --- Stage 1: text detection ------------------------------------------------
# Collect the quadrilateral bounding box of every detected text region.
arr = []

detector = TextDetection(model_name="PP-OCRv5_server_det")
for res in detector.predict(img_path, batch_size=1):
    polys = res['dt_polys']  # NumPy array of shape (N, 4, 2)
    if polys is not None:
        arr.extend(polys.tolist())

# Approximate reading order: sort by the top-left corner, y first then x.
arr.sort(key=lambda box: (box[0][1], box[0][0]))

print("Extracted bounding box coordinates (in reading order):")
for box in arr:
    print(box)

print(f"Number of detected text regions: {len(arr)}")
import cv2
import numpy as np
import os
import json

# --- Stage 2: crop each detected region with a perspective warp -------------

# Load the original image (BGR). cv2.imread returns None on failure, which
# would otherwise surface later as a cryptic error inside warpPerspective.
img = cv2.imread(img_path)
if img is None:
    raise FileNotFoundError(f"Could not read image: {img_path}")

# Output setup
output_dir = "./output/crops_warped"
os.makedirs(output_dir, exist_ok=True)

cropped_images = []

for i, box in enumerate(arr):
    box = np.array(box, dtype=np.float32)  # shape: (4, 2)

    # Target size: the longer of each pair of opposite edges, so skewed quads
    # are not squashed. Named crop_w/crop_h so we no longer clobber the
    # module-level width/height of the full image computed earlier.
    width_a = np.linalg.norm(box[0] - box[1])
    width_b = np.linalg.norm(box[2] - box[3])
    height_a = np.linalg.norm(box[0] - box[3])
    height_b = np.linalg.norm(box[1] - box[2])

    crop_w = int(max(width_a, width_b))
    crop_h = int(max(height_a, height_b))

    # Destination rectangle in the warped image.
    dst_rect = np.array([
        [0, 0],
        [crop_w - 1, 0],
        [crop_w - 1, crop_h - 1],
        [0, crop_h - 1],
    ], dtype=np.float32)

    # Perspective transform: map the source quad onto the upright rectangle.
    M = cv2.getPerspectiveTransform(box, dst_rect)
    warped = cv2.warpPerspective(img, M, (crop_w, crop_h))

    cropped_images.append(warped)

    # Save the warped crop for inspection.
    cv2.imwrite(os.path.join(output_dir, f"crop_{i}.png"), warped)

print(f"Cropped {len(cropped_images)} perspective-warped regions.")
97
+
98
+ # cropped_images.reverse()
99
+
100
+
101
+ import matplotlib.pyplot as plt
102
+
103
+ # Display all cropped images in a grid
104
+ for i, crop in enumerate(cropped_images):
105
+ plt.figure(figsize=(4, 4))
106
+ plt.imshow(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)) # Convert BGR to RGB
107
+ plt.title(f'Cropped Image {i}')
108
+ plt.axis('off')
109
+ plt.show()
110
+
# --- Stage 3: per-crop language identification with CLIP --------------------
clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

# Candidate prompts for zero-shot language/script classification. The
# second-to-last word of the winning prompt is taken as the language name
# below, so every entry must follow the "This is <Language> text" pattern.
# Add more languages here and mirror each addition in lang_map.
candidates = [
    "This is English text",
    "This is Telugu text",
    "This is Chinese text",
    "This is Korean text",
]
# Map the language name extracted from the winning CLIP prompt (lowercased)
# to its PaddleOCR language code. Keep this in sync with `candidates`
# (e.g. a "This is Japanese text" prompt would need "japanese": "japan").
lang_map = {
    "english": "en",
    "telugu": "te",
    "chinese": "ch",
    "korean": "korean",
}
# Classify the language of every crop. The per-crop codes are kept in
# detected_langs so downstream stages can use them; NOTE(review): the OCR
# stage below still uses only the LAST crop's `lang_code` for all regions —
# per-crop OCR would be the fuller fix.
detected_langs = []

for crop in cropped_images:  # renamed from `img`: it shadowed the full image
    # CLIP similarity between the crop and each candidate prompt.
    inputs = processor(text=candidates, images=crop, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits_per_image = clip_model(**inputs).logits_per_image
        probs = logits_per_image.softmax(dim=1)

    # Winning prompt -> language word ("This is English text" -> "english").
    best = probs.argmax().item()
    detected_lang_phrase = candidates[best]
    detected_lang = detected_lang_phrase.split()[-2].lower()
    lang_code = lang_map.get(detected_lang, "en")
    detected_langs.append(lang_code)

    print(f"\n✅ Detected script/language: {detected_lang_phrase} → PaddleOCR lang='{lang_code}'")
import numpy as np  # Ensure numpy is imported
import os  # Ensure os is imported

# --- Stage 4: recognition ---------------------------------------------------
# One PaddleOCR instance configured with the language detected above.
# All document-preprocessing classifiers are deliberately DISABLED — the
# crops are already perspective-warped upright (the original inline comments
# claimed the opposite of what the flags do).
ocr = PaddleOCR(
    use_doc_orientation_classify=False,  # crops are already deskewed
    use_doc_unwarping=False,
    use_textline_orientation=False,
    lang=lang_code,  # code from the CLIP stage (last crop wins)
    device="cpu"
)
for i, crop in enumerate(cropped_images):
    # Per-crop output paths (Colab-style /content/ scratch space).
    base_name = f"cropped_image_{i}"
    bounding_boxes_image_path = os.path.join("/content/", f"{base_name}_bounding_boxes.jpg")
    json_file_path = os.path.join("/content/", f"{base_name}.json")

    # Tiny crops make PaddleOCR error out — skip them.
    if crop.shape[0] < 10 or crop.shape[1] < 10:
        print(f"Skipping small image {i} with shape {crop.shape}")
        continue

    # Run recognition on the NumPy array and persist the results.
    result = ocr.predict(crop)
    if result and result[0]:
        result[0].save_to_img(bounding_boxes_image_path)
        result[0].save_to_json(json_file_path)
    else:
        print(f"No OCR results found for cropped image {i}.")

    # Show the annotated crop if it was written.
    saved_image_path = bounding_boxes_image_path
    if os.path.exists(saved_image_path):
        display(ColabImage(filename=saved_image_path))
    else:
        print(f"No OCR image found at: {saved_image_path}")

    # Rebuild a rough line layout from the recognised boxes.
    saved_json_path = json_file_path
    if os.path.exists(saved_json_path):
        with open(saved_json_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        elements = []
        for text, (x1, y1, x2, y2) in zip(data["rec_texts"], data["rec_boxes"]):
            elements.append({"text": text, "x": x1, "y": y1, "line_y": (y1 + y2) / 2})

        # Coarse bucket sort by vertical center, then greedy line grouping
        # with a 10 px vertical tolerance.
        elements.sort(key=lambda e: (round(e["line_y"] / 10), e["x"]))

        lines = []
        line = []
        current_line_y = None
        for elem in elements:
            if current_line_y is None or abs(elem["line_y"] - current_line_y) <= 10:
                line.append(elem["text"])
            else:
                lines.append(line)
                line = [elem["text"]]
            current_line_y = elem["line_y"]
        if line:
            lines.append(line)

        # Concatenate joined lines with no separator — identical to the
        # original += loop (the "\n\n" separator was commented out).
        markdown_output = "".join(" ".join(words) for words in lines)
        print(markdown_output)
    else:
        print(f"No JSON file found at: {saved_json_path}")
import json
import os

# --- Collect the recognised text for every crop, in crop order --------------

def _read_crop_text(path):
    """Return the space-joined 'rec_texts' from one crop's JSON, or ''."""
    if not os.path.exists(path):
        return ""
    with open(path, 'r', encoding='utf-8') as f:
        try:
            data = json.load(f)
            # predict() serialises a dict; 'rec_texts' holds the strings.
            if data and 'rec_texts' in data:
                return " ".join(data['rec_texts'])
        except json.JSONDecodeError:
            print(f"Warning: Could not decode JSON from {path}")
        except KeyError as e:
            print(f"Warning: Unexpected JSON structure in {path}: {e}")
    return ""

predicted_texts = [
    _read_crop_text(f"/content/cropped_image_{i}.json")
    for i in range(len(cropped_images))
]

# Display the final list of predicted texts
print("Predicted Texts Array (from JSON files):")
print(predicted_texts)
import json
import os

def get_box_center(box):
    """Return the (x, y) centroid of a polygon given as [[x, y], ...] pairs."""
    xs = [p[0] for p in box]
    ys = [p[1] for p in box]
    return sum(xs) / len(xs), sum(ys) / len(ys)
# --- Step 1: read each crop's text and pair it with its box centroid --------
all_text_blocks = []
for i, box in enumerate(arr):
    json_file_path = f"/content/cropped_image_{i}.json"
    if not os.path.exists(json_file_path):
        continue
    with open(json_file_path, 'r', encoding='utf-8') as f:
        result = json.load(f)
    if result and 'rec_texts' in result and result['rec_texts']:
        center_x, center_y = get_box_center(box)
        all_text_blocks.append({
            "text": " ".join(result['rec_texts']),
            "center_x": center_x,
            "center_y": center_y,
        })

# --- Step 2: order top-to-bottom, cluster into lines (40 px vertical
# tolerance), then sort each line left-to-right ------------------------------
if all_text_blocks:
    sorted_blocks = sorted(
        all_text_blocks,
        key=lambda item: (item["center_y"], item["center_x"]),
    )

    lines = []
    if sorted_blocks:

        def _flush(group):
            # Order a finished line by horizontal position and emit it.
            group.sort(key=lambda item: item["center_x"])
            lines.append(" ".join(item["text"] for item in group))

        current_line = [sorted_blocks[0]]
        for block in sorted_blocks[1:]:
            # Same line when vertical centers are within the threshold of the
            # previous block in the cluster.
            if abs(block["center_y"] - current_line[-1]["center_y"]) < 40:
                current_line.append(block)
            else:
                _flush(current_line)
                current_line = [block]
        if current_line:
            _flush(current_line)

    # --- Step 3: print the final reconstructed text -------------------------
    if lines:
        for line in lines:
            print(line)
    else:
        print("No text was reconstructed.")
else:
    print("No text blocks were found.")
416
+
417
These are the Colab code cells for printed-text OCR. I now need to deploy this as a Hugging Face Space on an H200 GPU. Make sure to include `import spaces` and the `@spaces.GPU` decorator so the GPU works; also provide a requirements.txt, and use Gradio for the UI.