# creative-help/rnnlm_model/modeling_rnnlm.py
# coding: utf-8
"""RNN Language Model for HuggingFace Transformers - PyTorch implementation."""
import torch
import torch.nn as nn
try:
    # Newer transformers versions expose generation utilities in transformers.generation
    from transformers import PreTrainedModel
    from transformers.modeling_outputs import CausalLMOutputWithPast
    from transformers.generation import GenerationMixin, LogitsProcessor, LogitsProcessorList
except ImportError:
    from transformers.modeling_utils import PreTrainedModel
    from transformers.modeling_outputs import CausalLMOutputWithPast
    try:
        from transformers.generation import GenerationMixin, LogitsProcessor, LogitsProcessorList
    except ImportError:
        try:
            # Older releases kept these in transformers.generation_utils
            from transformers.generation_utils import GenerationMixin, LogitsProcessor, LogitsProcessorList
        except ImportError:
            # generation_utils exists but does not expose GenerationMixin
            from transformers.generation_utils import LogitsProcessor, LogitsProcessorList
            GenerationMixin = None
from .configuration_rnnlm import RNNLMConfig
class PreventUnkLogitsProcessor(LogitsProcessor):
"""
Redistribute probability from pad (0) and unk (1) to other tokens before sampling.
Matches the original Keras model's prevent_unk behavior.
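    Example (a minimal sketch of the effect on a batch of logits):
        >>> proc = PreventUnkLogitsProcessor()
        >>> scores = proc(None, torch.zeros(1, 4))
        >>> bool((scores[0, :2] == -1e8).all())
        True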
"""
def __init__(self, pad_token_id: int = 0, unk_token_id: int = 1):
self.pad_token_id = pad_token_id
self.unk_token_id = unk_token_id
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Set pad and unk logits to very small value so they're never sampled
scores = scores.clone()
scores[:, self.pad_token_id] = -1e8
scores[:, self.unk_token_id] = -1e8
return scores
class GRUKerasCompat(nn.Module):
"""
    GRU matching Keras reset_after=False (GRU v1).
    The two conventions differ in where the reset gate r enters the
    candidate state n:
        Keras v1:        n = tanh(W_n·x + U_n·(r ⊙ h))
        PyTorch nn.GRU:  n = tanh(W_n·x + r ⊙ (U_n·h))
    This module implements the Keras formulation so that converted weights
    reproduce the original model.
    Uses same weight layout as nn.GRU: [r, z, n] gate order.
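    Shape example (a sketch with random weights, not converted ones):
        >>> gru = GRUKerasCompat(input_size=8, hidden_size=16)
        >>> out, h_n = gru(torch.randn(2, 5, 8))
        >>> out.shape, h_n.shape
        (torch.Size([2, 5, 16]), torch.Size([1, 2, 16]))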
"""
def __init__(self, input_size: int, hidden_size: int, batch_first: bool = True):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.batch_first = batch_first
self.weight_ih = nn.Parameter(torch.empty(3 * hidden_size, input_size))
self.weight_hh = nn.Parameter(torch.empty(3 * hidden_size, hidden_size))
self.bias_ih = nn.Parameter(torch.empty(3 * hidden_size))
self.bias_hh = nn.Parameter(torch.empty(3 * hidden_size))
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.weight_ih)
nn.init.xavier_uniform_(self.weight_hh)
nn.init.zeros_(self.bias_ih)
nn.init.zeros_(self.bias_hh)
    def forward(self, x: torch.Tensor, h_0: torch.Tensor = None):
        # Normalize input to (batch, seq, input)
        if not self.batch_first:
            x = x.transpose(0, 1)
batch, seq_len, _ = x.shape
if h_0 is None:
h = x.new_zeros(batch, self.hidden_size)
else:
h = h_0.squeeze(0) # (batch, hidden)
outputs = []
for t in range(seq_len):
x_t = x[:, t, :] # (batch, input)
            # Gate pre-activations; weight layout matches nn.GRU ([r, z, n]),
            # each slice (hidden, input) for weight_ih, (hidden, hidden) for weight_hh
            hs = self.hidden_size
            r_ih = x_t @ self.weight_ih[:hs].t() + self.bias_ih[:hs]
            z_ih = x_t @ self.weight_ih[hs:2 * hs].t() + self.bias_ih[hs:2 * hs]
            n_ih = x_t @ self.weight_ih[2 * hs:].t() + self.bias_ih[2 * hs:]
            r_hh = h @ self.weight_hh[:hs].t() + self.bias_hh[:hs]
            z_hh = h @ self.weight_hh[hs:2 * hs].t() + self.bias_hh[hs:2 * hs]
            r = torch.sigmoid(r_ih + r_hh)
            z = torch.sigmoid(z_ih + z_hh)
            # Keras reset_after=False: the reset gate scales h *before* the
            # recurrent matmul (PyTorch's nn.GRU applies it after)
            n_hh = (h * r) @ self.weight_hh[2 * hs:].t() + self.bias_hh[2 * hs:]
            n = torch.tanh(n_ih + n_hh)
            h = (1 - z) * n + z * h
outputs.append(h)
output = torch.stack(outputs, dim=1) # (batch, seq, hidden)
if not self.batch_first:
output = output.transpose(0, 1)
return output, h.unsqueeze(0)
class RNNLMForCausalLM(PreTrainedModel):
"""
RNN-based Causal Language Model for text generation.
Compatible with HuggingFace TextGenerationPipeline.
    Supports the base model (no POS tags, no feature vectors) out of the box;
    the POS and feature branches require additional preprocessing at
    generation time.
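    Example (a sketch; the config field names are assumed from the attributes
    read in __init__, see configuration_rnnlm.py for the actual signature):
        >>> config = RNNLMConfig(vocab_size=100, embedding_dim=32,
        ...                      hidden_size=64, num_hidden_layers=1)
        >>> model = RNNLMForCausalLM(config)
        >>> out = model(input_ids=torch.randint(2, 100, (2, 7)), use_cache=False)
        >>> out.logits.shape  # vocab_size + 1 outputs (index 0 is padding)
        torch.Size([2, 7, 101])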
"""
config_class = RNNLMConfig
base_model_prefix = "rnnlm"
supports_gradient_checkpointing = False
_no_split_modules = []
def __init__(self, config: RNNLMConfig, **kwargs):
super().__init__(config)
self.config = config
# RNNLM has no tied weights; transformers expects this attribute (dict) for .update()
self.all_tied_weights_keys = {}
self.vocab_size = config.vocab_size
self.embedding_dim = config.embedding_dim
self.hidden_size = config.hidden_size
self.num_hidden_layers = config.num_hidden_layers
self.use_pos = getattr(config, "use_pos", False)
self.use_features = getattr(config, "use_features", False)
# Embedding layer (vocab_size + 1 for padding at index 0)
self.embedding = nn.Embedding(
config.vocab_size + 1,
config.embedding_dim,
padding_idx=0,
)
# GRU layers (Keras reset_after=False compatible)
self.gru_layers = nn.ModuleList()
for i in range(config.num_hidden_layers):
input_size = config.embedding_dim if i == 0 else config.hidden_size
self.gru_layers.append(
GRUKerasCompat(
input_size=input_size,
hidden_size=config.hidden_size,
batch_first=True,
)
)
# Output size after GRU
lm_input_size = config.hidden_size
# Optional POS branch (for loading converted models - generation needs external POS)
if self.use_pos:
self.pos_embedding = nn.Embedding(
config.n_pos_tags + 1,
config.n_pos_embedding_nodes,
padding_idx=0,
)
self.pos_gru = nn.GRU(
input_size=config.n_pos_embedding_nodes,
hidden_size=config.n_pos_nodes,
num_layers=1,
batch_first=True,
)
lm_input_size = lm_input_size + config.n_pos_nodes
else:
self.pos_embedding = None
self.pos_gru = None
# Optional feature branch
if self.use_features:
self.feature_dense = nn.Sequential(
nn.Linear(config.vocab_size + 1, config.n_feature_nodes),
nn.Sigmoid(),
)
lm_input_size = lm_input_size + config.n_feature_nodes
else:
self.feature_dense = None
# Output projection
self.lm_head = nn.Linear(lm_input_size, config.vocab_size + 1)
# Initialize weights
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def get_input_embeddings(self):
return self.embedding
def set_input_embeddings(self, value):
self.embedding = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
"""
For RNN: past_key_values stores the hidden state tuple (h_n for each GRU layer).
During generation we only need the last token and the cached hidden state.
"""
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "past_key_values": past_key_values}
def forward(
self,
input_ids=None,
attention_mask=None,
past_key_values=None,
position_ids=None,
pos_ids=None,
feature_vecs=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = use_cache if use_cache is not None else getattr(self.config, "use_cache", True)
# Get embeddings
inputs_embeds = self.embedding(input_ids)
# Run through GRU layers
hidden_states = inputs_embeds
new_past_key_values = () if use_cache else None
for i, gru_layer in enumerate(self.gru_layers):
if past_key_values is not None and len(past_key_values) > i:
h_0 = past_key_values[i]
hidden_states, h_n = gru_layer(hidden_states, h_0)
else:
hidden_states, h_n = gru_layer(hidden_states)
if use_cache:
new_past_key_values = new_past_key_values + (h_n,)
# Optional: concatenate POS hidden states (requires pos_ids at each step)
if self.use_pos and pos_ids is not None:
pos_embeds = self.pos_embedding(pos_ids)
_, pos_h_n = self.pos_gru(pos_embeds)
            pos_hidden = pos_h_n.squeeze(0).unsqueeze(1).expand(-1, hidden_states.size(1), -1)
hidden_states = torch.cat([hidden_states, pos_hidden], dim=-1)
# Optional: concatenate feature vectors
if self.use_features and feature_vecs is not None:
features = self.feature_dense(feature_vecs)
            features = features.unsqueeze(1).expand(-1, hidden_states.size(1), -1)
hidden_states = torch.cat([hidden_states, features], dim=-1)
# Project to vocabulary
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
)
        if not return_dict:
            output = (logits, new_past_key_values) if use_cache else (logits,)
            return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=new_past_key_values,
hidden_states=None,
attentions=None,
)
    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder cached hidden states for beam search (each h_n is (1, batch, hidden), so batch is dim 1)."""
        return tuple(layer_past.index_select(1, beam_idx) for layer_past in past_key_values)
def generate(self, inputs=None, **kwargs):
"""Override to add prevent_unk (pad/unk suppression) during generation."""
pad_id = getattr(self.config, "pad_token_id", 0)
unk_id = getattr(self.config, "unk_token_id", 1)
processor = PreventUnkLogitsProcessor(pad_token_id=pad_id, unk_token_id=unk_id)
logits_processor = kwargs.pop("logits_processor", None)
if logits_processor is None:
logits_processor = LogitsProcessorList()
elif not isinstance(logits_processor, LogitsProcessorList):
logits_processor = LogitsProcessorList(logits_processor)
logits_processor.insert(0, processor)
kwargs["logits_processor"] = logits_processor
# RNNLM uses tuple cache (hidden states), not DynamicCache; avoid cache to prevent "not subscriptable" error
kwargs.setdefault("use_cache", False)
# Call GenerationMixin.generate explicitly (super() can fail in some loading contexts)
if GenerationMixin is not None:
return GenerationMixin.generate(self, inputs, **kwargs)
return super().generate(inputs, **kwargs)
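if __name__ == "__main__":
    # Minimal smoke test: a sketch only. The RNNLMConfig keyword arguments
    # below are assumptions based on the attributes read in __init__; adjust
    # them to the actual signature in configuration_rnnlm.py. Because of the
    # relative import above, run this from the package root as a module, e.g.
    #   python -m rnnlm_model.modeling_rnnlm
    config = RNNLMConfig(
        vocab_size=100,
        embedding_dim=32,
        hidden_size=64,
        num_hidden_layers=2,
    )
    model = RNNLMForCausalLM(config).eval()
    input_ids = torch.randint(2, config.vocab_size, (2, 7))  # avoid pad (0) and unk (1)
    with torch.no_grad():
        out = model(input_ids=input_ids, labels=input_ids, use_cache=False)
    print("logits:", tuple(out.logits.shape), "loss:", float(out.loss))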