import torch
import torch.nn as nn
from transformers import AutoModel


class PhoBERTMultiHeadGRU(nn.Module):
    """PhoBERT encoder followed by a bidirectional GRU and one linear head per label."""

    def __init__(self, phobert_path: str, gru_hidden_dim: int, num_labels: int, num_classes: int):
        super().__init__()
        self.phobert = AutoModel.from_pretrained(phobert_path)
        phobert_hidden_size = self.phobert.config.hidden_size

        # Bidirectional GRU over PhoBERT's contextual token embeddings.
        self.gru = nn.GRU(
            input_size=phobert_hidden_size,
            hidden_size=gru_hidden_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        # One classification head per label; each head sees the concatenated
        # forward/backward GRU features, hence gru_hidden_dim * 2.
        self.heads = nn.ModuleList(
            [nn.Linear(gru_hidden_dim * 2, num_classes) for _ in range(num_labels)]
        )

    @staticmethod
    def masked_mean_pool(sequence_output: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Average token features over real (non-padding) positions only.
        mask = attention_mask.unsqueeze(-1).type_as(sequence_output)
        masked_sum = (sequence_output * mask).sum(dim=1)
        token_count = mask.sum(dim=1).clamp(min=1.0)  # guard against division by zero
        return masked_sum / token_count

    def forward(self, input_ids, attention_mask):
        phobert_outputs = self.phobert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = phobert_outputs.last_hidden_state  # (batch, seq_len, hidden)

        # Pack the padded batch so the GRU skips padding tokens;
        # pack_padded_sequence requires lengths on the CPU.
        lengths = attention_mask.sum(dim=1).to(dtype=torch.long).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(
            sequence_output,
            lengths,
            batch_first=True,
            enforce_sorted=False,
        )
        packed_output, _ = self.gru(packed)
        gru_output, _ = nn.utils.rnn.pad_packed_sequence(
            packed_output,
            batch_first=True,
            total_length=sequence_output.size(1),
        )

        pooled_output = self.masked_mean_pool(gru_output, attention_mask)

        # Return one logits tensor of shape (batch, num_classes) per head.
        return [head(pooled_output) for head in self.heads]
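

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original model code).
# Assumptions: the public "vinai/phobert-base" checkpoint and arbitrary
# head/class counts (3 heads, 4 classes each). Note that PhoBERT expects
# word-segmented Vietnamese input (e.g. via VnCoreNLP); the raw strings
# below are only meant to demonstrate tensor shapes.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")
    model = PhoBERTMultiHeadGRU(
        phobert_path="vinai/phobert-base",
        gru_hidden_dim=256,
        num_labels=3,
        num_classes=4,
    )
    model.eval()

    batch = tokenizer(
        ["Tôi rất thích sản_phẩm này .", "Giao hàng hơi chậm ."],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits_per_head = model(batch["input_ids"], batch["attention_mask"])

    # One (batch_size, num_classes) tensor per classification head.
    print([tuple(t.shape) for t in logits_per_head])  # [(2, 4), (2, 4), (2, 4)]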