# trOCR / app.py
# Hugging Face Space by iammraat — "Update app.py", commit 5486945 (verified)
# import os
# import sys
# import numpy as np
# import cv2
# # ==========================================
# # πŸ”§ PATCH 1: Fix Torchvision Compatibility
# # ==========================================
# import torchvision.models.vgg
# if not hasattr(torchvision.models.vgg, 'model_urls'):
# torchvision.models.vgg.model_urls = {
# 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'
# }
# # ==========================================
# # πŸ”§ PATCH 2: Fix NumPy Crash AND Coordinates
# # ==========================================
# import craft_text_detector.craft_utils as craft_utils_module
# # CRITICAL FIX: Added 'ratio_net=2' to the signature and calculation.
# # Without this, your boxes are 2x smaller than they should be.
# def fixed_adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net=2):
# if not polys:
# return []
# adjusted = []
# for poly in polys:
# if poly is None or len(poly) == 0:
# continue
# # Convert to numpy array safely
# poly = np.array(poly).reshape(-1, 2)
# # Scale coordinates:
# # 1. ratio_net scales from Heatmap -> Resized Image
# # 2. ratio_w/h scales from Resized Image -> Original Image
# poly[:, 0] *= (ratio_w * ratio_net)
# poly[:, 1] *= (ratio_h * ratio_net)
# adjusted.append(poly)
# return adjusted
# # Apply the patch
# craft_utils_module.adjustResultCoordinates = fixed_adjustResultCoordinates
# # ==========================================
# import gradio as gr
# from craft_hw_ocr import OCR
# print("⏳ Loading OCR models...")
# ocr = OCR.load_models()
# print("βœ… Models loaded!")
# def do_ocr(inp):
# if inp is None:
# return None, "No image uploaded."
# try:
# # Detection
# # 'detected_img' is usually just the input image in this library
# detected_img, results = OCR.detection(inp, ocr[2])
# # Visualization
# # Now that coordinates are scaled correctly, this should look right
# viz_img = OCR.visualize(inp, results)
# # Recognition
# try:
# # Note: The library has a typo 'recoginition' (extra 'i')
# bboxes, text = OCR.recoginition(detected_img, results, ocr[0], ocr[1])
# except Exception as e:
# print(f"Recognition error: {e}")
# text = f"Detection successful, but recognition failed: {str(e)}"
# return viz_img, text
# except Exception as e:
# print(f"OCR error: {e}")
# return None, f"Error processing image: {str(e)}"
# inputs = gr.Image(label="Upload Image")
# o1 = gr.Image(label="Detections")
# o2 = gr.Textbox(label="Text")
# title = "CRAFT-OCR (Fixed Coords)"
# description = "Handwriting OCR using CRAFT + TrOCR. Patched for NumPy and Coordinates."
# gr.Interface(
# fn=do_ocr,
# inputs=inputs,
# outputs=[o1, o2],
# title=title,
# description=description
# ).launch()
import os
import sys
import numpy as np
import cv2
# ==========================================
# PATCH 1: Fix Torchvision Compatibility
# ==========================================
# Newer torchvision releases removed the module-level `model_urls` dict,
# but craft_text_detector still reads it at import time to locate the
# VGG16-BN backbone weights. Restore the single entry it needs so the
# (unmaintained) detector library keeps importing cleanly.
import torchvision.models.vgg
if not hasattr(torchvision.models.vgg, 'model_urls'):
    torchvision.models.vgg.model_urls = {
        'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'
    }
# ==========================================
# πŸ”§ PATCH 2: Fix NumPy Crash AND Coordinates
# ==========================================
import craft_text_detector.craft_utils as craft_utils_module
def fixed_adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net=2):
    """Rescale CRAFT polygon coordinates from heatmap space to original-image pixels.

    Drop-in replacement for craft_text_detector.craft_utils.adjustResultCoordinates.
    The upstream version crashed on newer NumPy and omitted the ``ratio_net``
    factor, yielding boxes 2x smaller than the true text regions.

    Args:
        polys: iterable (list or numpy array) of polygons, each an array-like
            of (x, y) points; may be None, empty, or contain None entries.
        ratio_w: horizontal scale from the resized image back to the original.
        ratio_h: vertical scale from the resized image back to the original.
        ratio_net: scale from the network heatmap up to the resized image
            (CRAFT outputs at half resolution, hence the default of 2).

    Returns:
        List of float64 numpy arrays of shape (N, 2) in original-image pixels.
    """
    # `not polys` is ambiguous when polys is a numpy array — test explicitly.
    if polys is None or len(polys) == 0:
        return []
    adjusted = []
    for poly in polys:
        if poly is None or len(poly) == 0:
            continue
        # Force float64: the in-place `*=` below raises a casting error if
        # the detector hands us integer coordinates.
        poly = np.asarray(poly, dtype=np.float64).reshape(-1, 2)
        # 1. ratio_net scales heatmap -> resized image
        # 2. ratio_w/ratio_h scale resized image -> original image
        poly[:, 0] *= (ratio_w * ratio_net)
        poly[:, 1] *= (ratio_h * ratio_net)
        adjusted.append(poly)
    return adjusted
# Monkeypatch the detector library so every code path (including internal
# calls we never see) uses the corrected coordinate scaling above.
craft_utils_module.adjustResultCoordinates = fixed_adjustResultCoordinates
# ==========================================
import gradio as gr
from craft_hw_ocr import OCR
# Import the core prediction function to bypass the wrapper limitations:
# the Craft wrapper's detect_text() does not expose the threshold knobs.
from craft_text_detector.predict import get_prediction
print("⏳ Loading OCR models...")
# load_models() returns a tuple; by observed usage here:
#   ocr[0], ocr[1] -> TrOCR processor/model pair fed to OCR.recoginition
#   ocr[2]         -> the Craft detector wrapper object
ocr = OCR.load_models()
# ocr[2] is the Craft object wrapper
craft_wrapper = ocr[2]
# Extract the raw networks so get_prediction() can be called directly.
craft_net = craft_wrapper.craft_net
refine_net = craft_wrapper.refine_net
print("βœ… Models loaded!")
def do_ocr(inp, text_threshold, link_threshold, low_text):
    """Detect text regions with CRAFT (honoring the tuning sliders), then read them with TrOCR.

    Returns a (visualization image, text) pair matching the two Gradio outputs;
    on failure the image slot is None and the text slot carries the error.
    """
    # Guard clause: nothing to process.
    if inp is None:
        return None, "No image uploaded."

    try:
        print(f"βš™οΈ Running Direct Inference: Text={text_threshold}, Link={link_threshold}, Low={low_text}")

        # Call the CRAFT engine directly rather than Craft.detect_text():
        # only the low-level entry point accepts all three slider values.
        detection = get_prediction(
            image=inp,
            craft_net=craft_net,
            refine_net=refine_net,
            text_threshold=text_threshold,
            link_threshold=link_threshold,
            low_text=low_text,
            cuda=False,  # the Space runs on CPU only
            poly=True,
        )

        # Draw the detected boxes onto the input so the user can verify them.
        overlay = OCR.visualize(inp, detection)

        # Recognition can fail independently of detection; report the error
        # in the text box instead of discarding the detection overlay.
        try:
            _, recognized = OCR.recoginition(inp, detection, ocr[0], ocr[1])
        except Exception as e:
            recognized = f"Detection successful, but recognition failed: {e}"

        return overlay, recognized

    except Exception as e:
        print(f"OCR error: {e}")
        return None, f"Error processing image: {str(e)}"
# ------------------------------------------------------
# UI with Tuning Sliders
# ------------------------------------------------------
# Blocks layout: left column holds the input image plus the three CRAFT
# threshold sliders; right column shows the detection overlay and the
# recognized text. Slider defaults mirror CRAFT's upstream defaults.
with gr.Blocks(title="CRAFT-OCR Tuner") as demo:
    gr.Markdown("## πŸ”§ CRAFT-OCR Parameter Tuner")
    gr.Markdown("Adjust sliders to fix issues like merged words or background noise.")
    with gr.Row():
        with gr.Column(scale=1):
            input_img = gr.Image(label="Upload Image")
            gr.Markdown("### πŸŽ›οΈ Fine-Tune Detection")
            text_thres = gr.Slider(0.1, 0.9, value=0.7, step=0.05, label="Text Threshold",
                                   info="Confidence to consider a pixel as text. Higher = Less Noise.")
            link_thres = gr.Slider(0.1, 0.9, value=0.4, step=0.05, label="Link Threshold",
                                   info="Confidence to link characters. HIGHER value splits merged words (Fixes 'Hamburgthen').")
            low_text = gr.Slider(0.1, 0.9, value=0.4, step=0.05, label="Low Text Threshold",
                                 info="Filters background noise. Higher = Cleaner background.")
            btn = gr.Button("Run OCR", variant="primary")
        with gr.Column(scale=1):
            viz_output = gr.Image(label="Detections (Verify Boxes)")
            text_output = gr.Textbox(label="Recognized Text", lines=10)
    # Wire the button to do_ocr; slider order must match the function signature.
    btn.click(
        fn=do_ocr,
        inputs=[input_img, text_thres, link_thres, low_text],
        outputs=[viz_output, text_output]
    )

if __name__ == "__main__":
    demo.launch()