File size: 4,999 Bytes
9335bef 20a77f4 9335bef 20a77f4 9335bef 046067f 9335bef 642943a 9335bef 20a77f4 046067f 20a77f4 046067f 20a77f4 046067f 20a77f4 046067f 20a77f4 046067f 20a77f4 9335bef 20a77f4 9335bef 20a77f4 9335bef 20a77f4 9335bef 20a77f4 9335bef 20a77f4 9335bef 20a77f4 9335bef f89b961 20a77f4 9335bef 20a77f4 9335bef 20a77f4 9335bef 20a77f4 9335bef 20a77f4 9335bef 20a77f4 9335bef 20a77f4 9335bef 20a77f4 9335bef 20a77f4 9335bef 20a77f4 9335bef 20a77f4 9335bef f89b961 9335bef 20a77f4 9335bef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
"""
Korean License Plate OCR - KLPR v2 (Model v5)
Hugging Face Gradio App
"""
from __future__ import annotations
import gradio as gr
import gradio_client.utils as client_utils
import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
# gradio_client's schema converter crashes on boolean JSON-schema nodes
# (`true`/`false` are legal schemas meaning "anything"/"nothing"). Wrap the
# private converter so a bool maps to "Any"; everything else is delegated.
# The sentinel attribute guards against double-patching on re-import.
if not getattr(client_utils, "_patched_bool_schema", False):
    _orig_json_schema_to_python_type = client_utils._json_schema_to_python_type

    def _safe_json_schema_to_python_type(schema, defs=None):
        """Map a boolean schema node to "Any"; defer all others unchanged."""
        return "Any" if isinstance(schema, bool) else _orig_json_schema_to_python_type(schema, defs)

    client_utils._json_schema_to_python_type = _safe_json_schema_to_python_type
    client_utils._patched_bool_schema = True
class CRNN(nn.Module):
    """CRNN for CTC text recognition.

    A convolutional feature extractor collapses the image height to 1 while
    keeping a (downsampled) width dimension as the time axis; a 2-layer
    bidirectional LSTM models the sequence, and a linear head emits
    per-timestep character logits.
    """

    def __init__(self, img_height, num_chars, rnn_hidden=256):
        # NOTE: img_height is accepted for interface compatibility but the
        # conv stack itself is fixed; it assumes the input height collapses
        # to 1 after the five height-halving poolings (i.e. height 32).
        super().__init__()
        # One row per conv stage: (in_ch, out_ch, use_batchnorm, pool kernel).
        # Flattened in this exact order so nn.Sequential indices — and hence
        # state_dict keys — match the original layer-by-layer listing.
        stages = [
            (1, 64, False, (2, 2)),
            (64, 128, False, (2, 2)),
            (128, 256, True, None),
            (256, 256, True, (2, 1)),
            (256, 512, True, None),
            (512, 512, True, (2, 1)),
            (512, 512, True, (2, 1)),
        ]
        modules = []
        for in_ch, out_ch, use_bn, pool in stages:
            modules.append(nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1))
            if use_bn:
                modules.append(nn.BatchNorm2d(out_ch))
            modules.append(nn.ReLU(inplace=True))
            if pool is not None:
                modules.append(nn.MaxPool2d(pool))
        self.cnn = nn.Sequential(*modules)
        self.rnn = nn.LSTM(512, rnn_hidden, bidirectional=True, num_layers=2, batch_first=True)
        self.fc = nn.Linear(rnn_hidden * 2, num_chars)

    def forward(self, x):
        """(B, 1, H, W) image batch -> (B, T, num_chars) logits, T = W // 4."""
        features = self.cnn(x)
        # Drop the collapsed height dim and make width the sequence axis:
        # (B, C, 1, W') -> (B, W', C).
        sequence = features.squeeze(2).permute(0, 2, 1)
        rnn_out, _ = self.rnn(sequence)
        return self.fc(rnn_out)
def decode_predictions(outputs, itos, blank_idx=0):
    """Greedy CTC decode of a batch of logits.

    Takes the per-timestep argmax, collapses consecutive repeats, drops
    blanks, and maps the surviving indices through `itos`. Returns one
    decoded string per batch element.
    """
    best_path = outputs.argmax(2).detach().cpu().numpy()
    texts = []
    for sequence in best_path:
        pieces = []
        previous = blank_idx
        for raw in sequence:
            symbol = int(raw)
            # Emit only on a change away from both blank and the last symbol.
            if symbol not in (blank_idx, previous):
                pieces.append(itos[symbol])
            previous = symbol
        texts.append("".join(pieces))
    return texts
def preprocess_image(image, img_height=32, max_width=200):
    """Convert an input image into the (1, 1, img_height, max_width) tensor
    the CRNN expects.

    Accepts a PIL Image, a numpy array, or anything `Image.open` takes
    (path / file object). The image is grayscaled, resized to `img_height`
    preserving aspect ratio (width capped at `max_width`), right-padded
    with white, and normalized to [-1, 1].
    """
    if not isinstance(image, Image.Image):
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype("uint8"))
        else:
            image = Image.open(image)
    image = image.convert("L")
    w, h = image.size
    # Keep aspect ratio but clamp to [1, max_width]: without the lower bound
    # an extremely wide-and-short input makes new_w == 0 and resize() raises.
    new_w = max(1, min(int(img_height * w / h), max_width))
    image = image.resize((new_w, img_height), Image.LANCZOS)
    # Pad on the right with white (255) so every sample has a fixed width.
    canvas = Image.new("L", (max_width, img_height), 255)
    canvas.paste(image, (0, 0))
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    )
    return transform(canvas).unsqueeze(0)
# ---- Model bootstrap (runs at import time) ----
print("모델 로딩 중...")  # "Loading model..."
checkpoint_path = "best_ocr_one_line.pth"
# NOTE(review): torch.load unpickles arbitrary objects — acceptable for a
# checkpoint bundled with the app, but never point this at untrusted files.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
img_h = checkpoint.get("img_h", 32)
max_w = checkpoint.get("max_w", 200)
itos = checkpoint["itos"]  # index -> character table for CTC decoding
num_chars = len(itos)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CRNN(img_h, num_chars, rnn_hidden=256).to(device)
model.load_state_dict(checkpoint["model_state"])
model.eval()
print(f"✔ 모델 로드 완료 (Device: {device})")  # "Model loaded"
print(f"  - Epoch: {checkpoint.get('epoch', '?')}")
# Bug fix: the old code formatted checkpoint.get('val_acc', '?') with ':.2%',
# which raises ValueError on the string fallback when 'val_acc' is missing.
val_acc = checkpoint.get("val_acc")
print(f"  - Val Acc: {val_acc:.2%}" if isinstance(val_acc, (int, float)) else "  - Val Acc: ?")
def predict_license_plate(image):
    """Gradio callback: OCR a single license-plate image.

    Returns the decoded plate string, or a user-facing Korean message for
    missing input, an empty decode, or a runtime error.

    Note: the original source contained a string literal corrupted by a
    stray control byte; the intended Korean messages are restored here.
    """
    if image is None:
        return "이미지를 업로드해 주세요."  # "Please upload an image."
    try:
        image_tensor = preprocess_image(image, img_h, max_w).to(device)
        with torch.no_grad():
            outputs = model(image_tensor).log_softmax(2)
        predictions = decode_predictions(outputs, itos)
        result = predictions[0]
        return result if result else "(인식 결과 없음)"  # "(no recognition result)"
    except Exception as exc:
        # UI boundary: surface the error as text instead of crashing the app.
        return f"오류 발생: {exc}"  # "Error occurred: ..."
# Gradio UI definition: one image in, plain-text OCR result out.
# The original `description` literal was physically split by a stray NEL
# control byte from mis-encoded Korean; the intended text is restored.
demo = gr.Interface(
    fn=predict_license_plate,
    inputs=gr.Image(type="pil", label="번호판 이미지"),  # "license plate image"
    outputs=gr.Textbox(label="인식 결과"),  # "recognition result"
    title="🚗 한국 번호판 OCR - KLPR v2",  # "Korean license plate OCR"
    description=(
        "번호판 이미지에서 문자를 인식합니다.\n\n"
        "**모델 정보:** CRNN (CNN + BiLSTM + CTC)\n"
        "**입력:** 번호판 이미지 1장"
    ),
    api_name="predict",
    cache_examples=False,
)
if __name__ == "__main__":
    # Start the Gradio HTTP server (blocking call).
    demo.launch()
|