Update app.py
Browse files
app.py
CHANGED
|
@@ -244,7 +244,6 @@
|
|
| 244 |
|
| 245 |
|
| 246 |
|
| 247 |
-
|
| 248 |
import gradio as gr
|
| 249 |
import torch
|
| 250 |
import numpy as np
|
|
@@ -257,23 +256,22 @@ from paddleocr import PaddleOCR
|
|
| 257 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 258 |
print(f"Loading TrOCR on {device}...")
|
| 259 |
|
| 260 |
-
# Using the 'base' model for better accuracy on the crops
|
| 261 |
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
|
| 262 |
model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten').to(device).eval()
|
| 263 |
|
| 264 |
# --- 2. SETUP PADDLEOCR (Detection Only) ---
|
| 265 |
print("Loading PaddleOCR (DBNet)...")
|
| 266 |
-
#
|
| 267 |
-
# lang='en' loads the English detection model
|
| 268 |
detector = PaddleOCR(use_angle_cls=True, lang='en', show_log=False)
|
| 269 |
|
| 270 |
def get_sorted_boxes(boxes):
|
| 271 |
"""Sorts boxes top-to-bottom (lines), then left-to-right."""
|
| 272 |
-
if
|
|
|
|
|
|
|
| 273 |
items = []
|
| 274 |
for box in boxes:
|
| 275 |
-
# Paddle returns boxes as
|
| 276 |
-
# We convert to numpy for easier calc
|
| 277 |
box = np.array(box).astype(np.float32)
|
| 278 |
cy = np.mean(box[:, 1])
|
| 279 |
cx = np.mean(box[:, 0])
|
|
@@ -287,43 +285,45 @@ def process_image(image):
|
|
| 287 |
if image is None:
|
| 288 |
return None, [], "Please upload an image."
|
| 289 |
|
| 290 |
-
# Convert to standard RGB Numpy array
|
| 291 |
image_np = np.array(image.convert("RGB"))
|
| 292 |
|
| 293 |
-
#
|
| 294 |
-
#
|
| 295 |
-
#
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
#
|
| 299 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
return image, [], "No text detected."
|
| 301 |
|
| 302 |
-
#
|
| 303 |
-
|
| 304 |
|
| 305 |
-
sorted_boxes = get_sorted_boxes(boxes)
|
| 306 |
annotated_img = image_np.copy()
|
| 307 |
results = []
|
| 308 |
debug_crops = []
|
| 309 |
|
| 310 |
-
#
|
| 311 |
for box in sorted_boxes:
|
| 312 |
box_int = box.astype(np.int32)
|
| 313 |
|
| 314 |
-
# Draw
|
| 315 |
cv2.polylines(annotated_img, [box_int], True, (255, 0, 0), 2)
|
| 316 |
|
| 317 |
-
#
|
| 318 |
-
|
| 319 |
-
PADDING = 8
|
| 320 |
-
|
| 321 |
x_min = max(0, np.min(box_int[:, 0]) - PADDING)
|
| 322 |
x_max = min(image_np.shape[1], np.max(box_int[:, 0]) + PADDING)
|
| 323 |
y_min = max(0, np.min(box_int[:, 1]) - PADDING)
|
| 324 |
y_max = min(image_np.shape[0], np.max(box_int[:, 1]) + PADDING)
|
| 325 |
|
| 326 |
-
# Skip
|
| 327 |
if (x_max - x_min) < 15 or (y_max - y_min) < 10:
|
| 328 |
continue
|
| 329 |
|
|
@@ -331,7 +331,7 @@ def process_image(image):
|
|
| 331 |
pil_crop = Image.fromarray(crop)
|
| 332 |
debug_crops.append(pil_crop)
|
| 333 |
|
| 334 |
-
#
|
| 335 |
with torch.no_grad():
|
| 336 |
pixel_values = processor(images=pil_crop, return_tensors="pt").pixel_values.to(device)
|
| 337 |
generated_ids = model.generate(pixel_values)
|
|
@@ -344,10 +344,10 @@ def process_image(image):
|
|
| 344 |
|
| 345 |
return Image.fromarray(annotated_img), debug_crops, full_text
|
| 346 |
|
| 347 |
-
# ---
|
| 348 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 349 |
-
gr.Markdown("# ⚡ PaddleOCR + TrOCR")
|
| 350 |
-
gr.Markdown("Using
|
| 351 |
|
| 352 |
with gr.Row():
|
| 353 |
with gr.Column(scale=1):
|
|
@@ -359,7 +359,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
| 359 |
output_txt = gr.Textbox(label="Extracted Text", lines=15, show_copy_button=True)
|
| 360 |
|
| 361 |
with gr.Row():
|
| 362 |
-
gallery = gr.Gallery(label="Line Crops", columns=6, height=200)
|
| 363 |
|
| 364 |
btn.click(process_image, input_img, [output_img, gallery, output_txt])
|
| 365 |
|
|
|
|
| 244 |
|
| 245 |
|
| 246 |
|
|
|
|
| 247 |
import gradio as gr
|
| 248 |
import torch
|
| 249 |
import numpy as np
|
|
|
|
# ----------------------------------------------------------------------
# Model setup (runs once at import time).
# 1) TrOCR handles recognition: the 'base' handwritten checkpoint gives
#    better accuracy on line crops than the 'small' variant.
# 2) PaddleOCR supplies detection only (DBNet); we keep the detector but
#    bypass its high-level .ocr() entry point elsewhere to avoid bugs.
# ----------------------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading TrOCR on {device}...")

# Recognition model: moved to the chosen device and frozen in eval mode.
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten').to(device).eval()

# --- 2. SETUP PADDLEOCR (Detection Only) ---
print("Loading PaddleOCR (DBNet)...")
# We load the detector but we will bypass the main .ocr() method to avoid bugs
detector = PaddleOCR(use_angle_cls=True, lang='en', show_log=False)
|
| 266 |
|
| 267 |
def get_sorted_boxes(boxes):
|
| 268 |
"""Sorts boxes top-to-bottom (lines), then left-to-right."""
|
| 269 |
+
if boxes is None or len(boxes) == 0:
|
| 270 |
+
return []
|
| 271 |
+
|
| 272 |
items = []
|
| 273 |
for box in boxes:
|
| 274 |
+
# Paddle returns boxes as numpy arrays or lists
|
|
|
|
| 275 |
box = np.array(box).astype(np.float32)
|
| 276 |
cy = np.mean(box[:, 1])
|
| 277 |
cx = np.mean(box[:, 0])
|
|
|
|
| 285 |
if image is None:
|
| 286 |
return None, [], "Please upload an image."
|
| 287 |
|
| 288 |
+
# Convert to standard RGB Numpy array
|
| 289 |
image_np = np.array(image.convert("RGB"))
|
| 290 |
|
| 291 |
+
# ============================================================
|
| 292 |
+
# 🔴 FIX: Direct Detection Bypass
|
| 293 |
+
# ============================================================
|
| 294 |
+
# The standard 'detector.ocr()' method has a bug in the current
|
| 295 |
+
# version that crashes when checking "if not boxes".
|
| 296 |
+
# We call the internal 'text_detector' directly to skip that check.
|
| 297 |
+
try:
|
| 298 |
+
dt_boxes, _ = detector.text_detector(image_np)
|
| 299 |
+
except Exception as e:
|
| 300 |
+
return image, [], f"Detection Error: {str(e)}"
|
| 301 |
+
|
| 302 |
+
if dt_boxes is None or len(dt_boxes) == 0:
|
| 303 |
return image, [], "No text detected."
|
| 304 |
|
| 305 |
+
# dt_boxes is already a numpy array of coordinates
|
| 306 |
+
sorted_boxes = get_sorted_boxes(dt_boxes)
|
| 307 |
|
|
|
|
| 308 |
annotated_img = image_np.copy()
|
| 309 |
results = []
|
| 310 |
debug_crops = []
|
| 311 |
|
| 312 |
+
# Process Boxes
|
| 313 |
for box in sorted_boxes:
|
| 314 |
box_int = box.astype(np.int32)
|
| 315 |
|
| 316 |
+
# Draw Box (Red, thickness 2)
|
| 317 |
cv2.polylines(annotated_img, [box_int], True, (255, 0, 0), 2)
|
| 318 |
|
| 319 |
+
# Crop with Padding (Prevents TrOCR Hallucinations)
|
| 320 |
+
PADDING = 10
|
|
|
|
|
|
|
| 321 |
x_min = max(0, np.min(box_int[:, 0]) - PADDING)
|
| 322 |
x_max = min(image_np.shape[1], np.max(box_int[:, 0]) + PADDING)
|
| 323 |
y_min = max(0, np.min(box_int[:, 1]) - PADDING)
|
| 324 |
y_max = min(image_np.shape[0], np.max(box_int[:, 1]) + PADDING)
|
| 325 |
|
| 326 |
+
# Skip noise
|
| 327 |
if (x_max - x_min) < 15 or (y_max - y_min) < 10:
|
| 328 |
continue
|
| 329 |
|
|
|
|
| 331 |
pil_crop = Image.fromarray(crop)
|
| 332 |
debug_crops.append(pil_crop)
|
| 333 |
|
| 334 |
+
# Recognition (TrOCR)
|
| 335 |
with torch.no_grad():
|
| 336 |
pixel_values = processor(images=pil_crop, return_tensors="pt").pixel_values.to(device)
|
| 337 |
generated_ids = model.generate(pixel_values)
|
|
|
|
| 344 |
|
| 345 |
return Image.fromarray(annotated_img), debug_crops, full_text
|
| 346 |
|
| 347 |
+
# --- UI ---
|
| 348 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 349 |
+
gr.Markdown("# ⚡ PaddleOCR + TrOCR (Robust)")
|
| 350 |
+
gr.Markdown("Using direct DBNet inference to avoid library bugs.")
|
| 351 |
|
| 352 |
with gr.Row():
|
| 353 |
with gr.Column(scale=1):
|
|
|
|
| 359 |
output_txt = gr.Textbox(label="Extracted Text", lines=15, show_copy_button=True)
|
| 360 |
|
| 361 |
with gr.Row():
|
| 362 |
+
gallery = gr.Gallery(label="Line Crops (Debug)", columns=6, height=200)
|
| 363 |
|
| 364 |
btn.click(process_image, input_img, [output_img, gallery, output_txt])
|
| 365 |
|