File size: 4,999 Bytes
9335bef
20a77f4
9335bef
 
 
20a77f4
 
9335bef
046067f
9335bef
 
 
 
642943a
9335bef
20a77f4
046067f
20a77f4
046067f
20a77f4
046067f
 
20a77f4
046067f
20a77f4
046067f
 
20a77f4
9335bef
 
20a77f4
9335bef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20a77f4
9335bef
 
 
 
 
 
 
 
20a77f4
9335bef
 
20a77f4
 
9335bef
 
 
 
 
 
 
 
20a77f4
9335bef
 
20a77f4
9335bef
f89b961
 
20a77f4
 
 
9335bef
20a77f4
9335bef
 
 
 
20a77f4
9335bef
 
20a77f4
 
 
 
9335bef
 
 
20a77f4
 
9335bef
20a77f4
 
 
9335bef
 
20a77f4
9335bef
20a77f4
9335bef
 
 
 
 
 
20a77f4
9335bef
 
20a77f4
9335bef
 
 
 
 
 
 
20a77f4
 
9335bef
 
 
 
f89b961
9335bef
20a77f4
 
 
 
 
 
 
 
9335bef
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""
Korean License Plate OCR - KLPR v2 (Model v5)
Hugging Face Gradio App
"""

from __future__ import annotations

import gradio as gr
import gradio_client.utils as client_utils
import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms as transforms
import numpy as np

# gradio_client's schema walker does not accept bare boolean JSON-schema
# nodes (`true`/`false` are legal schemas meaning "anything"/"nothing").
# Wrap the converter once, guarded by a sentinel flag so re-imports of this
# module don't stack wrappers.
if not getattr(client_utils, "_patched_bool_schema", False):
    _delegate_schema_to_python_type = client_utils._json_schema_to_python_type

    def _bool_tolerant_schema_to_python_type(schema, defs=None):
        """Map a bare boolean schema to "Any"; defer everything else."""
        return "Any" if isinstance(schema, bool) else _delegate_schema_to_python_type(schema, defs)

    client_utils._json_schema_to_python_type = _bool_tolerant_schema_to_python_type
    client_utils._patched_bool_schema = True


class CRNN(nn.Module):
    """CRNN for CTC-style text recognition.

    A VGG-like conv stack reduces a (N, 1, H, W) grayscale image to a
    width-wise feature sequence, a 2-layer BiLSTM models that sequence, and
    a linear head emits per-timestep character logits.

    The conv stack is built from a spec table; the resulting module order
    inside ``nn.Sequential`` is identical to the hand-written original, so
    checkpoint state-dict keys (``cnn.0.weight`` ...) are unchanged.
    """

    # One tuple per conv stage: (in_ch, out_ch, batch_norm?, pool kernel or None).
    # The (2, 1) pools halve height only, preserving the width/time axis.
    _CONV_SPEC = (
        (1, 64, False, (2, 2)),
        (64, 128, False, (2, 2)),
        (128, 256, True, None),
        (256, 256, True, (2, 1)),
        (256, 512, True, None),
        (512, 512, True, (2, 1)),
        (512, 512, True, (2, 1)),
    )

    def __init__(self, img_height, num_chars, rnn_hidden=256):
        # NOTE(review): img_height is unused here (as in the original); the
        # pooling arithmetic assumes H=32 collapses to 1 — kept for API compat.
        super().__init__()
        stages = []
        for cin, cout, with_bn, pool in self._CONV_SPEC:
            stages.append(nn.Conv2d(cin, cout, kernel_size=3, padding=1))
            if with_bn:
                stages.append(nn.BatchNorm2d(cout))
            stages.append(nn.ReLU(inplace=True))
            if pool is not None:
                stages.append(nn.MaxPool2d(pool))
        self.cnn = nn.Sequential(*stages)
        self.rnn = nn.LSTM(512, rnn_hidden, bidirectional=True, num_layers=2, batch_first=True)
        self.fc = nn.Linear(rnn_hidden * 2, num_chars)

    def forward(self, x):
        """(N, 1, H, W) -> (N, W', num_chars) per-timestep logits."""
        features = self.cnn(x)                       # (N, 512, 1, W') for H=32
        sequence = features.squeeze(2).permute(0, 2, 1)  # (N, W', 512)
        rnn_out, _ = self.rnn(sequence)
        return self.fc(rnn_out)


def decode_predictions(outputs, itos, blank_idx=0):
    """Greedy CTC decode: argmax per timestep, collapse repeats, drop blanks.

    outputs: (N, T, C) score tensor; itos maps class index -> character.
    Returns a list of N decoded strings.
    """
    sequences = outputs.argmax(2).detach().cpu().numpy()
    texts = []
    for seq in sequences:
        pieces = []
        last = blank_idx
        for raw in seq:
            cur = int(raw)
            # Emit only on a transition to a new non-blank class (CTC rule).
            if cur != blank_idx and cur != last:
                pieces.append(itos[cur])
            last = cur
        texts.append("".join(pieces))
    return texts


def preprocess_image(image, img_height=32, max_width=200):
    """Convert an input image to a normalized (1, 1, img_height, max_width) tensor.

    Accepts a PIL image, a numpy array, or anything ``Image.open`` accepts
    (path / file object). The image is converted to grayscale, resized to
    ``img_height`` preserving aspect ratio (width capped at ``max_width``),
    and right-padded with white onto a fixed-width canvas. ``Normalize(0.5,
    0.5)`` then maps pixels to [-1, 1].
    """
    if not isinstance(image, Image.Image):
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype("uint8"))
        else:
            image = Image.open(image)

    image = image.convert("L")
    w, h = image.size
    # Bug fix: clamp to >= 1 so an extremely tall/narrow input cannot yield
    # a zero-width resize, which raises inside PIL.
    new_w = max(1, min(int(img_height * w / h), max_width))
    image = image.resize((new_w, img_height), Image.LANCZOS)

    # Pad with white (255) so every sample has the same fixed width.
    new_img = Image.new("L", (max_width, img_height), 255)
    new_img.paste(image, (0, 0))

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    )
    return transform(new_img).unsqueeze(0)


print("๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
checkpoint_path = "best_ocr_one_line.pth"
checkpoint = torch.load(checkpoint_path, map_location="cpu")

img_h = checkpoint.get("img_h", 32)
max_w = checkpoint.get("max_w", 200)
itos = checkpoint["itos"]
num_chars = len(itos)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CRNN(img_h, num_chars, rnn_hidden=256).to(device)
model.load_state_dict(checkpoint["model_state"])
model.eval()

print(f"โœ“ ๋ชจ๋ธ ๋กœ๋“œ ์™„๋ฃŒ (Device: {device})")
print(f"  - Epoch: {checkpoint.get('epoch', '?')}")
print(f"  - Val Acc: {checkpoint.get('val_acc', '?'):.2%}")


def predict_license_plate(image):
    """Gradio handler: OCR a single plate image and return the decoded text.

    Returns a user-facing message (not an exception) on missing input or
    failure, so the UI always shows something readable.
    """
    if image is None:
        return "์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•ด ์ฃผ์„ธ์š”."
    try:
        batch = preprocess_image(image, img_h, max_w).to(device)
        with torch.no_grad():
            log_probs = model(batch).log_softmax(2)
        text = decode_predictions(log_probs, itos)[0]
    except Exception as exc:  # surface the error in the textbox instead of crashing
        return f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {exc}"
    return text if text else "(์ธ์‹ ๊ฒฐ๊ณผ ์—†์Œ)"


# Gradio UI: one image input, one textbox output, wired to the OCR handler.
# `api_name` names the endpoint for gradio_client callers; example caching is
# disabled since no examples are provided.
demo = gr.Interface(
    fn=predict_license_plate,
    inputs=gr.Image(type="pil", label="๋ฒˆํ˜ธํŒ ์ด๋ฏธ์ง€"),
    outputs=gr.Textbox(label="์ธ์‹ ๊ฒฐ๊ณผ"),
    title="๐Ÿš˜ ํ•œ๊ตญ ๋ฒˆํ˜ธํŒ OCR - KLPR v2",
    description=(
        "๋ฒˆํ˜ธํŒ ์ด๋ฏธ์ง€์—์„œ ๋ฌธ์ž๋ฅผ ์ธ์‹ํ•ฉ๋‹ˆ๋‹ค.\n\n"
        "**๋ชจ๋ธ ์ •๋ณด:** CRNN (CNN + BiLSTM + CTC)\n"
        "**์ž…๋ ฅ:** ๋ฒˆํ˜ธํŒ ์ด๋ฏธ์ง€ 1์žฅ"
    ),
    api_name="predict",
    cache_examples=False,
)

# Script entry point; when imported as a module, `demo` is exposed instead.
if __name__ == "__main__":
    demo.launch()