import torch
import torch.nn as nn
from transformers import AutoModel


class PhoBERTMultiHeadGRU(nn.Module):
    def __init__(self, phobert_path: str, gru_hidden_dim: int, num_labels: int, num_classes: int):
        super().__init__()
        # Pretrained PhoBERT encoder
        self.phobert = AutoModel.from_pretrained(phobert_path)
        phobert_hidden_size = self.phobert.config.hidden_size
        # Bidirectional GRU over the PhoBERT token embeddings
        self.gru = nn.GRU(
            input_size=phobert_hidden_size,
            hidden_size=gru_hidden_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        # One classification head per label, each predicting num_classes classes
        self.heads = nn.ModuleList(
            [nn.Linear(gru_hidden_dim * 2, num_classes) for _ in range(num_labels)]
        )

    @staticmethod
    def masked_mean_pool(sequence_output: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Mean-pool over non-padding tokens only
        mask = attention_mask.unsqueeze(-1).type_as(sequence_output)
        masked_sum = (sequence_output * mask).sum(dim=1)
        token_count = mask.sum(dim=1).clamp(min=1.0)
        return masked_sum / token_count

    def forward(self, input_ids, attention_mask):
        phobert_outputs = self.phobert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = phobert_outputs.last_hidden_state
        # Pack the padded sequences so the GRU skips padding tokens
        lengths = attention_mask.sum(dim=1).to(dtype=torch.long).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(
            sequence_output,
            lengths,
            batch_first=True,
            enforce_sorted=False,
        )
        packed_output, _ = self.gru(packed)
        gru_output, _ = nn.utils.rnn.pad_packed_sequence(
            packed_output,
            batch_first=True,
            total_length=sequence_output.size(1),
        )
        pooled_output = self.masked_mean_pool(gru_output, attention_mask)
        # Return a list of logits, one tensor per head
        return [head(pooled_output) for head in self.heads]
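

# Illustrative usage sketch (not part of the original model code): instantiate the
# model and run a forward pass on a small batch. The checkpoint name
# "vinai/phobert-base" and all dimensions below are assumed placeholder values;
# substitute the actual checkpoint path and task-specific sizes.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model = PhoBERTMultiHeadGRU(
        phobert_path="vinai/phobert-base",  # assumed checkpoint
        gru_hidden_dim=256,                 # assumed GRU hidden size
        num_labels=4,                       # assumed number of heads
        num_classes=3,                      # assumed classes per head
    )
    tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")

    # Note: PhoBERT expects word-segmented Vietnamese input (e.g. via VnCoreNLP);
    # these example sentences are already segmented with underscores.
    batch = tokenizer(
        ["món_ăn rất ngon", "dịch_vụ hơi chậm"],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        logits_per_head = model(batch["input_ids"], batch["attention_mask"])

    # One (batch_size, num_classes) tensor per head
    print([tuple(t.shape) for t in logits_per_head])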