import os os.environ["GRADIO_TEMP_DIR"] = "./tmp" import gradio as gr import numpy as np import cv2 from PIL import Image # == Model configurations == MODELS = { "PP-OCRv6 Medium Det": "PP-OCRv6_medium_det", "PP-OCRv6 Small Det": "PP-OCRv6_small_det", "PP-OCRv6 Tiny Det": "PP-OCRv6_tiny_det", } # == Global model variables == current_model = None current_model_key = None # (display_name, thresh, unclip_ratio) cached_results = None # (pil_img, dt_polys, dt_scores) _COLOR = (0, 140, 255) # BGR def load_model_if_needed(model_name, thresh, unclip_ratio): global current_model, current_model_key key = (model_name, round(thresh, 3), round(unclip_ratio, 2)) if current_model_key == key and current_model is not None: return True try: from paddleocr import TextDetection paddle_name = MODELS[model_name] print(f"Loading {paddle_name} thresh={thresh} unclip_ratio={unclip_ratio}") current_model = TextDetection( model_name=paddle_name, engine="transformers", thresh=thresh, unclip_ratio=unclip_ratio, ) current_model_key = key return True except Exception as e: print(f"Error loading model: {e}") return False def visualize_detections(image_input, dt_polys, dt_scores, alpha=0.3, show_scores=True): if isinstance(image_input, Image.Image): image = cv2.cvtColor(np.array(image_input), cv2.COLOR_RGB2BGR) else: image = cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR) if len(dt_polys) == 0: return cv2.cvtColor(image, cv2.COLOR_BGR2RGB) overlay = image.copy() for poly, score in zip(dt_polys, dt_scores): pts = np.array(poly, dtype=np.int32).reshape(-1, 1, 2) cv2.fillPoly(overlay, [pts], _COLOR) cv2.polylines(image, [pts], isClosed=True, color=_COLOR, thickness=3) if show_scores: ax, ay = int(pts[0, 0, 0]), int(pts[0, 0, 1]) text = f"{score:.3f}" (tw, th), bl = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2) cv2.rectangle(image, (ax, ay - th - bl - 4), (ax + tw + 8, ay), _COLOR, -1) cv2.putText(image, text, (ax + 4, ay - 6), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2) cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image) return cv2.cvtColor(image, cv2.COLOR_BGR2RGB) def toggle_labels_visualization(show_scores, alpha): global cached_results if cached_results is None: return None, "⚠️ No cached results. Please analyze an image first." input_img, dt_polys, dt_scores = cached_results output = visualize_detections(input_img, dt_polys, dt_scores, alpha=alpha, show_scores=show_scores) labels_status = "with scores" if show_scores else "without scores" info = f"✅ Visualization updated ({labels_status}) | {len(dt_polys)} detections" return output, info def process_image(input_img, model_name, thresh, box_thresh, unclip_ratio, alpha, show_scores): global cached_results if input_img is None: return None, "❌ Please upload an image first." if not load_model_if_needed(model_name, thresh, unclip_ratio): return None, f"❌ Failed to load model {model_name}." try: if isinstance(input_img, np.ndarray): input_img = Image.fromarray(input_img) if input_img.mode != "RGB": input_img = input_img.convert("RGB") results = current_model.predict(input=np.array(input_img), batch_size=1) if not results: cached_results = None return np.array(input_img), "ℹ️ No detections found." res_dict = results[0].res if hasattr(results[0], "res") else results[0] dt_polys = res_dict.get("dt_polys", []) dt_scores = res_dict.get("dt_scores", []) pairs = [(p, s) for p, s in zip(dt_polys, dt_scores) if s >= box_thresh] if pairs: dt_polys, dt_scores = map(list, zip(*pairs)) else: dt_polys, dt_scores = [], [] cached_results = (input_img, dt_polys, dt_scores) output = visualize_detections(input_img, dt_polys, dt_scores, alpha=alpha, show_scores=show_scores) labels_status = "with scores" if show_scores else "without scores" info = ( f"✅ Found {len(dt_polys)} detections ({labels_status}) | " f"Model: {MODELS[model_name]} | " f"thresh: {thresh:.2f} | box_thresh: {box_thresh:.2f} | unclip: {unclip_ratio:.1f}" ) return output, info except Exception as e: print(f"[ERROR] process_image failed: {e}") cached_results = None error_msg = f"❌ Processing error: {str(e)}" if input_img is not None: return np.array(input_img), error_msg return np.zeros((512, 512, 3), dtype=np.uint8), error_msg if __name__ == "__main__": print(f"🚀 Starting PP-OCRv6 Text Detection App") print(f"🤖 Available models: {len(MODELS)}") custom_css = """ .gradio-container { max-width: 100% !important; padding: 15px !important; } .control-panel { background: #f8f9fa; border-radius: 12px; border: 1px solid #e9ecef; padding: 20px; margin-bottom: 15px; } .results-panel { background: #f8f9fa; border-radius: 12px; border: 1px solid #e9ecef; padding: 20px; min-height: 600px; } /* Gradio 5.x renders the image drop-zone with border-style:dashed via the .placeholder class. Override to match the original solid look. */ .placeholder { border-style: solid !important; } """ with gr.Blocks( title="📄 PP-OCRv6 Text Detection", theme=gr.themes.Soft(), css=custom_css ) as demo: gr.HTML("""
Polygon-level text localisation with PP-OCRv6 models