import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import MaskedLMOutput
from transformers.modeling_utils import PreTrainedModel

from RougeBERT import RougeBERT as RougeCore


class RougeBERTConfig(PretrainedConfig):
    """Configuration for `RougeBERTForMaskedLM`.

    Most hyperparameters are forwarded to the `RougeCore` encoder (see
    `RougeBERTForMaskedLM.__init__`); `ff_dropout` is stored on the config
    but not consumed by this wrapper.
    """

    model_type = "rougebert"

    def __init__(
        self,
        vocab_size=1237,
        max_seq=512,
        num_layers=8,
        hidden_size=320,
        intermediate_size=1280,
        num_heads=8,
        kv_groups=2,
        rotary_max_seq=1024,
        window=16,
        dropout=0.1,
        ff_dropout=0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.max_seq = max_seq
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_heads = num_heads
        self.kv_groups = kv_groups
        self.rotary_max_seq = rotary_max_seq
        self.window = window
        self.dropout = dropout
        self.ff_dropout = ff_dropout
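
# Usage sketch (illustrative; the checkpoint path is a placeholder): the
# config round-trips through the standard HF serialization helpers.
#
#     config = RougeBERTConfig(hidden_size=320, num_layers=8)
#     config.save_pretrained("./rougebert-ckpt")        # writes config.json
#     config = RougeBERTConfig.from_pretrained("./rougebert-ckpt")
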
class RougeBERTForMaskedLM(PreTrainedModel):
    config_class = RougeBERTConfig
    # Declared so that save_pretrained/from_pretrained keep a single copy
    # of the tied tensor instead of serializing it twice.
    _tied_weights_keys = ["model.lm_head.weight"]

    def __init__(self, config: RougeBERTConfig):
        super().__init__(config)
        self.model = RougeCore(
            vocab_size=config.vocab_size,
            max_seq=config.max_seq,
            num_layers=config.num_layers,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            num_heads=config.num_heads,
            kv_groups=config.kv_groups,
            rotary_max_seq=config.rotary_max_seq,
            window=config.window,
            dropout=config.dropout,
        )
        # post_init() applies _init_weights to every submodule and calls
        # tie_weights().
        self.post_init()

    def _init_weights(self, module):
        """LLaMA-style initialization: linear weights ~ N(0, 1/sqrt(fan_in)),
        with residual projections scaled down by 1/sqrt(2 * num_layers)."""
        if isinstance(module, nn.Linear):
            std = 1.0 / math.sqrt(module.in_features)
            if getattr(module, "_is_residual", False):
                std = std / math.sqrt(2 * self.config.num_layers)
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=1.0 / math.sqrt(self.config.hidden_size))

    def tie_weights(self):
        """Tie `lm_head` to `tok_embeddings` so both share one parameter."""
        self.model.lm_head.weight = self.model.tok_embeddings.weight
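
    # Illustrative check (not part of the API): tying makes the two modules
    # share storage, so after construction
    #     model = RougeBERTForMaskedLM(RougeBERTConfig())
    #     assert model.model.lm_head.weight is model.model.tok_embeddings.weight
    # holds, and saved checkpoints contain a single copy of the matrix.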

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        global_positions: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[MaskedLMOutput, Tuple[torch.Tensor, ...]]:
        """
        Forward pass for RougeBERT masked language modeling.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens.
            attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid attending to padding token indices.
            global_positions (`torch.LongTensor` of shape `(batch_size, num_globals)`, *optional*):
                Indices of global tokens. Use `-1` for padding.
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss.
            output_attentions (`bool`, *optional*):
                Whether to return attention weights.
            output_hidden_states (`bool`, *optional*):
                Whether to return hidden states (not yet implemented).
            return_dict (`bool`, *optional*):
                Whether to return a `MaskedLMOutput` instead of a plain tuple.

        Returns:
            `MaskedLMOutput` or `tuple(torch.FloatTensor)`
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            global_positions=global_positions,
            labels=labels,
            output_attentions=output_attentions,
        )

        # Unpack the core's return value, whose shape varies with the inputs:
        #   - ((loss, logits), attentions) or (logits, attentions) when
        #     output_attentions is set,
        #   - (loss, logits) when labels are provided,
        #   - bare logits otherwise.
        if isinstance(outputs, tuple) and len(outputs) == 2:
            if output_attentions and isinstance(outputs[1], list):
                core_output, attentions = outputs
                if isinstance(core_output, tuple):
                    loss, logits = core_output
                else:
                    loss, logits = None, core_output
            else:
                loss, logits = outputs
                attentions = None
        else:
            loss, logits, attentions = None, outputs, None

        if not return_dict:
            output = (logits,) + ((attentions,) if attentions is not None else ())
            return ((loss,) + output) if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=attentions,
        )
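

# Minimal smoke test (a sketch: assumes RougeCore accepts the constructor
# arguments used above and returns logits of shape (batch, seq, vocab)).
if __name__ == "__main__":
    config = RougeBERTConfig()
    model = RougeBERTForMaskedLM(config)
    model.eval()
    # Random token ids inside the vocabulary; every position attended.
    input_ids = torch.randint(0, config.vocab_size, (2, 32))
    attention_mask = torch.ones_like(input_ids)
    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
    print(out.logits.shape)  # expected: torch.Size([2, 32, config.vocab_size])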