KanaWrite / utils.py
mooncake030's picture
add hira kata options
196a03f
from pathlib import Path
import cv2
import numpy as np
from openvino import Core
from pydantic import BaseModel
class KanaData(BaseModel):
    """Kana lookup tables validated from a JSON document.

    The field names below are the keys expected in that JSON document; the
    meaning of the individual entries is defined by the data file, not here.
    """

    class Category(BaseModel):
        """Four named lists of strings, one per kana sound class."""

        # NOTE(review): the names look like the usual Japanese sound classes
        # (seion = plain, dakuon = voiced, handakuon = semi-voiced,
        # youon = contracted) -- confirm against the data file.
        seion: list[str]
        dakuon: list[str]
        handakuon: list[str]
        youon: list[str]

    category: Category
    # NOTE(review): the key/value semantics of the three mappings below are
    # not visible from this module -- verify against the data file and callers.
    hiragana: dict[str, list[str]]
    katakana: dict[str, list[str]]
    spell: dict[str, list[str]]

    @classmethod
    def load(cls, path) -> "KanaData":
        """Read the JSON file at *path* and validate it into an instance.

        Args:
            path: Location of the JSON file; anything ``pathlib.Path`` accepts.

        Returns:
            The validated model instance.

        Raises:
            OSError: If the file cannot be read.
            UnicodeDecodeError: If the file is not valid UTF-8.
            pydantic.ValidationError: If the JSON does not match the schema.
        """
        # Decode explicitly as UTF-8: the default would be the platform locale
        # encoding, which is not UTF-8 everywhere and would corrupt kana text.
        return cls.model_validate_json(Path(path).read_text(encoding="utf-8"))
class Recognizer:
    """Text recognizer that runs an OpenVINO model and greedy-decodes its CTC output."""

    class Result(BaseModel):
        """One n-best candidate: a character and its softmax probability."""

        char: str
        prob: float

    def __init__(self, model_path, char_list_path, device="CPU", blank="[blank]"):
        """Load and compile the model and read the character table.

        Args:
            model_path: Path of the model file passed to ``Core.read_model``.
            char_list_path: UTF-8 text file with one character per line.
            device: OpenVINO device name used to compile the model.
            blank: Label stored at index 0 of the character table.
        """
        core = Core()
        self.model = core.read_model(model_path)
        self.compiled_model = core.compile_model(self.model, device)
        self.infer_request = self.compiled_model.create_infer_request()
        # The first input's shape is unpacked as (batch_size, channel, height, width).
        _, _, self.input_height, self.input_width = self.model.inputs[0].shape
        self.input_tensor_name = self.model.inputs[0].get_any_name()
        self.output_tensor_name = self.model.outputs[0].get_any_name()
        with open(char_list_path, "rt", encoding="UTF-8") as fp:
            # Index 0 is the blank label, so file line N maps to class index N + 1.
            self.chars = [blank] + fp.read().split("\n")

    def recognize(self, image, top_k=10):
        """Recognize the text in *image*.

        Args:
            image: 4-channel image array (see ``preprocess``).
            top_k: Number of candidates kept for every emitted character.

        Returns:
            The ``(texts, nbest)`` pair produced by ``ctc_decode``.
        """
        batch = self.preprocess(image, self.input_height, self.input_width)[None, :, :, :]
        # One inference per call: repeating it with the same input only adds latency.
        self.infer_request.infer(inputs={self.input_tensor_name: batch})
        preds = self.infer_request.get_tensor(self.output_tensor_name).data[:]
        return self.ctc_decode(preds, top_k)

    def preprocess(self, image, height, width, invert=False):
        """Convert *image* to a float32 grayscale array of shape ``(1, height, width)``.

        The image is scaled to ``height`` rows keeping its aspect ratio, then
        padded on the right by repeating the last column. If the scaled image
        would be wider than ``width`` it is squeezed horizontally to ``width``
        instead, because a negative pad amount would make ``np.pad`` raise.

        Args:
            image: 4-channel array, converted with ``cv2.COLOR_RGBA2GRAY``.
            height: Target number of rows.
            width: Target number of columns.
            invert: If true, replace every gray value ``v`` with ``255 - v``.

        Returns:
            The padded ``float32`` array.
        """
        src: np.ndarray = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
        src = (255 - src) if invert else src
        ratio = float(src.shape[1]) / float(src.shape[0])
        # Keep the resized width within [1, width]: 0 is rejected by cv2.resize
        # and anything above ``width`` would need negative padding.
        target_width = min(max(int(height * ratio), 1), width)
        rsz = cv2.resize(
            src, (target_width, height), interpolation=cv2.INTER_AREA
        ).astype(np.float32)
        img = rsz[None, :, :]  # [h,w] -> [c,h,w]
        _, h, w = img.shape
        # right edge padding
        return np.pad(img, ((0, 0), (0, height - h), (0, width - w)), mode="edge")

    def ctc_decode(self, preds, top_k) -> tuple[list, list[list[Result]]]:
        """Greedy CTC decoding with per-character n-best candidates.

        Args:
            preds: Score array indexed as ``preds[time][batch][class]``.
            top_k: Number of candidates kept for every emitted character.

        Returns:
            ``(texts, nbest)``: ``texts`` holds one decoded string per batch
            row that has at least one time step; ``nbest`` holds, for every
            emitted character in order, its ``top_k`` candidates sorted by
            descending probability.
        """
        texts, nbest = [], []
        # Best class per time step, arranged as one row per batch item.
        best_paths: np.ndarray = np.argmax(preds, 2).transpose(1, 0)
        for batch_idx, path in enumerate(best_paths):
            if path.shape[0] == 0:
                continue
            char_list = []
            for i, class_idx in enumerate(path):
                # Drop blanks (class 0) and repeats of the previous time step.
                if class_idx == 0:
                    continue
                if i > 0 and path[i - 1] == class_idx:
                    continue
                char_list.append(self.chars[class_idx])
                # n-best comes from the scores of the batch row being decoded.
                probs = self.softmax(preds[i][batch_idx])
                k_indices = np.argsort(-probs)[:top_k]
                nbest.append(
                    [
                        Recognizer.Result(char=self.chars[j], prob=float(probs[j]))
                        for j in k_indices
                    ]
                )
            texts.append("".join(char_list))
        return texts, nbest

    def softmax(self, x):
        """Return the softmax of *x*, normalized along axis 0.

        The maximum is subtracted before exponentiating to avoid overflow.
        """
        exp_x = np.exp(x - np.max(x))
        return exp_x / np.sum(exp_x, axis=0)