Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from app.utils import CHARS | |
# Width of the model's output layer: one class per entry in CHARS.
# NOTE(review): whether CHARS already includes a blank/padding symbol cannot
# be determined from this file -- confirm against app.utils.
NUM_CLASSES = len(CHARS)
class CRNN(nn.Module):
    """Convolutional-recurrent network that scores classes per image column.

    A four-stage CNN extracts a feature map, every column of that map is
    flattened into one time step, a 2-layer bidirectional LSTM runs over the
    resulting sequence, and a linear layer maps each step to class scores.

    Shape contract (derived from the layers below):
        * input:  ``(batch, 1, H, W)`` -- single-channel images.
        * the three pooling layers reduce height to ``H // 8`` and width to
          ``W // 4`` (the last pool only halves the height).
        * output: ``(batch, W // 4, num_classes)`` raw scores; no softmax or
          log-softmax is applied here.

    Args:
        num_classes: Size of the output layer. ``None`` (the default) uses
            the module-level ``NUM_CLASSES``, which preserves the original
            behaviour of ``CRNN()``.
        feature_height: Height of the CNN feature map, i.e. ``H // 8`` for
            the images the model will see. The default of 7 matches the
            previously hard-coded LSTM input size of ``256 * 7``.

    Raises:
        ValueError: If ``feature_height`` is smaller than 1.
    """

    def __init__(self, num_classes: "int | None" = None, feature_height: int = 7):
        super().__init__()

        if feature_height < 1:
            raise ValueError(
                f"feature_height must be >= 1, got {feature_height}"
            )
        if num_classes is None:
            # Looked up at call time (not as a default argument) so the class
            # can be defined before/without the global being evaluated.
            num_classes = NUM_CLASSES

        cnn_channels = 256
        hidden_size = 256

        # NOTE: layer order inside the Sequential fixes the state_dict keys
        # (cnn.0, cnn.1, cnn.4, ...); do not reorder without migrating
        # existing checkpoints.
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, cnn_channels, 3, padding=1),
            nn.BatchNorm2d(cnn_channels),
            nn.ReLU(),
            # Kernel (2, 1): halve the height only, keeping the width (and
            # therefore the sequence length) unchanged.
            nn.MaxPool2d((2, 1)),
            nn.Conv2d(cnn_channels, cnn_channels, 3, padding=1), nn.ReLU(),
        )
        self.rnn = nn.LSTM(
            # One time step is a full feature-map column: channels * height.
            input_size=cnn_channels * feature_height,
            hidden_size=hidden_size,
            num_layers=2,
            bidirectional=True,
            batch_first=True,
        )
        # Bidirectional LSTM concatenates both directions -> 2 * hidden_size.
        self.fc = nn.Linear(2 * hidden_size, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-column class scores for a batch of images.

        Args:
            x: Tensor of shape ``(batch, 1, H, W)`` where ``H // 8`` equals
                the ``feature_height`` the model was built with.

        Returns:
            Tensor of shape ``(batch, W // 4, num_classes)``.
        """
        features = self.cnn(x)
        batch, channels, height, width = features.shape
        # (B, C, H, W) -> (B, W, C, H) -> (B, W, C*H): width becomes the
        # sequence axis and each column is flattened into one feature vector.
        sequence = features.permute(0, 3, 1, 2).reshape(
            batch, width, channels * height
        )
        sequence, _ = self.rnn(sequence)
        return self.fc(sequence)