| | from .configuration_hypernet import ZettHypernetConfig |
| | from transformers import PreTrainedModel, RobertaConfig, RobertaModel |
| | from functools import partial |
| |
|
| | from torch import nn as nn |
| | import torch |
| | from torch.nn import functional as F |
| |
|
class Rescaler(nn.Module):
    """Element-wise affine rescaling: ``y = w * x + b``.

    Both parameters are frozen (``requires_grad=False``); they are intended to
    be assigned externally (e.g. loaded from a checkpoint) rather than trained.

    NOTE(review): ``b`` is initialized to ones, not zeros, so a freshly
    constructed Rescaler is *not* the identity transform — presumably real
    values are always loaded before use; confirm before reusing elsewhere.
    """

    def __init__(self, dim: int):
        super().__init__()

        self.dim = dim

        # Shape (1, dim) so the transform broadcasts over any leading axes.
        self.w = nn.Parameter(torch.ones((1, self.dim)), requires_grad=False)
        self.b = nn.Parameter(torch.ones((1, self.dim)), requires_grad=False)

    # Defined as `forward` (not `__call__`) so nn.Module's call machinery —
    # hooks, torch.compile, DataParallel — is not bypassed. `module(x)` calls
    # are unaffected.
    def forward(self, x):
        """Apply the affine transform, broadcasting ``w`` and ``b`` over x."""
        return self.w * x + self.b
|
class ProjectorBlock(nn.Module):
    """Residual two-layer MLP block: ``LayerNorm(x + MLP(x))``.

    The MLP is ``dense2(gelu(dense1(x)))`` with a final GELU, both using the
    tanh approximation. Because the output of ``dense2`` is added to the raw
    input ``x``, callers must pass ``input_dim == dim`` (as the file does —
    every instantiation uses ``hn_hidden_size`` for both).
    """

    def __init__(self, input_dim: int, dim: int, intermediate_dim: int):
        super().__init__()

        self.input_dim = input_dim
        self.dim = dim
        self.intermediate_dim = intermediate_dim

        self.dense1 = nn.Linear(self.input_dim, self.intermediate_dim)
        self.dense2 = nn.Linear(self.intermediate_dim, self.dim)

        self.ln = nn.LayerNorm(self.dim, eps=1e-6)

    # Defined as `forward` (not `__call__`) so nn.Module's call machinery —
    # hooks, torch.compile, DataParallel — is not bypassed. `module(x)` calls
    # are unaffected.
    def forward(self, x):
        """Return ``LayerNorm(gelu(dense2(gelu(dense1(x)))) + x)``."""
        h = F.gelu(
            self.dense2(F.gelu(self.dense1(x), approximate="tanh")),
            approximate="tanh",
        )
        return self.ln(h + x)
|
class ZettHypernet(PreTrainedModel):
    """Hypernetwork that predicts embedding matrices for a target tokenizer.

    Each target token is represented by its *surface form*: a sequence of ids
    in the source tokenizer's vocabulary. The surface forms are embedded with
    the source model's embedding matrix (plus a small fallback table for ids
    beyond the source vocabulary), encoded by a shrunken Roberta encoder, and
    the per-token encoder summary is projected to input embeddings, optional
    output embeddings, and an optional per-token logit bias.
    """

    config_class = ZettHypernetConfig

    def __init__(self, config: ZettHypernetConfig):
        super().__init__(config)

        self.config = config
        # If the target LM unties its input/output embeddings, the hypernet
        # predicts both; the "input" side then carries 2 * n_embd features.
        self.has_separate_out_embeddings = getattr(
            self.config, "separate_out_embeddings", False
        )
        self.lang_embeddings = nn.Embedding(
            self.config.n_langs, self.config.hn_hidden_size
        )

        if self.has_separate_out_embeddings:
            n_in_embd = self.config.n_embd * 2
            n_out_embd = self.config.n_embd
        else:
            n_in_embd = self.config.n_embd
            n_out_embd = self.config.n_embd

        if self.config.hn_model_type == "roberta":
            config = RobertaConfig.from_pretrained(
                self.config.hn_model_name_or_path
            )
            config.num_hidden_layers = self.config.hn_n_layers
            config.hidden_size = self.config.hn_hidden_size
            config.intermediate_size = self.config.hn_intermediate_size
            # Default to 64-dim attention heads when not configured.
            if getattr(self.config, "hn_num_attention_heads", None) is None:
                self.config.hn_num_attention_heads = self.config.hn_hidden_size // 64
            config.num_attention_heads = self.config.hn_num_attention_heads
            self.embed_init_range = config.initializer_range
            module_class = partial(RobertaModel, add_pooling_layer=False)
        elif self.config.hn_model_type == "t5":
            raise NotImplementedError()
        else:
            # Fail fast with a clear message instead of hitting an unbound
            # `module_class` (NameError) further down.
            raise ValueError(
                f"Unsupported hn_model_type: {self.config.hn_model_type}"
            )

        if self.config.hn_embed_using_source_embeddings:
            # Input embeddings are supplied externally, so the encoder's own
            # token embedding table only needs to cover ids up to pad.
            config.vocab_size = self.config.pad_token_id + 1

        if (
            self.config.hn_add_inter_token_attention
            or self.config.hn_embed_target_priors
        ):
            raise NotImplementedError()

        self.pad_token_id = self.config.pad_token_id
        assert self.pad_token_id is not None
        self.model = module_class(config)

        # Embeddings for surface-form ids at/beyond the source vocabulary
        # (extra tokens); at least one row so the module stays well-formed
        # even when hn_n_extra_tokens == 0.
        self.fallback_embeddings = nn.Embedding(
            max(self.config.hn_n_extra_tokens, 1), n_in_embd
        )

        if self.config.hn_embed_using_source_embeddings:
            # Project source-model embeddings into the encoder width.
            self.input_projection = nn.Sequential(
                nn.Linear(n_in_embd, self.config.hn_hidden_size),
                ProjectorBlock(
                    self.config.hn_hidden_size,
                    self.config.hn_hidden_size,
                    self.config.hn_intermediate_size,
                ),
            )

        if self.config.hn_single_head:
            # Single head emits the concatenated in(+out) embedding vector,
            # which forward() slices apart.
            self.output_projection = nn.Sequential(
                ProjectorBlock(
                    self.config.hn_hidden_size,
                    self.config.hn_hidden_size,
                    self.config.hn_intermediate_size,
                ),
                nn.Linear(self.config.hn_hidden_size, n_in_embd),
            )
        else:
            self.output_projection = nn.Sequential(
                ProjectorBlock(
                    self.config.hn_hidden_size,
                    self.config.hn_hidden_size,
                    self.config.hn_intermediate_size,
                ),
                nn.Linear(self.config.hn_hidden_size, n_out_embd),
            )
            if self.has_separate_out_embeddings:
                # Separate head for the (untied) output embeddings.
                self.output_projection_out = nn.Sequential(
                    ProjectorBlock(
                        self.config.hn_hidden_size,
                        self.config.hn_hidden_size,
                        self.config.hn_intermediate_size,
                    ),
                    nn.Linear(self.config.hn_hidden_size, self.config.n_embd),
                )

        if self.config.hn_rescale_embeddings:
            self.in_scaler = Rescaler(n_in_embd)
            self.scaler = Rescaler(n_out_embd)

            if self.has_separate_out_embeddings:
                self.out_scaler = Rescaler(self.config.n_embd)

        if getattr(self.config, "hn_predict_bias", False):
            self.bias_projection = nn.Linear(self.config.hn_hidden_size, 1)

    # Defined as `forward` (not `__call__`) so nn.Module's call machinery —
    # hooks, torch.compile, DataParallel — is not bypassed. `model(...)`
    # call sites are unaffected.
    def forward(
        self,
        target_surface_forms,
        target_priors=None,
        source_embeddings=None,
        lang_index=None,
        deterministic: bool = True,
    ):
        """Predict embeddings (and bias) for a batch of target tokens.

        Args:
            target_surface_forms: integer tensor of source-tokenizer ids
                spelling out each target token; ids >= ``original_vocab_size``
                address the fallback embedding table. Assumed shape
                (n_tokens, seq_len) — TODO confirm against callers.
            target_priors: unsupported; must be ``None``.
            source_embeddings: embedding matrix used to embed the surface
                forms (width ``n_in_embd``).
            lang_index: index into the language embedding table; used only
                when ``hn_embed_lang_id`` is set.
            deterministic: unused; kept for interface compatibility.

        Returns:
            Tuple ``(embeddings_in, embeddings_out_or_None, bias)``.
        """
        if target_priors is not None:
            raise NotImplementedError()

        if not self.config.hn_embed_using_source_embeddings:
            raise NotImplementedError()

        # Ids at/beyond the source vocab are served by the fallback table.
        use_fallback = target_surface_forms >= self.config.original_vocab_size

        # Clamp into each table's valid range: the discarded side of the
        # torch.where below is still gathered, so it must be safe to index.
        main_ids = torch.minimum(
            target_surface_forms,
            torch.tensor(self.config.original_vocab_size - 1, device=self.device),
        )
        fallback_ids = torch.maximum(
            target_surface_forms - self.config.original_vocab_size,
            torch.tensor(0, device=self.device),
        )

        source_embeds = F.embedding(main_ids, weight=source_embeddings)

        if self.config.hn_rescale_embeddings:
            source_embeds = self.in_scaler(source_embeds)

        inputs_embeds = torch.where(
            use_fallback[..., None],
            self.fallback_embeddings(fallback_ids),
            source_embeds,
        )
        inputs_embeds = self.input_projection(inputs_embeds)
        attention_mask = target_surface_forms != self.pad_token_id

        if self.config.hn_embed_lang_id:
            lang_embedding = self.lang_embeddings(lang_index).squeeze()
            # Pre-subtract the token-type and position embeddings Roberta
            # will add at the appended position, so the language token enters
            # the encoder with (approximately) just its own embedding.
            lang_embedding -= self.model.embeddings.token_type_embeddings(
                torch.tensor(0, device=self.device)
            ) + self.model.embeddings.position_embeddings(
                torch.tensor(attention_mask.shape[1], device=self.device)
            )

            lang_embedding = lang_embedding[None, None, :].expand(
                inputs_embeds.shape[0], -1, -1
            )

            # Append the language token to every sequence and extend the mask.
            inputs_embeds = torch.cat(
                [
                    inputs_embeds,
                    lang_embedding,
                ],
                dim=1,
            )
            attention_mask = torch.cat(
                [
                    attention_mask,
                    torch.ones(
                        lang_embedding.shape[:-1],
                        dtype=torch.bool,
                        device=self.device,
                    ),
                ],
                dim=1,
            )

        position_ids = torch.broadcast_to(
            torch.arange(torch.atleast_2d(attention_mask).shape[-1], device=self.device),
            attention_mask.shape,
        )

        hidden_states = self.model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
        ).last_hidden_state

        if self.config.hn_concat_last_hidden_state:
            # Flatten all positions into one feature vector per target token.
            hidden_states = hidden_states.reshape(target_surface_forms.shape[0], -1)
        else:
            # CLS-style summary: take the first position.
            hidden_states = hidden_states[:, 0]

        predicted_embeddings = self.output_projection(hidden_states)

        if self.config.hn_single_head:
            # Single head emitted [in | out] concatenated; slice it apart.
            predicted_embeddings_in = predicted_embeddings[..., : self.config.n_embd]

            if self.has_separate_out_embeddings:
                predicted_embeddings_out = predicted_embeddings[
                    ..., self.config.n_embd :
                ]
            else:
                predicted_embeddings_out = None
        else:
            predicted_embeddings_in = predicted_embeddings
            if self.has_separate_out_embeddings:
                predicted_embeddings_out = self.output_projection_out(hidden_states)
            else:
                predicted_embeddings_out = None

        if self.config.hn_rescale_embeddings:
            predicted_embeddings_in = self.scaler(predicted_embeddings_in)

            if predicted_embeddings_out is not None:
                predicted_embeddings_out = self.out_scaler(predicted_embeddings_out)

        if getattr(self.config, "hn_predict_bias", False):
            predicted_bias = self.bias_projection(hidden_states)[..., 0]
        else:
            predicted_bias = torch.zeros_like(
                target_surface_forms[..., 0], dtype=self.dtype
            )

        return predicted_embeddings_in, predicted_embeddings_out, predicted_bias
| |
|