# KLPR_v2 / app.py
# nice22090's picture
# Rebuild app for HF Spaces
# 20a77f4
"""
Korean License Plate OCR - KLPR v2 (Model v5)
Hugging Face Gradio App
"""
from __future__ import annotations
import gradio as gr
import gradio_client.utils as client_utils
import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
def _install_bool_schema_patch():
    """Teach gradio_client to tolerate boolean JSON-schema nodes.

    Wraps ``client_utils._json_schema_to_python_type`` so that a schema node
    that is a plain ``bool`` is reported as ``"Any"`` instead of being handed
    to the original converter. The ``_patched_bool_schema`` flag stored on the
    module makes the installation idempotent.
    """
    if getattr(client_utils, "_patched_bool_schema", False):
        # Already wrapped; wrapping again would stack wrappers.
        return

    original_converter = client_utils._json_schema_to_python_type

    def _safe_json_schema_to_python_type(schema, defs=None):
        # Boolean schema nodes carry no structure to convert.
        if isinstance(schema, bool):
            return "Any"
        return original_converter(schema, defs)

    client_utils._json_schema_to_python_type = _safe_json_schema_to_python_type
    client_utils._patched_bool_schema = True


_install_bool_schema_patch()
class CRNN(nn.Module):
    """Convolutional-recurrent network that emits per-timestep character logits.

    The convolutional stack takes single-channel images, halves the height
    five times (x32 overall) and the width twice (x4 overall). The remaining
    width axis is then treated as the time axis of a 2-layer bidirectional
    LSTM, followed by a linear projection to ``num_chars`` classes.

    Args:
        img_height: Nominal input height. It does not size any layer; it is
            kept only so error messages can mention the configured value.
            The feature map must collapse to height 1, which holds for
            inputs 32..63 pixels high.
        num_chars: Number of output classes (size of the last dimension).
        rnn_hidden: Hidden size of each LSTM direction.
    """

    def __init__(self, img_height, num_chars, rnn_hidden=256):
        super().__init__()
        self.img_height = img_height
        # NOTE: the order of layers inside this Sequential defines the
        # state_dict keys ("cnn.0.weight", "cnn.3.weight", ...). Do not
        # reorder or insert layers, or existing checkpoints stop loading.
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d((2, 2)),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d((2, 2)),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d((2, 1)),  # height only; keeps horizontal resolution
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d((2, 1)),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d((2, 1)),
        )
        self.rnn = nn.LSTM(
            512, rnn_hidden, bidirectional=True, num_layers=2, batch_first=True
        )
        # Bidirectional LSTM concatenates both directions -> 2 * rnn_hidden.
        self.fc = nn.Linear(rnn_hidden * 2, num_chars)

    def forward(self, x):
        """Return logits of shape ``(batch, width // 4, num_chars)``.

        Args:
            x: Tensor of shape ``(batch, 1, height, width)``.

        Raises:
            ValueError: If the convolutional features do not collapse to
                height 1, i.e. the input height is not compatible with the
                fixed x32 vertical downsampling.
        """
        conv = self.cnn(x)
        if conv.size(2) != 1:
            # Without this check squeeze(2) is a silent no-op and permute()
            # fails with an unrelated-looking dimension error.
            raise ValueError(
                f"CNN output height is {conv.size(2)}, expected 1; "
                f"input height {x.size(2)} is incompatible with this network "
                f"(configured img_height={self.img_height})"
            )
        # (batch, channels, 1, width) -> (batch, width, channels)
        conv = conv.squeeze(2).permute(0, 2, 1)
        rnn_out, _ = self.rnn(conv)
        return self.fc(rnn_out)
def decode_predictions(outputs, itos, blank_idx=0):
    """Greedy CTC decoding of a batch of network outputs.

    For every sample the highest-scoring class is taken at each timestep,
    consecutive repeats are merged into one symbol, and blanks are dropped.

    Args:
        outputs: Tensor indexed as ``(batch, time, class)``.
        itos: Index-to-character lookup, subscriptable by ``int``.
        blank_idx: Class index reserved for the CTC blank.

    Returns:
        A list with one decoded string per batch element.
    """
    from itertools import groupby

    best_paths = outputs.argmax(2).detach().cpu().tolist()
    texts = []
    for path in best_paths:
        # groupby yields one key per run of identical indices, which is the
        # "merge repeats" step; the filter is the "remove blanks" step.
        symbols = [itos[int(label)] for label, _ in groupby(path) if label != blank_idx]
        texts.append("".join(symbols))
    return texts
def preprocess_image(image, img_height=32, max_width=200):
    """Convert an input image into the normalized tensor the model consumes.

    The image is converted to grayscale, resized to ``img_height`` while
    keeping its aspect ratio (width capped at ``max_width``), pasted at the
    left edge of a white ``max_width`` x ``img_height`` canvas, and then
    converted with ``ToTensor`` + ``Normalize((0.5,), (0.5,))``.

    Args:
        image: A ``PIL.Image.Image``, a NumPy array (cast to ``uint8``), or
            anything ``PIL.Image.open`` accepts (path or file object).
        img_height: Target height in pixels.
        max_width: Canvas width in pixels; wider images are squeezed to fit.

    Returns:
        The transformed canvas with a leading batch dimension added by
        ``unsqueeze(0)``.
    """
    if isinstance(image, Image.Image):
        gray = image.convert("L")
    elif isinstance(image, np.ndarray):
        gray = Image.fromarray(image.astype("uint8")).convert("L")
    else:
        # convert() returns an independent, fully loaded image, so the file
        # handle opened here can be released right away.
        with Image.open(image) as opened:
            gray = opened.convert("L")

    w, h = gray.size
    # Clamp to [1, max_width]: for very tall, narrow inputs the scaled width
    # truncates to 0, which Image.resize rejects.
    new_w = max(1, min(int(img_height * w / h), max_width))
    gray = gray.resize((new_w, img_height), Image.LANCZOS)

    # White (255) padding on the right keeps every sample the same width.
    canvas = Image.new("L", (max_width, img_height), 255)
    canvas.paste(gray, (0, 0))

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    )
    return transform(canvas).unsqueeze(0)
# --- Model loading (runs once at import time) -------------------------------
print("모델 로딩 중...")  # "Loading model..."
checkpoint_path = "best_ocr_one_line.pth"
# NOTE(security): torch.load unpickles the file. Only load checkpoints from a
# trusted source; consider weights_only=True if the checkpoint allows it.
checkpoint = torch.load(checkpoint_path, map_location="cpu")

# Preprocessing geometry falls back to 32x200 when the checkpoint omits it.
img_h = checkpoint.get("img_h", 32)
max_w = checkpoint.get("max_w", 200)
itos = checkpoint["itos"]
num_chars = len(itos)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CRNN(img_h, num_chars, rnn_hidden=256).to(device)
model.load_state_dict(checkpoint["model_state"])
model.eval()

# The percent format spec only works on numbers; a missing or non-numeric
# "val_acc" must degrade to "?" instead of crashing the app at startup.
try:
    val_acc_text = f"{checkpoint['val_acc']:.2%}"
except (KeyError, TypeError, ValueError):
    val_acc_text = "?"

print(f"✓ 모델 로드 완료 (Device: {device})")  # "Model loaded"
print(f" - Epoch: {checkpoint.get('epoch', '?')}")
print(f" - Val Acc: {val_acc_text}")
def predict_license_plate(image):
    """Run OCR on one uploaded image and return the text shown in the UI.

    Args:
        image: The image supplied by the Gradio input component, or ``None``
            when nothing was uploaded.

    Returns:
        The decoded plate string, or a Korean user-facing message when no
        image was given, nothing was recognized, or an error occurred.
    """
    if image is None:
        return "이미지를 업로드해 주세요."  # "Please upload an image."
    try:
        image_tensor = preprocess_image(image, img_h, max_w).to(device)
        with torch.no_grad():
            outputs = model(image_tensor).log_softmax(2)
        predictions = decode_predictions(outputs, itos)
        result = predictions[0]
        return result if result else "(인식 결과 없음)"  # "(no recognition result)"
    except Exception as exc:
        # UI boundary: report the failure in the textbox rather than raising.
        return f"오류 발생: {exc}"  # "Error occurred: ..."
# Gradio UI: one image in, one text box out. User-facing strings are Korean.
demo = gr.Interface(
    fn=predict_license_plate,
    inputs=gr.Image(type="pil", label="번호판 이미지"),  # "License plate image"
    outputs=gr.Textbox(label="인식 결과"),  # "Recognition result"
    title="🚘 한국 번호판 OCR - KLPR v2",  # "Korean license plate OCR"
    description=(
        "번호판 이미지에서 문자를 인식합니다.\n\n"
        "**모델 정보:** CRNN (CNN + BiLSTM + CTC)\n"
        "**입력:** 번호판 이미지 1장"
    ),
    api_name="predict",
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch()