Update app.py
Browse files
app.py
CHANGED
|
@@ -1,179 +1,6 @@
|
|
| 1 |
-
# import gradio as gr
|
| 2 |
-
# import logging
|
| 3 |
-
# import os
|
| 4 |
-
# import numpy as np
|
| 5 |
-
# import torch
|
| 6 |
-
# from PIL import Image, ImageDraw
|
| 7 |
-
# from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
| 8 |
-
|
| 9 |
-
# # --- SURYA IMPORTS ---
|
| 10 |
-
# try:
|
| 11 |
-
# from surya.detection import batch_text_detection
|
| 12 |
-
# from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
|
| 13 |
-
# except ImportError:
|
| 14 |
-
# from surya.detection import batch_inference as batch_text_detection
|
| 15 |
-
# from surya.model.detection.segformer import load_model as load_det_model, load_processor as load_det_processor
|
| 16 |
-
|
| 17 |
-
# # ==========================================
|
| 18 |
-
# # 1. SETUP MODELS
|
| 19 |
-
# # ==========================================
|
| 20 |
-
# device = "cpu"
|
| 21 |
-
# logging.basicConfig(level=logging.INFO)
|
| 22 |
-
# logger = logging.getLogger(__name__)
|
| 23 |
-
|
| 24 |
-
# logger.info("⏳ Loading Models...")
|
| 25 |
-
|
| 26 |
-
# # A. SURYA DETECTION
|
| 27 |
-
# det_processor = load_det_processor()
|
| 28 |
-
# det_model = load_det_model().to(device)
|
| 29 |
-
|
| 30 |
-
# # B. TROCR RECOGNITION
|
| 31 |
-
# # NOTE: We do NOT use quantization here. It destroys the attention mechanism in ViT
|
| 32 |
-
# # encoders on CPU, leading to "mode collapse" (hallucinations).
|
| 33 |
-
# trocr_processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
|
| 34 |
-
# trocr_model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten').to(device)
|
| 35 |
-
|
| 36 |
-
# logger.info("✅ All Models Loaded.")
|
| 37 |
-
|
| 38 |
-
# # ==========================================
|
| 39 |
-
# # 2. HELPER FUNCTIONS
|
| 40 |
-
# # ==========================================
|
| 41 |
-
# def recognize_batch(crops):
|
| 42 |
-
# """
|
| 43 |
-
# Feeds raw crops directly to TrOCR.
|
| 44 |
-
# """
|
| 45 |
-
# if not crops: return []
|
| 46 |
-
|
| 47 |
-
# # Ensure crops are valid
|
| 48 |
-
# valid_crops = [c for c in crops if c.size[0] > 0 and c.size[1] > 0]
|
| 49 |
-
# if not valid_crops: return []
|
| 50 |
-
|
| 51 |
-
# pixel_values = trocr_processor(images=valid_crops, return_tensors="pt").pixel_values.to(device)
|
| 52 |
-
|
| 53 |
-
# with torch.no_grad():
|
| 54 |
-
# # Using a slightly lower max_length prevents it from rambling if it gets confused
|
| 55 |
-
# generated_ids = trocr_model.generate(pixel_values, max_length=64)
|
| 56 |
-
# text = trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)
|
| 57 |
-
# return text
|
| 58 |
-
|
| 59 |
-
# def draw_boxes(image, prediction_objects):
|
| 60 |
-
# draw = ImageDraw.Draw(image)
|
| 61 |
-
# for obj in prediction_objects:
|
| 62 |
-
# if hasattr(obj, "bbox"):
|
| 63 |
-
# draw.rectangle(obj.bbox, outline="red", width=2)
|
| 64 |
-
# else:
|
| 65 |
-
# # Fallback if obj is just a list/tuple
|
| 66 |
-
# draw.rectangle(obj, outline="red", width=2)
|
| 67 |
-
# return image
|
| 68 |
-
|
| 69 |
-
# # ==========================================
|
| 70 |
-
# # 3. MAIN WORKFLOW
|
| 71 |
-
# # ==========================================
|
| 72 |
-
# def hybrid_ocr_workflow(image):
|
| 73 |
-
# if image is None: return None, "Please upload an image."
|
| 74 |
-
|
| 75 |
-
# # CRITICAL FIX: Ensure image is RGB (TrOCR fails on RGBA/P modes silently)
|
| 76 |
-
# if image.mode != "RGB":
|
| 77 |
-
# image = image.convert("RGB")
|
| 78 |
-
|
| 79 |
-
# # 1. DETECT (Surya)
|
| 80 |
-
# logger.info("Step 1: Detecting Lines with Surya...")
|
| 81 |
-
# # Surya expects list of images
|
| 82 |
-
# predictions = batch_text_detection([image], det_model, det_processor)
|
| 83 |
-
# result = predictions[0]
|
| 84 |
-
|
| 85 |
-
# # Extract BBoxes
|
| 86 |
-
# lines_objects = []
|
| 87 |
-
# if hasattr(result, "bboxes"):
|
| 88 |
-
# lines_objects = result.bboxes
|
| 89 |
-
# elif hasattr(result, "text_lines"):
|
| 90 |
-
# lines_objects = result.text_lines
|
| 91 |
-
|
| 92 |
-
# # Sort by Y-coordinate (top to bottom)
|
| 93 |
-
# lines_objects.sort(key=lambda x: x.bbox[1])
|
| 94 |
-
|
| 95 |
-
# # 2. CROP & RECOGNIZE
|
| 96 |
-
# logger.info(f"Step 2: Recognizing {len(lines_objects)} lines with TrOCR...")
|
| 97 |
-
|
| 98 |
-
# line_crops = []
|
| 99 |
-
# w, h = image.size
|
| 100 |
-
|
| 101 |
-
# for obj in lines_objects:
|
| 102 |
-
# bbox = obj.bbox
|
| 103 |
-
|
| 104 |
-
# # Crop the full line
|
| 105 |
-
# pad = 6
|
| 106 |
-
# x1 = max(0, int(bbox[0]) - pad)
|
| 107 |
-
# y1 = max(0, int(bbox[1]) - pad)
|
| 108 |
-
# x2 = min(w, int(bbox[2]) + pad)
|
| 109 |
-
# y2 = min(h, int(bbox[3]) + pad)
|
| 110 |
-
|
| 111 |
-
# line_crop = image.crop((x1, y1, x2, y2))
|
| 112 |
-
# line_crops.append(line_crop)
|
| 113 |
-
|
| 114 |
-
# # Batch processing
|
| 115 |
-
# full_text_lines = []
|
| 116 |
-
# batch_size = 4
|
| 117 |
-
|
| 118 |
-
# for i in range(0, len(line_crops), batch_size):
|
| 119 |
-
# batch = line_crops[i:i+batch_size]
|
| 120 |
-
# try:
|
| 121 |
-
# batch_results = recognize_batch(batch)
|
| 122 |
-
# full_text_lines.extend(batch_results)
|
| 123 |
-
# except Exception as e:
|
| 124 |
-
# logger.error(f"Batch failed: {e}")
|
| 125 |
-
# full_text_lines.append("[Error processing line]")
|
| 126 |
-
|
| 127 |
-
# final_text = "\n".join(full_text_lines)
|
| 128 |
-
|
| 129 |
-
# # Visualize
|
| 130 |
-
# vis_img = draw_boxes(image.copy(), lines_objects)
|
| 131 |
-
|
| 132 |
-
# return vis_img, final_text
|
| 133 |
-
|
| 134 |
-
# # ==========================================
|
| 135 |
-
# # 4. GRADIO UI
|
| 136 |
-
# # ==========================================
|
| 137 |
-
# custom_css = """
|
| 138 |
-
# .gen-button { background-color: #ff4081 !important; color: white !important; font-weight: bold !important; }
|
| 139 |
-
# """
|
| 140 |
-
|
| 141 |
-
# with gr.Blocks(css=custom_css) as demo:
|
| 142 |
-
# gr.Markdown("# 🚀 Hybrid OCR: Surya (Raw) + TrOCR (Corrected)")
|
| 143 |
-
|
| 144 |
-
# with gr.Row():
|
| 145 |
-
# ocr_input = gr.Image(type="pil", label="Upload Image")
|
| 146 |
-
# ocr_output_img = gr.Image(type="pil", label="Surya Detections")
|
| 147 |
-
|
| 148 |
-
# ocr_text = gr.Textbox(label="Recognized Text", lines=20)
|
| 149 |
-
# ocr_button = gr.Button("Run Hybrid OCR", elem_classes="gen-button")
|
| 150 |
-
|
| 151 |
-
# ocr_button.click(hybrid_ocr_workflow, inputs=[ocr_input], outputs=[ocr_output_img, ocr_text])
|
| 152 |
-
|
| 153 |
-
# if __name__ == "__main__":
|
| 154 |
-
# demo.launch(theme=gr.themes.Soft(), css=custom_css)
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
import os
|
| 165 |
-
|
| 166 |
-
# ==========================================
|
| 167 |
-
# 0. SURYA CONFIGURATION
|
| 168 |
-
# ==========================================
|
| 169 |
-
# MUST be set before importing surya to take effect.
|
| 170 |
-
# 1. Lower text threshold (0.6 -> 0.50) to catch faint handwriting strokes
|
| 171 |
-
os.environ["DETECTOR_TEXT_THRESHOLD"] = "0.50"
|
| 172 |
-
# 2. Raise blank threshold (0.7 -> 0.80) to prevent splitting wavy lines
|
| 173 |
-
os.environ["DETECTOR_BLANK_THRESHOLD"] = "0.80"
|
| 174 |
-
|
| 175 |
import gradio as gr
|
| 176 |
import logging
|
|
|
|
| 177 |
import numpy as np
|
| 178 |
import torch
|
| 179 |
from PIL import Image, ImageDraw
|
|
@@ -201,7 +28,8 @@ det_processor = load_det_processor()
|
|
| 201 |
det_model = load_det_model().to(device)
|
| 202 |
|
| 203 |
# B. TROCR RECOGNITION
|
| 204 |
-
# NOTE:
|
|
|
|
| 205 |
trocr_processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
|
| 206 |
trocr_model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten').to(device)
|
| 207 |
|
|
@@ -210,43 +38,20 @@ logger.info("✅ All Models Loaded.")
|
|
| 210 |
# ==========================================
|
| 211 |
# 2. HELPER FUNCTIONS
|
| 212 |
# ==========================================
|
| 213 |
-
def pad_to_square(image):
|
| 214 |
-
"""
|
| 215 |
-
Pads a crop to be roughly square (or at least 4:3) to prevent
|
| 216 |
-
the ViT encoder from squashing long text strips into nonsense.
|
| 217 |
-
"""
|
| 218 |
-
w, h = image.size
|
| 219 |
-
# If already roughly square or tall, leave it
|
| 220 |
-
if w <= h * 1.5:
|
| 221 |
-
return image
|
| 222 |
-
|
| 223 |
-
# Target a 2:1 aspect ratio roughly (or just make it taller)
|
| 224 |
-
target_h = int(w * 0.5)
|
| 225 |
-
if target_h <= h: return image
|
| 226 |
-
|
| 227 |
-
# Create white background
|
| 228 |
-
new_img = Image.new("RGB", (w, target_h), (255, 255, 255))
|
| 229 |
-
paste_y = (target_h - h) // 2
|
| 230 |
-
new_img.paste(image, (0, paste_y))
|
| 231 |
-
return new_img
|
| 232 |
-
|
| 233 |
def recognize_batch(crops):
|
| 234 |
"""
|
| 235 |
-
Feeds
|
| 236 |
"""
|
| 237 |
if not crops: return []
|
| 238 |
|
| 239 |
-
#
|
| 240 |
valid_crops = [c for c in crops if c.size[0] > 0 and c.size[1] > 0]
|
| 241 |
if not valid_crops: return []
|
| 242 |
|
| 243 |
-
|
| 244 |
-
processed_crops = [pad_to_square(c) for c in valid_crops]
|
| 245 |
-
|
| 246 |
-
pixel_values = trocr_processor(images=processed_crops, return_tensors="pt").pixel_values.to(device)
|
| 247 |
|
| 248 |
with torch.no_grad():
|
| 249 |
-
# max_length
|
| 250 |
generated_ids = trocr_model.generate(pixel_values, max_length=64)
|
| 251 |
text = trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)
|
| 252 |
return text
|
|
@@ -257,6 +62,7 @@ def draw_boxes(image, prediction_objects):
|
|
| 257 |
if hasattr(obj, "bbox"):
|
| 258 |
draw.rectangle(obj.bbox, outline="red", width=2)
|
| 259 |
else:
|
|
|
|
| 260 |
draw.rectangle(obj, outline="red", width=2)
|
| 261 |
return image
|
| 262 |
|
|
@@ -266,12 +72,13 @@ def draw_boxes(image, prediction_objects):
|
|
| 266 |
def hybrid_ocr_workflow(image):
|
| 267 |
if image is None: return None, "Please upload an image."
|
| 268 |
|
| 269 |
-
# CRITICAL: TrOCR fails
|
| 270 |
if image.mode != "RGB":
|
| 271 |
image = image.convert("RGB")
|
| 272 |
|
| 273 |
# 1. DETECT (Surya)
|
| 274 |
logger.info("Step 1: Detecting Lines with Surya...")
|
|
|
|
| 275 |
predictions = batch_text_detection([image], det_model, det_processor)
|
| 276 |
result = predictions[0]
|
| 277 |
|
|
@@ -282,7 +89,7 @@ def hybrid_ocr_workflow(image):
|
|
| 282 |
elif hasattr(result, "text_lines"):
|
| 283 |
lines_objects = result.text_lines
|
| 284 |
|
| 285 |
-
# Sort by Y-coordinate
|
| 286 |
lines_objects.sort(key=lambda x: x.bbox[1])
|
| 287 |
|
| 288 |
# 2. CROP & RECOGNIZE
|
|
@@ -294,7 +101,7 @@ def hybrid_ocr_workflow(image):
|
|
| 294 |
for obj in lines_objects:
|
| 295 |
bbox = obj.bbox
|
| 296 |
|
| 297 |
-
# Crop
|
| 298 |
pad = 6
|
| 299 |
x1 = max(0, int(bbox[0]) - pad)
|
| 300 |
y1 = max(0, int(bbox[1]) - pad)
|
|
@@ -306,7 +113,7 @@ def hybrid_ocr_workflow(image):
|
|
| 306 |
|
| 307 |
# Batch processing
|
| 308 |
full_text_lines = []
|
| 309 |
-
batch_size = 4
|
| 310 |
|
| 311 |
for i in range(0, len(line_crops), batch_size):
|
| 312 |
batch = line_crops[i:i+batch_size]
|
|
@@ -332,11 +139,11 @@ custom_css = """
|
|
| 332 |
"""
|
| 333 |
|
| 334 |
with gr.Blocks(css=custom_css) as demo:
|
| 335 |
-
gr.Markdown("# 🚀 Hybrid OCR: Surya (
|
| 336 |
|
| 337 |
with gr.Row():
|
| 338 |
ocr_input = gr.Image(type="pil", label="Upload Image")
|
| 339 |
-
ocr_output_img = gr.Image(type="pil", label="Surya Detections
|
| 340 |
|
| 341 |
ocr_text = gr.Textbox(label="Recognized Text", lines=20)
|
| 342 |
ocr_button = gr.Button("Run Hybrid OCR", elem_classes="gen-button")
|
|
@@ -344,4 +151,12 @@ with gr.Blocks(css=custom_css) as demo:
|
|
| 344 |
ocr_button.click(hybrid_ocr_workflow, inputs=[ocr_input], outputs=[ocr_output_img, ocr_text])
|
| 345 |
|
| 346 |
if __name__ == "__main__":
|
| 347 |
-
demo.launch(theme=gr.themes.Soft(), css=custom_css)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import logging
|
| 3 |
+
import os
|
| 4 |
import numpy as np
|
| 5 |
import torch
|
| 6 |
from PIL import Image, ImageDraw
|
|
|
|
| 28 |
det_model = load_det_model().to(device)
|
| 29 |
|
| 30 |
# B. TROCR RECOGNITION
|
| 31 |
+
# NOTE: We do NOT use quantization here. It destroys the attention mechanism in ViT
|
| 32 |
+
# encoders on CPU, leading to "mode collapse" (hallucinations).
|
| 33 |
trocr_processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
|
| 34 |
trocr_model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten').to(device)
|
| 35 |
|
|
|
|
| 38 |
# ==========================================
|
| 39 |
# 2. HELPER FUNCTIONS
|
| 40 |
# ==========================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
def recognize_batch(crops):
|
| 42 |
"""
|
| 43 |
+
Feeds raw crops directly to TrOCR.
|
| 44 |
"""
|
| 45 |
if not crops: return []
|
| 46 |
|
| 47 |
+
# Ensure crops are valid
|
| 48 |
valid_crops = [c for c in crops if c.size[0] > 0 and c.size[1] > 0]
|
| 49 |
if not valid_crops: return []
|
| 50 |
|
| 51 |
+
pixel_values = trocr_processor(images=valid_crops, return_tensors="pt").pixel_values.to(device)
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
with torch.no_grad():
|
| 54 |
+
# Cap generation at max_length=64 so a confused decoder cannot keep rambling
|
| 55 |
generated_ids = trocr_model.generate(pixel_values, max_length=64)
|
| 56 |
text = trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)
|
| 57 |
return text
|
|
|
|
| 62 |
if hasattr(obj, "bbox"):
|
| 63 |
draw.rectangle(obj.bbox, outline="red", width=2)
|
| 64 |
else:
|
| 65 |
+
# Fallback if obj is just a list/tuple
|
| 66 |
draw.rectangle(obj, outline="red", width=2)
|
| 67 |
return image
|
| 68 |
|
|
|
|
| 72 |
def hybrid_ocr_workflow(image):
|
| 73 |
if image is None: return None, "Please upload an image."
|
| 74 |
|
| 75 |
+
# CRITICAL FIX: Ensure image is RGB (TrOCR silently fails on RGBA/P-mode images)
|
| 76 |
if image.mode != "RGB":
|
| 77 |
image = image.convert("RGB")
|
| 78 |
|
| 79 |
# 1. DETECT (Surya)
|
| 80 |
logger.info("Step 1: Detecting Lines with Surya...")
|
| 81 |
+
# Surya expects a list of images
|
| 82 |
predictions = batch_text_detection([image], det_model, det_processor)
|
| 83 |
result = predictions[0]
|
| 84 |
|
|
|
|
| 89 |
elif hasattr(result, "text_lines"):
|
| 90 |
lines_objects = result.text_lines
|
| 91 |
|
| 92 |
+
# Sort by Y-coordinate (top to bottom)
|
| 93 |
lines_objects.sort(key=lambda x: x.bbox[1])
|
| 94 |
|
| 95 |
# 2. CROP & RECOGNIZE
|
|
|
|
| 101 |
for obj in lines_objects:
|
| 102 |
bbox = obj.bbox
|
| 103 |
|
| 104 |
+
# Crop the full line
|
| 105 |
pad = 6
|
| 106 |
x1 = max(0, int(bbox[0]) - pad)
|
| 107 |
y1 = max(0, int(bbox[1]) - pad)
|
|
|
|
| 113 |
|
| 114 |
# Batch processing
|
| 115 |
full_text_lines = []
|
| 116 |
+
batch_size = 4
|
| 117 |
|
| 118 |
for i in range(0, len(line_crops), batch_size):
|
| 119 |
batch = line_crops[i:i+batch_size]
|
|
|
|
| 139 |
"""
|
| 140 |
|
| 141 |
with gr.Blocks(css=custom_css) as demo:
|
| 142 |
+
gr.Markdown("# 🚀 Hybrid OCR: Surya (Raw) + TrOCR (Corrected)")
|
| 143 |
|
| 144 |
with gr.Row():
|
| 145 |
ocr_input = gr.Image(type="pil", label="Upload Image")
|
| 146 |
+
ocr_output_img = gr.Image(type="pil", label="Surya Detections")
|
| 147 |
|
| 148 |
ocr_text = gr.Textbox(label="Recognized Text", lines=20)
|
| 149 |
ocr_button = gr.Button("Run Hybrid OCR", elem_classes="gen-button")
|
|
|
|
| 151 |
ocr_button.click(hybrid_ocr_workflow, inputs=[ocr_input], outputs=[ocr_output_img, ocr_text])
|
| 152 |
|
| 153 |
if __name__ == "__main__":
|
| 154 |
+
demo.launch(theme=gr.themes.Soft(), css=custom_css)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
|