File size: 7,636 Bytes
514209d 9005605 514209d 9005605 514209d 9005605 514209d 9005605 514209d 9005605 514209d 9005605 514209d 9005605 514209d 5bce9d0 514209d 5bce9d0 514209d 5bce9d0 514209d 9005605 5bce9d0 514209d 9005605 514209d 5bce9d0 9005605 5bce9d0 d182412 514209d 5bce9d0 d182412 5bce9d0 d182412 8ccf654 d182412 514209d 6552627 8ccf654 514209d 6552627 514209d 8ccf654 5486945 514209d 8ccf654 5486945 8ccf654 514209d 6552627 514209d d182412 514209d 5bce9d0 a1046da 5486945 a1046da 5bce9d0 a1046da 5486945 5bce9d0 a1046da 9005605 5bce9d0 712b04b 5486945 9005605 5486945 9005605 712b04b 5486945 9005605 712b04b 5486945 712b04b 9005605 712b04b 5486945 712b04b 9005605 5486945 9005605 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 |
# import os
# import sys
# import numpy as np
# import cv2
# # ==========================================
# # π§ PATCH 1: Fix Torchvision Compatibility
# # ==========================================
# import torchvision.models.vgg
# if not hasattr(torchvision.models.vgg, 'model_urls'):
# torchvision.models.vgg.model_urls = {
# 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'
# }
# # ==========================================
# # π§ PATCH 2: Fix NumPy Crash AND Coordinates
# # ==========================================
# import craft_text_detector.craft_utils as craft_utils_module
# # CRITICAL FIX: Added 'ratio_net=2' to the signature and calculation.
# # Without this, your boxes are 2x smaller than they should be.
# def fixed_adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net=2):
# if not polys:
# return []
# adjusted = []
# for poly in polys:
# if poly is None or len(poly) == 0:
# continue
# # Convert to numpy array safely
# poly = np.array(poly).reshape(-1, 2)
# # Scale coordinates:
# # 1. ratio_net scales from Heatmap -> Resized Image
# # 2. ratio_w/h scales from Resized Image -> Original Image
# poly[:, 0] *= (ratio_w * ratio_net)
# poly[:, 1] *= (ratio_h * ratio_net)
# adjusted.append(poly)
# return adjusted
# # Apply the patch
# craft_utils_module.adjustResultCoordinates = fixed_adjustResultCoordinates
# # ==========================================
# import gradio as gr
# from craft_hw_ocr import OCR
# print("β³ Loading OCR models...")
# ocr = OCR.load_models()
# print("β
Models loaded!")
# def do_ocr(inp):
# if inp is None:
# return None, "No image uploaded."
# try:
# # Detection
# # 'detected_img' is usually just the input image in this library
# detected_img, results = OCR.detection(inp, ocr[2])
# # Visualization
# # Now that coordinates are scaled correctly, this should look right
# viz_img = OCR.visualize(inp, results)
# # Recognition
# try:
# # Note: The library has a typo 'recoginition' (extra 'i')
# bboxes, text = OCR.recoginition(detected_img, results, ocr[0], ocr[1])
# except Exception as e:
# print(f"Recognition error: {e}")
# text = f"Detection successful, but recognition failed: {str(e)}"
# return viz_img, text
# except Exception as e:
# print(f"OCR error: {e}")
# return None, f"Error processing image: {str(e)}"
# inputs = gr.Image(label="Upload Image")
# o1 = gr.Image(label="Detections")
# o2 = gr.Textbox(label="Text")
# title = "CRAFT-OCR (Fixed Coords)"
# description = "Handwriting OCR using CRAFT + TrOCR. Patched for NumPy and Coordinates."
# gr.Interface(
# fn=do_ocr,
# inputs=inputs,
# outputs=[o1, o2],
# title=title,
# description=description
# ).launch()
import os
import sys
import numpy as np
import cv2
# ==========================================
# PATCH 1: Torchvision compatibility shim.
# Newer torchvision versions removed the module-level `model_urls` dict that
# craft_text_detector still reads at import time; restore the single entry
# (vgg16_bn weights URL) that the CRAFT backbone loader needs.
# ==========================================
import torchvision.models.vgg
if not hasattr(torchvision.models.vgg, 'model_urls'):
    torchvision.models.vgg.model_urls = {
        'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'
    }
# ==========================================
# PATCH 2: Import the craft_utils module so its buggy
# `adjustResultCoordinates` can be replaced (monkey-patched) further below.
# ==========================================
import craft_text_detector.craft_utils as craft_utils_module
def fixed_adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net=2):
    """Scale detected polygons from heatmap space back to original-image pixels.

    Drop-in replacement for ``craft_text_detector.craft_utils.adjustResultCoordinates``.
    Each polygon is scaled twice: ``ratio_net`` maps heatmap -> resized image,
    and ``ratio_w`` / ``ratio_h`` map resized image -> original image.

    Args:
        polys: Iterable of polygons (lists or ndarrays of (x, y) points);
            may also be None, an empty list, or an ndarray of polygons.
        ratio_w: Horizontal resize ratio (resized -> original).
        ratio_h: Vertical resize ratio (resized -> original).
        ratio_net: Heatmap-to-resized-image scale factor (CRAFT uses 2).

    Returns:
        List of float64 ndarrays of shape (N, 2) in original-image coordinates.
    """
    # `polys` is often a NumPy array; `not polys` raises ValueError on a
    # multi-element ndarray, so test emptiness via len() instead.
    if polys is None or len(polys) == 0:
        return []
    adjusted = []
    for poly in polys:
        if poly is None or len(poly) == 0:
            continue
        # Copy into a float64 array: in-place scaling of an integer-dtype
        # array would raise a casting error (or truncate), and copying
        # avoids mutating the caller's data.
        poly = np.array(poly, dtype=np.float64).reshape(-1, 2)
        # 1. ratio_net scales from heatmap -> resized image
        # 2. ratio_w/h scales from resized image -> original image
        poly[:, 0] *= (ratio_w * ratio_net)
        poly[:, 1] *= (ratio_h * ratio_net)
        adjusted.append(poly)
    return adjusted
# Install the patched coordinate scaler into the detector library so every
# internal caller (including get_prediction below) uses the fixed version.
craft_utils_module.adjustResultCoordinates = fixed_adjustResultCoordinates
# ==========================================
import gradio as gr
from craft_hw_ocr import OCR
# Import the core prediction function directly: the Craft wrapper's
# detect_text() does not expose the threshold parameters tuned by the UI.
from craft_text_detector.predict import get_prediction

print("Loading OCR models...")
ocr = OCR.load_models()
# ocr[2] is the Craft object wrapper
craft_wrapper = ocr[2]
# Extract the actual networks so get_prediction can be called manually.
craft_net = craft_wrapper.craft_net
refine_net = craft_wrapper.refine_net
# Fix: the original print statement's string literal was split across two
# lines by encoding garbling, which is a syntax error.
print("Models loaded!")
def do_ocr(inp, text_threshold, link_threshold, low_text):
    """Detect text regions in *inp* with CRAFT, then transcribe them with TrOCR.

    Args:
        inp: Input image from the Gradio Image component (or None).
        text_threshold: Per-pixel text confidence threshold.
        link_threshold: Character-linking confidence threshold.
        low_text: Low-bound text score used to filter background noise.

    Returns:
        (visualization image, recognized text) for the Gradio outputs; on
        failure the image slot is None and the text explains the error.
    """
    if inp is None:
        return None, "No image uploaded."
    try:
        print(f"βοΈ Running Direct Inference: Text={text_threshold}, Link={link_threshold}, Low={low_text}")
        # Call the CRAFT engine directly (not Craft.detect_text) so the
        # slider-controlled thresholds are actually honoured.
        prediction = get_prediction(
            image=inp,
            craft_net=craft_net,
            refine_net=refine_net,
            text_threshold=text_threshold,
            link_threshold=link_threshold,
            low_text=low_text,
            cuda=False,  # Space is CPU
            poly=True,
        )
        # Draw the detected regions over the input image.
        overlay = OCR.visualize(inp, prediction)
        # Detection output is still useful when recognition fails, so a
        # recognition error is reported as the text result instead of raised.
        try:
            # NOTE: the library API really is spelled 'recoginition'.
            _, recognized = OCR.recoginition(inp, prediction, ocr[0], ocr[1])
        except Exception as e:
            recognized = f"Detection successful, but recognition failed: {e}"
        return overlay, recognized
    except Exception as e:
        print(f"OCR error: {e}")
        return None, f"Error processing image: {str(e)}"
# ------------------------------------------------------
# UI with tuning sliders: the three thresholds are forwarded
# verbatim to get_prediction() via do_ocr().
# ------------------------------------------------------
with gr.Blocks(title="CRAFT-OCR Tuner") as app:
    gr.Markdown("## π§ CRAFT-OCR Parameter Tuner")
    gr.Markdown("Adjust sliders to fix issues like merged words or background noise.")
    with gr.Row():
        with gr.Column(scale=1):
            image_in = gr.Image(label="Upload Image")
            gr.Markdown("### ποΈ Fine-Tune Detection")
            slider_text = gr.Slider(
                minimum=0.1, maximum=0.9, value=0.7, step=0.05,
                label="Text Threshold",
                info="Confidence to consider a pixel as text. Higher = Less Noise.",
            )
            slider_link = gr.Slider(
                minimum=0.1, maximum=0.9, value=0.4, step=0.05,
                label="Link Threshold",
                info="Confidence to link characters. HIGHER value splits merged words (Fixes 'Hamburgthen').",
            )
            slider_low = gr.Slider(
                minimum=0.1, maximum=0.9, value=0.4, step=0.05,
                label="Low Text Threshold",
                info="Filters background noise. Higher = Cleaner background.",
            )
            run_btn = gr.Button("Run OCR", variant="primary")
        with gr.Column(scale=1):
            det_image = gr.Image(label="Detections (Verify Boxes)")
            ocr_text = gr.Textbox(label="Recognized Text", lines=10)
    run_btn.click(
        fn=do_ocr,
        inputs=[image_in, slider_text, slider_link, slider_low],
        outputs=[det_image, ocr_text],
    )

if __name__ == "__main__":
    app.launch()