# RougeBERTHF.py
# Minimal Hugging Face wrapper for RougeBERT (masked language modeling)
# Beware: not yet registered with the AutoModel classes (see the commented-out registration at the bottom of this file)
# by @gbyuvd
import math
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers.modeling_outputs import MaskedLMOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from RougeBERT import RougeBERT as RougeCore
# -------------------------
# Config
# -------------------------
class RougeBERTConfig(PretrainedConfig):
model_type = "rougebert"
def __init__(
self,
vocab_size=1237,
max_seq=512,
num_layers=8,
hidden_size=320,
intermediate_size=1280,
num_heads=8,
kv_groups=2,
rotary_max_seq=1024,
window=16,
dropout=0.1,
ff_dropout=0.1,
**kwargs,
):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.max_seq = max_seq
self.num_layers = num_layers
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_heads = num_heads
self.kv_groups = kv_groups
self.rotary_max_seq = rotary_max_seq
self.window = window
self.dropout = dropout
self.ff_dropout = ff_dropout
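# A quick, hedged sketch of constructing the config (defaults mirror the values above).
# Any field can be overridden by keyword; the smaller variant below is hypothetical:
#   config = RougeBERTConfig()                                 # 8 layers, 320-dim hidden, 1237-token vocab
#   tiny_cfg = RougeBERTConfig(num_layers=2, hidden_size=128)  # hypothetical debug-sized model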
# -------------------------
# Hugging Face Model Wrapper
# -------------------------
class RougeBERTForMaskedLM(PreTrainedModel):
config_class = RougeBERTConfig
def __init__(self, config: RougeBERTConfig):
super().__init__(config)
self.model = RougeCore(
vocab_size=config.vocab_size,
max_seq=config.max_seq,
num_layers=config.num_layers,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
num_heads=config.num_heads,
kv_groups=config.kv_groups,
rotary_max_seq=config.rotary_max_seq,
window=config.window,
dropout=config.dropout,
)
self.post_init()
def _init_weights(self, module):
"""LLaMA-style initialization"""
if isinstance(module, nn.Linear):
std = 1.0 / math.sqrt(module.in_features)
if getattr(module, "_is_residual", False):
std = std / math.sqrt(2 * self.config.num_layers)
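# Assumption: modules flagged `_is_residual` are the residual-branch output projections
# (set by RougeCore); the extra 1/sqrt(2 * num_layers) factor is the usual GPT-2/LLaMA-style
# depth scaling that keeps the residual-stream variance roughly constant as layers stack.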
nn.init.normal_(module.weight, mean=0.0, std=std)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=1.0 / math.sqrt(self.config.hidden_size))
@property
def _tied_weights_keys(self):
return ["model.lm_head.weight"]
def tie_weights(self):
"""Tie lm_head to tok_embeddings"""
self.model.lm_head.weight = self.model.tok_embeddings.weight
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
global_positions: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[MaskedLMOutput, Tuple[torch.Tensor, ...]]:
"""
Forward pass for RougeBERT masked language modeling.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens.
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid attention on padding token indices.
global_positions (`torch.LongTensor` of shape `(batch_size, num_globals)`, *optional*):
Indices of global tokens. Use `-1` for padding.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss.
output_attentions (`bool`, *optional*):
Whether to return attention weights.
output_hidden_states (`bool`, *optional*):
Whether to return hidden states (not yet implemented).
return_dict (`bool`, *optional*):
Whether to return a `MaskedLMOutput` instead of a plain tuple.
Returns:
`MaskedLMOutput` or `tuple(torch.FloatTensor)`
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
global_positions=global_positions,
labels=labels,
output_attentions=output_attentions,
)
# Parse core output
if isinstance(outputs, tuple) and len(outputs) == 2:
if output_attentions and isinstance(outputs[1], list):
# (loss_logits_tuple, attentions_list)
core_output, attentions = outputs
if isinstance(core_output, tuple):
loss, logits = core_output
else:
loss, logits = None, core_output
else:
# (loss, logits) — attentions not returned
loss, logits = outputs
attentions = None
else:
loss, logits, attentions = None, outputs, None
if not return_dict:
output = (logits,) + ((attentions,) if attentions is not None else ())
return ((loss,) + output) if loss is not None else output
return MaskedLMOutput(
loss=loss,
logits=logits,
hidden_states=None, # Extend later if needed
attentions=attentions,
)
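# -------------------------
# Usage sketch
# -------------------------
# A minimal, hedged example of a masked-LM forward pass using random dummy tensors;
# real inputs would come from the project's tokenizer, and the global_positions
# convention is whatever RougeCore expects (omitted here).
#
# config = RougeBERTConfig()
# model = RougeBERTForMaskedLM(config).eval()
# input_ids = torch.randint(0, config.vocab_size, (2, 64))   # (batch, seq_len)
# attention_mask = torch.ones(2, 64, dtype=torch.long)
# labels = input_ids.clone()                                  # toy labels: predict every token
# with torch.no_grad():
#     out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
# print(out.loss, out.logits.shape)                           # logits: (2, 64, config.vocab_size)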
# -------------------------
# Register with the Auto classes (not enabled yet; this is still a prototype)
# -------------------------
# from transformers import AutoConfig, AutoModelForMaskedLM
# AutoConfig.register("rougebert", RougeBERTConfig)
# AutoModelForMaskedLM.register(RougeBERTConfig, RougeBERTForMaskedLM)
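# Once registered and saved (config.save_pretrained(...) plus model.save_pretrained(...)),
# the checkpoint could be loaded generically; a sketch, assuming a local checkpoint directory:
# model = AutoModelForMaskedLM.from_pretrained("path/to/rougebert-checkpoint")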