imperiusrex committed on
Commit
4b6aee6
·
verified ·
1 Parent(s): d99150f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -147
app.py CHANGED
@@ -1,175 +1,120 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import CLIPProcessor, CLIPModel
4
- from paddleocr import PaddleOCR, TextDetection
5
- from PIL import Image
6
  import numpy as np
7
  import cv2
8
- import spaces
9
-
10
- # --- Global setup for models and data ---
11
- print("πŸ”„ Initializing models...")
12
-
13
- # Check for GPU and set device
14
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
- print(f"Device being used: {device}")
16
 
17
- # Load CLIP model once. This is memory-intensive, so we do it once.
18
- clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
19
- processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
20
 
21
- # Initialize Paddle's text detection model.
22
- # The latest versions of PaddlePaddle/PaddleOCR automatically use the GPU.
23
- det_model = TextDetection(model_name="PP-OCRv5_server_det")
 
 
 
24
 
25
- # Candidate language phrases for detection
26
  candidates = [
27
  "This is English text",
28
  "This is Telugu text",
29
  "This is Chinese text",
30
- "This is Korean text",
31
  ]
32
 
33
- # Map detected languages to PaddleOCR language codes
34
- lang_map = {
35
- "english": "en",
36
- "telugu": "te",
37
- "chinese": "ch",
38
- "korean": "korean",
39
- }
 
 
40
 
41
- print("βœ… Models loaded successfully.")
42
-
43
- # --- Utility Functions ---
44
- def get_box_center(box):
45
- """Calculates the center of a bounding box."""
46
- x_coords = [p[0] for p in box]
47
- y_coords = [p[1] for p in box]
48
- center_x = sum(x_coords) / len(x_coords)
49
- center_y = sum(y_coords) / len(y_coords)
50
- return center_x, center_y
51
-
52
- @spaces.GPU
53
- def ocr_pipeline(image_pil: Image.Image) -> str:
54
- """
55
- Performs OCR on an input image using a multi-step pipeline.
56
-
57
- Args:
58
- image_pil: A PIL Image object from the Gradio interface.
59
-
60
- Returns:
61
- A string containing the reconstructed text.
62
- """
63
- if image_pil is None:
64
- return "No image provided."
65
-
66
- print("Starting OCR pipeline...")
67
-
68
- # Convert PIL image to a NumPy array for OpenCV and Paddle
69
- img_np = np.array(image_pil.convert("RGB"))
70
-
71
- # Step 1: Text Detection with PaddleOCR's model
72
- output = det_model.predict(img_np, batch_size=1)
73
-
74
  arr = []
75
- if output and output[0] and 'dt_polys' in output[0] and output[0]['dt_polys'] is not None:
76
- arr.extend(output[0]['dt_polys'].tolist())
 
 
77
 
78
- # Sort the bounding boxes in reading order
79
- sorted_polys = sorted(arr, key=lambda box: (box[0][1], box[0][0]))
80
 
81
- if not sorted_polys:
82
- print("No text regions detected.")
83
- return "No text regions detected."
84
-
85
  cropped_images = []
86
- for box in sorted_polys:
 
 
87
  box = np.array(box, dtype=np.float32)
88
- width_a = np.linalg.norm(box[0] - box[1])
89
- width_b = np.linalg.norm(box[2] - box[3])
90
- height_a = np.linalg.norm(box[0] - box[3])
91
- height_b = np.linalg.norm(box[1] - box[2])
92
- width = int(max(width_a, width_b))
93
- height = int(max(height_a, height_b))
94
- dst_rect = np.array([
95
- [0, 0],
96
- [width - 1, 0],
97
- [width - 1, height - 1],
98
- [0, height - 1]
99
- ], dtype=np.float32)
100
  M = cv2.getPerspectiveTransform(box, dst_rect)
101
- warped = cv2.warpPerspective(img_np, M, (width, height))
102
  cropped_images.append(warped)
 
103
 
104
- # Step 2: Language detection with CLIP and OCR on cropped images
105
- all_text_blocks = []
106
-
107
- for i, img in enumerate(cropped_images):
108
- pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
109
-
110
- # Use CLIP to detect language. The model is already on the GPU.
111
- inputs = processor(text=candidates, images=pil_img, return_tensors="pt", padding=True).to(device)
112
  with torch.no_grad():
113
- outputs = clip_model(**inputs)
114
- logits_per_image = outputs.logits_per_image
115
- probs = logits_per_image.softmax(dim=1)
116
-
117
- best = probs.argmax().item()
118
- detected_lang_phrase = candidates[best]
119
- detected_lang = detected_lang_phrase.split()[-2].lower()
120
- lang_code = lang_map.get(detected_lang, "en")
121
-
122
- # Initialize PaddleOCR with the detected language.
123
- ocr = PaddleOCR(lang=lang_code, use_angle_cls=False, use_doc_unwarping=False, use_gpu=True)
124
- result = ocr.predict(img)
125
-
126
- # Extract text from OCR result
127
- text_for_this_image = ""
128
- if result and result[0] and 'rec_texts' in result[0]:
129
- text_for_this_image = " ".join(result[0]['rec_texts'])
130
 
131
- # Store text and bounding box information
132
- center_x, center_y = get_box_center(sorted_polys[i])
133
- all_text_blocks.append({
134
- "text": text_for_this_image,
135
- "center_x": center_x,
136
- "center_y": center_y
137
- })
138
-
139
- # Step 3: Reconstruct the text in reading order
140
- if not all_text_blocks:
141
- print("No text could be extracted.")
142
- return "No text could be extracted."
143
 
144
- sorted_blocks = sorted(all_text_blocks, key=lambda item: (item["center_y"], item["center_x"]))
145
-
 
 
 
 
 
 
 
 
 
 
146
  lines = []
147
- if sorted_blocks:
148
- current_line = [sorted_blocks[0]]
149
- for block in sorted_blocks[1:]:
150
- if abs(block["center_y"] - current_line[-1]["center_y"]) < 40:
151
- current_line.append(block)
152
- else:
153
- current_line.sort(key=lambda item: item["center_x"])
154
- lines.append(" ".join([item["text"] for item in current_line]))
155
- current_line = [block]
156
-
157
- if current_line:
158
- current_line.sort(key=lambda item: item["center_x"])
159
- lines.append(" ".join([item["text"] for item in current_line]))
160
-
161
- final_text = "\n".join(lines)
162
- print("OCR pipeline finished successfully.")
163
- return final_text
164
-
165
- # --- Gradio Interface ---
166
- iface = gr.Interface(
167
- fn=ocr_pipeline,
168
- inputs=gr.Image(type="pil", label="Upload Image"),
169
- outputs=gr.Textbox(label="Recognized Text"),
170
- title="Printed Text OCR with PaddleOCR and CLIP",
171
- description="Upload a printed text image. The app will detect text regions, identify the language with CLIP, and perform OCR to return the text in reading order. This space uses an H200 GPU for high-speed processing."
172
- )
173
 
174
  if __name__ == "__main__":
175
- iface.launch()
 
 
1
  import gradio as gr
2
  import torch
 
 
 
3
  import numpy as np
4
  import cv2
5
+ import os
6
+ import json
7
+ from PIL import Image
8
+ from transformers import CLIPProcessor, CLIPModel
9
+ from paddleocr import PaddleOCR, TextDetection
10
+ from spaces import GPU # Required for ZeroGPU on Hugging Face
 
 
11
 
12
+ # Setup
13
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
14
+ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
15
 
16
+ lang_map = {
17
+ "english": "en",
18
+ "telugu": "te",
19
+ "chinese": "ch",
20
+ "korean": "korean",
21
+ }
22
 
 
23
  candidates = [
24
  "This is English text",
25
  "This is Telugu text",
26
  "This is Chinese text",
27
+ "This is Korean text"
28
  ]
29
 
30
+ text_detector = TextDetection(model_name="PP-OCRv5_server_det")
31
+
32
+ @GPU
33
+ def ocr_pipeline(image_np):
34
+ image_pil = Image.fromarray(image_np).convert("RGB")
35
+ width, height = image_pil.size
36
+ img_cv = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
37
+
38
+ output = text_detector.predict(image_np, batch_size=1)
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  arr = []
41
+ for res in output:
42
+ polys = res.get("dt_polys", [])
43
+ if polys is not None:
44
+ arr.extend(polys.tolist())
45
 
46
+ arr = sorted(arr, key=lambda box: (box[0][1], box[0][0]))
 
47
 
 
 
 
 
48
  cropped_images = []
49
+ warped_boxes = []
50
+
51
+ for box in arr:
52
  box = np.array(box, dtype=np.float32)
53
+ width = int(max(np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[2] - box[3])))
54
+ height = int(max(np.linalg.norm(box[0] - box[3]), np.linalg.norm(box[1] - box[2])))
55
+ dst_rect = np.array([[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]], dtype=np.float32)
 
 
 
 
 
 
 
 
 
56
  M = cv2.getPerspectiveTransform(box, dst_rect)
57
+ warped = cv2.warpPerspective(img_cv, M, (width, height))
58
  cropped_images.append(warped)
59
+ warped_boxes.append(box)
60
 
61
+ final_output_lines = []
62
+
63
+ for i, crop in enumerate(cropped_images):
64
+ if crop.shape[0] < 10 or crop.shape[1] < 10:
65
+ continue
66
+
67
+ # Language detection
68
+ clip_inputs = clip_processor(text=candidates, images=crop, return_tensors="pt", padding=True)
69
  with torch.no_grad():
70
+ probs = clip_model(**clip_inputs).logits_per_image.softmax(dim=1)
71
+ lang_index = probs.argmax().item()
72
+ lang_detected = candidates[lang_index].split()[-2].lower()
73
+ lang_code = lang_map.get(lang_detected, "en")
74
+
75
+ ocr = PaddleOCR(lang=lang_code, use_doc_orientation_classify=False,
76
+ use_doc_unwarping=False, use_textline_orientation=False, device='cpu')
 
 
 
 
 
 
 
 
 
 
77
 
78
+ result = ocr.ocr(crop)
79
+ if not result or not result[0]:
80
+ continue
 
 
 
 
 
 
 
 
 
81
 
82
+ for line in result[0]:
83
+ text = line[1][0]
84
+ box = line[0]
85
+ center_x = sum([p[0] for p in box]) / 4
86
+ center_y = sum([p[1] for p in box]) / 4
87
+ final_output_lines.append({"text": text, "cx": center_x, "cy": center_y})
88
+
89
+ if not final_output_lines:
90
+ return "❌ No text detected."
91
+
92
+ # Grouping by line
93
+ sorted_blocks = sorted(final_output_lines, key=lambda b: (b["cy"], b["cx"]))
94
  lines = []
95
+ current_line = [sorted_blocks[0]]
96
+ for block in sorted_blocks[1:]:
97
+ if abs(block["cy"] - current_line[-1]["cy"]) < 40:
98
+ current_line.append(block)
99
+ else:
100
+ lines.append(" ".join([x["text"] for x in sorted(current_line, key=lambda b: b["cx"])]))
101
+ current_line = [block]
102
+ if current_line:
103
+ lines.append(" ".join([x["text"] for x in sorted(current_line, key=lambda b: b["cx"])]))
104
+
105
+ return "\n".join(lines)
106
+
107
+
108
+ # Gradio Interface
109
+ def build_interface():
110
+ return gr.Interface(
111
+ fn=ocr_pipeline,
112
+ inputs=gr.Image(type="numpy", label="Upload Handwritten Image"),
113
+ outputs="text",
114
+ title="🌐 Multilingual Handwritten OCR with CLIP + PaddleOCR",
115
+ description="πŸ“„ Upload a handwritten document image. Detects language using CLIP and performs text detection + recognition with PaddleOCR."
116
+ )
 
 
 
 
117
 
118
  if __name__ == "__main__":
119
+ iface = build_interface()
120
+ iface.launch()