|
|
""" |
|
|
Korean License Plate OCR - KLPR v2 (Model v5) |
|
|
Hugging Face Gradio App |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import gradio as gr |
|
|
import gradio_client.utils as client_utils |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from PIL import Image |
|
|
import torchvision.transforms as transforms |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
# gradio_client crashes on JSON schemas that are a bare boolean (`true`/`false`
# are valid schemas meaning "anything"/"nothing").  Patch its converter once,
# guarding with a sentinel attribute so re-imports don't double-wrap.
if not getattr(client_utils, "_patched_bool_schema", False):
    _orig_json_schema_to_python_type = client_utils._json_schema_to_python_type

    def _safe_json_schema_to_python_type(schema, defs=None):
        """Map boolean schemas to "Any"; delegate everything else unchanged."""
        return "Any" if isinstance(schema, bool) else _orig_json_schema_to_python_type(schema, defs)

    client_utils._json_schema_to_python_type = _safe_json_schema_to_python_type
    client_utils._patched_bool_schema = True
|
|
|
|
|
|
|
|
class CRNN(nn.Module):
    """CRNN for CTC-style OCR: CNN feature extractor -> 2-layer BiLSTM -> linear head.

    The conv/pool stack is laid out so that a 32-px-high input collapses to
    height 1 after the five pooling stages, while width is only halved twice
    (the (2, 1) pools keep the time axis long for CTC).

    NOTE: `img_height` is accepted for interface compatibility but not used
    directly; the architecture implicitly assumes a 32-px input height.
    """

    def __init__(self, img_height, num_chars, rnn_hidden=256):
        super().__init__()

        def _conv_unit(in_ch, out_ch, with_bn):
            # One conv "unit": 3x3 conv (+ optional BatchNorm) + ReLU.
            unit = [nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)]
            if with_bn:
                unit.append(nn.BatchNorm2d(out_ch))
            unit.append(nn.ReLU(inplace=True))
            return unit

        # Layer order (and therefore nn.Sequential indices / state-dict keys)
        # must stay exactly as in the released checkpoint.
        stack = []
        stack += _conv_unit(1, 64, with_bn=False)
        stack.append(nn.MaxPool2d((2, 2)))
        stack += _conv_unit(64, 128, with_bn=False)
        stack.append(nn.MaxPool2d((2, 2)))
        stack += _conv_unit(128, 256, with_bn=True)
        stack += _conv_unit(256, 256, with_bn=True)
        stack.append(nn.MaxPool2d((2, 1)))
        stack += _conv_unit(256, 512, with_bn=True)
        stack += _conv_unit(512, 512, with_bn=True)
        stack.append(nn.MaxPool2d((2, 1)))
        stack += _conv_unit(512, 512, with_bn=True)
        stack.append(nn.MaxPool2d((2, 1)))
        self.cnn = nn.Sequential(*stack)

        self.rnn = nn.LSTM(512, rnn_hidden, bidirectional=True, num_layers=2, batch_first=True)
        self.fc = nn.Linear(rnn_hidden * 2, num_chars)

    def forward(self, x):
        """Map a (B, 1, H, W) grayscale batch to (B, W', num_chars) logits."""
        features = self.cnn(x)                 # (B, 512, 1, W')
        sequence = features.squeeze(2)         # drop collapsed height axis
        sequence = sequence.permute(0, 2, 1)   # (B, W', 512): time-major for the LSTM
        rnn_out, _ = self.rnn(sequence)
        return self.fc(rnn_out)
|
|
|
|
|
|
|
|
def decode_predictions(outputs, itos, blank_idx=0):
    """Greedy (best-path) CTC decoding.

    Takes the argmax class at each timestep, collapses consecutive
    duplicates, and drops blank tokens, producing one string per batch item.

    outputs: (B, T, C) logits/log-probs tensor; itos: index -> character table.
    """
    best_paths = outputs.argmax(2).detach().cpu().numpy()
    texts = []
    for raw_path in best_paths:
        path = [int(i) for i in raw_path]
        # Pair each timestep with its predecessor (blank before the first step)
        # so repeats and blanks can be filtered in one pass.
        kept = [
            itos[cur]
            for prev, cur in zip([blank_idx] + path[:-1], path)
            if cur != blank_idx and cur != prev
        ]
        texts.append("".join(kept))
    return texts
|
|
|
|
|
|
|
|
def preprocess_image(image, img_height=32, max_width=200):
    """Convert an input image to a normalized (1, 1, img_height, max_width) tensor.

    Accepts a PIL image, a numpy array, or a path / file-like object.
    The image is grayscaled, resized to `img_height` with aspect ratio kept
    (width capped at `max_width`), left-aligned on a white canvas of fixed
    width, and normalized to [-1, 1].
    """
    # Coerce non-PIL inputs: arrays are wrapped directly, anything else is
    # handed to Image.open as a path / file object.
    if not isinstance(image, Image.Image):
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype("uint8"))
        else:
            image = Image.open(image)

    image = image.convert("L")
    w, h = image.size
    # Keep aspect ratio, but clamp to >= 1 so Image.resize never receives a
    # zero width for extremely tall/narrow inputs (int(...) can round to 0).
    new_w = max(1, min(int(img_height * w / h), max_width))
    image = image.resize((new_w, img_height), Image.LANCZOS)

    # Left-paste onto a fixed-size white canvas so every tensor has the
    # same width regardless of the source aspect ratio.
    new_img = Image.new("L", (max_width, img_height), 255)
    new_img.paste(image, (0, 0))

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    )
    return transform(new_img).unsqueeze(0)
|
|
|
|
|
|
|
|
# ---- Model bootstrap (runs at import time, as Gradio/HF Spaces expect) ----
print("모델 로딩 중...")
checkpoint_path = "best_ocr_one_line.pth"
# weights_only=False: the checkpoint stores plain Python objects (itos, metadata)
# besides tensors, which the torch>=2.6 weights_only default would reject.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

# Training-time preprocessing parameters, with fallbacks for older checkpoints.
img_h = checkpoint.get("img_h", 32)
max_w = checkpoint.get("max_w", 200)
itos = checkpoint["itos"]  # index -> character table used by the CTC decoder
num_chars = len(itos)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CRNN(img_h, num_chars, rnn_hidden=256).to(device)
model.load_state_dict(checkpoint["model_state"])
model.eval()

print(f"✅ 모델 로드 완료 (Device: {device})")
print(f"   - Epoch: {checkpoint.get('epoch', '?')}")
# Bug fix: the original applied the :.2% format spec to the '?' fallback
# string, raising ValueError whenever 'val_acc' was absent from the checkpoint.
val_acc = checkpoint.get("val_acc")
print(f"   - Val Acc: {val_acc:.2%}" if isinstance(val_acc, (int, float)) else "   - Val Acc: ?")
|
|
|
|
|
|
|
|
def predict_license_plate(image):
    """Run OCR on one plate image and return the decoded text or a user message.

    `image` comes from the Gradio Image component (PIL image, or None when the
    user submits without an upload).  Fixes the corrupted/mis-encoded Korean
    string literals of the original (one was split across source lines).
    """
    if image is None:
        return "이미지를 업로드해 주세요."
    try:
        image_tensor = preprocess_image(image, img_h, max_w).to(device)
        with torch.no_grad():
            # log_softmax keeps the outputs on the same scale as CTC training.
            outputs = model(image_tensor).log_softmax(2)
        result = decode_predictions(outputs, itos)[0]
        return result if result else "(인식 결과 없음)"
    except Exception as exc:  # UI boundary: surface the error text instead of crashing the app
        return f"오류 발생: {exc}"
|
|
|
|
|
|
|
|
# Gradio UI definition: single image in -> recognized plate text out.
# Restores the mis-encoded Korean labels/description (the description literal
# was split across source lines by the encoding corruption).
demo = gr.Interface(
    fn=predict_license_plate,
    inputs=gr.Image(type="pil", label="번호판 이미지"),
    outputs=gr.Textbox(label="인식 결과"),
    title="🚗 한국 번호판 OCR - KLPR v2",
    description=(
        "번호판 이미지에서 문자를 인식합니다.\n\n"
        "**모델 정보:** CRNN (CNN + BiLSTM + CTC)\n"
        "**입력:** 번호판 이미지 1장"
    ),
    api_name="predict",
    cache_examples=False,
)
|
|
|
|
|
# Script entry point: start the Gradio server (per the header, this file is a
# Hugging Face Spaces app, so import-time setup above must already have run).
if __name__ == "__main__":
    demo.launch()
|
|
|