import torch
import torch.nn as nn
from transformers import AutoModel

class PhoBERTMultiHeadGRU(nn.Module):
    """PhoBERT encoder followed by a bidirectional GRU and one linear
    classification head per label (multi-head / multi-task setup)."""

    def __init__(self, phobert_path: str, gru_hidden_dim: int, num_labels: int, num_classes: int):
        super().__init__()
        self.phobert = AutoModel.from_pretrained(phobert_path)
        phobert_hidden_size = self.phobert.config.hidden_size
        # Bidirectional GRU over the PhoBERT token embeddings.
        self.gru = nn.GRU(
            input_size=phobert_hidden_size,
            hidden_size=gru_hidden_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        # One classifier per label; each head sees the concatenated
        # forward/backward GRU states, hence gru_hidden_dim * 2.
        self.heads = nn.ModuleList(
            [nn.Linear(gru_hidden_dim * 2, num_classes) for _ in range(num_labels)]
        )

    @staticmethod
    def masked_mean_pool(sequence_output: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Mean over real tokens only: padding positions are zeroed by the mask
        # and excluded from the denominator (clamped to avoid division by zero).
        mask = attention_mask.unsqueeze(-1).type_as(sequence_output)
        masked_sum = (sequence_output * mask).sum(dim=1)
        token_count = mask.sum(dim=1).clamp(min=1.0)
        return masked_sum / token_count

    def forward(self, input_ids, attention_mask):
        phobert_outputs = self.phobert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = phobert_outputs.last_hidden_state
        # Pack the padded batch so the GRU skips padding tokens;
        # pack_padded_sequence requires the lengths tensor on the CPU.
        lengths = attention_mask.sum(dim=1).to(dtype=torch.long).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(
            sequence_output,
            lengths,
            batch_first=True,
            enforce_sorted=False,
        )
        packed_output, _ = self.gru(packed)
        # Pad back to the original sequence length so the attention mask
        # still lines up with the GRU output during pooling.
        gru_output, _ = nn.utils.rnn.pad_packed_sequence(
            packed_output,
            batch_first=True,
            total_length=sequence_output.size(1),
        )
        pooled_output = self.masked_mean_pool(gru_output, attention_mask)
        # Return a list of logits, one tensor per head.
        return [head(pooled_output) for head in self.heads]
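

# Usage sketch, not part of the original module. Assumptions: the public
# "vinai/phobert-base" checkpoint and its tokenizer; the head counts
# (num_labels=4, num_classes=3) are illustrative values only. Note that
# PhoBERT expects word-segmented Vietnamese input.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")
    model = PhoBERTMultiHeadGRU(
        phobert_path="vinai/phobert-base",
        gru_hidden_dim=128,
        num_labels=4,
        num_classes=3,
    )
    model.eval()

    batch = tokenizer(
        ["Ví dụ câu tiếng Việt đã được tách_từ ."],
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    with torch.no_grad():
        logits_per_head = model(batch["input_ids"], batch["attention_mask"])

    # One logits tensor of shape (batch_size, num_classes) per head.
    print([t.shape for t in logits_per_head])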