cpdaily / app.py
Boos4721's picture
dedupe API: only /recognize + /predict_json
46520ac verified
Raw
History Blame Contribute Delete
6.46 kB
import json
import numpy as np
import onnxruntime as ort
from PIL import Image
# --- Fix gradio_client bug: json_schema_to_python_type crashes on bool schemas
# (TypeError: argument of type 'bool' is not iterable) -> breaks /info & API ---
# The patch must run before `import gradio` further down, hence the mid-file import.
import gradio_client.utils as _gcu
# Keep the original converter so the wrapper can delegate every non-bool schema to it.
_orig_j2p = _gcu._json_schema_to_python_type
def _safe_j2p(schema, defs=None):
if isinstance(schema, bool):
return "Any"
return _orig_j2p(schema, defs)
# Install the bool-tolerant converter in place of the library's private helper.
_gcu._json_schema_to_python_type = _safe_j2p
# Keep the original get_type so its wrapper can delegate dict schemas to it.
_orig_get_type = _gcu.get_type
def _safe_get_type(schema):
if not isinstance(schema, dict):
return "Any"
return _orig_get_type(schema)
# Install the dict-only wrapper in place of gradio_client's get_type.
_gcu.get_type = _safe_get_type
import gradio as gr
from huggingface_hub import hf_hub_download
MODEL_REPO = "Boos4721/cpdaily-ocr"
# Fetch the model and the character table from the HF Hub.
model_path = hf_hub_download(MODEL_REPO, "cpdaily_captcha_ocr_fp16.onnx")
charset_path = hf_hub_download(MODEL_REPO, "charset.json")
# Read the charset through a context manager so the file handle is closed
# promptly, and pin UTF-8 so decoding does not depend on the platform locale.
with open(charset_path, encoding="utf-8") as _charset_file:
    CHARS = json.load(_charset_file)  # ["<blank>", "A", "B", ...]
# Height and width (in pixels) every input image is resized to before inference.
IMG_H, IMG_W = 32, 160
_opts = ort.SessionOptions()
# Quieten ONNX Runtime logging (severity level 3).
_opts.log_severity_level = 3
# One shared inference session, restricted to the CPU execution provider.
_sess = ort.InferenceSession(model_path, sess_options=_opts, providers=["CPUExecutionProvider"])
# Name of the model's first input, used as the feed key when running the session.
_inp = _sess.get_inputs()[0].name
# HTML shown in the result panel before any image is recognised; it is the
# initial value of the output component and what recognize() returns for None.
PLACEHOLDER = (
    "<div class='result-card empty'>"
    "<div class='result-text'>—</div>"
    "<div class='result-hint'>上传或拖入一张验证码图片开始识别</div>"
    "</div>"
)
def _result_html(text, avg):
pct = avg * 100
color = "#16a34a" if pct >= 95 else ("#f59e0b" if pct >= 80 else "#dc2626")
chars = "".join(f"<span class='ch'>{c}</span>" for c in text)
return (
"<div class='result-card'>"
f"<div class='result-text'>{chars}</div>"
"<div class='conf-row'>"
"<span class='conf-label'>平均置信度 / Confidence</span>"
f"<span class='conf-val' style='color:{color}'>{pct:.1f}%</span>"
"</div>"
"<div class='conf-bar'>"
f"<div class='conf-fill' style='width:{pct:.1f}%;background:{color}'></div>"
"</div>"
"</div>"
)
def _infer(image):
    """Core inference: return (text, avg_confidence, per_char_confidences).

    The image is converted to grayscale, resized to IMG_W x IMG_H, scaled to
    [0, 1] and given leading batch and channel axes before being fed to the
    ONNX session. The first output is then decoded greedily, CTC style.

    Returns:
        A tuple ``(text, avg, confs)`` where ``confs`` holds one float per
        emitted character and ``avg`` is their mean (0.0 when nothing is
        emitted).
    """
    gray = Image.fromarray(image).convert("L")
    gray = gray.resize((IMG_W, IMG_H), Image.BILINEAR)
    pixels = np.asarray(gray, dtype=np.float32) / 255.0
    batch = pixels[None, None, :, :]

    # First output, first batch item; per the original author's note this is
    # a [T, 63] array of log-softmax scores.
    logits = _sess.run(None, {_inp: batch})[0][0]
    prob = np.exp(logits)
    best_class = logits.argmax(-1)
    best_prob = prob.max(-1)

    decoded = []
    confs = []
    last = -1
    for step, cls in enumerate(best_class):
        changed = cls != last
        last = cls
        # CTC greedy decode: collapse repeated classes and drop index 0 (blank).
        if changed and cls != 0:
            decoded.append(CHARS[cls])
            confs.append(float(best_prob[step]))

    text = "".join(decoded)
    if confs:
        avg = float(np.mean(confs))
    else:
        avg = 0.0
    return text, avg, confs
def recognize(image):
    """Return the HTML result card for an image, or the placeholder for None."""
    if image is not None:
        text, avg, _unused = _infer(image)
        return _result_html(text, avg)
    return PLACEHOLDER
def recognize_json(image):
    """JSON API endpoint: return {text, confidence, chars}.

    ``confidence`` and each per-character confidence are rounded to four
    decimals. A missing image yields an empty text, 0.0 and an empty list.
    """
    if image is None:
        return {"text": "", "confidence": 0.0, "chars": []}

    text, avg, confs = _infer(image)

    per_char = []
    for ch, score in zip(text, confs):
        per_char.append({"char": ch, "confidence": round(score, 4)})

    payload = {
        "text": text,
        "confidence": round(avg, 4),
        "chars": per_char,
    }
    return payload
# Custom stylesheet handed to gr.Blocks(css=...): constrains the page width,
# styles the title block, the result card and the confidence bar, and hides
# the page footer.
CSS = """
.gradio-container {max-width: 880px !important; margin: 0 auto !important;}
#title-block h1 {font-size: 1.7rem; margin-bottom: .2rem;}
#title-block p, #title-block li {color: var(--body-text-color-subdued); font-size: .9rem; line-height: 1.5;}
.result-card {min-height: 180px; display: flex; flex-direction: column;
align-items: center; justify-content: center; gap: 14px;
border: 1px solid var(--border-color-primary); border-radius: 14px;
padding: 24px; background: var(--block-background-fill);}
.result-card.empty {opacity: .7;}
.result-text {font-size: 2.6rem; font-weight: 700; letter-spacing: .35rem;
font-family: ui-monospace, "SF Mono", Menlo, monospace; line-height: 1;}
.result-text .ch {display: inline-block; padding: 0 2px;}
.result-hint {font-size: .9rem; color: var(--body-text-color-subdued); letter-spacing: normal; font-weight: 400;}
.conf-row {display: flex; align-items: baseline; gap: 10px;}
.conf-label {font-size: .85rem; color: var(--body-text-color-subdued);}
.conf-val {font-size: 1.25rem; font-weight: 700;}
.conf-bar {width: 70%; height: 8px; border-radius: 999px;
background: var(--neutral-200); overflow: hidden;}
.conf-fill {height: 100%; border-radius: 999px; transition: width .3s ease;}
footer {display: none !important;}
"""
# Build the Gradio page: title, input/output row, examples, footnote, event wiring.
# NOTE(review): nesting below is reconstructed from the statement order — confirm
# that the button belongs in the left column and the examples sit below the row.
with gr.Blocks(title="cpdaily-ocr · 今日校园验证码识别", theme=gr.themes.Soft(
        primary_hue="orange", secondary_hue="blue"), css=CSS) as demo:
    with gr.Column(elem_id="title-block"):
        gr.Markdown(
            "# 🔠 cpdaily-ocr · 今日校园验证码识别\n"
            "上传一张今日校园 / Cpdaily 的 5 位验证码图片,自动识别其中字符。\n\n"
            "轻量 **CRNN + CTC** 模型 · 纯 ONNX · **1.07 MB** · **99.37%** 准确率"
        )
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            # type="numpy" hands the handlers an array, which _infer passes to Image.fromarray.
            inp = gr.Image(label="验证码图片 / Captcha", type="numpy",
                           height=200, sources=["upload", "clipboard"])
            btn = gr.Button("识别 / Recognize", variant="primary", size="lg")
        with gr.Column(scale=1):
            out_html = gr.HTML(value=PLACEHOLDER)
    gr.Markdown("### 示例 / Examples")
    gr.Examples(
        examples=[["examples/sample1.png"], ["examples/sample2.png"], ["examples/sample3.png"]],
        inputs=inp,
    )
    gr.Markdown(
        "字符集 A-Z / a-z / 0-9(62 类)· 输入灰度 32×160 · CTC 贪心解码 · "
        "模型 [`Boos4721/cpdaily-ocr`](https://huggingface.co/Boos4721/cpdaily-ocr)"
    )
    # UI interaction: re-run recognition whenever the image changes. api_name=False
    # keeps this listener out of the programmable API to avoid a duplicate endpoint.
    inp.change(recognize, inputs=inp, outputs=out_html, api_name=False)
    # The button runs the same handler and is the one registered under api_name "recognize".
    btn.click(recognize, inputs=inp, outputs=out_html, api_name="recognize")
    # ---- Pure JSON API endpoint (nothing rendered in the UI): api_name "predict_json" ----
    # Both components are created with visible=False; they exist only to carry the API call.
    api_in = gr.Image(type="numpy", visible=False)
    api_out = gr.JSON(visible=False)
    api_in.change(recognize_json, inputs=api_in, outputs=api_out,
                  api_name="predict_json", show_progress="hidden")
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()