from typing import Any

import torch
from torch import nn
from transformers import (
    PreTrainedModel,
    XLMRobertaConfig,
    XLMRobertaModel,
)

from .configuration_comet import CometModelConfig

class Encoder(nn.Module):
    """Encoder module based on XLM-RoBERTa (InfoXLM-large)."""

    def __init__(self) -> None:
        super().__init__()
        # Build the backbone from the InfoXLM-large configuration; the pooling
        # layer is dropped because only the per-layer hidden states are used.
        self.model = XLMRobertaModel(
            config=XLMRobertaConfig.from_pretrained("microsoft/infoxlm-large"),
            add_pooling_layer=False,
        )
        self.model.encoder.output_hidden_states = True

    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs: Any
    ) -> tuple[torch.Tensor, ...]:
        # With return_dict=False and output_hidden_states=True, the last element
        # of the output tuple is the tuple of hidden states from every layer.
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=False,
        )[-1]

    @property
    def num_layers(self) -> int:
        """Number of hidden-state layers available (embeddings + transformer layers)."""
        return self.model.config.num_hidden_layers + 1

    @property
    def output_units(self) -> int:
        """Dimensionality of the encoder's hidden states."""
        return self.model.config.hidden_size

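# Note: Encoder.forward returns a tuple with `num_layers` hidden-state tensors
# (the embedding output plus one tensor per transformer layer), each of shape
# [batch_size, seq_len, hidden_size]. For the infoxlm-large configuration that
# is 25 tensors of width 1024.
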
class LayerwiseAttention(nn.Module):
    """Module that applies attention (a learned scalar mix) across model layers."""

    def __init__(
        self,
        num_layers: int,
        layer_weights: list[float] | None = None,
    ) -> None:
        super().__init__()
        layer_weights = layer_weights or [0.0] * num_layers
        # One learnable scalar weight per layer (half precision), plus a single
        # global scaling factor (gamma).
        self.scalar_parameters = nn.ParameterList(
            [
                nn.Parameter(torch.HalfTensor([layer_weights[i]]), requires_grad=True)
                for i in range(num_layers)
            ]
        )
        self.weight = nn.Parameter(torch.HalfTensor([1.0]), requires_grad=True)

    def forward(
        self,
        tensors: list[torch.Tensor],
        mask: torch.Tensor,
    ) -> torch.Tensor:
        # `mask` is accepted for interface compatibility but not used here.
        weights = torch.cat([parameter for parameter in self.scalar_parameters])
        normed_weights = torch.softmax(weights, dim=0)
        normed_weights = torch.split(normed_weights, split_size_or_sections=1)
        # Weighted sum of the layer representations, scaled by gamma.
        return self.weight * sum(
            weight * tensor for weight, tensor in zip(normed_weights, tensors)
        )

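# Sketch of what LayerwiseAttention computes (notation only, not executed):
# given hidden states h_0, ..., h_{L-1} of shape [batch, seq_len, hidden], it
# returns gamma * sum_k softmax(w)_k * h_k, i.e. a softmax-normalised, learned
# mix of all layers scaled by a single learned factor gamma (`self.weight`).
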
class Estimator(nn.Module):
    """Feed-forward estimator module that maps pooled features to a score."""

    def __init__(
        self,
        in_dim: int,
        out_dim: int = 1,
        hidden_sizes: list[int] = [3072, 1024],
        activations: str = "Tanh",
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        modules: list[nn.Module] = []

        # Input projection.
        modules.append(nn.Linear(in_dim, hidden_sizes[0]))
        modules.append(self._get_activation(activations))
        modules.append(nn.Dropout(dropout))

        # Hidden layers.
        for i in range(1, len(hidden_sizes)):
            modules.append(nn.Linear(hidden_sizes[i - 1], hidden_sizes[i]))
            modules.append(self._get_activation(activations))
            modules.append(nn.Dropout(dropout))

        # Output projection.
        modules.append(nn.Linear(hidden_sizes[-1], int(out_dim)))

        self.ff = nn.Sequential(*modules)

    def _get_activation(self, activation: str) -> nn.Module:
        """Look up an activation function in torch.nn by (title-cased) name."""
        if hasattr(nn, activation.title()):
            return getattr(nn, activation.title())()
        raise ValueError(f"{activation} is not a valid activation function!")

    def forward(self, in_features: torch.Tensor) -> torch.Tensor:
        return self.ff(in_features)

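# With the default arguments the estimator is the MLP
#     Linear(in_dim, 3072) -> Tanh -> Dropout -> Linear(3072, 1024) -> Tanh -> Dropout -> Linear(1024, 1)
# CometModel below overrides these defaults with the values from its configuration.
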
class CometModel(PreTrainedModel):
    """COMET model: XLM-R encoder, layerwise attention, and a feed-forward estimator that outputs one score per input."""

    config_class = CometModelConfig
    _no_split_modules = ["Encoder", "LayerwiseAttention", "Estimator"]

    def __init__(self, config: CometModelConfig) -> None:
        super().__init__(config)

        self.encoder = Encoder()
        self.layerwise_attention = LayerwiseAttention(
            num_layers=self.encoder.num_layers
        )
        self.estimator = Estimator(
            in_dim=self.encoder.output_units,
            hidden_sizes=config.hidden_sizes,
            activations=config.activations,
            dropout=config.dropout,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor | None = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        # Encode and collect the hidden states of every layer.
        encoder_out = self.encoder(
            input_ids,
            attention_mask,
            token_type_ids=token_type_ids,
        )
        # Mix the layers into a single sequence representation.
        embeddings = self.layerwise_attention(
            encoder_out,
            attention_mask,
        )

        # Use the <s> (CLS) embedding as the sentence representation and
        # regress a single score per input.
        embedding = embeddings[:, 0, :]
        return self.estimator(embedding).view(-1)
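# Example usage (illustrative sketch only: the checkpoint path and the way
# inputs are formatted into a single sequence are assumptions, not defined by
# this module):
#
#     import torch
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("microsoft/infoxlm-large")
#     model = CometModel.from_pretrained("path/to/comet-checkpoint")
#     model.eval()
#
#     batch = tokenizer(
#         ["text to score"],  # input formatting depends on how the checkpoint was trained
#         return_tensors="pt",
#         padding=True,
#         truncation=True,
#     )
#     with torch.no_grad():
#         scores = model(batch["input_ids"], batch["attention_mask"])
#     # `scores` is a 1-D tensor with one value per batch element.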