nice22090 committed on
Commit
cb24c40
·
1 Parent(s): 0c56ad2

Rebuild app for HF Spaces

Browse files
Files changed (1) hide show
  1. app.py +44 -87
app.py CHANGED
@@ -1,8 +1,10 @@
1
  """
2
- ํ•œ๊ตญ ๋ฒˆํ˜ธํŒ OCR - KLPR_v1 (Model v4)
3
  Hugging Face Gradio App
4
  """
5
 
 
 
6
  import gradio as gr
7
  import gradio_client.utils as client_utils
8
  import torch
@@ -11,77 +13,60 @@ from PIL import Image
11
  import torchvision.transforms as transforms
12
  import numpy as np
13
 
14
- # Work around gradio_client schema parsing when additionalProperties is boolean.
15
  if not getattr(client_utils, "_patched_bool_schema", False):
16
- _orig_json_schema_to_python_type = client_utils.json_schema_to_python_type
17
 
18
- def _safe_json_schema_to_python_type(schema):
19
  if isinstance(schema, bool):
20
  return "Any"
21
- return _orig_json_schema_to_python_type(schema)
22
 
23
- client_utils.json_schema_to_python_type = _safe_json_schema_to_python_type
24
  client_utils._patched_bool_schema = True
25
 
26
- # ============================================================================
27
- # ๋ชจ๋ธ ์ •์˜
28
- # ============================================================================
29
  class CRNN(nn.Module):
30
  def __init__(self, img_height, num_chars, rnn_hidden=256):
31
- super(CRNN, self).__init__()
32
-
33
- # CNN - 32x200 -> 1x50
34
  self.cnn = nn.Sequential(
35
  nn.Conv2d(1, 64, kernel_size=3, padding=1),
36
  nn.ReLU(inplace=True),
37
  nn.MaxPool2d((2, 2)),
38
-
39
  nn.Conv2d(64, 128, kernel_size=3, padding=1),
40
  nn.ReLU(inplace=True),
41
  nn.MaxPool2d((2, 2)),
42
-
43
  nn.Conv2d(128, 256, kernel_size=3, padding=1),
44
  nn.BatchNorm2d(256),
45
  nn.ReLU(inplace=True),
46
-
47
  nn.Conv2d(256, 256, kernel_size=3, padding=1),
48
  nn.BatchNorm2d(256),
49
  nn.ReLU(inplace=True),
50
  nn.MaxPool2d((2, 1)),
51
-
52
  nn.Conv2d(256, 512, kernel_size=3, padding=1),
53
  nn.BatchNorm2d(512),
54
  nn.ReLU(inplace=True),
55
-
56
  nn.Conv2d(512, 512, kernel_size=3, padding=1),
57
  nn.BatchNorm2d(512),
58
  nn.ReLU(inplace=True),
59
  nn.MaxPool2d((2, 1)),
60
-
61
  nn.Conv2d(512, 512, kernel_size=3, padding=1),
62
  nn.BatchNorm2d(512),
63
  nn.ReLU(inplace=True),
64
- nn.MaxPool2d((2, 1))
65
  )
66
-
67
  self.rnn = nn.LSTM(512, rnn_hidden, bidirectional=True, num_layers=2, batch_first=True)
68
  self.fc = nn.Linear(rnn_hidden * 2, num_chars)
69
 
70
  def forward(self, x):
71
  conv = self.cnn(x)
72
- b, c, h, w = conv.size()
73
  conv = conv.squeeze(2).permute(0, 2, 1)
74
  rnn_out, _ = self.rnn(conv)
75
- output = self.fc(rnn_out)
76
- return output
77
 
78
- # ============================================================================
79
- # CTC ๋””์ฝ”๋”ฉ
80
- # ============================================================================
81
- def decode_predictions(outputs, itos, blank_idx=0):
82
- """CTC ๋””์ฝ”๋”ฉ"""
83
- preds = outputs.argmax(2).detach().cpu().numpy() # (B, T)
84
 
 
 
85
  decoded = []
86
  for pred in preds:
87
  char_list = []
@@ -90,104 +75,76 @@ def decode_predictions(outputs, itos, blank_idx=0):
90
  if idx != blank_idx and idx != prev_idx:
91
  char_list.append(itos[int(idx)])
92
  prev_idx = idx
93
- decoded.append(''.join(char_list))
94
  return decoded
95
 
96
- # ============================================================================
97
- # ์ด๋ฏธ์ง€ ์ „์ฒ˜๋ฆฌ
98
- # ============================================================================
99
  def preprocess_image(image, img_height=32, max_width=200):
100
- """๋ฒˆํ˜ธํŒ ์ด๋ฏธ์ง€ ์ „์ฒ˜๋ฆฌ"""
101
- # PIL Image๋กœ ๋ณ€ํ™˜ (Gradio 4.x์—์„œ type="pil"๋กœ ์ด๋ฏธ PIL Image)
102
  if not isinstance(image, Image.Image):
103
  if isinstance(image, np.ndarray):
104
- image = Image.fromarray(image.astype('uint8'))
105
-
106
- image = image.convert('L')
107
 
108
- # ๋ฆฌ์‚ฌ์ด์ฆˆ (aspect ratio ์œ ์ง€)
109
  w, h = image.size
110
  new_w = min(int(img_height * w / h), max_width)
111
  image = image.resize((new_w, img_height), Image.LANCZOS)
112
 
113
- # ํŒจ๋”ฉ
114
- new_img = Image.new('L', (max_width, img_height), 255)
115
  new_img.paste(image, (0, 0))
116
 
117
- # Transform
118
- transform = transforms.Compose([
119
- transforms.ToTensor(),
120
- transforms.Normalize((0.5,), (0.5,))
121
- ])
122
 
123
- return transform(new_img).unsqueeze(0) # (1, 1, H, W)
124
 
125
- # ============================================================================
126
- # ๋ชจ๋ธ ๋กœ๋“œ
127
- # ============================================================================
128
  print("๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
129
- checkpoint_path = 'best_ocr_one_line.pth'
130
- checkpoint = torch.load(checkpoint_path, map_location='cpu')
131
 
132
- img_h = checkpoint.get('img_h', 32)
133
- max_w = checkpoint.get('max_w', 200)
134
- itos = checkpoint['itos']
135
  num_chars = len(itos)
136
 
137
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
138
  model = CRNN(img_h, num_chars, rnn_hidden=256).to(device)
139
- model.load_state_dict(checkpoint['model_state'])
140
  model.eval()
141
 
142
  print(f"โœ“ ๋ชจ๋ธ ๋กœ๋“œ ์™„๋ฃŒ (Device: {device})")
143
  print(f" - Epoch: {checkpoint.get('epoch', '?')}")
144
  print(f" - Val Acc: {checkpoint.get('val_acc', '?'):.2%}")
145
 
146
- # ============================================================================
147
- # ์ถ”๋ก  ํ•จ์ˆ˜
148
- # ============================================================================
149
  def predict_license_plate(image):
150
- """๋ฒˆํ˜ธํŒ ์ด๋ฏธ์ง€์—์„œ ํ…์ŠคํŠธ ์˜ˆ์ธก"""
151
  if image is None:
152
- return "์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•ด์ฃผ์„ธ์š”."
153
-
154
  try:
155
- # ์ „์ฒ˜๋ฆฌ
156
  image_tensor = preprocess_image(image, img_h, max_w).to(device)
157
-
158
- # ์ถ”๋ก 
159
  with torch.no_grad():
160
  outputs = model(image_tensor).log_softmax(2)
161
  predictions = decode_predictions(outputs, itos)
162
-
163
  result = predictions[0]
164
  return result if result else "(์ธ์‹ ๊ฒฐ๊ณผ ์—†์Œ)"
 
 
165
 
166
- except Exception as e:
167
- return f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}"
168
 
169
- # ============================================================================
170
- # Gradio ์ธํ„ฐํŽ˜์ด์Šค
171
- # ============================================================================
172
  demo = gr.Interface(
173
  fn=predict_license_plate,
174
  inputs=gr.Image(type="pil", label="๋ฒˆํ˜ธํŒ ์ด๋ฏธ์ง€"),
175
  outputs=gr.Textbox(label="์ธ์‹ ๊ฒฐ๊ณผ"),
176
- title="๐Ÿš— ํ•œ๊ตญ ๋ฒˆํ˜ธํŒ OCR - KLPR v1",
177
- description="""
178
- ํ•œ๊ตญ ์ž๋™์ฐจ ๋ฒˆํ˜ธํŒ์„ ์ธ์‹ํ•˜๋Š” OCR ๋ชจ๋ธ์ž…๋‹ˆ๋‹ค.
179
-
180
- **๋ชจ๋ธ ์ •๋ณด:**
181
- - Model: CRNN (CNN + Bidirectional LSTM + CTC)
182
- - Validation Accuracy: 92.38%
183
- - Epoch: 2
184
- - ์ง€์› ๋ฌธ์ž: 72๊ฐœ (ํ•œ๊ธ€ + ์ˆซ์ž)
185
-
186
- **์‚ฌ์šฉ ๋ฐฉ๋ฒ•:**
187
- 1. ๋ฒˆํ˜ธํŒ ์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•˜์„ธ์š”
188
- 2. ์ž๋™์œผ๋กœ ๋ฒˆํ˜ธํŒ ๋ฒˆํ˜ธ๊ฐ€ ์ธ์‹๋ฉ๋‹ˆ๋‹ค
189
- """,
190
- api_name="predict"
191
  )
192
 
193
  if __name__ == "__main__":
 
1
  """
2
+ Korean License Plate OCR - KLPR v1 (Model v4)
3
  Hugging Face Gradio App
4
  """
5
 
6
+ from __future__ import annotations
7
+
8
  import gradio as gr
9
  import gradio_client.utils as client_utils
10
  import torch
 
13
  import torchvision.transforms as transforms
14
  import numpy as np
15
 
16
# gradio_client chokes on JSON Schema nodes that are bare booleans
# (e.g. "additionalProperties": true); translate those to "Any" before
# delegating to the stock converter. Guarded by a flag so re-imports of
# this module do not stack the patch.
if not getattr(client_utils, "_patched_bool_schema", False):
    _unpatched_schema_to_type = client_utils._json_schema_to_python_type

    def _safe_json_schema_to_python_type(schema, defs=None):
        # A boolean schema carries no structured type information, so there
        # is nothing better to report than "Any".
        if isinstance(schema, bool):
            return "Any"
        return _unpatched_schema_to_type(schema, defs)

    client_utils._json_schema_to_python_type = _safe_json_schema_to_python_type
    client_utils._patched_bool_schema = True
27
 
28
+
 
 
29
class CRNN(nn.Module):
    """CRNN for single-line plate OCR: CNN features -> BiLSTM -> per-step logits.

    Takes a (B, 1, H, W) grayscale batch (designed for H=32, where the five
    height-halving pools reduce the feature map to height 1) and returns
    (B, W', num_chars) logits suitable for CTC decoding.
    """

    def __init__(self, img_height, num_chars, rnn_hidden=256):
        super().__init__()

        def conv_unit(in_ch, out_ch, batch_norm=False):
            # 3x3 conv (+ optional BN) followed by ReLU, returned as a flat
            # layer list so nn.Sequential indices stay identical to the
            # layout the checkpoint was trained with.
            unit = [nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)]
            if batch_norm:
                unit.append(nn.BatchNorm2d(out_ch))
            unit.append(nn.ReLU(inplace=True))
            return unit

        # Height is halved by every pool; width only by the first two, so the
        # (2, 1) pools preserve horizontal resolution for the sequence model.
        feature_layers = (
            conv_unit(1, 64) + [nn.MaxPool2d((2, 2))]
            + conv_unit(64, 128) + [nn.MaxPool2d((2, 2))]
            + conv_unit(128, 256, batch_norm=True)
            + conv_unit(256, 256, batch_norm=True) + [nn.MaxPool2d((2, 1))]
            + conv_unit(256, 512, batch_norm=True)
            + conv_unit(512, 512, batch_norm=True) + [nn.MaxPool2d((2, 1))]
            + conv_unit(512, 512, batch_norm=True) + [nn.MaxPool2d((2, 1))]
        )
        self.cnn = nn.Sequential(*feature_layers)
        self.rnn = nn.LSTM(512, rnn_hidden, bidirectional=True, num_layers=2, batch_first=True)
        self.fc = nn.Linear(rnn_hidden * 2, num_chars)

    def forward(self, x):
        # (B, 512, 1, W') -> drop the unit height axis -> (B, W', 512),
        # i.e. one 512-dim feature vector per horizontal position.
        features = self.cnn(x)
        sequence = features.squeeze(2).permute(0, 2, 1)
        rnn_out, _ = self.rnn(sequence)
        return self.fc(rnn_out)
 
66
 
 
 
 
 
 
 
67
 
68
def decode_predictions(outputs, itos, blank_idx=0):
    """Greedy CTC decode: collapse repeated ids, drop blanks, map via itos.

    outputs is a (B, T, C) score tensor; returns a list of B strings.
    """
    best_paths = outputs.argmax(2).detach().cpu().numpy()
    texts = []
    for path in best_paths:
        chars = []
        previous = blank_idx
        for idx in path:
            # Emit only on a change away from blank; a blank between two
            # equal ids resets `previous`, so "A blank A" decodes to "AA".
            if idx != blank_idx and idx != previous:
                chars.append(itos[int(idx)])
            previous = idx
        texts.append("".join(chars))
    return texts
80
 
81
+
 
 
82
def preprocess_image(image, img_height=32, max_width=200):
    """Convert an input image to a normalized (1, 1, img_height, max_width) tensor.

    Accepts a PIL image, a numpy uint8-convertible array, or anything
    ``Image.open`` accepts (path / file object). The image is grayscaled,
    resized to ``img_height`` preserving aspect ratio, left-aligned on a
    white canvas of width ``max_width``, then normalized to [-1, 1].
    """
    if not isinstance(image, Image.Image):
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype("uint8"))
        else:
            image = Image.open(image)

    image = image.convert("L")
    w, h = image.size
    # Keep aspect ratio, but clamp to [1, max_width]: an extremely tall input
    # would otherwise yield width 0 and make Image.resize raise.
    new_w = min(max(int(img_height * w / h), 1), max_width)
    image = image.resize((new_w, img_height), Image.LANCZOS)

    # Pad on the right with white (255) up to the fixed model width.
    canvas = Image.new("L", (max_width, img_height), 255)
    canvas.paste(image, (0, 0))

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    )
    return transform(canvas).unsqueeze(0)  # (1, 1, H, W)
 
101
 
 
102
 
 
 
 
103
  print("๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
104
+ checkpoint_path = "best_ocr_one_line.pth"
105
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
106
 
107
+ img_h = checkpoint.get("img_h", 32)
108
+ max_w = checkpoint.get("max_w", 200)
109
+ itos = checkpoint["itos"]
110
  num_chars = len(itos)
111
 
112
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
113
  model = CRNN(img_h, num_chars, rnn_hidden=256).to(device)
114
+ model.load_state_dict(checkpoint["model_state"])
115
  model.eval()
116
 
117
  print(f"โœ“ ๋ชจ๋ธ ๋กœ๋“œ ์™„๋ฃŒ (Device: {device})")
118
  print(f" - Epoch: {checkpoint.get('epoch', '?')}")
119
  print(f" - Val Acc: {checkpoint.get('val_acc', '?'):.2%}")
120
 
121
+
 
 
122
def predict_license_plate(image):
    """Run OCR on an uploaded plate image and return the decoded text.

    Returns a user-facing message (not an exception) on missing input or
    failure, since the string is rendered directly in the Gradio textbox.
    """
    if image is None:
        return "์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•ด ์ฃผ์„ธ์š”."
    try:
        batch = preprocess_image(image, img_h, max_w).to(device)
        with torch.no_grad():
            log_probs = model(batch).log_softmax(2)
            texts = decode_predictions(log_probs, itos)
        best = texts[0]
    except Exception as exc:  # surface the failure in the UI instead of a 500
        return f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {exc}"
    return best if best else "(์ธ์‹ ๊ฒฐ๊ณผ ์—†์Œ)"
134
 
 
 
135
 
 
 
 
136
# Gradio UI: single image in, recognized plate string out.
_plate_input = gr.Image(type="pil", label="๋ฒˆํ˜ธํŒ ์ด๋ฏธ์ง€")
_text_output = gr.Textbox(label="์ธ์‹ ๊ฒฐ๊ณผ")
_description = (
    "๋ฒˆํ˜ธํŒ ์ด๋ฏธ์ง€์—์„œ ๋ฌธ์ž๋ฅผ ์ธ์‹ํ•ฉ๋‹ˆ๋‹ค.\n\n"
    "**๋ชจ๋ธ ์ •๋ณด:** CRNN (CNN + BiLSTM + CTC)\n"
    "**์ž…๋ ฅ:** ๋ฒˆํ˜ธํŒ ์ด๋ฏธ์ง€ 1์žฅ"
)

demo = gr.Interface(
    fn=predict_license_plate,
    inputs=_plate_input,
    outputs=_text_output,
    title="๐Ÿš˜ ํ•œ๊ตญ ๋ฒˆํ˜ธํŒ OCR - KLPR v1",
    description=_description,
    api_name="predict",
    cache_examples=False,
)
149
 
150
  if __name__ == "__main__":