"""RNN Language Model for HuggingFace Transformers - PyTorch implementation."""
|
|
from typing import Optional

import torch
import torch.nn as nn

try:
    from transformers import PreTrainedModel
    from transformers.modeling_outputs import CausalLMOutputWithPast
    from transformers.generation import GenerationMixin, LogitsProcessor, LogitsProcessorList
except ImportError:
    from transformers.modeling_utils import PreTrainedModel
    from transformers.modeling_outputs import CausalLMOutputWithPast
    try:
        from transformers.generation import GenerationMixin, LogitsProcessor, LogitsProcessorList
    except ImportError:
        # Older transformers releases keep the logits processors and GenerationMixin
        # in separate modules.
        from transformers.generation_logits_process import LogitsProcessor, LogitsProcessorList
        try:
            from transformers.generation_utils import GenerationMixin
        except ImportError:
            GenerationMixin = None

from .configuration_rnnlm import RNNLMConfig
|
|
|
|
class PreventUnkLogitsProcessor(LogitsProcessor):
    """
    Suppress the pad (0) and unk (1) logits so their probability mass is
    redistributed to the remaining tokens before sampling.
    Matches the original Keras model's prevent_unk behavior.
    """

    def __init__(self, pad_token_id: int = 0, unk_token_id: int = 1):
        self.pad_token_id = pad_token_id
        self.unk_token_id = unk_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Mask pad/unk with a large negative value so softmax assigns them ~0 probability.
        scores = scores.clone()
        scores[:, self.pad_token_id] = -1e8
        scores[:, self.unk_token_id] = -1e8
        return scores
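
# A minimal usage sketch (illustrative only, not part of the model's API): the processor
# can also be applied through a standalone LogitsProcessorList when sampling manually.
#
#     processors = LogitsProcessorList([PreventUnkLogitsProcessor(pad_token_id=0, unk_token_id=1)])
#     filtered = processors(input_ids, logits[:, -1, :])
#     next_token = torch.multinomial(torch.softmax(filtered, dim=-1), num_samples=1)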
|
|
|
|
class GRUKerasCompat(nn.Module):
    """
    GRU matching Keras reset_after=False (GRU v1).
    Keras:           h_new = tanh(W_h·x + W_hn·(r⊙h))
    PyTorch default: h_new = tanh(W_h·x + r⊙(W_hn·h))
    We implement the Keras formulation so converted weights reproduce the original model.
    Uses the same weight layout as nn.GRU: [r, z, n] gate order.
    """

    def __init__(self, input_size: int, hidden_size: int, batch_first: bool = True):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.batch_first = batch_first
        self.weight_ih = nn.Parameter(torch.empty(3 * hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.empty(3 * hidden_size, hidden_size))
        self.bias_ih = nn.Parameter(torch.empty(3 * hidden_size))
        self.bias_hh = nn.Parameter(torch.empty(3 * hidden_size))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight_ih)
        nn.init.xavier_uniform_(self.weight_hh)
        nn.init.zeros_(self.bias_ih)
        nn.init.zeros_(self.bias_hh)

    def forward(self, x: torch.Tensor, h_0: Optional[torch.Tensor] = None):
        if not self.batch_first:
            x = x.transpose(0, 1)
        batch, seq_len, _ = x.shape
        if h_0 is None:
            h = x.new_zeros(batch, self.hidden_size)
        else:
            h = h_0.squeeze(0)

        hs = self.hidden_size
        outputs = []
        for t in range(seq_len):
            x_t = x[:, t, :]

            # Input projections for the reset (r), update (z), and new (n) gates.
            r_ih = x_t @ self.weight_ih[:hs].t() + self.bias_ih[:hs]
            z_ih = x_t @ self.weight_ih[hs:2 * hs].t() + self.bias_ih[hs:2 * hs]
            n_ih = x_t @ self.weight_ih[2 * hs:].t() + self.bias_ih[2 * hs:]

            # Recurrent projections for the reset and update gates.
            r_hh = h @ self.weight_hh[:hs].t() + self.bias_hh[:hs]
            z_hh = h @ self.weight_hh[hs:2 * hs].t() + self.bias_hh[hs:2 * hs]

            r = torch.sigmoid(r_ih + r_hh)
            z = torch.sigmoid(z_ih + z_hh)
            # Keras reset_after=False: apply the reset gate to h *before* the recurrent
            # projection of the candidate state.
            n_hh = (r * h) @ self.weight_hh[2 * hs:].t() + self.bias_hh[2 * hs:]
            n = torch.tanh(n_ih + n_hh)
            h = (1 - z) * n + z * h
            outputs.append(h)

        output = torch.stack(outputs, dim=1)
        if not self.batch_first:
            output = output.transpose(0, 1)
        return output, h.unsqueeze(0)
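
# A minimal weight-conversion sketch (assumptions: Keras GRU v1 arrays `kernel`
# (input_dim, 3*units), `recurrent_kernel` (units, 3*units), and `bias` (3*units,)
# in Keras' [z, r, h] gate order; the helper name is hypothetical):
#
#     def load_keras_gru(layer: GRUKerasCompat, kernel, recurrent_kernel, bias):
#         u = layer.hidden_size
#         reorder = lambda w: torch.cat([w[:, u:2 * u], w[:, :u], w[:, 2 * u:]], dim=1)  # [z,r,h] -> [r,z,n]
#         layer.weight_ih.data = reorder(torch.as_tensor(kernel)).t().contiguous()
#         layer.weight_hh.data = reorder(torch.as_tensor(recurrent_kernel)).t().contiguous()
#         b = torch.as_tensor(bias)
#         layer.bias_ih.data = torch.cat([b[u:2 * u], b[:u], b[2 * u:]])
#         layer.bias_hh.data.zero_()  # Keras v1 GRU has a single bias term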
|
|
|
|
class RNNLMForCausalLM(PreTrainedModel):
    """
    RNN-based causal language model for text generation.
    Compatible with the HuggingFace TextGenerationPipeline.
    Supports the base model (no POS tags, no features); POS tags and feature
    vectors require additional preprocessing at generation time.
    """

    config_class = RNNLMConfig
    base_model_prefix = "rnnlm"
    supports_gradient_checkpointing = False
    _no_split_modules = []

    def __init__(self, config: RNNLMConfig, **kwargs):
        super().__init__(config)
        self.config = config

        self.all_tied_weights_keys = {}
        self.vocab_size = config.vocab_size
        self.embedding_dim = config.embedding_dim
        self.hidden_size = config.hidden_size
        self.num_hidden_layers = config.num_hidden_layers
        self.use_pos = getattr(config, "use_pos", False)
        self.use_features = getattr(config, "use_features", False)

        # Token embedding (index 0 is the padding token).
        self.embedding = nn.Embedding(
            config.vocab_size + 1,
            config.embedding_dim,
            padding_idx=0,
        )

        # Stacked Keras-compatible GRU layers.
        self.gru_layers = nn.ModuleList()
        for i in range(config.num_hidden_layers):
            input_size = config.embedding_dim if i == 0 else config.hidden_size
            self.gru_layers.append(
                GRUKerasCompat(
                    input_size=input_size,
                    hidden_size=config.hidden_size,
                    batch_first=True,
                )
            )

        lm_input_size = config.hidden_size

        # Optional POS-tag branch: embedded POS ids run through their own GRU.
        if self.use_pos:
            self.pos_embedding = nn.Embedding(
                config.n_pos_tags + 1,
                config.n_pos_embedding_nodes,
                padding_idx=0,
            )
            self.pos_gru = nn.GRU(
                input_size=config.n_pos_embedding_nodes,
                hidden_size=config.n_pos_nodes,
                num_layers=1,
                batch_first=True,
            )
            lm_input_size = lm_input_size + config.n_pos_nodes
        else:
            self.pos_embedding = None
            self.pos_gru = None

        # Optional dense feature branch.
        if self.use_features:
            self.feature_dense = nn.Sequential(
                nn.Linear(config.vocab_size + 1, config.n_feature_nodes),
                nn.Sigmoid(),
            )
            lm_input_size = lm_input_size + config.n_feature_nodes
        else:
            self.feature_dense = None

        # Output projection over the (vocab_size + 1) token ids.
        self.lm_head = nn.Linear(lm_input_size, config.vocab_size + 1)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def get_input_embeddings(self):
        return self.embedding

    def set_input_embeddings(self, value):
        self.embedding = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        """
        For an RNN, past_key_values stores the hidden states (one h_n per GRU layer).
        During cached generation we only need the last token and the cached hidden states.
        """
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]
        return {"input_ids": input_ids, "past_key_values": past_key_values}
|
|
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
        position_ids=None,
        pos_ids=None,
        feature_vecs=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        inputs_embeds = self.embedding(input_ids)

        # Run the stacked GRU layers, threading cached hidden states through if present.
        hidden_states = inputs_embeds
        new_past_key_values = () if use_cache else None

        for i, gru_layer in enumerate(self.gru_layers):
            if past_key_values is not None and len(past_key_values) > i:
                h_0 = past_key_values[i]
                hidden_states, h_n = gru_layer(hidden_states, h_0)
            else:
                hidden_states, h_n = gru_layer(hidden_states)

            if use_cache:
                new_past_key_values = new_past_key_values + (h_n,)

        # Optional POS branch: encode the POS sequence and broadcast its final state
        # across all time steps.
        if self.use_pos and pos_ids is not None:
            pos_embeds = self.pos_embedding(pos_ids)
            _, pos_h_n = self.pos_gru(pos_embeds)
            pos_hidden = pos_h_n.squeeze(0).unsqueeze(1).expand(-1, hidden_states.size(1), -1)
            hidden_states = torch.cat([hidden_states, pos_hidden], dim=-1)

        # Optional feature branch: project the feature vector and broadcast it.
        if self.use_features and feature_vecs is not None:
            features = self.feature_dense(feature_vecs)
            features = features.unsqueeze(1).expand(-1, hidden_states.size(1), -1)
            hidden_states = torch.cat([hidden_states, features], dim=-1)

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Standard causal LM shift: predict token t+1 from position t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        if not return_dict:
            output = (logits, new_past_key_values) if use_cache else (logits,)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=new_past_key_values,
            hidden_states=None,
            attentions=None,
        )
|
|
    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder cached hidden states for beam search (each h_n has shape (1, batch, hidden))."""
        return tuple(
            layer_past.index_select(1, beam_idx.to(layer_past.device))
            for layer_past in past_key_values
        )

    def generate(self, inputs=None, **kwargs):
        """Override to add prevent_unk (pad/unk suppression) during generation."""
        pad_id = getattr(self.config, "pad_token_id", 0)
        unk_id = getattr(self.config, "unk_token_id", 1)
        processor = PreventUnkLogitsProcessor(pad_token_id=pad_id, unk_token_id=unk_id)
        logits_processor = kwargs.pop("logits_processor", None)
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        elif not isinstance(logits_processor, LogitsProcessorList):
            logits_processor = LogitsProcessorList(logits_processor)
        logits_processor.insert(0, processor)
        kwargs["logits_processor"] = logits_processor

        # By default the full sequence is re-run through the GRUs at each step.
        kwargs.setdefault("use_cache", False)

        if GenerationMixin is not None:
            return GenerationMixin.generate(self, inputs, **kwargs)
        return super().generate(inputs, **kwargs)
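

# A minimal end-to-end sketch (the checkpoint path is hypothetical; the tokenizer is
# assumed to live alongside the model weights):
#
#     from transformers import AutoTokenizer, pipeline
#
#     tokenizer = AutoTokenizer.from_pretrained("path/to/rnnlm-checkpoint")
#     model = RNNLMForCausalLM.from_pretrained("path/to/rnnlm-checkpoint")
#     generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
#     print(generator("once upon a time", max_new_tokens=20, do_sample=True))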
|
|