imperiusrex committed on
Commit
7ed45dc
·
verified ·
1 Parent(s): bc5fb2a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +222 -107
app.py CHANGED
@@ -20,119 +20,234 @@ device = "cpu"
20
  clip_model.to(device)
21
 
22
  # Language map for OCR models
23
# Map from a human-readable language name (as extracted from the winning
# CLIP candidate phrase) to the language code PaddleOCR recognizers expect.
lang_map = {
    "english": "en",
    "telugu": "te",
    "chinese": "ch",
    "korean": "korean",
}

# Prompt sentences scored by CLIP against each text crop; the highest-scoring
# phrase determines which OCR language model is used. NOTE: order matters —
# language detection indexes into this list by argmax position.
candidates = [
    "This is English text",
    "This is Telugu text",
    "This is Chinese text",
    "This is Korean text",
]
37
-
38
# Single shared PaddleOCR instance used purely for text-box detection;
# per-language recognition is handled by get_ocr_recognizer.
ocr_detector = PaddleOCR(lang='en', use_textline_orientation=False)
40
-
41
-
42
-
43
@lru_cache(maxsize=4)
def get_ocr_recognizer(lang_code):
    """Return a recognition-only PaddleOCR instance for *lang_code*.

    Instances are cached (one per language, up to four) so that repeated
    crops in the same language reuse the already-loaded model instead of
    paying the construction cost every time.
    """
    return PaddleOCR(
        lang=lang_code,
        det=False,
        rec=True,
        use_textline_orientation=False,
        use_gpu=False,
    )
47
-
48
def is_valid_ocr_detection(det_result):
    """Return True iff *det_result* contains at least one detected text box.

    PaddleOCR detection results arrive as a list whose first element is the
    list of boxes for the (single) input image; ``None``, an empty result,
    a non-list first element, and an empty box list all count as "nothing
    detected".

    Fix: the original bare ``and`` chain returned the falsy operand itself
    (e.g. ``None`` or ``[]``) rather than a boolean, contradicting the
    ``is_``-predicate name; coerce with bool().
    """
    return bool(det_result and isinstance(det_result[0], list) and det_result[0])
51
-
52
def ocr_pipeline(image_np):
    """Detect, language-classify, and recognize printed text in an image.

    Pipeline: PaddleOCR detects text boxes; each box is perspective-rectified;
    CLIP zero-shot-classifies the crop's language; a per-language PaddleOCR
    recognizer extracts the text; finally the text blocks are regrouped into
    visual lines and returned top-to-bottom, left-to-right.

    Args:
        image_np: RGB image as a numpy array (H, W, 3) — the format Gradio's
            ``type="numpy"`` image input delivers.

    Returns:
        Recognized text (one reconstructed visual line per output line), or
        an error message when nothing was detected.
    """
    image_pil = Image.fromarray(image_np).convert("RGB")
    img_cv = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)

    detection_result = ocr_detector.ocr(image_np, det=True, rec=False)
    if not is_valid_ocr_detection(detection_result):
        # Fix: the original returned " No text detected." here (garbled,
        # leading space) but "❌ No text detected." below — now consistent.
        return "❌ No text detected."

    # Boxes sorted top-to-bottom, then left-to-right, by their first corner.
    arr = sorted((line[0] for line in detection_result[0]),
                 key=lambda box: (box[0][1], box[0][0]))

    final_output_lines = []
    for box in arr:
        box = np.array(box, dtype=np.float32)
        width = int(max(np.linalg.norm(box[0] - box[1]),
                        np.linalg.norm(box[2] - box[3])))
        height = int(max(np.linalg.norm(box[0] - box[3]),
                         np.linalg.norm(box[1] - box[2])))
        # Fix: guard BEFORE warping — cv2.warpPerspective raises on a
        # zero-sized destination, and tiny slivers are OCR noise anyway.
        # (The original warped first and checked crop.shape afterwards.)
        if width < 10 or height < 10:
            continue

        dst_rect = np.array([[0, 0], [width - 1, 0],
                             [width - 1, height - 1], [0, height - 1]],
                            dtype=np.float32)
        M = cv2.getPerspectiveTransform(box, dst_rect)
        crop = cv2.warpPerspective(img_cv, M, (width, height))

        # Fix: the crop comes from the BGR working image, but CLIP (via PIL)
        # expects RGB — convert before language classification.
        crop_rgb = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
        clip_inputs = clip_processor(text=candidates,
                                     images=Image.fromarray(crop_rgb),
                                     return_tensors="pt", padding=True)
        clip_inputs = {k: v.to(device) for k, v in clip_inputs.items()}
        with torch.no_grad():
            probs = clip_model(**clip_inputs).logits_per_image.softmax(dim=1)

        # Candidate phrases read "This is <Language> text" -> the language
        # name is the second-to-last word.
        lang_detected = candidates[probs.argmax().item()].split()[-2].lower()
        lang_code = lang_map.get(lang_detected, "en")

        result = get_ocr_recognizer(lang_code).ocr(crop)
        if not result or not isinstance(result[0], list) or len(result[0]) == 0:
            continue

        for line in result[0]:
            text = line[1][0]
            line_box = line[0]
            # Centroid of the recognized line's quadrilateral.
            final_output_lines.append({
                "text": text,
                "cx": sum(p[0] for p in line_box) / 4,
                "cy": sum(p[1] for p in line_box) / 4,
            })

    if not final_output_lines:
        return "❌ No text detected."

    # Group blocks into visual lines: vertical centers within 40px belong to
    # the same line; each line is then ordered left-to-right.
    sorted_blocks = sorted(final_output_lines, key=lambda b: (b["cy"], b["cx"]))
    lines = []
    current_line = [sorted_blocks[0]]
    for block in sorted_blocks[1:]:
        if abs(block["cy"] - current_line[-1]["cy"]) < 40:
            current_line.append(block)
        else:
            lines.append(" ".join(
                x["text"] for x in sorted(current_line, key=lambda b: b["cx"])))
            current_line = [block]
    if current_line:
        lines.append(" ".join(
            x["text"] for x in sorted(current_line, key=lambda b: b["cx"])))

    return "\n".join(lines)
125
-
126
def build_interface():
    """Build the Gradio UI: a single image-in, text-out wrapper around
    ocr_pipeline."""
    return gr.Interface(
        fn=ocr_pipeline,
        inputs=gr.Image(type="numpy", label="Upload Printed Image"),
        outputs="text",
        title="\U0001F310 Multilingual Printed OCR with CLIP + PaddleOCR",
        description="\U0001F4C4 Upload a printed document image. Detects language using CLIP and performs text detection + recognition with PaddleOCR.",
    )


if __name__ == "__main__":
    build_interface().launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  clip_model.to(device)
21
 
22
  # Language map for OCR models
23
from functools import lru_cache

# Candidate phrases scored by CLIP against each crop; the second-to-last word
# of the winning phrase names the language. Only languages with an entry in
# _LANG_MAP below are enabled — the original's long commented-out list of
# extra languages was dead noise; add a phrase AND a map entry to enable more.
_CANDIDATES = [
    "This is English text",
    "This is Telugu text",
    "This is Chinese text",
    "This is Korean text",
]

# Detected language name -> PaddleOCR language code.
_LANG_MAP = {
    "english": "en",
    "telugu": "te",
    "chinese": "ch",
    "korean": "korean",
}


@lru_cache(maxsize=1)
def _clip_components():
    """Load the CLIP model and processor exactly once.

    Fix: the original called ``from_pretrained`` inside process_image, so
    both multi-GB downloads/loads happened on every single request.
    """
    model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
    return model, processor


@lru_cache(maxsize=1)
def _text_detector():
    """Build the PaddleOCR text-detection model once and reuse it."""
    return TextDetection(model_name="PP-OCRv5_server_det")


@lru_cache(maxsize=8)
def _ocr_for_lang(lang_code):
    """Return a cached recognition pipeline for *lang_code*.

    Fix: the original constructed a brand-new PaddleOCR instance for every
    cropped text box, which is extremely slow.
    """
    return PaddleOCR(
        use_doc_orientation_classify=False,
        use_doc_unwarping=False,
        use_textline_orientation=False,
        lang=lang_code,
        device="cpu",
    )


def _box_center(box):
    """Centroid of a quadrilateral given as four (x, y) points."""
    xs = [p[0] for p in box]
    ys = [p[1] for p in box]
    return sum(xs) / len(xs), sum(ys) / len(ys)


def _detect_boxes(img_path):
    """Run text detection; return boxes sorted top-to-bottom, left-to-right."""
    boxes = []
    for res in _text_detector().predict(img_path, batch_size=1):
        polys = res['dt_polys']
        if polys is not None:
            boxes.extend(polys.tolist())
    return sorted(boxes, key=lambda box: (box[0][1], box[0][0]))


def _warp_crop(img, box):
    """Perspective-rectify one detected quadrilateral out of *img* (BGR).

    Returns None for degenerate (zero-sized) boxes, which would otherwise
    make cv2.warpPerspective raise.
    """
    quad = np.array(box, dtype=np.float32)
    width = int(max(np.linalg.norm(quad[0] - quad[1]),
                    np.linalg.norm(quad[2] - quad[3])))
    height = int(max(np.linalg.norm(quad[0] - quad[3]),
                     np.linalg.norm(quad[1] - quad[2])))
    if width < 1 or height < 1:
        return None
    dst = np.array([[0, 0], [width - 1, 0],
                    [width - 1, height - 1], [0, height - 1]], dtype=np.float32)
    M = cv2.getPerspectiveTransform(quad, dst)
    return cv2.warpPerspective(img, M, (width, height))


def _recognize_crop(crop_bgr):
    """Classify the crop's language with CLIP, then OCR it.

    Returns the recognized text, or '' when nothing is recognized.
    """
    clip_model, processor = _clip_components()
    # Fix: cv2 images are BGR but CLIP's processor expects RGB; the original
    # passed the BGR array straight through, skewing language detection.
    crop_rgb = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB)
    inputs = processor(text=_CANDIDATES, images=crop_rgb,
                       return_tensors="pt", padding=True)
    with torch.no_grad():
        probs = clip_model(**inputs).logits_per_image.softmax(dim=1)
    phrase = _CANDIDATES[probs.argmax().item()]
    lang_code = _LANG_MAP.get(phrase.split()[-2].lower(), "en")

    # OCR itself runs on the BGR crop, matching cv2/PaddleOCR conventions.
    result = _ocr_for_lang(lang_code).predict(crop_bgr)
    if result and result[0] and 'rec_texts' in result[0]:
        return " ".join(result[0]['rec_texts'])
    return ""


def process_image(img_path):
    """
    Processes an image to detect, crop, and OCR text, returning it in reading order.

    Args:
        img_path: The path to the image file.

    Returns:
        A string containing the reconstructed text ('' when no text is found).
    """
    arr = _detect_boxes(img_path)
    img = cv2.imread(img_path)

    # One recognized string per detected box, kept index-aligned with `arr`
    # so each text can be matched back to its box's centroid below.
    predicted_texts = []
    for box in arr:
        crop = _warp_crop(img, box)
        predicted_texts.append(_recognize_crop(crop) if crop is not None else "")

    # Collect the non-empty text blocks with their centroid coordinates.
    all_text_blocks = []
    for box, text in zip(arr, predicted_texts):
        if text:
            cx, cy = _box_center(box)
            all_text_blocks.append({"text": text, "center_x": cx, "center_y": cy})

    if not all_text_blocks:
        return ""

    # Group blocks into visual lines: vertical centers within 40px belong to
    # the same line; each line is then ordered left-to-right.
    sorted_blocks = sorted(all_text_blocks,
                           key=lambda b: (b["center_y"], b["center_x"]))
    lines = []
    current_line = [sorted_blocks[0]]
    for block in sorted_blocks[1:]:
        if abs(block["center_y"] - current_line[-1]["center_y"]) < 40:
            current_line.append(block)
        else:
            current_line.sort(key=lambda b: b["center_x"])
            lines.append(" ".join(b["text"] for b in current_line))
            current_line = [block]
    if current_line:
        current_line.sort(key=lambda b: b["center_x"])
        lines.append(" ".join(b["text"] for b in current_line))

    return "\n".join(lines)
242
+
243
# Gradio UI: a single image-path-in, text-out interface around process_image.
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="filepath"),
    outputs=gr.Text(),
    title="Image OCR and Text Reconstruction",
    description="Upload an image to perform text detection, cropping, language detection, OCR, and reconstruct the text in reading order.",
)

if __name__ == "__main__":
    iface.launch(debug=True)
253
+