| # import os | |
| # import sys | |
| # import numpy as np | |
| # import cv2 | |
| # # ========================================== | |
| # # π§ PATCH 1: Fix Torchvision Compatibility | |
| # # ========================================== | |
| # import torchvision.models.vgg | |
| # if not hasattr(torchvision.models.vgg, 'model_urls'): | |
| # torchvision.models.vgg.model_urls = { | |
| # 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth' | |
| # } | |
| # # ========================================== | |
| # # π§ PATCH 2: Fix NumPy Crash AND Coordinates | |
| # # ========================================== | |
| # import craft_text_detector.craft_utils as craft_utils_module | |
| # # CRITICAL FIX: Added 'ratio_net=2' to the signature and calculation. | |
| # # Without this, your boxes are 2x smaller than they should be. | |
| # def fixed_adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net=2): | |
| # if not polys: | |
| # return [] | |
| # adjusted = [] | |
| # for poly in polys: | |
| # if poly is None or len(poly) == 0: | |
| # continue | |
| # # Convert to numpy array safely | |
| # poly = np.array(poly).reshape(-1, 2) | |
| # # Scale coordinates: | |
| # # 1. ratio_net scales from Heatmap -> Resized Image | |
| # # 2. ratio_w/h scales from Resized Image -> Original Image | |
| # poly[:, 0] *= (ratio_w * ratio_net) | |
| # poly[:, 1] *= (ratio_h * ratio_net) | |
| # adjusted.append(poly) | |
| # return adjusted | |
| # # Apply the patch | |
| # craft_utils_module.adjustResultCoordinates = fixed_adjustResultCoordinates | |
| # # ========================================== | |
| # import gradio as gr | |
| # from craft_hw_ocr import OCR | |
| # print("β³ Loading OCR models...") | |
| # ocr = OCR.load_models() | |
| # print("β Models loaded!") | |
| # def do_ocr(inp): | |
| # if inp is None: | |
| # return None, "No image uploaded." | |
| # try: | |
| # # Detection | |
| # # 'detected_img' is usually just the input image in this library | |
| # detected_img, results = OCR.detection(inp, ocr[2]) | |
| # # Visualization | |
| # # Now that coordinates are scaled correctly, this should look right | |
| # viz_img = OCR.visualize(inp, results) | |
| # # Recognition | |
| # try: | |
| # # Note: The library has a typo 'recoginition' (extra 'i') | |
| # bboxes, text = OCR.recoginition(detected_img, results, ocr[0], ocr[1]) | |
| # except Exception as e: | |
| # print(f"Recognition error: {e}") | |
| # text = f"Detection successful, but recognition failed: {str(e)}" | |
| # return viz_img, text | |
| # except Exception as e: | |
| # print(f"OCR error: {e}") | |
| # return None, f"Error processing image: {str(e)}" | |
| # inputs = gr.Image(label="Upload Image") | |
| # o1 = gr.Image(label="Detections") | |
| # o2 = gr.Textbox(label="Text") | |
| # title = "CRAFT-OCR (Fixed Coords)" | |
| # description = "Handwriting OCR using CRAFT + TrOCR. Patched for NumPy and Coordinates." | |
| # gr.Interface( | |
| # fn=do_ocr, | |
| # inputs=inputs, | |
| # outputs=[o1, o2], | |
| # title=title, | |
| # description=description | |
| # ).launch() | |
| import os | |
| import sys | |
| import numpy as np | |
| import cv2 | |
| # ========================================== | |
| # π§ PATCH 1: Fix Torchvision Compatibility | |
| # ========================================== | |
| import torchvision.models.vgg | |
| if not hasattr(torchvision.models.vgg, 'model_urls'): | |
| torchvision.models.vgg.model_urls = { | |
| 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth' | |
| } | |
| # ========================================== | |
| # π§ PATCH 2: Fix NumPy Crash AND Coordinates | |
| # ========================================== | |
| import craft_text_detector.craft_utils as craft_utils_module | |
def fixed_adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net=2):
    """Scale detection polygons from heatmap space back to the original image.

    Drop-in replacement for
    ``craft_text_detector.craft_utils.adjustResultCoordinates``. The CRAFT
    heatmap is ``ratio_net`` times smaller than the resized input, and
    ``ratio_w`` / ``ratio_h`` undo the resize, so every coordinate is
    multiplied by ``ratio * ratio_net``.

    Args:
        polys: Sequence of polygons (each array-like of (x, y) points), or
            None. None/empty entries are skipped.
        ratio_w: Width scale from the resized image back to the original.
        ratio_h: Height scale from the resized image back to the original.
        ratio_net: Heatmap-to-resized-image factor (CRAFT uses 2).

    Returns:
        list[np.ndarray]: One float (N, 2) array per non-empty input polygon.
    """
    # `if not polys:` raises "truth value is ambiguous" when `polys` is a
    # multi-element numpy array, so test None / length explicitly.
    if polys is None or len(polys) == 0:
        return []
    adjusted = []
    for poly in polys:
        if poly is None or len(poly) == 0:
            continue
        # Force a float *copy*: integer input would make the in-place `*=`
        # below fail with a casting error, and copying avoids mutating the
        # caller's array.
        pts = np.asarray(poly, dtype=np.float64).reshape(-1, 2)
        # 1. ratio_net scales Heatmap -> Resized Image
        # 2. ratio_w/h scales Resized Image -> Original Image
        pts[:, 0] *= (ratio_w * ratio_net)
        pts[:, 1] *= (ratio_h * ratio_net)
        adjusted.append(pts)
    return adjusted
| craft_utils_module.adjustResultCoordinates = fixed_adjustResultCoordinates | |
| # ========================================== | |
| import gradio as gr | |
| from craft_hw_ocr import OCR | |
| # Import the core prediction function to bypass the wrapper limitations | |
| from craft_text_detector.predict import get_prediction | |
# --- Model loading (runs once, at import time) ---
print("β³ Loading OCR models...")
# load_models() returns a tuple; indices 0 and 1 are passed to recognition
# below, index 2 is the Craft detector wrapper — TODO confirm exact contents
# against craft_hw_ocr.OCR.load_models.
ocr = OCR.load_models()
# ocr[2] is the Craft object wrapper
craft_wrapper = ocr[2]
# Extract the underlying networks so get_prediction() can be called directly,
# bypassing the wrapper's fixed thresholds.
craft_net = craft_wrapper.craft_net
refine_net = craft_wrapper.refine_net
print("β Models loaded!")
def do_ocr(inp, text_threshold, link_threshold, low_text):
    """Run CRAFT text detection plus recognition on a single uploaded image.

    Args:
        inp: Image from the Gradio input component (numpy array), or None
            when nothing was uploaded.
        text_threshold: Pixel confidence required to treat a region as text.
        link_threshold: Confidence required to link characters into one word.
        low_text: Lower bound used to filter background noise.

    Returns:
        tuple: ``(visualization image or None, recognized/status text)``.
        Never raises — failures are reported in the text output instead.
    """
    if inp is None:
        return None, "No image uploaded."
    try:
        print(f"βοΈ Running Direct Inference: Text={text_threshold}, Link={link_threshold}, Low={low_text}")
        # 1. Direct Detection (bypassing Craft.detect_text): calling the
        # engine directly lets all three threshold sliders take effect.
        prediction_result = get_prediction(
            image=inp,
            craft_net=craft_net,
            refine_net=refine_net,
            text_threshold=text_threshold,
            link_threshold=link_threshold,
            low_text=low_text,
            cuda=False,  # Space is CPU
            poly=True
        )
        # 2. Visualization — OCR.visualize expects (image, results_dict).
        viz_img = OCR.visualize(inp, prediction_result)
        # 3. Recognition. NOTE: the library's method name carries a typo
        # ('recoginition'). A failure here should not hide the detection
        # result, so it is reported in the text box instead of raised.
        try:
            bboxes, text = OCR.recoginition(inp, prediction_result, ocr[0], ocr[1])
        except Exception as e:
            # Log server-side as well, consistent with the outer handler.
            print(f"Recognition error: {e}")
            text = f"Detection successful, but recognition failed: {e}"
        return viz_img, text
    except Exception as e:
        print(f"OCR error: {e}")
        return None, f"Error processing image: {str(e)}"
| # ------------------------------------------------------ | |
| # ποΈ UI with Tuning Sliders | |
| # ------------------------------------------------------ | |
| with gr.Blocks(title="CRAFT-OCR Tuner") as demo: | |
| gr.Markdown("## π§ CRAFT-OCR Parameter Tuner") | |
| gr.Markdown("Adjust sliders to fix issues like merged words or background noise.") | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| input_img = gr.Image(label="Upload Image") | |
| gr.Markdown("### ποΈ Fine-Tune Detection") | |
| text_thres = gr.Slider(0.1, 0.9, value=0.7, step=0.05, label="Text Threshold", | |
| info="Confidence to consider a pixel as text. Higher = Less Noise.") | |
| link_thres = gr.Slider(0.1, 0.9, value=0.4, step=0.05, label="Link Threshold", | |
| info="Confidence to link characters. HIGHER value splits merged words (Fixes 'Hamburgthen').") | |
| low_text = gr.Slider(0.1, 0.9, value=0.4, step=0.05, label="Low Text Threshold", | |
| info="Filters background noise. Higher = Cleaner background.") | |
| btn = gr.Button("Run OCR", variant="primary") | |
| with gr.Column(scale=1): | |
| viz_output = gr.Image(label="Detections (Verify Boxes)") | |
| text_output = gr.Textbox(label="Recognized Text", lines=10) | |
| btn.click( | |
| fn=do_ocr, | |
| inputs=[input_img, text_thres, link_thres, low_text], | |
| outputs=[viz_output, text_output] | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch() |