import types
import torch
import torch.nn.functional as F
import numpy as np
from torch import nn
from transformers import (T5Tokenizer, T5ForConditionalGeneration, T5EncoderModel, AutoModel,
                          LogitsProcessor, LogitsProcessorList, PreTrainedModel)
from functools import partial
from undecorate import unwrap
from types import MethodType
from utils import *
from ling_disc import DebertaReplacedTokenizer
from const import *
from lingconv_t5 import LingConvT5ForConditionalGeneration
from dataclasses import dataclass
from transformers.modeling_outputs import Seq2SeqLMOutput
from typing import Optional, Dict, Any


def vae_sample(mu, logvar):
    # Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return eps * std + mu


class VAE(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(args.input_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Linear(args.hidden_dim, args.hidden_dim),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Linear(args.latent_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Linear(args.hidden_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Linear(args.hidden_dim, args.input_dim),
        )
        self.fc_mu = nn.Linear(args.hidden_dim, args.latent_dim)
        self.fc_var = nn.Linear(args.hidden_dim, args.latent_dim)

    def forward(self, x):
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_var(h)
        x = vae_sample(mu, logvar)
        o = self.decoder(x)
        return o, (mu, logvar)
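
# Usage sketch (illustrative; `input_dim`, `hidden_dim`, and `latent_dim` below are
# placeholder values, not project defaults):
#   from types import SimpleNamespace
#   vae = VAE(SimpleNamespace(input_dim=40, hidden_dim=128, latent_dim=16))
#   recon, (mu, logvar) = vae(torch.randn(8, 40))  # recon: (8, 40)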


class LingGenerator(nn.Module):
    def __init__(self, args, hidden_dim=1000):
        super().__init__()
        self.gen = T5EncoderModel.from_pretrained('google/flan-t5-small')
        self.hidden_size = self.gen.config.d_model
        self.ling_embed = nn.Linear(args.lng_dim, self.hidden_size)
        # self.gen = nn.Sequential(
        #     nn.Linear(args.lng_dim, 2*hidden_dim),
        #     nn.ReLU(),
        #     nn.BatchNorm1d(2*hidden_dim),
        #     nn.Linear(2*hidden_dim, 2*hidden_dim),
        #     nn.ReLU(),
        #     nn.BatchNorm1d(2*hidden_dim),
        #     nn.Linear(2*hidden_dim, hidden_dim),
        #     nn.ReLU(),
        # )
        self.gen_type = args.linggen_type
        self.gen_input = args.linggen_input
        if self.gen_type == 'vae':
            # The encoder output has dimension d_model, so the VAE heads must match it
            # (not the legacy `hidden_dim` of the commented-out MLP above).
            self.gen_mu = nn.Linear(self.hidden_size, args.lng_dim)
            self.gen_logvar = nn.Linear(self.hidden_size, args.lng_dim)
        elif self.gen_type == 'det':
            self.projection = nn.Linear(self.hidden_size, args.lng_dim)

    def forward(self, batch):
        inputs_embeds = self.gen.shared(batch['sentence1_input_ids'])
        inputs_att_mask = batch['sentence1_attention_mask']
        bs = inputs_embeds.shape[0]
        if self.gen_input == 's+l':
            sentence1_ling = self.ling_embed(batch['sentence1_ling'])
            sentence1_ling = sentence1_ling.view(bs, 1, -1)
            inputs_embeds = inputs_embeds + sentence1_ling
        gen = self.gen(inputs_embeds=inputs_embeds,
                       attention_mask=inputs_att_mask).last_hidden_state.mean(1)
        # gen = self.gen(batch['sentence1_ling'])
        cache = {}
        if self.gen_type == 'vae':
            mu = self.gen_mu(gen)
            logvar = self.gen_logvar(gen)
            output = vae_sample(mu, logvar)
            cache['linggen_mu'] = mu
            cache['linggen_logvar'] = logvar
        elif self.gen_type == 'det':
            output = self.projection(gen)
        return output, cache


class LingDisc(nn.Module):
    def __init__(self,
                 model_name,
                 disc_type,
                 disc_ckpt,
                 lng_dim=40,
                 quant_nbins=1,
                 disc_lng_dim=None,
                 lng_ids=None,
                 **kwargs):
        super().__init__()
        if disc_type == 't5':
            self.encoder = T5EncoderModel.from_pretrained(model_name)
            hidden_dim = self.encoder.config.d_model
            self.dropout = nn.Dropout(0.2)
            self.lng_dim = disc_lng_dim if disc_lng_dim else lng_dim
            self.quant = quant_nbins > 1
            self.quant = False  # NOTE: hard-coded override; quantized prediction is disabled regardless of quant_nbins
            if self.quant:
                self.ling_classifier = nn.Linear(hidden_dim, self.lng_dim * quant_nbins)
            else:
                self.ling_classifier = nn.Linear(hidden_dim, self.lng_dim)
            lng_ids = torch.tensor(lng_ids) if lng_ids is not None else None
            # from const import used_indices
            # lng_ids = torch.tensor(used_indices)
            self.register_buffer('lng_ids', lng_ids)
        elif disc_type == 'deberta':
            self.encoder = DebertaReplacedTokenizer.from_pretrained(
                pretrained_model_name_or_path=disc_ckpt,
                tok_model_name=model_name,
                problem_type='regression', num_labels=40)
            self.quant = False
        self.disc_type = disc_type

    def forward(self, **batch):
        if 'attention_mask' not in batch:
            if 'input_ids' in batch:
                att_mask = torch.ones_like(batch['input_ids'])
            else:
                att_mask = torch.ones_like(batch['logits'])[:, :, 0]
        else:
            att_mask = batch['attention_mask']
        if 'input_ids' in batch:
            enc_output = self.encoder(input_ids=batch['input_ids'],
                                      attention_mask=att_mask)
        elif 'logits' in batch:
            # Straight-through estimator: the forward pass uses the one-hot argmax,
            # while gradients flow through the softmax scores.
            logits = batch['logits']
            scores = F.softmax(logits, dim=-1)
            onehot = F.one_hot(logits.argmax(-1), num_classes=logits.shape[2]).float().to(logits.device)
            onehot_ = scores - scores.detach() + onehot
            embed_layer = self.encoder.get_input_embeddings()
            if isinstance(embed_layer, nn.Sequential):
                for i, module in enumerate(embed_layer):
                    if i == 0:
                        embeds = torch.matmul(onehot_, module.weight)
                    else:
                        embeds = module(embeds)
            else:
                embeds = torch.matmul(onehot_, embed_layer.weight)
            enc_output = self.encoder(inputs_embeds=embeds,
                                      attention_mask=att_mask)
        if self.disc_type == 't5':
            sent_emb = self.dropout(enc_output.last_hidden_state.mean(1))
            bs = sent_emb.shape[0]
            output = self.ling_classifier(sent_emb)
            if self.quant:
                output = output.reshape(bs, -1, self.lng_dim)
            if self.lng_ids is not None:
                output = torch.index_select(output, 1, self.lng_ids)
        elif self.disc_type == 'deberta':
            output = enc_output.logits
        return output
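
# Usage sketch (illustrative; both calls assume real checkpoints and tensors):
#   disc = LingDisc('google/flan-t5-base', 'deberta', 'mohdelgaar/lingconv-discriminator')
#   feats = disc(input_ids=input_ids)       # predict features from token ids
#   feats = disc(logits=generator_logits)   # differentiable path via the straight-through trick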


class SemEmb(T5EncoderModel):
    def __init__(self, config, sep_token_id):
        super().__init__(config)
        self.sep_token_id = sep_token_id
        hidden_dim = self.config.d_model
        self.projection = nn.Sequential(nn.ReLU(),
                                        nn.Dropout(0.2),
                                        nn.Linear(hidden_dim, 1))

    def compare_sem(self, **batch):
        bs = batch['attention_mask'].shape[0]
        ones = torch.ones((bs, 1), device=batch['attention_mask'].device)
        sep = torch.ones((bs, 1), dtype=torch.long,
                         device=batch['attention_mask'].device) * self.sep_token_id
        att_mask = torch.cat([batch['attention_mask'], ones, batch['sentence2_attention_mask']], dim=1)
        if 'logits' in batch:
            input_ids = torch.cat([batch['input_ids'], sep], dim=1)
            embeds1 = self.shared(input_ids)
            # Straight-through estimator over the generated logits (see LingDisc.forward).
            logits = batch['logits']
            scores = F.softmax(logits, dim=-1)
            onehot = F.one_hot(logits.argmax(-1), num_classes=logits.shape[2]).float().to(logits.device)
            onehot_ = scores - scores.detach() + onehot
            embeds2 = onehot_ @ self.shared.weight
            embeds1_2 = torch.cat([embeds1, embeds2], dim=1)
            hidden_units = super().forward(inputs_embeds=embeds1_2,
                                           attention_mask=att_mask).last_hidden_state.mean(1)
        elif 'sentence2_input_ids' in batch:
            input_ids = torch.cat([batch['input_ids'], sep, batch['sentence2_input_ids']], dim=1)
            hidden_units = super().forward(input_ids=input_ids,
                                           attention_mask=att_mask).last_hidden_state.mean(1)
        probs = self.projection(hidden_units)
        return probs


def prepare_inputs_for_generation(
    combine_method,
    ling2_only,
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    head_mask=None,
    decoder_head_mask=None,
    cross_attn_head_mask=None,
    use_cache=None,
    encoder_outputs=None,
    sentence1_ling=None,
    sentence2_ling=None,
    **kwargs
):
    # cut decoder_input_ids if past is used
    if past_key_values is not None:
        input_ids = input_ids[:, -1:]
    # guard against past_key_values=None on the first decoding step
    cached = use_cache and past_key_values is not None and len(past_key_values) > 0
    input_ids = input_ids.clone()
    decoder_inputs_embeds = self.shared(input_ids)
    ling_embed = None
    if combine_method == 'layer_injection':
        # For layer injection, we'll pass the ling embeddings separately
        ling_embed = sentence2_ling if ling2_only else (sentence1_ling + sentence2_ling)
    elif combine_method == 'decoder_add_first' and not cached:
        sentence2_ling = torch.cat([sentence2_ling,
                                    torch.repeat_interleave(torch.zeros_like(sentence2_ling),
                                                            input_ids.shape[1] - 1, dim=1)], dim=1)
    elif combine_method == 'decoder_concat':
        if ling2_only:
            decoder_inputs_embeds = torch.cat([sentence2_ling, decoder_inputs_embeds], dim=1)
        else:
            decoder_inputs_embeds = torch.cat([sentence1_ling, sentence2_ling, decoder_inputs_embeds], dim=1)
    if combine_method == 'decoder_add' or (not cached and combine_method == 'decoder_add_first'):
        if ling2_only:
            decoder_inputs_embeds = decoder_inputs_embeds + sentence2_ling
        else:
            decoder_inputs_embeds = decoder_inputs_embeds + sentence1_ling + sentence2_ling
    return {
        "decoder_inputs_embeds": decoder_inputs_embeds,
        "past_key_values": past_key_values,
        "encoder_outputs": encoder_outputs,
        "attention_mask": attention_mask,
        "head_mask": head_mask,
        "decoder_head_mask": decoder_head_mask,
        "cross_attn_head_mask": cross_attn_head_mask,
        "use_cache": use_cache,
        "ling_embed": ling_embed,
    }


class LogitsAdd(LogitsProcessor):
    def __init__(self, sentence2_ling):
        super().__init__()
        self.sentence2_ling = sentence2_ling

    def __call__(self, input_ids, scores):
        # Shift the vocabulary logits by the target linguistic embedding at every step.
        return scores + self.sentence2_ling
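
# Usage sketch (illustrative): bias every decoding step by the target-style
# embedding; this requires the embedding to live in vocabulary space
# (hidden_dim == vocab_size, as in the 'logits_add' combine method):
#   processors = LogitsProcessorList([LogitsAdd(sentence2_ling.view(bs, -1))])
#   model.generate(..., logits_processor=processors)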


class EncoderDecoderVAE(LingConvT5ForConditionalGeneration):
    def __init__(self, config, args, pad_token_id, sepeos_token_id, vocab_size=32128):
        if args.combine_method == 'layer_injection':
            if args.injection_layer < 0 or args.injection_layer >= config.num_decoder_layers:
                raise ValueError(f"Invalid injection layer: {args.injection_layer}. Must be between 0 and {config.num_decoder_layers - 1}.")
            config.ling_injection_layer = args.injection_layer
            config.ling_injection_type = args.injection_type  # 'first' or 'all'
        super().__init__(config)
        self.prepare_inputs_for_generation = types.MethodType(
            partial(prepare_inputs_for_generation, args.combine_method, args.ling2_only),
            self)
        self.args = args
        self.pad_token_id = pad_token_id
        self.eos_token_id = sepeos_token_id
        hidden_dim = self.config.d_model if 'logits' not in args.combine_method else vocab_size
        if args.combine_method == 'fusion1':
            self.fusion = nn.Sequential(
                nn.Linear(hidden_dim + 2 * args.lng_dim, hidden_dim),
            )
        elif args.combine_method == 'fusion2':
            self.fusion = nn.Sequential(
                nn.Linear(hidden_dim + 2 * args.lng_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )
        elif 'concat' in args.combine_method or 'add' in args.combine_method or 'layer_injection' in args.combine_method:
            if args.ling_embed_type == 'two-layer':
                self.ling_embed = nn.Sequential(
                    nn.Linear(args.lng_dim, args.lng_dim),
                    nn.ReLU(),
                    nn.Linear(args.lng_dim, hidden_dim),
                )
            else:
                self.ling_embed = nn.Linear(args.lng_dim, hidden_dim)
            self.ling_dropout = nn.Dropout(args.ling_dropout)
            self.ling_embed.apply(self._init_weights)
        if args.ling_vae:
            self.ling_mu = nn.Linear(hidden_dim, hidden_dim)
            self.ling_logvar = nn.Linear(hidden_dim, hidden_dim)
            if isinstance(self.ling_embed, nn.Linear):
                # Xavier init applies directly only to the single-linear variant;
                # the two-layer variant was already initialized via _init_weights.
                nn.init.xavier_uniform_(self.ling_embed.weight)
            nn.init.xavier_uniform_(self.ling_mu.weight)
            nn.init.xavier_uniform_(self.ling_logvar.weight)
        generate_with_grad = unwrap(super().generate)
        self.generate_with_grad = MethodType(generate_with_grad, self)
        self.generate_original = super().generate

    def _init_weights(self, module):
        std = self.args.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def get_fusion_layer(self):
        if 'fusion' in self.args.combine_method:
            return self.fusion
        elif 'concat' in self.args.combine_method or 'add' in self.args.combine_method:
            return self.ling_embed
        else:
            return None

    def sample(self, mu, logvar):
        # Reparameterization trick (same as vae_sample above).
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)

    def _process_ling_embeddings(self, sentence1_ling, sentence2_ling,
                                 sentence1_ling_embed, sentence2_ling_embed, bs):
        """Helper method to process linguistic embeddings"""
        cache = {}
        # Process sentence1 embedding
        if sentence1_ling_embed is not None:
            sentence1_ling = sentence1_ling_embed
        elif sentence1_ling is not None:
            sentence1_ling = self.ling_embed(self.ling_dropout(sentence1_ling))
        else:
            sentence1_ling = None
        # Process sentence2 embedding
        if sentence2_ling_embed is not None:
            sentence2_ling = sentence2_ling_embed
        elif sentence2_ling is not None:
            sentence2_ling = self.ling_embed(self.ling_dropout(sentence2_ling))
        else:
            sentence2_ling = None
        # Apply VAE if configured
        if self.args.ling_vae and sentence1_ling is not None and sentence2_ling is not None:
            sentence1_ling = F.leaky_relu(sentence1_ling)
            sent1_mu, sent1_logvar = self.ling_mu(sentence1_ling), self.ling_logvar(sentence1_ling)
            sentence1_ling = self.sample(sent1_mu, sent1_logvar)
            sentence2_ling = F.leaky_relu(sentence2_ling)
            sent2_mu, sent2_logvar = self.ling_mu(sentence2_ling), self.ling_logvar(sentence2_ling)
            sentence2_ling = self.sample(sent2_mu, sent2_logvar)
            cache.update({
                'sent1_mu': sent1_mu, 'sent1_logvar': sent1_logvar,
                'sent2_mu': sent2_mu, 'sent2_logvar': sent2_logvar,
                'sentence1_ling': sentence1_ling, 'sentence2_ling': sentence2_ling
            })
        else:
            if sentence2_ling is not None:
                cache['sentence2_ling'] = sentence2_ling
            if sentence1_ling is not None:
                cache['sentence1_ling'] = sentence1_ling
        # Reshape embeddings to (bs, 1, hidden)
        if sentence1_ling is not None:
            sentence1_ling = sentence1_ling.view(bs, 1, -1)
        if sentence2_ling is not None:
            sentence2_ling = sentence2_ling.view(bs, 1, -1)
        return sentence1_ling, sentence2_ling, cache

    def encode(self,
               input_ids=None,
               attention_mask=None,
               sentence1_ling=None,
               sentence2_ling=None,
               sentence1_ling_embed=None,
               sentence2_ling_embed=None,
               inputs_embeds=None,
               ):
        if inputs_embeds is None:
            inputs_embeds = self.shared(input_ids)
        inputs_att_mask = attention_mask if attention_mask is not None else torch.ones_like(input_ids)
        bs = inputs_embeds.shape[0]
        if self.args.combine_method in ('input_concat', 'input_add'):
            sentence1_ling, sentence2_ling, cache = self._process_ling_embeddings(
                sentence1_ling, sentence2_ling,
                sentence1_ling_embed, sentence2_ling_embed, bs
            )
            if self.args.combine_method == 'input_concat':
                if self.args.ling2_only:
                    inputs_embeds = torch.cat([inputs_embeds, sentence2_ling], dim=1)
                    inputs_att_mask = torch.cat([inputs_att_mask,
                                                 torch.ones((bs, 1)).to(inputs_embeds.device)], dim=1)
                else:
                    inputs_embeds = torch.cat([inputs_embeds, sentence1_ling, sentence2_ling], dim=1)
                    inputs_att_mask = torch.cat([inputs_att_mask,
                                                 torch.ones((bs, 2)).to(inputs_embeds.device)], dim=1)
            elif self.args.combine_method == 'input_add':
                if self.args.ling2_only:
                    inputs_embeds = inputs_embeds + sentence2_ling
                else:
                    inputs_embeds = inputs_embeds + sentence1_ling + sentence2_ling
        else:
            cache = {}
        return self.encoder(inputs_embeds=inputs_embeds,
                            attention_mask=inputs_att_mask), inputs_att_mask, cache

    def decode(self,
               sentence2_input_ids=None,
               sentence1_ling=None,
               sentence2_ling=None,
               encoder_outputs=None,
               encoder_attention_mask=None,
               decoder_inputs_embeds=None,
               decoder_attention_mask=None,
               generate=False,
               sentence1_ling_embed=None,
               sentence2_ling_embed=None,
               ling_embed=None,
               generate_with_grad=False,
               **kwargs
               ):
        bs = encoder_outputs[0].shape[0]
        cache = {}
        if decoder_inputs_embeds is None:
            if self.args.combine_method in ('embed_concat', 'decoder_concat', 'decoder_add',
                                            'logits_add', 'decoder_add_first', 'layer_injection'):
                sentence1_ling, sentence2_ling, cache = self._process_ling_embeddings(
                    sentence1_ling, sentence2_ling,
                    sentence1_ling_embed, sentence2_ling_embed, bs
                )
                if (self.args.combine_method == 'decoder_add_first' or
                        (self.args.combine_method == 'layer_injection' and
                         self.args.injection_type == 'first')) and not generate:
                    sentence2_ling = torch.cat([sentence2_ling,
                                                torch.repeat_interleave(torch.zeros_like(sentence2_ling),
                                                                        sentence2_input_ids.shape[1] - 1, dim=1)], dim=1)
            else:
                sentence1_ling, sentence2_ling = None, None
        if generate:
            if self.args.combine_method == 'logits_add':
                logits_processor = LogitsProcessorList([LogitsAdd(sentence2_ling.view(bs, -1))])
            else:
                logits_processor = LogitsProcessorList()
            generate_fn = self.generate_with_grad if generate_with_grad else self.generate_original
            dec_output = generate_fn(
                attention_mask=encoder_attention_mask,
                encoder_outputs=encoder_outputs,
                sentence1_ling=sentence1_ling,
                sentence2_ling=sentence2_ling,
                logits_processor=logits_processor,
                # renormalize_logits=True,
                # do_sample=True,
                # top_p=0.8,
                eos_token_id=self.eos_token_id,
                # min_new_tokens=3,
                # repetition_penalty=1.2,
                max_length=self.args.max_length,
                use_cache=True,
                **kwargs
            )
            return dec_output, cache
        if sentence2_input_ids is not None:
            labels = sentence2_input_ids.clone()
            labels[labels == self.pad_token_id] = -100
        else:
            labels = None
        if decoder_inputs_embeds is None:
            decoder_input_ids = self._shift_right(sentence2_input_ids)
            decoder_inputs_embeds = self.shared(decoder_input_ids)
        if self.args.combine_method == 'decoder_concat':
            if self.args.ling2_only:
                decoder_inputs_embeds = torch.cat([sentence2_ling, decoder_inputs_embeds], dim=1)
                decoder_attention_mask = torch.cat([torch.ones((bs, 1)).to(decoder_inputs_embeds.device), decoder_attention_mask], dim=1)
                labels = torch.cat([torch.ones((bs, 1), dtype=torch.int64).to(decoder_inputs_embeds.device) * self.pad_token_id,
                                    labels], dim=1)
            else:
                decoder_inputs_embeds = torch.cat([sentence1_ling, sentence2_ling, decoder_inputs_embeds], dim=1)
                decoder_attention_mask = torch.cat([torch.ones((bs, 2)).to(decoder_inputs_embeds.device), decoder_attention_mask], dim=1)
                labels = torch.cat([torch.ones((bs, 2), dtype=torch.int64).to(decoder_inputs_embeds.device) * self.pad_token_id,
                                    labels], dim=1)
        elif self.args.combine_method == 'decoder_add' or self.args.combine_method == 'decoder_add_first':
            if self.args.ling2_only:
                decoder_inputs_embeds = decoder_inputs_embeds + self.args.combine_weight * sentence2_ling
            else:
                decoder_inputs_embeds = decoder_inputs_embeds + sentence1_ling + sentence2_ling
        if ling_embed is None:
            ling_embed = sentence2_ling
        dec_output = super().forward(
            decoder_inputs_embeds=decoder_inputs_embeds,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            attention_mask=encoder_attention_mask,
            labels=labels,
            ling_embed=ling_embed,
            **kwargs
        )
        if self.args.combine_method == 'logits_add':
            dec_output.logits = dec_output.logits + self.args.combine_weight * sentence2_ling
            vocab_size = dec_output.logits.size(-1)
            dec_output.loss = F.cross_entropy(dec_output.logits.view(-1, vocab_size), labels.view(-1))
        return dec_output, cache

    def generate(self, *args, **kwargs):
        return self.forward(*args, **kwargs, generate=True)

    def forward(self,
                input_ids=None,
                attention_mask=None,
                labels=None,
                decoder_attention_mask=None,
                decoder_inputs_embeds=None,
                sentence1_ling=None,
                sentence2_ling=None,
                sentence1_ling_embed=None,
                sentence2_ling_embed=None,
                inputs_embeds=None,
                generate=False,
                encoder_outputs=None,
                encoder_attention_mask=None,
                ling_embed=None,
                generate_with_grad=False,
                **kwargs):
        cache = {}
        if encoder_outputs is None:
            encoder_outputs, encoder_attention_mask, cache = self.encode(
                input_ids=input_ids,
                attention_mask=attention_mask,
                sentence1_ling=sentence1_ling,
                sentence2_ling=sentence2_ling,
                sentence1_ling_embed=sentence1_ling_embed,
                sentence2_ling_embed=sentence2_ling_embed,
                inputs_embeds=inputs_embeds
            )
        dec_output, cache2 = self.decode(
            sentence2_input_ids=labels,
            sentence1_ling=sentence1_ling,
            sentence2_ling=sentence2_ling,
            decoder_inputs_embeds=decoder_inputs_embeds,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            encoder_attention_mask=encoder_attention_mask,
            generate=generate,
            sentence1_ling_embed=sentence1_ling_embed,
            sentence2_ling_embed=sentence2_ling_embed,
            ling_embed=ling_embed,
            generate_with_grad=generate_with_grad,
            **kwargs
        )
        cache.update(cache2)
        if generate:
            return dec_output
        else:
            return MySeq2SeqLMOutput(
                loss=dec_output.loss,
                logits=dec_output.logits,
                past_key_values=dec_output.past_key_values,
                decoder_hidden_states=dec_output.decoder_hidden_states,
                decoder_attentions=dec_output.decoder_attentions,
                cross_attentions=dec_output.cross_attentions,
                encoder_last_hidden_state=encoder_outputs[0],
                encoder_hidden_states=getattr(encoder_outputs, 'hidden_states', None),
                encoder_attentions=getattr(encoder_outputs, 'attentions', None),
                cache=cache
            )

    def infer_with_cache(self, batch):
        # forward() with generate=True returns the generate() output directly, so
        # request per-step scores here and expose them through the cache dict
        # (infer_with_feedback_BP reads cache['scores'] as logits).
        outputs = self(**batch, generate=True, output_scores=True, return_dict_in_generate=True)
        cache = {'scores': torch.stack(outputs.scores, dim=1)}
        return outputs.sequences, cache

    def infer(self, batch):
        dec_output, _ = self.infer_with_cache(batch)
        return dec_output

    def infer_with_feedback_BP(self, ling_disc, sem_emb, batch, tokenizer, progress=None):
        from torch.autograd import grad
        interpolations = []

        def line_search():
            # Exponentially grow the step size until the discriminator loss drops
            # while semantic similarity stays above the acceptance threshold.
            eta = 1e3
            sem_prob = 1
            patience = 4
            while patience > 0:
                param_ = param - eta * grads
                with torch.no_grad():
                    new_loss, pred = get_loss(param_)
                max_len = pred.shape[1]
                lens = torch.where(pred == self.eos_token_id, 1, 0).argmax(-1) + 1
                sem_batch = {**batch,
                             'sentence2_input_ids': pred,
                             'sentence2_attention_mask': sequence_mask(lens, max_len=max_len)}
                sem_prob = torch.sigmoid(sem_emb.compare_sem(**sem_batch)).item()
                if new_loss < loss and sem_prob >= 0.90 and lens.item() > 1:
                    return param_
                eta *= 2.25
                patience -= 1
            return False

        def get_loss(param):
            if self.args.feedback_param == 'l':
                batch.update({'sentence2_ling_embed': param})
            elif self.args.feedback_param == 's':
                batch.update({'inputs_embeds': param})
            if self.args.feedback_param == 'logits':
                logits = param
                pred = param.argmax(-1)
            else:
                outputs = self.generate(**batch, output_scores=True, return_dict_in_generate=True,
                                        generate_with_grad=True)
                pred = outputs.sequences
                logits = torch.stack(outputs.scores, dim=1)
            out = ling_disc(logits=logits)
            if ling_disc.quant:
                loss = F.cross_entropy(out, batch['sentence2_discr'])
            else:
                loss = F.mse_loss(out, batch['sentence2_ling'])
            return loss, pred

        # Choose which quantity to optimize: the target ling embedding ('l'),
        # the source token embeddings ('s'), or the output logits ('logits').
        if self.args.feedback_param == 'l':
            ling2_embed = self.ling_embed(batch['sentence2_ling'])
            param = torch.nn.Parameter(ling2_embed, requires_grad=True)
        elif self.args.feedback_param == 's':
            inputs_embeds = self.shared(batch['input_ids'])
            param = torch.nn.Parameter(inputs_embeds, requires_grad=True)
        elif self.args.feedback_param == 'logits':
            logits = self.infer_with_cache(batch)[1]['scores']
            param = torch.nn.Parameter(logits, requires_grad=True)
        num_iter = 0
        while num_iter < 3:
            loss, pred = get_loss(param)
            pred_text = tokenizer.batch_decode(pred.cpu().numpy(),
                                               skip_special_tokens=True)[0]
            interpolations.append(pred_text)
            if loss < 1:
                break
            self.zero_grad()
            grads = grad(loss, param)[0]
            param = line_search()
            if param is False:
                break
            num_iter += 1
            if progress is not None:
                progress((num_iter, None), unit='intermediate paraphrase generated.')
        return pred, [pred_text, interpolations]


def set_grad(module, state):
    if module is not None:
        for p in module.parameters():
            p.requires_grad = state


def set_grad_except(model, name, state):
    for n, p in model.named_parameters():
        if name not in n:
            p.requires_grad = state


class SemEmbPipeline():
    def __init__(self,
                 ckpt="/data/mohamed/checkpoints/ling_conversion_sem_emb_best.pt"):
        self.tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
        # SemEmb subclasses T5EncoderModel, so load it via from_pretrained and
        # pass the separator token id as an extra constructor argument.
        self.model = SemEmb.from_pretrained('google/flan-t5-base',
                                            self.tokenizer.get_vocab()['</s>'])
        state = torch.load(ckpt)
        self.model.load_state_dict(state['model'], strict=False)
        self.model.eval()
        self.model.cuda()

    def __call__(self, sentence1, sentence2):
        sentence1 = self.tokenizer(sentence1, return_attention_mask=True, return_tensors='pt')
        sentence2 = self.tokenizer(sentence2, return_attention_mask=True, return_tensors='pt')
        # compare_sem expects the first sentence under 'input_ids'/'attention_mask'
        # and the second under the 'sentence2_*' keys.
        sem_logit = self.model.compare_sem(
            input_ids=sentence1.input_ids.cuda(),
            attention_mask=sentence1.attention_mask.cuda(),
            sentence2_input_ids=sentence2.input_ids.cuda(),
            sentence2_attention_mask=sentence2.attention_mask.cuda(),
        )
        sem_prob = torch.sigmoid(sem_logit).item()
        return sem_prob
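
# Usage sketch (illustrative; requires a GPU and the checkpoint above):
#   sem = SemEmbPipeline()
#   prob = sem('The cat sat on the mat.', 'A cat was sitting on the mat.')
#   # prob close to 1.0 indicates the model judges the pair semantically equivalent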


class LingDiscPipeline():
    def __init__(self,
                 model_name="google/flan-t5-base",
                 disc_type='deberta',
                 disc_ckpt='mohdelgaar/lingconv-discriminator',
                 # disc_type='t5',
                 # disc_ckpt='/data/mohamed/checkpoints/ling_conversion_ling_disc.pt',
                 ):
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = LingDisc(model_name, disc_type, disc_ckpt)
        self.model.eval()
        self.model.cuda()

    def __call__(self, sentence):
        inputs = self.tokenizer(sentence, return_tensors='pt')
        with torch.no_grad():
            ling_pred = self.model(input_ids=inputs.input_ids.cuda())
        return ling_pred
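
# Usage sketch (illustrative; requires a GPU):
#   disc = LingDiscPipeline()
#   feats = disc('The cat sat on the mat.')  # tensor of 40 predicted linguistic features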


def get_model(args, tokenizer, device):
    if args.pretrain_disc or args.disc_loss or args.disc_ckpt:
        ling_disc = LingDisc(args.model_name, args.disc_type, args.disc_model_path).to(device)
    else:
        ling_disc = None
    if args.model_path:
        model = EncoderDecoderVAE.from_pretrained(args.model_path, args, tokenizer.pad_token_id,
                                                  tokenizer.eos_token_id).to(device)
    else:
        model = EncoderDecoderVAE.from_pretrained(args.model_name, args, tokenizer.pad_token_id,
                                                  tokenizer.eos_token_id).to(device)
    if args.sem_loss or args.model_path:
        if args.sem_loss_type == 'shared':
            sem_emb = model.encoder
        elif args.sem_loss_type == 'dedicated':
            sem_emb = SemEmb.from_pretrained(args.sem_model_path, tokenizer.eos_token_id).to(device)
        else:
            raise NotImplementedError('Semantic loss type')
    else:
        sem_emb = None
    return model, ling_disc, sem_emb


@dataclass
class MySeq2SeqLMOutput(Seq2SeqLMOutput):
    """
    Extends Seq2SeqLMOutput with a cache dictionary for additional model outputs.

    Args:
        cache (`Dict[str, Any]`):
            Dictionary containing additional model outputs like linguistic features,
            VAE parameters, scores, etc.
    """
    cache: Optional[Dict[str, Any]] = None
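

# Minimal smoke test (illustrative; assumes a GPU and access to the
# checkpoints referenced above):
if __name__ == '__main__':
    disc = LingDiscPipeline()
    print(disc('The quick brown fox jumps over the lazy dog.'))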