import html
import json

import numpy as np
import onnxruntime as ort
from PIL import Image

# --- Work around a gradio_client bug: json_schema_to_python_type crashes on
# boolean JSON schemas (TypeError: argument of type 'bool' is not iterable),
# which breaks the /info route and the API docs. The patch is applied before
# gradio is imported, keeping the original import ordering. ---
import gradio_client.utils as _gcu

_orig_j2p = _gcu._json_schema_to_python_type


def _safe_j2p(schema, defs=None):
    """Return "Any" for boolean schemas; otherwise defer to gradio_client."""
    if isinstance(schema, bool):
        return "Any"
    return _orig_j2p(schema, defs)


_gcu._json_schema_to_python_type = _safe_j2p

_orig_get_type = _gcu.get_type


def _safe_get_type(schema):
    """Return "Any" for non-dict schemas; otherwise defer to gradio_client."""
    if not isinstance(schema, dict):
        return "Any"
    return _orig_get_type(schema)


_gcu.get_type = _safe_get_type

import gradio as gr  # noqa: E402  (intentionally after the patch above)
from huggingface_hub import hf_hub_download  # noqa: E402

MODEL_REPO = "Boos4721/cpdaily-ocr"

# Pull the ONNX model and the character table from the HF Hub.
model_path = hf_hub_download(MODEL_REPO, "cpdaily_captcha_ocr_fp16.onnx")
charset_path = hf_hub_download(MODEL_REPO, "charset.json")
# Class index -> character, e.g. ["", "A", "B", ...]; index 0 is the CTC
# blank (it is skipped during decoding in _infer).
with open(charset_path, encoding="utf-8") as _charset_file:
    CHARS = json.load(_charset_file)

IMG_H, IMG_W = 32, 160  # model input height and width

_opts = ort.SessionOptions()
_opts.log_severity_level = 3  # suppress ORT log messages below error level
_sess = ort.InferenceSession(
    model_path, sess_options=_opts, providers=["CPUExecutionProvider"]
)
_inp = _sess.get_inputs()[0].name

# HTML shown in the result panel before any image has been provided.
PLACEHOLDER = (
    "<div class='result-card empty'>"
    "<div class='result-text'>"
    "<div class='result-hint'>上传或拖入一张验证码图片开始识别</div>"
    "</div>"
    "</div>"
)


def _result_html(text, avg):
    """Render the recognition result card.

    Args:
        text: decoded captcha string; each character is shown in its own span.
        avg: mean confidence in [0, 1]; drives the percentage label, the bar
            width and the colour (green >= 95%, amber >= 80%, red otherwise).

    Returns:
        An HTML fragment styled by the CSS classes defined in ``CSS``.
    """
    pct = avg * 100
    color = "#16a34a" if pct >= 95 else ("#f59e0b" if pct >= 80 else "#dc2626")
    # Escape each character: the charset is downloaded at runtime, so do not
    # trust it to be HTML-safe.
    chars = "".join(f"<span class='ch'>{html.escape(c)}</span>" for c in text)
    return (
        "<div class='result-card'>"
        f"<div class='result-text'>{chars}</div>"
        "<div class='conf-row'>"
        "<span class='conf-label'>平均置信度 / Confidence</span>"
        f"<span class='conf-val' style='color:{color}'>{pct:.1f}%</span>"
        "</div>"
        "<div class='conf-bar'>"
        f"<div class='conf-fill' style='width:{pct:.1f}%;background:{color}'></div>"
        "</div>"
        "</div>"
    )


def _infer(image):
    """Core inference: run the OCR model on one image.

    Args:
        image: image array as delivered by ``gr.Image(type="numpy")``. It is
            converted to 8-bit grayscale and resized to IMG_W x IMG_H.

    Returns:
        (text, avg_confidence, per_char_confidences): the greedy-CTC decoded
        string, the mean of the per-character confidences (0.0 if nothing was
        decoded), and a list with one confidence per emitted character.
    """
    img = Image.fromarray(image).convert("L").resize((IMG_W, IMG_H), Image.BILINEAR)
    # Scale to [0, 1] and add batch + channel axes -> shape (1, 1, H, W).
    x = (np.asarray(img, dtype=np.float32) / 255.0)[None, None, :, :]
    logits = _sess.run(None, {_inp: x})[0][0]  # [T, 63] log-softmax
    idx = logits.argmax(-1)
    # exp is monotonic, so exp(max log-prob) equals max(prob) without
    # exponentiating every class.
    conf = np.exp(logits.max(-1))

    out, confs, prev = [], [], -1
    for t, p in enumerate(idx):
        # Greedy CTC: collapse consecutive repeats and drop the blank (class 0).
        if p != prev and p != 0:
            out.append(CHARS[p])
            confs.append(float(conf[t]))
        prev = p
    text = "".join(out)
    avg = float(np.mean(confs)) if confs else 0.0
    return text, avg, confs


def recognize(image):
    """UI handler: return the result-card HTML, or the placeholder if no image."""
    if image is None:
        return PLACEHOLDER
    text, avg, _ = _infer(image)
    return _result_html(text, avg)


def recognize_json(image):
    """JSON API endpoint: return {text, confidence, chars}.

    ``chars`` pairs each decoded character with its own confidence; confidences
    are rounded to 4 decimal places. An empty result is returned when no image
    is supplied.
    """
    if image is None:
        return {"text": "", "confidence": 0.0, "chars": []}
    text, avg, confs = _infer(image)
    return {
        "text": text,
        "confidence": round(avg, 4),
        "chars": [
            {"char": c, "confidence": round(cf, 4)} for c, cf in zip(text, confs)
        ],
    }


CSS = """
.gradio-container {max-width: 880px !important; margin: 0 auto !important;}
#title-block h1 {font-size: 1.7rem; margin-bottom: .2rem;}
#title-block p, #title-block li {color: var(--body-text-color-subdued); font-size: .9rem; line-height: 1.5;}
.result-card {min-height: 180px; display: flex; flex-direction: column; align-items: center;
  justify-content: center; gap: 14px; border: 1px solid var(--border-color-primary);
  border-radius: 14px; padding: 24px; background: var(--block-background-fill);}
.result-card.empty {opacity: .7;}
.result-text {font-size: 2.6rem; font-weight: 700; letter-spacing: .35rem;
  font-family: ui-monospace, "SF Mono", Menlo, monospace; line-height: 1;}
.result-text .ch {display: inline-block; padding: 0 2px;}
.result-hint {font-size: .9rem; color: var(--body-text-color-subdued); letter-spacing: normal; font-weight: 400;}
.conf-row {display: flex; align-items: baseline; gap: 10px;}
.conf-label {font-size: .85rem; color: var(--body-text-color-subdued);}
.conf-val {font-size: 1.25rem; font-weight: 700;}
.conf-bar {width: 70%; height: 8px; border-radius: 999px; background: var(--neutral-200); overflow: hidden;}
.conf-fill {height: 100%; border-radius: 999px; transition: width .3s ease;}
footer {display: none !important;}
"""

with gr.Blocks(
    title="cpdaily-ocr · 今日校园验证码识别",
    theme=gr.themes.Soft(primary_hue="orange", secondary_hue="blue"),
    css=CSS,
) as demo:
    with gr.Column(elem_id="title-block"):
        gr.Markdown(
            "# 🔠 cpdaily-ocr · 今日校园验证码识别\n"
            "上传一张今日校园 / Cpdaily 的 5 位验证码图片,自动识别其中字符。\n\n"
            "轻量 **CRNN + CTC** 模型 · 纯 ONNX · **1.07 MB** · **99.37%** 准确率"
        )

    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            inp = gr.Image(
                label="验证码图片 / Captcha",
                type="numpy",
                height=200,
                sources=["upload", "clipboard"],
            )
            btn = gr.Button("识别 / Recognize", variant="primary", size="lg")
        with gr.Column(scale=1):
            out_html = gr.HTML(value=PLACEHOLDER)

    gr.Markdown("### 示例 / Examples")
    gr.Examples(
        examples=[
            ["examples/sample1.png"],
            ["examples/sample2.png"],
            ["examples/sample3.png"],
        ],
        inputs=inp,
    )

    gr.Markdown(
        "字符集 A-Z / a-z / 0-9(62 类)· 输入灰度 32×160 · CTC 贪心解码 · "
        "模型 [`Boos4721/cpdaily-ocr`](https://huggingface.co/Boos4721/cpdaily-ocr)"
    )

    # UI interaction: re-run on every image change. Not exposed as a
    # programmatic API, to avoid a duplicate endpoint next to "recognize".
    inp.change(recognize, inputs=inp, outputs=out_html, api_name=False)
    btn.click(recognize, inputs=inp, outputs=out_html, api_name="recognize")

    # ---- Pure JSON API endpoint (renders no UI): api_name="/predict_json" ----
    api_in = gr.Image(type="numpy", visible=False)
    api_out = gr.JSON(visible=False)
    api_in.change(
        recognize_json,
        inputs=api_in,
        outputs=api_out,
        api_name="predict_json",
        show_progress="hidden",
    )

if __name__ == "__main__":
    demo.launch()