import copy
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from datetime import datetime
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig, GPT2Config
from transformers.modeling_outputs import MaskedLMOutput

from gpst.gpt2_flash_attn import GPT2Model
from gpst.backend_loader import load_cpp_backend

# The C++ backend is loaded before importing the inside-outside module; the original
# import order suggests that module relies on the backend being available.
cppbackend = load_cpp_backend()

from gpst.r2d2_insideoutside import *


def load_model(model, model_path, strict=True):
    state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
    # Checkpoints saved with DataParallel/DistributedDataParallel prefix every
    # parameter name with "module."; strip it so the keys match this model.
    transferred_state_dict = {}
    for k, v in state_dict.items():
        new_k = k.replace('module.', '')
        transferred_state_dict[new_k] = v
    model.load_state_dict(transferred_state_dict, strict=strict)
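

# Example usage (hypothetical checkpoint path): restore weights from a DDP
# checkpoint while tolerating missing or renamed keys.
#     model = GPST(config)
#     load_model(model, "checkpoints/gpst_step10000.bin", strict=False)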


def index_sanity_hook(module, input, output):
    """Forward hook that flags integer input tensors with implausible index values."""
    def check(tensor):
        if isinstance(tensor, torch.Tensor):
            if tensor.dtype == torch.long or "int" in str(tensor.dtype):
                if tensor.max() > 10000 or tensor.min() < -10000:
                    print(f"[!] Suspicious index in {module.__class__.__name__}: "
                          f"min={tensor.min().item()}, max={tensor.max().item()}, shape={tensor.shape}")

    for item in input:
        if isinstance(item, (tuple, list)):
            for sub in item:
                check(sub)
        else:
            check(item)
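

# Optional debugging aid: registering the hook on every submodule reports any
# integer input tensor with out-of-range values during forward().
#     for m in model.modules():
#         m.register_forward_hook(index_sanity_hook)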


@dataclass(kw_only=True)
class R2D2GenOutput:
    struct_loss: Optional[torch.FloatTensor] = None
    non_struct_loss: Optional[torch.FloatTensor] = None
    non_struct_loss_fullscale: Optional[torch.FloatTensor] = None
    action_logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[torch.FloatTensor] = None
    cls_hidden_states: Optional[torch.FloatTensor] = None
    tgt_ids: Optional[torch.LongTensor] = None
    pred: Optional[torch.FloatTensor] = None
    splits: Optional[torch.LongTensor] = None
    gpt_loss: Optional[torch.FloatTensor] = None
    action_loss: Optional[torch.FloatTensor] = None
    inside_outside_loss: Optional[torch.FloatTensor] = None
    parser_loss: Optional[torch.FloatTensor] = None
    glue_finetune_loss: Optional[torch.FloatTensor] = None
    past_kv: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    loss: Optional[torch.FloatTensor] = None


class GPSTConfig(PretrainedConfig):
    model_type = "gpst"

    def __init__(self, r2d2=None, gpt=None, **kwargs):
        # Both sub-configurations are kept as plain dicts here and re-hydrated
        # with *.from_dict() in GPST.__init__.
        self.gptconfig = gpt
        self.r2d2config = r2d2
        super().__init__(**kwargs)
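

# Illustrative construction (assumed, hypothetical values): the gpt dict mirrors a
# GPT2Config plus the extra fields this file reads (action_layer_num,
# dense_hidden_factor); the r2d2 dict must carry whatever InsideOutsideModule
# expects, of which only ext_vocab_size is referenced in this file.
#     gpt_cfg = GPT2Config(n_layer=12, n_embd=768).to_dict()
#     gpt_cfg.update({"action_layer_num": 3, "dense_hidden_factor": 4})
#     config = GPSTConfig(gpt=gpt_cfg, r2d2=r2d2_cfg_dict)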


class GPST(PreTrainedModel):
    config_class = GPSTConfig

    def __init__(self, config, gradient_checkpoint=False):
        super().__init__(config)
        self.config = config
        self.r2d2_config = PretrainedConfig.from_dict(config.r2d2config)
        self.gpt_config = GPT2Config.from_dict(config.gptconfig)

        self.vocab_size = self.gpt_config.vocab_size

        # Split the GPT-2 stack: the first `action_layer_num` layers form the action
        # (structure) transformer, the remaining layers form the generation transformer.
        total_layer = self.gpt_config.n_layer
        action_transformers = GPT2Model(copy.deepcopy(self.gpt_config), no_embedding=True,
                                        no_layer_norm=True, n_layers_manual=self.gpt_config.action_layer_num)
        action_transformers.gradient_checkpointing = gradient_checkpoint
        self.gpt_config.n_layer = total_layer - self.gpt_config.action_layer_num
        self.gpt_config.num_hidden_layers = total_layer - self.gpt_config.action_layer_num
        gpt_transformers = GPT2Model(self.gpt_config, no_embedding=True, no_extra_embedding=True)
        gpt_transformers.gradient_checkpointing = gradient_checkpoint

        r2d2 = InsideOutsideModule(self.r2d2_config)
        self.model = FastGenerativeR2D2(
            r2d2=r2d2,
            action_layers=action_transformers,
            generation_layers=gpt_transformers,
            vocab_size=self.vocab_size,
            r2d2_input_dim=r2d2.input_dim,
            embedding_dim=self.gpt_config.n_embd,
            ext_vocab_size=self.r2d2_config.ext_vocab_size,
            dense_hidden_factor=self.gpt_config.dense_hidden_factor,
        )

    def get_input_embeddings(self):
        return self.model.embeddings

    def forward(self, **kwargs):
        return self.model(**kwargs)
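

# Example usage (illustrative tensors): the HF-style wrapper forwards all keyword
# arguments to FastGenerativeR2D2.forward, which consumes input_ids and
# attention_mask and returns an R2D2GenOutput.
#     out = model(input_ids=input_ids, attention_mask=attention_mask)
#     out.loss.backward()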


class FastGenerativeR2D2(nn.Module):
    def __init__(self, r2d2, action_layers, generation_layers, vocab_size,
                 r2d2_input_dim, embedding_dim, dropout_rate=0.2, ext_vocab_size=0,
                 fix_embeddings=False, dense_hidden_factor=4):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.r2d2_input_dim = r2d2_input_dim
        self.r2d2 = r2d2

        self.vocab_size = vocab_size
        self.dense_hidden_factor = dense_hidden_factor

        self.enable_gpt = False
        if action_layers is not None and generation_layers is not None:
            self.action_layers = action_layers
            self.generation_layers = generation_layers
            self.bos_embedding = nn.Parameter(torch.rand(self.embedding_dim))
            self.up_scale = nn.Linear(self.r2d2_input_dim, self.embedding_dim)
            self.dense = nn.Sequential(nn.Linear(self.embedding_dim, self.dense_hidden_factor * self.embedding_dim),
                                       nn.GELU(),
                                       nn.Dropout(dropout_rate),
                                       nn.Linear(self.dense_hidden_factor * self.embedding_dim, self.embedding_dim))
            self.action_mlp = nn.Sequential(nn.LayerNorm(self.embedding_dim),
                                            nn.Linear(self.embedding_dim, self.embedding_dim),
                                            nn.GELU(),
                                            nn.Dropout(dropout_rate),
                                            nn.Linear(self.embedding_dim, 2))
            self.enable_gpt = True

        self.classifier = nn.Linear(self.embedding_dim, vocab_size, bias=False)
        self.embeddings = nn.Embedding(vocab_size, self.embedding_dim)
        # Freeze the embedding weights (not the module object) when fix_embeddings is set.
        self.embeddings.weight.requires_grad = not fix_embeddings
        self.down_scale = nn.Linear(self.embedding_dim, self.r2d2_input_dim)

        self.insideoutside_dense = nn.Sequential(
            nn.Linear(r2d2_input_dim, self.dense_hidden_factor * r2d2_input_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(self.dense_hidden_factor * r2d2_input_dim, self.embedding_dim)
        )

        self._init_weights()
        self._tie_weights()
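
    # Rough dataflow in forward(): token ids -> embeddings -> down_scale -> R2D2
    # inside-outside composition -> up_scale -> action_layers (action_mlp scores a
    # binary action target derived from r2d2.reduce_id) -> gather the token-generating
    # positions -> generation_layers -> dense -> classifier (tied to the embeddings).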
    def _init_weights(self):
        if self.enable_gpt:
            self.bos_embedding.data.normal_(mean=0, std=0.02)
        self.embeddings.weight.data.normal_(mean=0, std=0.02)

    def _tie_weights(self):
        # Share the output projection with the input embeddings.
        self.classifier.weight = self.embeddings.weight

    def get_parser(self):
        return self.r2d2.parser

    def from_pretrain(self, model_path, strict=True):
        load_model(self, model_path, strict=strict)
        # Re-tie after loading, since load_state_dict may have replaced the shared tensor.
        self._tie_weights()
    def _append_eos_label(self, eos_labels, chunk_input_ids, chunk_masks, next_token_indices, max_input_len):
        # Binarize the sentence-id mask, then use per-row lengths to place each EOS
        # label in the first padding slot of its row.
        chunk_masks = (chunk_masks > 0).to(int)
        seq_lens = chunk_masks.sum(dim=1)
        temp_ids = torch.zeros((chunk_input_ids.shape[0], chunk_input_ids.shape[1] + 1),
                               dtype=chunk_input_ids.dtype, device=chunk_input_ids.device)
        temp_ids.fill_(-100)
        temp_ids[:, :-1] = chunk_input_ids
        temp_ids.scatter_(1, seq_lens.unsqueeze(1), torch.tensor(eos_labels, device=chunk_input_ids.device).unsqueeze(1))
        chunk_input_ids = temp_ids
        next_token_indices = next_token_indices[:, :max_input_len + 1]
        return next_token_indices, chunk_input_ids
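
    # Worked example (illustrative ids; 50256 stands in for an EOS label):
    # for a row with chunk_input_ids = [5, 6, 0] and mask [1, 1, 0] (seq_len = 2),
    # temp_ids becomes [5, 6, 0, -100] and the scatter writes the EOS at column 2,
    # giving [5, 6, 50256, -100] as the new target row.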
    def gpst_prep(self, input_ids, attention_mask, max_len=96):
        """Build the batch dictionary expected by forward() from plain HF-style inputs.

        Field layout (shapes follow the original example of 66 sentences, max length 81):
        - `input_ids`: token ids padded with 0, shape 66 x 81.
        - `masks`: same shape as `input_ids`; 1 for real tokens, 0 for padding.
        - `chunk_input_ids`: the same tokens packed into chunk rows (in the packed
          setting, e.g. 2 rows of 1024 tokens; here simply one sentence per row).
        - `chunk_masks`: same shape as `chunk_input_ids`; each position holds the
          1-based id of the sentence it belongs to, 0 for padding.
        - `group_ids`: 1-D array of length 66 mapping each sentence to its chunk row
          (values 0 and 1 in the two-chunk example; here the row index itself).
        - `atom_spans` is None, `span_ids` is an empty list, and `external_vocab_ids`
          is None.
        """
        device = attention_mask.device
        max_batch_len = attention_mask.sum(dim=1).max().item()

        # Trim padding down to the longest sequence actually present in the batch.
        input_ids = input_ids[:, :max_batch_len]
        attention_mask = attention_mask[:, :max_batch_len]
        chunk_input_ids = input_ids

        group_ids = torch.arange(input_ids.shape[0], dtype=torch.long, device=device)

        chunk_masks = attention_mask * (group_ids + 1).unsqueeze(1)
        gpst_batch = {
            "input_ids": input_ids,
            "masks": attention_mask,
            "chunk_input_ids": chunk_input_ids,
            "group_ids": group_ids.cpu().numpy(),
            "chunk_masks": chunk_masks,
        }
        return gpst_batch
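
    # Worked example (illustrative values): for two sentences padded to length 4,
    #     input_ids      = [[5, 6, 7, 0], [8, 9, 0, 0]]
    #     attention_mask = [[1, 1, 1, 0], [1, 1, 0, 0]]
    # gpst_prep truncates both to the longest real length (3) and returns
    #     group_ids   = [0, 1]
    #     chunk_masks = [[1, 1, 1], [2, 2, 0]]   # row i is tagged with sentence id i + 1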
    def forward(self, chunk_input_ids=None, chunk_masks=None, input_ids=None, masks=None, eos_labels=None, group_ids=None,
                atom_spans=None, span_ids=None, external_vocab_ids=None, attention_mask=None,
                coeff=1.0, temperature=1.0, past_key_values=None):
        # The chunk-level arguments are rebuilt from input_ids / attention_mask, so only
        # those two inputs (plus the optional extras) are actually consumed here.
        gpst_batch = self.gpst_prep(input_ids, attention_mask)
        input_ids = gpst_batch["input_ids"]
        masks = gpst_batch["masks"]
        chunk_input_ids = gpst_batch["chunk_input_ids"]
        group_ids = gpst_batch["group_ids"]
        chunk_masks = gpst_batch["chunk_masks"]

        batch_size = max(group_ids) + 1
        # -100 marks positions ignored by the loss; map them to token id 0 for the embedding lookup.
        r2d2_input_ids = torch.where(chunk_input_ids == -100, 0, chunk_input_ids)
        input_embeddings = self.embeddings(r2d2_input_ids)
        r2d2_embeddings = self.down_scale(input_embeddings)

        max_input_len = (chunk_masks != 0).sum(dim=1).max().to('cpu', non_blocking=True)

        # Run the inside-outside (R2D2) composition; its outputs feed both the
        # structural losses and the GPT stacks below.
        ctx, outside_tgt, ldr_repr, position_ids, tgt_ids, token_indices, ext_ids, split_targets, l_height = \
            self.r2d2(r2d2_input_ids, chunk_masks, input_ids, masks, r2d2_embeddings, group_ids,
                      max_input_len, atom_spans=atom_spans, coeff=coeff, temperature=temperature, span_ids=span_ids,
                      eos_labels=eos_labels, external_vocab_ids=external_vocab_ids)

        if self.training:
            # Structural losses are only computed during training.
            parser_loss = self.r2d2.parser_loss(ctx)
            outside_embeddings = self.r2d2.outside_embeddings(ctx)
            io_dense = self.insideoutside_dense(outside_embeddings)
            outside_logits = self.classifier(io_dense)
            insideoutside_loss = F.cross_entropy(outside_logits, outside_tgt)
        else:
            parser_loss = insideoutside_loss = 0

        logits = action_logits = None
        gpt_loss = action_loss = 0
        past_kv = None
        hidden_states = None

        if self.enable_gpt:
            if past_key_values is not None:
                action_past_kv, gen_past_kv = past_key_values
            else:
                action_past_kv = gen_past_kv = None
            # Project the R2D2 node representations up to the GPT width, then overwrite
            # the positions listed in token_indices with the original token embeddings.
            gpt_input = self.up_scale(ldr_repr).clone()
            gpt_input.scatter_(1, token_indices.unsqueeze(2).repeat(1, 1, input_embeddings.shape[-1]).clone(),
                               input_embeddings.to(gpt_input.dtype))

            # Prepend a learned BOS embedding to every sequence.
            bos_emb = self.bos_embedding.unsqueeze(0).repeat(batch_size, 1)
            cat_input = torch.cat([bos_emb.unsqueeze(1), gpt_input], dim=1)

            outputs = self.action_layers(inputs_embeds=cat_input, position_ids=position_ids, past_key_values=action_past_kv)
            action_logits = self.action_mlp(outputs.last_hidden_state)

            # Binary action target: 1 where the target is REDUCE, 0 otherwise, -1 (ignored) for padding.
            action_tgt = torch.where(tgt_ids == self.r2d2.reduce_id, 1, 0)
            action_tgt = torch.where(tgt_ids != -1, action_tgt, -1)

            # A stable descending argsort over the non-REDUCE indicator moves the
            # token-generating positions to the front while preserving their order.
            next_token_indices = (tgt_ids != self.r2d2.reduce_id).int().argsort(dim=-1, descending=True, stable=True)
            if eos_labels is None:
                truncated_len = max_input_len
                next_token_indices = next_token_indices[:, :truncated_len]
            else:
                next_token_indices, chunk_input_ids = self._append_eos_label(eos_labels, chunk_input_ids, chunk_masks,
                                                                             next_token_indices, max_input_len)

            # Gather the hidden states at the token-generating positions and feed them
            # to the remaining GPT layers to predict the next tokens.
            next_token_indices_reformat = next_token_indices.unsqueeze(2).repeat(1, 1, self.embedding_dim)
            generation_inputs = outputs.last_hidden_state.gather(1, next_token_indices_reformat)

            token_outputs = self.generation_layers(inputs_embeds=generation_inputs, past_key_values=gen_past_kv)

            hidden_states = token_outputs.last_hidden_state
            logits = self.classifier(self.dense(hidden_states))

            gpt_loss = F.cross_entropy(logits.permute(0, 2, 1), chunk_input_ids, ignore_index=-100)
            action_loss = F.cross_entropy(action_logits.permute(0, 2, 1), action_tgt, ignore_index=-1)
            past_kv = (outputs.past_key_values, token_outputs.past_key_values)

        return R2D2GenOutput(struct_loss=insideoutside_loss + l_height,
                             non_struct_loss=0.5 * gpt_loss + action_loss + parser_loss,
                             non_struct_loss_fullscale=gpt_loss + action_loss + parser_loss,
                             logits=logits,
                             action_logits=action_logits,
                             hidden_states=hidden_states,
                             tgt_ids=chunk_input_ids,
                             gpt_loss=gpt_loss,
                             action_loss=action_loss,
                             inside_outside_loss=insideoutside_loss,
                             parser_loss=parser_loss,
                             past_kv=past_kv,
                             splits=split_targets,
                             loss=action_loss + gpt_loss + parser_loss + insideoutside_loss + l_height)


class FastGenerativeR2D2_discriminant_glue(FastGenerativeR2D2):
    # Variant for discriminative GLUE fine-tuning: the appended column stays -100
    # (ignored by the loss) instead of receiving an EOS label.
    def _append_eos_label(self, eos_labels, chunk_input_ids, chunk_masks, next_token_indices, max_input_len):
        chunk_masks = (chunk_masks > 0).to(int)
        seq_lens = chunk_masks.sum(dim=1)
        temp_ids = torch.zeros((chunk_input_ids.shape[0], chunk_input_ids.shape[1] + 1),
                               dtype=chunk_input_ids.dtype, device=chunk_input_ids.device)
        temp_ids.fill_(-100)
        temp_ids[:, :-1] = chunk_input_ids
        chunk_input_ids = temp_ids
        next_token_indices = next_token_indices[:, :max_input_len + 1]
        return next_token_indices, chunk_input_ids