| | |
| | import torch |
| | import torch.nn as nn |
| | from mmcv.runner import BaseModule |
| |
|
| | from mmocr.models.builder import FUSERS |
| |
|
| |
|
@FUSERS.register_module()
class ABIFuser(BaseModule):
    """Mix and align the visual feature and the linguistic feature.

    Implementation of the fusion (alignment) module of
    `ABINet <https://arxiv.org/abs/2103.06495>`_. A learned sigmoid gate
    mixes the visual and linguistic features element-wise. The mixed
    feature is then projected to per-position character logits.

    Args:
        d_model (int): Hidden size :math:`E` of both input features.
            Defaults to 512.
        max_seq_len (int): Maximum text sequence length :math:`T`. It is
            stored on the instance as ``max_seq_len + 1`` (presumably to
            reserve a slot for an end token -- confirm with callers). It is
            not used by :meth:`forward`. Defaults to 40.
        num_chars (int): Number of text characters :math:`C`.
            Defaults to 90.
        init_cfg (dict, optional): Specifies the initialization method for
            model layers. Defaults to None.
        **kwargs: Extra keyword arguments; accepted and ignored.
    """

    def __init__(self,
                 d_model=512,
                 max_seq_len=40,
                 num_chars=90,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg=init_cfg)

        self.max_seq_len = max_seq_len + 1
        # Gate projection: takes the concatenated (2 * E) features and
        # produces one gate value per hidden unit (E).
        self.w_att = nn.Linear(2 * d_model, d_model)
        # Classifier from the fused hidden feature to character logits.
        self.cls = nn.Linear(d_model, num_chars)

    def forward(self, l_feature, v_feature):
        """Gate-fuse the two features and classify every position.

        Args:
            l_feature (Tensor): Linguistic feature of shape (N, T, E), where
                N is batch size, T is sequence length and E is ``d_model``.
                Any number of leading dimensions is accepted as long as the
                last dimension is E.
            v_feature (Tensor): Visual feature with the same shape as
                ``l_feature``.

        Returns:
            dict: A dict with key ``logits``, holding a tensor of shape
            (N, T, C), where C is ``num_chars``.
        """
        # Concatenate along the feature (last) axis so the gate sees both
        # modalities. ``dim=-1`` equals ``dim=2`` for 3-D inputs and also
        # supports inputs of other ranks, matching nn.Linear's last-axis
        # semantics.
        f = torch.cat((l_feature, v_feature), dim=-1)
        # Element-wise gate in (0, 1): the weight given to the visual
        # feature; (1 - gate) is the weight of the linguistic feature.
        f_att = torch.sigmoid(self.w_att(f))
        output = f_att * v_feature + (1 - f_att) * l_feature

        logits = self.cls(output)

        return {'logits': logits}
| |
|