|
|
from .config import Config |
|
|
from .model import Text_recognization_model |
|
|
import os |
|
|
import torch |
|
|
|
|
|
from .utils import CTCLabelConverter,Averager |
|
|
|
|
|
from PIL import Image |
|
|
import math |
|
|
import numpy as np |
|
|
from .dataset import NormalizePAD |
|
|
import tempfile |
|
|
|
|
|
|
|
|
import os |
|
|
import math |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import torch |
|
|
|
|
|
class TextRecognition:
    """CTC-based text recognizer: loads a trained model once, then decodes single images."""

    def __init__(self, model_path='model/recognization_model.pth', device='cpu'):
        """Load the character set, build the label converter, and restore model weights.

        Args:
            model_path: Path to the weights file, resolved relative to this
                package's directory.
            device: Torch device string the model runs on (e.g. 'cpu', 'cuda').
        """
        self.opt = Config()
        self.opt.device = device
        self.model_path = model_path

        # The character vocabulary ships next to this module; one character per
        # line. A trailing space is appended as an extra recognizable symbol.
        current_dir = os.path.dirname(os.path.abspath(__file__))
        vocab_path = os.path.join(current_dir, "words.txt")
        with open(vocab_path, "r", encoding="utf-8") as f:
            self.opt.character = ''.join(line.strip('\n') for line in f) + " "

        if 'CTC' in self.opt.Prediction:
            self.converter = CTCLabelConverter(self.opt.character)
        else:
            # NOTE(review): AttnLabelConverter is not imported at the top of
            # this file, so this branch raises NameError at runtime — confirm
            # the intended import (presumably from .utils) before relying on
            # a non-CTC Prediction setting.
            self.converter = AttnLabelConverter(self.opt.character)

        self.opt.num_class = len(self.converter.character)

        # Model weights are also resolved relative to the package directory.
        model_path = os.path.join(current_dir, self.model_path)
        self.model = Text_recognization_model(self.opt)
        self.model.load_state_dict(
            torch.load(model_path, map_location=self.opt.device, weights_only=True))
        self.model = self.model.to(self.opt.device)
        self.model.eval()

    def recognize_image(self, image):
        """Recognize the text in a single image.

        Args:
            image: Either a filesystem path to an image file, or a NumPy array
                that is HxWx3 (RGB; converted to grayscale) or HxW (grayscale).

        Returns:
            The decoded text string.

        Raises:
            ValueError: If a NumPy array has an unsupported shape.
            TypeError: If ``image`` is neither a str nor a NumPy array.
        """
        if isinstance(image, str):
            pil_image = Image.open(image).convert('L')
        elif isinstance(image, np.ndarray):
            if len(image.shape) == 3 and image.shape[2] == 3:
                # ITU-R BT.601 luma weights for RGB -> grayscale.
                gray_array = np.dot(image[..., :3], [0.2989, 0.5870, 0.1140])
                pil_image = Image.fromarray(gray_array.astype('uint8'))
            elif len(image.shape) == 2:
                pil_image = Image.fromarray(image.astype('uint8'))
            else:
                raise ValueError("Unsupported image format!")
        else:
            raise TypeError("Input must be a file path (str) or a NumPy array.")

        # Mirror horizontally before recognition — presumably the model was
        # trained on mirrored text (e.g. a right-to-left script); TODO confirm.
        pil_image = pil_image.transpose(Image.Transpose.FLIP_LEFT_RIGHT)

        # Resize to the model's input height, preserving aspect ratio but
        # capping width at opt.imgW; NormalizePAD pads the remaining width.
        w, h = pil_image.size
        ratio = w / float(h)
        if math.ceil(self.opt.imgH * ratio) > self.opt.imgW:
            resized_w = self.opt.imgW
        else:
            resized_w = math.ceil(self.opt.imgH * ratio)
        pil_image = pil_image.resize((resized_w, self.opt.imgH), Image.Resampling.BICUBIC)

        transform = NormalizePAD((1, self.opt.imgH, self.opt.imgW))
        img = transform(pil_image)
        img = img.unsqueeze(0)  # add batch dimension
        img = img.to(self.opt.device)

        # Fix: run the forward pass without autograd — the original built a
        # gradient graph for a pure-inference call, wasting memory and compute.
        with torch.no_grad():
            preds = self.model(img)
            # Per-sample sequence length required by the CTC decoder.
            preds_size = torch.IntTensor([preds.size(1)])
            _, preds_index = preds.max(2)
            preds_str = self.converter.decode(preds_index.data, preds_size.data)[0]

        return preds_str
|
|
|
|
|
|