imperiusrex committed on
Commit
6635183
·
verified ·
1 Parent(s): 802fa97

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +226 -94
app.py CHANGED
@@ -4,115 +4,247 @@ import numpy as np
4
  import cv2
5
  from PIL import Image
6
  from transformers import CLIPProcessor, CLIPModel
7
- from paddleocr import PaddleOCR
8
  from functools import lru_cache
 
9
 
10
- # Use GPU if available
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
 
13
- # Load CLIP globally
14
- clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
 
15
  clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
16
 
17
- # Language detection candidates
18
- candidates = [
19
- "This is English text",
20
- "This is Telugu text",
21
- "This is Chinese text",
22
- "This is Korean text"
23
- ]
24
-
25
- lang_map = {
26
- "english": "en",
27
- "telugu": "te",
28
- "chinese": "ch",
29
- "korean": "korean"
30
- }
31
-
32
- @lru_cache(maxsize=4)
33
- def get_ocr_model(lang_code):
34
- return PaddleOCR(
35
- use_doc_orientation_classify=False,
36
- use_doc_unwarping=False,
37
- use_textline_orientation=False,
38
- lang=lang_code,
39
- use_angle_cls=False,
40
- use_gpu=torch.cuda.is_available()
41
- )
42
-
43
- def get_box_center(box):
44
- x_coords = [p[0] for p in box]
45
- y_coords = [p[1] for p in box]
46
- return sum(x_coords) / 4, sum(y_coords) / 4
47
 
 
48
  def process_image(img_path):
49
- image = cv2.imread(img_path)
50
- ocr = get_ocr_model("en")
51
- result = ocr.ocr(image, cls=False)
52
-
53
- if not result or not result[0]:
54
- return "❌ No text detected."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
- boxes = result[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  cropped_images = []
58
- centers = []
59
- for line in boxes:
60
- box, (_, _) = line
61
- box = np.array(box).astype(np.float32)
62
- x_min, y_min = box[:, 0].min(), box[:, 1].min()
63
- x_max, y_max = box[:, 0].max(), box[:, 1].max()
64
- warped = image[int(y_min):int(y_max), int(x_min):int(x_max)]
 
 
 
 
65
  cropped_images.append(warped)
66
- centers.append(get_box_center(box))
67
 
 
68
  predicted_texts = []
69
- for crop in cropped_images:
70
- inputs = clip_processor(text=candidates, images=Image.fromarray(crop), return_tensors="pt", padding=True).to(device)
 
71
  with torch.no_grad():
72
- probs = clip_model(**inputs).logits_per_image.softmax(dim=1)
73
- lang = candidates[probs.argmax()].split()[-2].lower()
74
- lang_code = lang_map.get(lang, "en")
75
-
76
- ocr = get_ocr_model(lang_code)
77
- ocr_result = ocr.ocr(crop, cls=False)
78
-
79
- text = ""
80
- if ocr_result and ocr_result[0]:
81
- text = " ".join([line[1][0] for line in ocr_result[0]])
82
- predicted_texts.append(text)
83
-
84
- # Reconstruct text based on y then x sorting
85
- text_blocks = [
86
- {"text": text, "center": center}
87
- for text, center in zip(predicted_texts, centers) if text.strip()
88
- ]
89
- text_blocks.sort(key=lambda b: (b["center"][1], b["center"][0]))
90
-
91
- lines = []
92
- line = []
93
- last_y = None
94
- for block in text_blocks:
95
- y = block["center"][1]
96
- if last_y is None or abs(y - last_y) < 40:
97
- line.append(block)
98
- else:
99
- lines.append(line)
100
- line = [block]
101
- last_y = y
102
- if line:
103
- lines.append(line)
104
-
105
- final_text = "\n".join(" ".join([b["text"] for b in sorted(l, key=lambda b: b["center"][0])]) for l in lines)
106
- return final_text or "❌ No text reconstructed."
107
-
108
- # Gradio app
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  iface = gr.Interface(
110
  fn=process_image,
111
- inputs=gr.Image(type="filepath", label="Upload Image"),
112
- outputs=gr.Text(label="Extracted Text"),
113
- title="Multilingual OCR + Text Reconstruction",
114
- description="Uses PaddleOCR + CLIP to detect language, perform OCR, and reconstruct reading order text using GPU."
115
  )
116
 
117
- if __name__ == "__main__":
118
- iface.launch()
 
4
  import cv2
5
  from PIL import Image
6
  from transformers import CLIPProcessor, CLIPModel
7
+ from paddleocr import PaddleOCR, TextDetection
8
  from functools import lru_cache
9
+ from spaces import GPU
10
 
 
11
# Hub identifier for the fine-tuned OCR model associated with this Space.
MODEL_HUB_ID = "imperiusrex/printedpaddle"

# Load CLIP once at import time; per-call loading is far too slow.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

# Inference is pinned to CPU at import time; the @GPU-decorated entry point
# (Hugging Face `spaces`) manages GPU allocation per request.
# NOTE(review): the previous `device = torch.device("cuda" ...)` line was dead
# code — it was immediately overwritten by this assignment, so it was removed.
device = "cpu"
clip_model.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
# Candidate phrases CLIP scores to guess the script/language of a crop.
# Extend this list (and _LANG_MAP) to support more languages.
_CANDIDATE_PHRASES = [
    "This is English text",
    "This is Telugu text",
    "This is Chinese text",
    "This is Korean text",
]

# Map the detected language word (lowercased) to a PaddleOCR language code.
_LANG_MAP = {
    "english": "en",
    "telugu": "te",
    "chinese": "ch",
    "korean": "korean",
}


@lru_cache(maxsize=1)
def _get_text_detector():
    """Build the PaddleOCR text-detection model once and reuse it."""
    return TextDetection(model_name="PP-OCRv5_server_det")


@lru_cache(maxsize=8)
def _get_ocr_model(lang_code):
    """Build and cache one PaddleOCR recognizer per language code.

    Constructing a PaddleOCR pipeline is expensive; the previous version
    rebuilt it for every single crop.
    """
    return PaddleOCR(
        use_doc_orientation_classify=False,
        use_doc_unwarping=False,
        use_textline_orientation=False,
        lang=lang_code,
        device="cpu",
    )


def _get_box_center(box):
    """Return the (x, y) centroid of a polygon given as [[x, y], ...]."""
    xs = [p[0] for p in box]
    ys = [p[1] for p in box]
    return sum(xs) / len(xs), sum(ys) / len(ys)


def _detect_boxes(img_path):
    """Run text detection; return boxes sorted top-to-bottom, left-to-right."""
    boxes = []
    for res in _get_text_detector().predict(img_path, batch_size=1):
        polys = res["dt_polys"]
        if polys is not None:
            boxes.extend(polys.tolist())
    boxes.sort(key=lambda box: (box[0][1], box[0][0]))
    return boxes


def _crop_box(img, box):
    """Perspective-warp one quadrilateral text region out of *img*."""
    box = np.array(box, dtype=np.float32)
    width = int(max(np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[2] - box[3])))
    height = int(max(np.linalg.norm(box[0] - box[3]), np.linalg.norm(box[1] - box[2])))
    dst_rect = np.array(
        [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]],
        dtype=np.float32,
    )
    M = cv2.getPerspectiveTransform(box, dst_rect)
    return cv2.warpPerspective(img, M, (width, height))


def _ocr_crop(crop):
    """Detect the crop's language with CLIP, then OCR it with PaddleOCR.

    Uses the module-level CLIP model/processor instead of reloading them
    from the hub on every call.
    """
    # cv2.imread yields BGR; CLIP expects RGB images.
    rgb = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
    inputs = clip_processor(
        text=_CANDIDATE_PHRASES,
        images=Image.fromarray(rgb),
        return_tensors="pt",
        padding=True,
    )
    with torch.no_grad():
        probs = clip_model(**inputs).logits_per_image.softmax(dim=1)
    # Phrase shape is "This is <Language> text" -> second-to-last word.
    lang_word = _CANDIDATE_PHRASES[probs.argmax().item()].split()[-2].lower()
    ocr = _get_ocr_model(_LANG_MAP.get(lang_word, "en"))
    result = ocr.predict(crop)
    if result and result[0] and "rec_texts" in result[0]:
        return " ".join(result[0]["rec_texts"])
    return ""


def _reconstruct_lines(text_blocks, y_threshold=40):
    """Group text blocks into lines and join them in reading order.

    Args:
        text_blocks: dicts with "text", "center_x", "center_y" keys.
        y_threshold: max vertical-center distance (pixels) for two blocks
            to be considered part of the same line.

    Returns:
        The reconstructed multi-line string ("" for no blocks).
    """
    if not text_blocks:
        return ""
    ordered = sorted(text_blocks, key=lambda b: (b["center_y"], b["center_x"]))
    lines = []
    current = [ordered[0]]
    for block in ordered[1:]:
        # Same line if vertical centers are close to the previous block's.
        if abs(block["center_y"] - current[-1]["center_y"]) < y_threshold:
            current.append(block)
        else:
            current.sort(key=lambda b: b["center_x"])
            lines.append(" ".join(b["text"] for b in current))
            current = [block]
    current.sort(key=lambda b: b["center_x"])
    lines.append(" ".join(b["text"] for b in current))
    return "\n".join(lines)


@GPU
def process_image(img_path):
    """Detect, crop, and OCR all text in an image, returning reading-order text.

    Args:
        img_path: Path to the image file on disk.

    Returns:
        The reconstructed text, one detected line per output line; an empty
        string when no text is found.
    """
    boxes = _detect_boxes(img_path)
    img = cv2.imread(img_path)

    text_blocks = []
    for box in boxes:
        text = _ocr_crop(_crop_box(img, box))
        if text:  # Skip crops where OCR produced nothing.
            center_x, center_y = _get_box_center(box)
            text_blocks.append(
                {"text": text, "center_x": center_x, "center_y": center_y}
            )

    return _reconstruct_lines(text_blocks)
240
+
241
# Gradio UI: upload an image, get the reconstructed text back.
# Restores the explicit input/output labels present in the previous revision.
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="filepath", label="Upload Image"),
    outputs=gr.Text(label="Extracted Text"),
    title="Image OCR and Text Reconstruction",
    description="Upload an image to perform text detection, cropping, language detection, OCR, and reconstruct the text in reading order.",
)

if __name__ == "__main__":
    iface.launch(debug=True)