# wmt22-cometkiwi-da-v2 / modeling_comet.py
from typing import Any
import torch
from torch import nn
from transformers import (
PreTrainedModel,
XLMRobertaConfig,
XLMRobertaModel,
)
from .configuration_comet import CometModelConfig


class Encoder(nn.Module):
    """Encoder module based on XLM-RoBERTa (InfoXLM-large)."""

    def __init__(self) -> None:
        super().__init__()
        self.model = XLMRobertaModel(
            config=XLMRobertaConfig.from_pretrained("microsoft/infoxlm-large"),
            add_pooling_layer=False,
        )
        # Hidden states from all layers are also requested explicitly in
        # forward(); this flag is kept for compatibility with the original code.
        self.model.encoder.output_hidden_states = True

    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs: Any
    ) -> tuple[torch.Tensor, ...]:
        # With return_dict=False and output_hidden_states=True, the last element
        # of the returned tuple is the tuple of hidden states from every layer
        # (embedding layer plus each transformer layer). Extra keyword arguments
        # (e.g. token_type_ids) are accepted but ignored.
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=False,
        )[-1]

    @property
    def num_layers(self) -> int:
        """Number of hidden-state layers (transformer layers plus the embedding layer)."""
        return self.model.config.num_hidden_layers + 1

    @property
    def output_units(self) -> int:
        """Dimensionality of the encoder's hidden states."""
        return self.model.config.hidden_size


class LayerwiseAttention(nn.Module):
    """Learned scalar mix ("layerwise attention") over the encoder's hidden-state layers."""

    def __init__(
        self,
        num_layers: int,
        layer_weights: list[float] | None = None,
    ) -> None:
        super().__init__()
        layer_weights = layer_weights or [0.0] * num_layers
        # One learnable scalar per layer plus a global scale factor, stored in
        # half precision.
        self.scalar_parameters = nn.ParameterList(
            [
                nn.Parameter(torch.HalfTensor([layer_weights[i]]), requires_grad=True)
                for i in range(num_layers)
            ]
        )
        self.weight = nn.Parameter(torch.HalfTensor([1.0]), requires_grad=True)

    def forward(
        self,
        tensors: list[torch.Tensor],
        mask: torch.Tensor,
    ) -> torch.Tensor:
        # `mask` is accepted for interface compatibility but is not used by this
        # weighted-sum implementation.
        weights = torch.cat([parameter for parameter in self.scalar_parameters])
        normed_weights = torch.softmax(weights, dim=0)
        normed_weights = torch.split(normed_weights, split_size_or_sections=1)
        # Softmax-weighted sum of the per-layer hidden states, scaled by a
        # global learned weight.
        return self.weight * sum(
            weight * tensor for weight, tensor in zip(normed_weights, tensors)
        )
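

# Illustration only: a plain-tensor sketch of the computation performed by
# LayerwiseAttention, assuming `hidden_states` is the tuple returned by
# Encoder.forward (one [batch, seq_len, hidden] tensor per layer). This helper
# is not used by the model; it just spells out the softmax-weighted layer mix.
def _layerwise_mix_reference(
    hidden_states: tuple[torch.Tensor, ...],
    layer_scalars: torch.Tensor,
    gamma: torch.Tensor,
) -> torch.Tensor:
    # layer_scalars has shape [num_layers]; gamma is a single learned scale.
    normed = torch.softmax(layer_scalars, dim=0)
    stacked = torch.stack(list(hidden_states), dim=0)  # [num_layers, batch, seq, hidden]
    mixed = (normed.view(-1, 1, 1, 1) * stacked).sum(dim=0)  # [batch, seq, hidden]
    return gamma * mixed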


class Estimator(nn.Module):
    """Feed-forward regression head mapping a sentence embedding to a quality score."""

    def _get_activation(self, activation: str) -> nn.Module:
        """Get an activation module from torch.nn by (title-cased) name."""
        if hasattr(nn, activation.title()):
            return getattr(nn, activation.title())()
        raise ValueError(f"{activation} is not a valid activation function!")

    def __init__(
self,
in_dim: int,
out_dim: int = 1,
hidden_sizes: list[int] = [3072, 1024],
activations: str = "Tanh",
dropout: float = 0.1,
) -> None:
super().__init__()
modules: list[nn.Module] = []
# First layer
modules.append(nn.Linear(in_dim, hidden_sizes[0]))
modules.append(self._get_activation(activations))
modules.append(nn.Dropout(dropout))
# Hidden layers
for i in range(1, len(hidden_sizes)):
modules.append(nn.Linear(hidden_sizes[i - 1], hidden_sizes[i]))
modules.append(self._get_activation(activations))
modules.append(nn.Dropout(dropout))
# Output layer
modules.append(nn.Linear(hidden_sizes[-1], int(out_dim)))
self.ff = nn.Sequential(*modules)

    def forward(self, in_features: torch.Tensor) -> torch.Tensor:
        return self.ff(in_features)


class CometModel(PreTrainedModel):
    """COMET-style quality estimation model: XLM-R encoder, layerwise attention, and a feed-forward estimator."""

    config_class = CometModelConfig
    _no_split_modules = ["Encoder", "LayerwiseAttention", "Estimator"]

    def __init__(self, config: CometModelConfig) -> None:
        super().__init__(config)
        self.encoder = Encoder()
        self.layerwise_attention = LayerwiseAttention(
            num_layers=self.encoder.num_layers
        )
        self.estimator = Estimator(
            in_dim=self.encoder.output_units,
            hidden_sizes=config.hidden_sizes,
            activations=config.activations,
            dropout=config.dropout,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor | None = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        # token_type_ids is accepted for tokenizer compatibility; Encoder.forward
        # ignores it, since XLM-R does not use token type embeddings.
        encoder_out = self.encoder(
            input_ids,
            attention_mask,
            token_type_ids=token_type_ids,
        )
        embeddings = self.layerwise_attention(
            encoder_out,
            attention_mask,
        )
        # Use the CLS token representation as the sentence embedding.
        embedding = embeddings[:, 0, :]
        return self.estimator(embedding).view(-1)
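

# --- Usage sketch (illustration only, not part of the model definition) ---
# Assumes this repository exposes CometModel through `auto_map` so it can be
# loaded with AutoModel and trust_remote_code=True. The repository id below is
# a placeholder, and the ordering of translation and source passed to the
# tokenizer is an assumption; follow the repository's documented preprocessing.
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    repo_id = "path/to/wmt22-cometkiwi-da-v2"  # placeholder: repo id or local path
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
    model.eval()

    # Reference-free quality estimation scores a (translation, source) pair.
    batch = tokenizer(
        "Hello world!",  # example machine translation (assumed first segment)
        "Hallo Welt!",   # example source sentence (assumed second segment)
        return_tensors="pt",
    )
    with torch.no_grad():
        score = model(**batch)
    print(score)  # one quality score per input pair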