Spaces:
Sleeping
Sleeping
| # app/infer.py | |
| import os | |
| import torch | |
| import torch.nn as nn | |
| from PIL import Image | |
| import torchvision.transforms as T | |
| from app.utils import CHARS, idx2char, BLANK_CHAR | |
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Width of the classifier output: one logit per entry of the alphabet table.
# (Presumably CHARS includes the CTC blank, since the decoder compares
# CHARS[idx] against BLANK_CHAR — TODO confirm in app.utils.)
NUM_CLASSES = len(CHARS)
# --------------------
# CRNN MODEL (SAME AS TRAINING)
# --------------------
class CRNN(nn.Module):
    """CNN + bidirectional-LSTM text recognizer producing per-column CTC logits.

    The layer order and sizes must stay identical to the training script:
    the checkpoint is restored with ``load_state_dict``, whose keys (e.g.
    ``cnn.0.weight`` ... ``cnn.12.weight``, ``rnn.*``, ``fc.*``) and tensor
    shapes are derived from this exact layout. The defaults reproduce the
    original hard-coded configuration.

    Args:
        num_classes: Output alphabet size. ``None`` (default) means the
            module-level ``NUM_CLASSES``.
        img_height: Input image height in pixels. Each of the three max-pools
            halves the height (flooring), so the LSTM receives
            ``256 * (img_height // 8)`` features per time step; the default
            of 60 gives the original ``256 * 7``.

    Raises:
        ValueError: If ``img_height`` is below 8, so that no feature rows
            would survive the pooling.
    """

    def __init__(self, num_classes=None, img_height=60):
        super().__init__()
        if num_classes is None:
            num_classes = NUM_CLASSES
        if img_height < 8:
            raise ValueError(f"img_height must be >= 8, got {img_height}")

        # Keep this Sequential's order untouched: indices are state_dict keys.
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.MaxPool2d(2, 2),  # height and width halved
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.MaxPool2d(2, 2),  # height and width halved again
            nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
            nn.MaxPool2d((2, 1)),  # height halved, width kept (time resolution)
            nn.Conv2d(256, 256, 3, padding=1), nn.ReLU(),
        )

        # 3x3 convs with padding=1 preserve size; each pool floors H / 2.
        feat_height = img_height // 8
        self.rnn = nn.LSTM(
            input_size=256 * feat_height,
            hidden_size=256,
            num_layers=2,
            bidirectional=True,
            batch_first=True,
        )
        # Bidirectional LSTM concatenates 2 directions x 256 hidden units.
        self.fc = nn.Linear(2 * 256, num_classes)

    def forward(self, x):
        """Map a batch of images to per-time-step class logits.

        Args:
            x: Float tensor of shape (batch, 1, img_height, width).

        Returns:
            Tensor of shape (batch, width // 4, num_classes); the time axis
            runs left to right across the image.
        """
        feats = self.cnn(x)
        b, c, h, w = feats.shape
        # (b, c, h, w) -> (b, w, c*h): each feature-map column is one time step.
        seq = feats.permute(0, 3, 1, 2).reshape(b, w, c * h)
        seq, _ = self.rnn(seq)
        return self.fc(seq)
# --------------------
# LOAD MODEL
# --------------------
# Resolve the checkpoint relative to this source file rather than the current
# working directory, so the app can be launched from any directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
WEIGHTS = os.path.join(BASE_DIR, "weights", "ocr_model.pth")

# Built once at import time; predict() reuses this single instance.
model = CRNN()
model.to(DEVICE)  # nn.Module.to moves parameters in place (and returns self)
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
# NOTE(review): torch.load unpickles the file. It is a bundled artifact here,
# but passing weights_only=True (torch >= 1.13) would harden the load.
model.load_state_dict(torch.load(WEIGHTS, map_location=DEVICE))
# Switch BatchNorm layers to inference behaviour (use running statistics).
model.eval()
# --------------------
# TRANSFORM
# --------------------
# Preprocessing applied to every input image before it reaches the CRNN:
#   1. collapse to one grayscale channel (the first conv expects 1 channel);
#   2. resize to height 60 x width 160 -- the height must match the one the
#      CRNN's LSTM input size was built for (60 px -> 7 rows after pooling);
#   3. convert to a float tensor of shape (1, 60, 160).
transform = T.Compose(
    [
        T.Grayscale(num_output_channels=1),
        T.Resize(size=(60, 160)),
        T.ToTensor(),
    ]
)
# --------------------
# CTC DECODER
# --------------------
def ctc_decode(logits, chars=None, blank_char=None):
    """Greedy (best-path) CTC decoding of the first sequence in a batch.

    Picks the most likely class at every time step, collapses each run of
    identical classes into one, and drops the blank symbol.

    Args:
        logits: Tensor indexed as (batch, time, classes), as produced by the
            CRNN; only batch element 0 is decoded.
        chars: Index -> symbol table. ``None`` (default) uses ``CHARS``.
        blank_char: Symbol treated as the CTC blank. ``None`` (default) uses
            ``BLANK_CHAR``.

    Returns:
        The decoded string ("" when every step is blank).
    """
    from itertools import groupby

    if chars is None:
        chars = CHARS
    if blank_char is None:
        blank_char = BLANK_CHAR

    # softmax is monotonic, so argmax over the raw logits yields the same
    # path; a single .tolist() avoids one device->host sync per time step.
    best_path = logits[0].argmax(dim=-1).tolist()

    # groupby yields one key per run of equal indices (the CTC "collapse");
    # a blank between two equal symbols splits them into separate runs, so
    # genuine double letters survive.
    return "".join(
        chars[idx] for idx, _ in groupby(best_path) if chars[idx] != blank_char
    )
# --------------------
# PUBLIC API
# --------------------
def predict(pil_img: Image.Image):
    """Recognise the text in a single image.

    Args:
        pil_img: Input image; ``transform`` turns it into a 1x60x160
            grayscale tensor before inference.

    Returns:
        ``(text, confidence)``: ``text`` is the greedy CTC decoding of the
        model output, and ``confidence`` is the largest softmax probability
        found anywhere in that output (any time step, any class), rounded
        to 3 decimals.
    """
    batch = transform(pil_img)[None].to(DEVICE)  # prepend a batch axis of size 1
    with torch.no_grad():
        logits = model(batch)

    text = ctc_decode(logits)
    # NOTE(review): this is the peak probability of a single frame (which may
    # well be a blank frame), so it says little about the decoded string. A
    # per-character score, e.g. the mean/min of the max probability over
    # non-blank frames, would be more informative. Kept as-is for callers.
    peak_prob = logits.softmax(dim=2).max().item()
    return text, round(peak_prob, 3)