Spaces:
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,175 +1,120 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import torch
|
| 3 |
-
from transformers import CLIPProcessor, CLIPModel
|
| 4 |
-
from paddleocr import PaddleOCR, TextDetection
|
| 5 |
-
from PIL import Image
|
| 6 |
import numpy as np
|
| 7 |
import cv2
|
| 8 |
-
import
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
#
|
| 14 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 15 |
-
print(f"Device being used: {device}")
|
| 16 |
|
| 17 |
-
#
|
| 18 |
-
clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
|
| 19 |
-
|
| 20 |
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
-
# Candidate language phrases for detection
|
| 26 |
candidates = [
|
| 27 |
"This is English text",
|
| 28 |
"This is Telugu text",
|
| 29 |
"This is Chinese text",
|
| 30 |
-
"This is Korean text"
|
| 31 |
]
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
|
|
|
|
|
|
| 40 |
|
| 41 |
-
print("β
Models loaded successfully.")
|
| 42 |
-
|
| 43 |
-
# --- Utility Functions ---
|
| 44 |
-
def get_box_center(box):
|
| 45 |
-
"""Calculates the center of a bounding box."""
|
| 46 |
-
x_coords = [p[0] for p in box]
|
| 47 |
-
y_coords = [p[1] for p in box]
|
| 48 |
-
center_x = sum(x_coords) / len(x_coords)
|
| 49 |
-
center_y = sum(y_coords) / len(y_coords)
|
| 50 |
-
return center_x, center_y
|
| 51 |
-
|
| 52 |
-
@spaces.GPU
|
| 53 |
-
def ocr_pipeline(image_pil: Image.Image) -> str:
|
| 54 |
-
"""
|
| 55 |
-
Performs OCR on an input image using a multi-step pipeline.
|
| 56 |
-
|
| 57 |
-
Args:
|
| 58 |
-
image_pil: A PIL Image object from the Gradio interface.
|
| 59 |
-
|
| 60 |
-
Returns:
|
| 61 |
-
A string containing the reconstructed text.
|
| 62 |
-
"""
|
| 63 |
-
if image_pil is None:
|
| 64 |
-
return "No image provided."
|
| 65 |
-
|
| 66 |
-
print("Starting OCR pipeline...")
|
| 67 |
-
|
| 68 |
-
# Convert PIL image to a NumPy array for OpenCV and Paddle
|
| 69 |
-
img_np = np.array(image_pil.convert("RGB"))
|
| 70 |
-
|
| 71 |
-
# Step 1: Text Detection with PaddleOCR's model
|
| 72 |
-
output = det_model.predict(img_np, batch_size=1)
|
| 73 |
-
|
| 74 |
arr = []
|
| 75 |
-
|
| 76 |
-
|
|
|
|
|
|
|
| 77 |
|
| 78 |
-
|
| 79 |
-
sorted_polys = sorted(arr, key=lambda box: (box[0][1], box[0][0]))
|
| 80 |
|
| 81 |
-
if not sorted_polys:
|
| 82 |
-
print("No text regions detected.")
|
| 83 |
-
return "No text regions detected."
|
| 84 |
-
|
| 85 |
cropped_images = []
|
| 86 |
-
|
|
|
|
|
|
|
| 87 |
box = np.array(box, dtype=np.float32)
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
height_b = np.linalg.norm(box[1] - box[2])
|
| 92 |
-
width = int(max(width_a, width_b))
|
| 93 |
-
height = int(max(height_a, height_b))
|
| 94 |
-
dst_rect = np.array([
|
| 95 |
-
[0, 0],
|
| 96 |
-
[width - 1, 0],
|
| 97 |
-
[width - 1, height - 1],
|
| 98 |
-
[0, height - 1]
|
| 99 |
-
], dtype=np.float32)
|
| 100 |
M = cv2.getPerspectiveTransform(box, dst_rect)
|
| 101 |
-
warped = cv2.warpPerspective(
|
| 102 |
cropped_images.append(warped)
|
|
|
|
| 103 |
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
#
|
| 111 |
-
|
| 112 |
with torch.no_grad():
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
lang_code = lang_map.get(detected_lang, "en")
|
| 121 |
-
|
| 122 |
-
# Initialize PaddleOCR with the detected language.
|
| 123 |
-
ocr = PaddleOCR(lang=lang_code, use_angle_cls=False, use_doc_unwarping=False, use_gpu=True)
|
| 124 |
-
result = ocr.predict(img)
|
| 125 |
-
|
| 126 |
-
# Extract text from OCR result
|
| 127 |
-
text_for_this_image = ""
|
| 128 |
-
if result and result[0] and 'rec_texts' in result[0]:
|
| 129 |
-
text_for_this_image = " ".join(result[0]['rec_texts'])
|
| 130 |
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
"text": text_for_this_image,
|
| 135 |
-
"center_x": center_x,
|
| 136 |
-
"center_y": center_y
|
| 137 |
-
})
|
| 138 |
-
|
| 139 |
-
# Step 3: Reconstruct the text in reading order
|
| 140 |
-
if not all_text_blocks:
|
| 141 |
-
print("No text could be extracted.")
|
| 142 |
-
return "No text could be extracted."
|
| 143 |
|
| 144 |
-
|
| 145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
lines = []
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
outputs=gr.Textbox(label="Recognized Text"),
|
| 170 |
-
title="Printed Text OCR with PaddleOCR and CLIP",
|
| 171 |
-
description="Upload a printed text image. The app will detect text regions, identify the language with CLIP, and perform OCR to return the text in reading order. This space uses an H200 GPU for high-speed processing."
|
| 172 |
-
)
|
| 173 |
|
| 174 |
if __name__ == "__main__":
|
| 175 |
-
iface
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import torch
|
|
|
|
|
|
|
|
|
|
| 3 |
import numpy as np
|
| 4 |
import cv2
|
| 5 |
+
import os
|
| 6 |
+
import json
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from transformers import CLIPProcessor, CLIPModel
|
| 9 |
+
from paddleocr import PaddleOCR, TextDetection
|
| 10 |
+
from spaces import GPU # Required for ZeroGPU on Hugging Face
|
|
|
|
|
|
|
| 11 |
|
| 12 |
+
# Setup
|
| 13 |
+
clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
|
| 14 |
+
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
|
| 15 |
|
| 16 |
+
lang_map = {
|
| 17 |
+
"english": "en",
|
| 18 |
+
"telugu": "te",
|
| 19 |
+
"chinese": "ch",
|
| 20 |
+
"korean": "korean",
|
| 21 |
+
}
|
| 22 |
|
|
|
|
| 23 |
candidates = [
|
| 24 |
"This is English text",
|
| 25 |
"This is Telugu text",
|
| 26 |
"This is Chinese text",
|
| 27 |
+
"This is Korean text"
|
| 28 |
]
|
| 29 |
|
| 30 |
+
text_detector = TextDetection(model_name="PP-OCRv5_server_det")
|
| 31 |
+
|
| 32 |
+
@GPU
|
| 33 |
+
def ocr_pipeline(image_np):
|
| 34 |
+
image_pil = Image.fromarray(image_np).convert("RGB")
|
| 35 |
+
width, height = image_pil.size
|
| 36 |
+
img_cv = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
|
| 37 |
+
|
| 38 |
+
output = text_detector.predict(image_np, batch_size=1)
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
arr = []
|
| 41 |
+
for res in output:
|
| 42 |
+
polys = res.get("dt_polys", [])
|
| 43 |
+
if polys is not None:
|
| 44 |
+
arr.extend(polys.tolist())
|
| 45 |
|
| 46 |
+
arr = sorted(arr, key=lambda box: (box[0][1], box[0][0]))
|
|
|
|
| 47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
cropped_images = []
|
| 49 |
+
warped_boxes = []
|
| 50 |
+
|
| 51 |
+
for box in arr:
|
| 52 |
box = np.array(box, dtype=np.float32)
|
| 53 |
+
width = int(max(np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[2] - box[3])))
|
| 54 |
+
height = int(max(np.linalg.norm(box[0] - box[3]), np.linalg.norm(box[1] - box[2])))
|
| 55 |
+
dst_rect = np.array([[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]], dtype=np.float32)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
M = cv2.getPerspectiveTransform(box, dst_rect)
|
| 57 |
+
warped = cv2.warpPerspective(img_cv, M, (width, height))
|
| 58 |
cropped_images.append(warped)
|
| 59 |
+
warped_boxes.append(box)
|
| 60 |
|
| 61 |
+
final_output_lines = []
|
| 62 |
+
|
| 63 |
+
for i, crop in enumerate(cropped_images):
|
| 64 |
+
if crop.shape[0] < 10 or crop.shape[1] < 10:
|
| 65 |
+
continue
|
| 66 |
+
|
| 67 |
+
# Language detection
|
| 68 |
+
clip_inputs = clip_processor(text=candidates, images=crop, return_tensors="pt", padding=True)
|
| 69 |
with torch.no_grad():
|
| 70 |
+
probs = clip_model(**clip_inputs).logits_per_image.softmax(dim=1)
|
| 71 |
+
lang_index = probs.argmax().item()
|
| 72 |
+
lang_detected = candidates[lang_index].split()[-2].lower()
|
| 73 |
+
lang_code = lang_map.get(lang_detected, "en")
|
| 74 |
+
|
| 75 |
+
ocr = PaddleOCR(lang=lang_code, use_doc_orientation_classify=False,
|
| 76 |
+
use_doc_unwarping=False, use_textline_orientation=False, device='cpu')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
+
result = ocr.ocr(crop)
|
| 79 |
+
if not result or not result[0]:
|
| 80 |
+
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
+
for line in result[0]:
|
| 83 |
+
text = line[1][0]
|
| 84 |
+
box = line[0]
|
| 85 |
+
center_x = sum([p[0] for p in box]) / 4
|
| 86 |
+
center_y = sum([p[1] for p in box]) / 4
|
| 87 |
+
final_output_lines.append({"text": text, "cx": center_x, "cy": center_y})
|
| 88 |
+
|
| 89 |
+
if not final_output_lines:
|
| 90 |
+
return "β No text detected."
|
| 91 |
+
|
| 92 |
+
# Grouping by line
|
| 93 |
+
sorted_blocks = sorted(final_output_lines, key=lambda b: (b["cy"], b["cx"]))
|
| 94 |
lines = []
|
| 95 |
+
current_line = [sorted_blocks[0]]
|
| 96 |
+
for block in sorted_blocks[1:]:
|
| 97 |
+
if abs(block["cy"] - current_line[-1]["cy"]) < 40:
|
| 98 |
+
current_line.append(block)
|
| 99 |
+
else:
|
| 100 |
+
lines.append(" ".join([x["text"] for x in sorted(current_line, key=lambda b: b["cx"])]))
|
| 101 |
+
current_line = [block]
|
| 102 |
+
if current_line:
|
| 103 |
+
lines.append(" ".join([x["text"] for x in sorted(current_line, key=lambda b: b["cx"])]))
|
| 104 |
+
|
| 105 |
+
return "\n".join(lines)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# Gradio Interface
|
| 109 |
+
def build_interface():
|
| 110 |
+
return gr.Interface(
|
| 111 |
+
fn=ocr_pipeline,
|
| 112 |
+
inputs=gr.Image(type="numpy", label="Upload Handwritten Image"),
|
| 113 |
+
outputs="text",
|
| 114 |
+
title="π Multilingual Handwritten OCR with CLIP + PaddleOCR",
|
| 115 |
+
description="π Upload a handwritten document image. Detects language using CLIP and performs text detection + recognition with PaddleOCR."
|
| 116 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
if __name__ == "__main__":
|
| 119 |
+
iface = build_interface()
|
| 120 |
+
iface.launch()
|