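# app.py -- snapshot of commit 67aeb09 ("Update app.py", verified), author: enpaiva, 10.4 kB.
# The file-viewer navigation captured with this snapshot (Raw / History / Blame /
# Contribute / Delete) is not part of the program.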
import os
os.environ["GRADIO_TEMP_DIR"] = "./tmp"
import gradio as gr
import numpy as np
import cv2
from PIL import Image
# == Drawing configuration ==
# Single colour used for polygon fill, outline and score-label background.
# OpenCV channel order is BGR, so this is (B=0, G=140, R=255).
_COLOR = (0, 140, 255)

# == Model configurations ==
# Maps the name shown in the UI dropdown to the model name handed to
# paddleocr.TextDetection.
MODELS = {
    "PP-OCRv6 Medium Det": "PP-OCRv6_medium_det",
    "PP-OCRv6 Small Det": "PP-OCRv6_small_det",
    "PP-OCRv6 Tiny Det": "PP-OCRv6_tiny_det",
}

# == Global model variables ==
# Detector instance currently held in memory (None until the first load).
current_model = None
# Settings the loaded detector was built with:
# (display_name, rounded thresh, rounded unclip_ratio).
current_model_key = None
# Last successful detection, kept so the overlay can be redrawn without
# re-running the model: (pil_img, dt_polys, dt_scores).
cached_results = None
def load_model_if_needed(model_name, thresh, unclip_ratio):
    """Make sure the global detector matches the requested settings.

    The detector is rebuilt only when the (model, thresh, unclip_ratio)
    combination differs from the one it was last built with; otherwise the
    instance already held in ``current_model`` is reused.

    Args:
        model_name: Display name of the model; must be a key of ``MODELS``.
        thresh: Value forwarded as ``thresh`` to ``TextDetection``.
        unclip_ratio: Value forwarded as ``unclip_ratio`` to ``TextDetection``.

    Returns:
        True if a matching detector is available, False if building it
        failed (the error is printed and the previous detector, if any, is
        left in place).
    """
    global current_model, current_model_key

    # Rounding keeps tiny float differences from forcing a reload.
    requested_key = (model_name, round(thresh, 3), round(unclip_ratio, 2))
    if current_model is not None and current_model_key == requested_key:
        return True

    try:
        # Imported lazily so the module can be imported without paddleocr
        # being loaded up front.
        from paddleocr import TextDetection

        paddle_name = MODELS[model_name]
        print(f"Loading {paddle_name} thresh={thresh} unclip_ratio={unclip_ratio}")
        detector = TextDetection(
            model_name=paddle_name,
            engine="transformers",
            thresh=thresh,
            unclip_ratio=unclip_ratio,
        )
    except Exception as e:
        print(f"Error loading model: {e}")
        return False

    # Only publish the new detector once construction has fully succeeded.
    current_model = detector
    current_model_key = requested_key
    return True
def visualize_detections(image_input, dt_polys, dt_scores, alpha=0.3, show_scores=True):
    """Draw detected text polygons (and optionally their scores) on an image.

    Args:
        image_input: RGB image, either a ``PIL.Image.Image`` or an array
            accepted by ``cv2.cvtColor``.
        dt_polys: Sequence of polygons; each is reshaped to ``(-1, 1, 2)``
            int32 points.
        dt_scores: Sequence of scores, paired with ``dt_polys`` by position.
        alpha: Weight of the filled-polygon layer in the final blend.
        show_scores: When True, a score label is drawn at each polygon's
            first point.

    Returns:
        The annotated image as an RGB array.
    """
    rgb = np.array(image_input) if isinstance(image_input, Image.Image) else image_input
    canvas = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

    # Nothing to draw: hand back the image unchanged (as RGB).
    if len(dt_polys) == 0:
        return cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.8
    font_thickness = 2

    # Fills go on a separate copy so they can be blended semi-transparently;
    # outlines and labels are drawn on the canvas itself.
    fill_layer = canvas.copy()
    for polygon, confidence in zip(dt_polys, dt_scores):
        contour = np.array(polygon, dtype=np.int32).reshape(-1, 1, 2)
        cv2.fillPoly(fill_layer, [contour], _COLOR)
        cv2.polylines(canvas, [contour], isClosed=True, color=_COLOR, thickness=3)

        if not show_scores:
            continue

        # Anchor the label at the polygon's first vertex, growing upwards.
        anchor_x, anchor_y = (int(v) for v in contour[0, 0])
        label = f"{confidence:.3f}"
        (label_w, label_h), baseline = cv2.getTextSize(
            label, font, font_scale, font_thickness
        )
        cv2.rectangle(
            canvas,
            (anchor_x, anchor_y - label_h - baseline - 4),
            (anchor_x + label_w + 8, anchor_y),
            _COLOR,
            -1,
        )
        cv2.putText(
            canvas,
            label,
            (anchor_x + 4, anchor_y - 6),
            font,
            font_scale,
            (255, 255, 255),
            font_thickness,
        )

    # Blend in place: canvas = alpha * fill_layer + (1 - alpha) * canvas.
    cv2.addWeighted(fill_layer, alpha, canvas, 1 - alpha, 0, canvas)
    return cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
def toggle_labels_visualization(show_scores, alpha):
    """Redraw the last detection result with new display settings.

    Uses the image, polygons and scores stored in ``cached_results`` so the
    model does not have to run again when only the overlay options change.

    Args:
        show_scores: Whether to draw the score label on each polygon.
        alpha: Fill transparency forwarded to ``visualize_detections``.

    Returns:
        ``(image, message)``. ``image`` is None when there is no cached
        result to redraw; ``message`` is the status text for the UI.
    """
    # Read-only access to the module-level cache; no ``global`` needed.
    if cached_results is None:
        return None, "⚠️ No cached results. Please analyze an image first."

    input_img, dt_polys, dt_scores = cached_results
    output = visualize_detections(
        input_img, dt_polys, dt_scores, alpha=alpha, show_scores=show_scores
    )
    labels_status = "with scores" if show_scores else "without scores"
    # The check mark was previously stored as mis-decoded bytes ("βœ…").
    info = f"✅ Visualization updated ({labels_status}) | {len(dt_polys)} detections"
    return output, info
def process_image(input_img, model_name, thresh, box_thresh, unclip_ratio, alpha, show_scores):
    """Run text detection on an image and return the annotated result.

    Args:
        input_img: Uploaded image as a ``PIL.Image.Image`` or a numpy array,
            or None when nothing has been uploaded.
        model_name: Display name of the model (a key of ``MODELS``).
        thresh: Detection threshold the model is built with.
        box_thresh: Minimum score a polygon needs to be kept.
        unclip_ratio: Region expansion factor the model is built with.
        alpha: Fill transparency for the overlay.
        show_scores: Whether to draw score labels.

    Returns:
        ``(image, message)``. ``image`` is None when no image was supplied
        or the model could not be loaded, the unannotated input when the
        model returned nothing or an error occurred, and the annotated image
        otherwise. ``message`` is the status text for the UI.

    Side effects:
        Updates ``cached_results`` with the converted image and the filtered
        polygons/scores on success; resets it to None when the model returns
        nothing or an error occurs.
    """
    global cached_results

    if input_img is None:
        return None, "❌ Please upload an image first."
    if not load_model_if_needed(model_name, thresh, unclip_ratio):
        return None, f"❌ Failed to load model {model_name}."

    try:
        # Normalise to an RGB PIL image before handing pixels to the model.
        if isinstance(input_img, np.ndarray):
            input_img = Image.fromarray(input_img)
        if input_img.mode != "RGB":
            input_img = input_img.convert("RGB")

        results = current_model.predict(input=np.array(input_img), batch_size=1)
        if not results:
            cached_results = None
            return np.array(input_img), "ℹ️ No detections found."

        # The first result either exposes its payload as ``.res`` or is the
        # payload itself.
        first = results[0]
        res_dict = first.res if hasattr(first, "res") else first

        # Keep only polygons whose score reaches the box-confidence filter.
        kept = [
            (poly, score)
            for poly, score in zip(
                res_dict.get("dt_polys", []), res_dict.get("dt_scores", [])
            )
            if score >= box_thresh
        ]
        dt_polys = [poly for poly, _ in kept]
        dt_scores = [score for _, score in kept]

        cached_results = (input_img, dt_polys, dt_scores)
        output = visualize_detections(
            input_img, dt_polys, dt_scores, alpha=alpha, show_scores=show_scores
        )
        labels_status = "with scores" if show_scores else "without scores"
        # The check mark was previously stored as mis-decoded bytes ("βœ…").
        info = (
            f"✅ Found {len(dt_polys)} detections ({labels_status}) | "
            f"Model: {MODELS[model_name]} | "
            f"thresh: {thresh:.2f} | box_thresh: {box_thresh:.2f} | unclip: {unclip_ratio:.1f}"
        )
        return output, info
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing.
        print(f"[ERROR] process_image failed: {e}")
        cached_results = None
        # input_img cannot be None here (guarded above), so the input image
        # is always available to show alongside the error.
        return np.array(input_img), f"❌ Processing error: {e}"
if __name__ == "__main__":
    # Startup banner. The emoji in this block were previously stored as
    # mis-decoded bytes (e.g. "πŸš€"); they are restored to the intended
    # characters here.
    print("🚀 Starting PP-OCRv6 Text Detection App")
    print(f"🤖 Available models: {len(MODELS)}")

    # Page-level styling for the container and the two panels.
    custom_css = """
    .gradio-container {
        max-width: 100% !important;
        padding: 15px !important;
    }
    .control-panel {
        background: #f8f9fa;
        border-radius: 12px;
        border: 1px solid #e9ecef;
        padding: 20px;
        margin-bottom: 15px;
    }
    .results-panel {
        background: #f8f9fa;
        border-radius: 12px;
        border: 1px solid #e9ecef;
        padding: 20px;
        min-height: 600px;
    }
    /* Gradio 5.x renders the image drop-zone with border-style:dashed via
       the .placeholder class. Override to match the original solid look. */
    .placeholder {
        border-style: solid !important;
    }
    """

    with gr.Blocks(
        title="📄 PP-OCRv6 Text Detection",
        theme=gr.themes.Soft(),
        css=custom_css,
    ) as demo:
        # Page header.
        gr.HTML("""
        <div style='text-align: center; padding: 20px; background: linear-gradient(135deg, #f97316 0%, #c2410c 100%); color: white; border-radius: 12px; margin-bottom: 20px;'>
            <h1 style='margin: 0; font-size: 2.5em;'>🔍 PP-OCRv6 Text Detection</h1>
            <p style='margin: 8px 0 0 0; font-size: 1.1em; opacity: 0.9;'>Polygon-level text localisation with PP-OCRv6 models</p>
        </div>
        """)

        with gr.Row():
            # LEFT COLUMN - Controls
            with gr.Column(scale=1):
                with gr.Group(elem_classes=["control-panel"]):
                    # 1. Image upload
                    gr.HTML("<h3>📄 Upload Image</h3>")
                    input_img = gr.Image(
                        label="Document Image",
                        type="pil",
                        height=300,
                        interactive=True,
                    )

                    # 2. Model selection
                    model_dropdown = gr.Dropdown(
                        choices=list(MODELS.keys()),
                        value="PP-OCRv6 Medium Det",
                        label="AI Model",
                        info="Model will be loaded automatically",
                        interactive=True,
                    )

                    # 3. Detection and display parameters
                    with gr.Row():
                        thresh_slider = gr.Slider(
                            minimum=0.1, maximum=0.9, value=0.3, step=0.05,
                            label="Pixel Threshold", info="Detection threshold",
                        )
                        box_thresh_slider = gr.Slider(
                            minimum=0.1, maximum=0.99, value=0.6, step=0.05,
                            label="Box Confidence", info="Polygon score filter",
                        )
                    with gr.Row():
                        unclip_slider = gr.Slider(
                            minimum=1.0, maximum=3.0, value=1.5, step=0.1,
                            label="Unclip Ratio", info="Region expansion factor", scale=2,
                        )
                        alpha_slider = gr.Slider(
                            minimum=0.0, maximum=1.0, value=0.3, step=0.1,
                            label="Transparency", scale=1,
                        )

                    # 4. Analyze button
                    analyze_btn = gr.Button("🔍 Detect Text", variant="primary", size="lg")

            # RIGHT COLUMN - Results
            with gr.Column(scale=1):
                with gr.Group(elem_classes=["results-panel"]):
                    gr.HTML("<h3>🎯 Detection Results</h3>")
                    output_img = gr.Image(
                        label="Detected Text Regions",
                        type="numpy",
                        height=450,
                        interactive=False,
                    )
                    detection_info = gr.Textbox(
                        label="Detection Summary",
                        value="",
                        interactive=False,
                        lines=2,
                        placeholder="Results will appear here...",
                    )
                    show_scores_checkbox = gr.Checkbox(
                        value=True,
                        label="Show Confidence Scores",
                        info="Toggle scores without reprocessing",
                        interactive=True,
                    )

        # Event handlers
        # Full pipeline: (re)load the model if needed, detect, draw.
        analyze_btn.click(
            fn=process_image,
            inputs=[input_img, model_dropdown, thresh_slider, box_thresh_slider,
                    unclip_slider, alpha_slider, show_scores_checkbox],
            outputs=[output_img, detection_info],
        )

        # Display-only controls: redraw from the cached detections instead of
        # running the model again. Both controls share the same wiring.
        for display_control in (show_scores_checkbox, alpha_slider):
            display_control.change(
                fn=toggle_labels_visualization,
                inputs=[show_scores_checkbox, alpha_slider],
                outputs=[output_img, detection_info],
            )

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        debug=True,
        share=False,
        show_error=True,
        inbrowser=True,
    )