Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| import torch | |
| import torch.nn.functional as F | |
| import torchvision | |
| import matplotlib.pyplot as plt | |
| import zipfile | |
| import os | |
| import gradio as gr | |
| from PIL import Image | |
# Alphabet of the recognizer. The position of a character in CHARS is its
# label index; the first two symbols ('~' and '=') occupy the indices that
# BLANK and PAD name below.
CHARS = "~=" + " abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789,.'-!?:;\""
BLANK, PAD = 0, 1

# Reverse lookup: character -> label index.
CHARS_DICT = dict(zip(CHARS, range(len(CHARS))))

# NOTE(review): TEXTLEN looks like a maximum text length, but nothing in the
# visible code reads it -- confirm before relying on it.
TEXTLEN = 30

# Token list: every alphabet character plus a '|' silence token, which is
# appended only when the alphabet does not already contain it.
silence_token = '|'
tokens_list = [*CHARS_DICT]
if silence_token not in tokens_list:
    tokens_list.append(silence_token)
def fit_picture(img, target_height=32, target_width=400):
    """Convert an image into a fixed-size, normalized grayscale tensor.

    The image is converted to one grayscale channel, turned into a tensor,
    normalized with mean 0.5 and std 0.5, scaled (keeping its aspect ratio)
    to fit inside a ``target_height`` x ``target_width`` box and finally
    padded on the right and bottom up to exactly that size.

    Args:
        img: Image exposing ``width`` and ``height`` attributes that the
            torchvision transforms accept (a PIL image in this app).
        target_height: Height of the returned tensor, in pixels.
        target_width: Width of the returned tensor, in pixels.

    Returns:
        A tensor of shape ``(1, target_height, target_width)``.
    """
    # Pick the scale that makes the image touch the box along its tighter
    # dimension while preserving the aspect ratio.
    aspect_ratio = img.width / img.height
    if aspect_ratio > target_width / target_height:
        resize_width = target_width
        resize_height = int(target_width / aspect_ratio)
    else:
        resize_height = target_height
        resize_width = int(target_height * aspect_ratio)

    # int() truncates, so an extremely wide or tall image would otherwise
    # get a 0-pixel dimension and make the resize fail.
    resize_height = max(1, resize_height)
    resize_width = max(1, resize_width)

    # The resized image never exceeds the box; max() only guards rounding.
    padding_height = max(0, target_height - resize_height)
    padding_width = max(0, target_width - resize_width)

    pipeline = torchvision.transforms.Compose([
        torchvision.transforms.Grayscale(num_output_channels=1),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(0.5, 0.5),
        torchvision.transforms.Resize((resize_height, resize_width)),
        # Padding is (left, top, right, bottom). The fill value 0 is applied
        # after normalization, i.e. it corresponds to a raw intensity of 0.5.
        torchvision.transforms.Pad(
            (0, 0, padding_width, padding_height),
            fill=0,
            padding_mode='constant',
        ),
    ])
    return pipeline(img)
def load_model(filename):
    """Restore the module-level recognizer and optimizer from a checkpoint.

    The checkpoint is read onto the CPU with ``weights_only=True`` and must
    contain a ``"recognizer"`` and an ``"optimizer"`` entry, each holding the
    state dict of the object of the same name.

    Args:
        filename: Path of the checkpoint file passed to ``torch.load``.
    """
    checkpoint = torch.load(
        filename,
        map_location=torch.device('cpu'),
        weights_only=True,
    )
    # Same order as the checkpoint keys are listed: model first, optimizer second.
    for key, target in (("recognizer", recognizer), ("optimizer", optimizer)):
        target.load_state_dict(checkpoint[key])
def ctc_decode_sequence(seq, blank=None):
    """Collapse one CTC label sequence: merge repetitions, then drop blanks.

    Consecutive identical labels count as a single symbol, and blank labels
    are removed. A blank between two equal labels keeps them separate, so
    ``[3, 0, 3]`` decodes to ``[3, 3]``.

    A symbol is emitted when its run *starts*, so the final run is kept no
    matter which label it is, and an empty sequence simply yields ``[]``.
    No other label (PAD included) is treated specially.

    Args:
        seq: Iterable of label indices for one sequence.
        blank: Index of the blank label. Defaults to the module-level
            ``BLANK`` when omitted.

    Returns:
        A list with the collapsed labels, in order.
    """
    if blank is None:
        blank = BLANK
    ret = []
    prev = blank
    for x in seq:
        # A new symbol begins whenever the label changes to a non-blank.
        if x != prev and x != blank:
            ret.append(x)
        prev = x
    return ret
def ctc_decode(codes):
    """Decode every sequence of a batch.

    ``codes`` is transposed before iterating, so each column of ``codes`` is
    handed to ``ctc_decode_sequence`` as one sequence.

    Args:
        codes: Two-dimensional array-like exposing a ``.T`` attribute.

    Returns:
        A list with one decoded label list per column of ``codes``.
    """
    return [ctc_decode_sequence(column) for column in codes.T]
def decode_text(codes):
    """Turn a sequence of label indices into the string they spell in CHARS."""
    return ''.join(CHARS[code] for code in codes)
class Residual(torch.nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers, a skip path, dropout.

    The skip path is the identity when the block keeps both the channel count
    and the resolution (``stride == 1``); otherwise it is a 1x1 convolution
    with the same stride so that both paths produce matching shapes.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        stride: Stride of the first convolution (int or tuple).
        pdrop: Probability used by the final ``Dropout2d``.
    """

    def __init__(self, in_channels, out_channels, stride, pdrop=0.2):
        super().__init__()
        nn = torch.nn
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        shape_preserved = in_channels == out_channels and stride == 1
        if shape_preserved:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=stride, padding=0
            )
        self.dropout = nn.Dropout2d(pdrop)

    def forward(self, x):
        """Apply the block to ``x`` and return the activated, dropped-out sum."""
        relu = torch.nn.functional.relu
        main = relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))
        return self.dropout(relu(main + self.skip(x)))
class TextRecognizer(torch.nn.Module):
    """Convolutional feature extractor + bidirectional LSTM + linear scorer.

    The feature extractor stacks nine ``Residual`` blocks. Five of them use a
    stride of 2 along the height and two of those also along the width, so an
    input of height 32 is reduced to height 1 while the width shrinks by 4.

    Args:
        labels: Number of output scores produced per time step.
    """

    def __init__(self, labels):
        super().__init__()
        # One (in_channels, out_channels, stride) entry per residual block.
        stages = [
            (1, 32, 1),
            (32, 32, 2),
            (32, 32, 1),
            (32, 64, 2),
            (64, 64, 1),
            (64, 128, (2, 1)),
            (128, 128, 1),
            (128, 128, (2, 1)),
            (128, 128, (2, 1)),
        ]
        blocks = [Residual(*stage) for stage in stages]
        self.feature_extractor = torch.nn.Sequential(*blocks)
        self.recurrent = torch.nn.LSTM(
            input_size=128, hidden_size=128, num_layers=1, bidirectional=True
        )
        # 256 = 128 hidden units for each of the two LSTM directions.
        self.output = torch.nn.Linear(256, labels)

    def forward(self, x):
        """Score every horizontal position of the input.

        The height axis (dimension 2) of the feature map is squeezed away,
        which only happens when it has size 1, and the width axis becomes the
        leading sequence axis for the LSTM.

        Returns:
            A tensor laid out as ``(width, batch, labels)``.
        """
        features = self.feature_extractor(x)
        sequence = features.squeeze(2).permute(2, 0, 1)
        hidden, _ = self.recurrent(sequence)
        return self.output(hidden)
# Inference setup: build the network with one output per alphabet symbol,
# move it to the available device, restore the trained weights from
# 'model.pt' and switch to evaluation mode.
recognizer = TextRecognizer(len(CHARS))

if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Device:", DEVICE)

LR = 1e-3
recognizer = recognizer.to(DEVICE)
# load_model() also restores the optimizer state, so an optimizer has to
# exist before the checkpoint is loaded.
optimizer = torch.optim.Adam(recognizer.parameters(), lr=LR)
load_model('model.pt')
recognizer.eval()
def ctc_read(image):
    """Recognize the text written in a single image.

    The image is fitted to the network's input size, scored by the
    module-level recognizer, decoded greedily (arg-max label per time step
    followed by CTC collapsing) and converted to a string.

    Args:
        image: Image to transcribe, or ``None`` when no image was supplied.

    Returns:
        The recognized text, or an empty string when ``image`` is ``None``.
    """
    if image is None:
        # The UI can submit the form without an image; answer with empty
        # text instead of failing inside the preprocessing.
        return ""

    # Add a batch dimension of size 1 in front of the fitted image.
    image_tensor = fit_picture(image).unsqueeze(0).to(DEVICE)
    print(image_tensor.size())

    with torch.no_grad():
        scores = recognizer(image_tensor)

    # Best label per time step (last axis of the scores holds the labels).
    predictions = scores.argmax(2).cpu().numpy()
    decoded_sequences = ctc_decode(predictions)

    # The batch holds exactly one image, hence exactly one decoded sequence.
    return decode_text(decoded_sequences[0])
# Gradio interface: one image input handed to ctc_read as a PIL image, and
# the recognized text shown as plain text output.
iface = gr.Interface(
    fn=ctc_read,
    inputs=gr.Image(type="pil"),  # PIL Image input
    outputs="text",  # Text output
    title="Handwritten Text Recognition",
    description="Upload an image, and the custom AI will extract the text.",
)
iface.launch(share=True)