import types
from dataclasses import dataclass
from functools import partial
from types import MethodType
from typing import Any, Dict, Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import (AutoModel, LogitsProcessor, LogitsProcessorList,
                          PreTrainedModel, T5EncoderModel,
                          T5ForConditionalGeneration, T5Tokenizer)
from transformers.modeling_outputs import Seq2SeqLMOutput
from undecorate import unwrap

from const import *
from ling_disc import DebertaReplacedTokenizer
from lingconv_t5 import LingConvT5ForConditionalGeneration
from utils import *


def vae_sample(mu, logvar):
    # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
    # keeping the sample differentiable w.r.t. mu and logvar.
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return eps * std + mu


class VAE(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(args.input_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Linear(args.hidden_dim, args.hidden_dim),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Linear(args.latent_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Linear(args.hidden_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Linear(args.hidden_dim, args.input_dim),
        )
        self.fc_mu = nn.Linear(args.hidden_dim, args.latent_dim)
        self.fc_var = nn.Linear(args.hidden_dim, args.latent_dim)

    def forward(self, x):
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_var(h)
        z = vae_sample(mu, logvar)
        o = self.decoder(z)
        return o, (mu, logvar)
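
# Minimal sketch (not part of the training pipeline): how vae_sample implements the
# reparameterization trick, assuming a batch of 2 items with a 16-dim latent.
def _demo_vae_sample():
    mu = torch.zeros(2, 16)
    logvar = torch.zeros(2, 16)   # std = exp(0.5 * 0) = 1
    z = vae_sample(mu, logvar)    # differentiable w.r.t. mu and logvar
    assert z.shape == (2, 16)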
class LingGenerator(nn.Module):
    def __init__(self, args, hidden_dim=1000):
        super().__init__()
        self.gen = T5EncoderModel.from_pretrained('google/flan-t5-small')
        self.hidden_size = self.gen.config.d_model
        self.ling_embed = nn.Linear(args.lng_dim, self.hidden_size)
        self.gen_type = args.linggen_type
        self.gen_input = args.linggen_input
        if self.gen_type == 'vae':
            # Project from the T5 encoder hidden size (an earlier MLP-based
            # generator used hidden_dim here, which no longer matches).
            self.gen_mu = nn.Linear(self.hidden_size, args.lng_dim)
            self.gen_logvar = nn.Linear(self.hidden_size, args.lng_dim)
        elif self.gen_type == 'det':
            self.projection = nn.Linear(self.hidden_size, args.lng_dim)

    def forward(self, batch):
        inputs_embeds = self.gen.shared(batch['sentence1_input_ids'])
        inputs_att_mask = batch['sentence1_attention_mask']
        bs = inputs_embeds.shape[0]
        if self.gen_input == 's+l':
            sentence1_ling = self.ling_embed(batch['sentence1_ling'])
            sentence1_ling = sentence1_ling.view(bs, 1, -1)
            inputs_embeds = inputs_embeds + sentence1_ling
        gen = self.gen(inputs_embeds=inputs_embeds,
                       attention_mask=inputs_att_mask).last_hidden_state.mean(1)

        cache = {}
        if self.gen_type == 'vae':
            mu = self.gen_mu(gen)
            logvar = self.gen_logvar(gen)
            output = vae_sample(mu, logvar)
            cache['linggen_mu'] = mu
            cache['linggen_logvar'] = logvar
        elif self.gen_type == 'det':
            output = self.projection(gen)

        return output, cache


class LingDisc(nn.Module):
    def __init__(self,
                 model_name,
                 disc_type,
                 disc_ckpt,
                 lng_dim=40,
                 quant_nbins=1,
                 disc_lng_dim=None,
                 lng_ids=None,
                 **kwargs):
        super().__init__()
        if disc_type == 't5':
            self.encoder = T5EncoderModel.from_pretrained(model_name)
            hidden_dim = self.encoder.config.d_model
            self.dropout = nn.Dropout(0.2)
            self.lng_dim = disc_lng_dim if disc_lng_dim else lng_dim
            # Quantized (classification) output is currently disabled: the second
            # assignment overrides the quant_nbins-based setting.
            self.quant = quant_nbins > 1
            self.quant = False
            if self.quant:
                self.ling_classifier = nn.Linear(hidden_dim, self.lng_dim * quant_nbins)
            else:
                self.ling_classifier = nn.Linear(hidden_dim, self.lng_dim)
            lng_ids = torch.tensor(lng_ids) if lng_ids is not None else None
            self.register_buffer('lng_ids', lng_ids)
        elif disc_type == 'deberta':
            self.encoder = DebertaReplacedTokenizer.from_pretrained(
                pretrained_model_name_or_path=disc_ckpt,
                tok_model_name=model_name,
                problem_type='regression',
                num_labels=40)
            self.quant = False
        self.disc_type = disc_type

    def forward(self, **batch):
        if 'attention_mask' not in batch:
            if 'input_ids' in batch:
                att_mask = torch.ones_like(batch['input_ids'])
            else:
                att_mask = torch.ones_like(batch['logits'])[:, :, 0]
        else:
            att_mask = batch['attention_mask']
        if 'input_ids' in batch:
            enc_output = self.encoder(input_ids=batch['input_ids'],
                                      attention_mask=att_mask)
        elif 'logits' in batch:
            logits = batch['logits']
            scores = F.softmax(logits, dim=-1)
            # Straight-through one-hot: hard argmax in the forward pass,
            # softmax gradients in the backward pass.
            onehot = F.one_hot(logits.argmax(-1),
                               num_classes=logits.shape[2]).float().to(logits.device)
            onehot_ = scores - scores.detach() + onehot
            embed_layer = self.encoder.get_input_embeddings()
            if isinstance(embed_layer, nn.Sequential):
                for i, module in enumerate(embed_layer):
                    if i == 0:
                        embeds = torch.matmul(onehot_, module.weight)
                    else:
                        embeds = module(embeds)
            else:
                embeds = torch.matmul(onehot_, embed_layer.weight)
            enc_output = self.encoder(inputs_embeds=embeds, attention_mask=att_mask)
        if self.disc_type == 't5':
            sent_emb = self.dropout(enc_output.last_hidden_state.mean(1))
            bs = sent_emb.shape[0]
            output = self.ling_classifier(sent_emb)
            if self.quant:
                output = output.reshape(bs, -1, self.lng_dim)
            if self.lng_ids is not None:
                output = torch.index_select(output, 1, self.lng_ids)
        elif self.disc_type == 'deberta':
            output = enc_output.logits
        return output


class SemEmb(T5EncoderModel):
    def __init__(self, config, sep_token_id):
        super().__init__(config)
        self.sep_token_id = sep_token_id
        hidden_dim = self.config.d_model
        self.projection = nn.Sequential(nn.ReLU(),
                                        nn.Dropout(0.2),
                                        nn.Linear(hidden_dim, 1))

    def compare_sem(self, **batch):
        # Returns a semantic-equivalence logit; callers apply torch.sigmoid.
        bs = batch['attention_mask'].shape[0]
        ones = torch.ones((bs, 1), device=batch['attention_mask'].device)
        sep = torch.ones((bs, 1), dtype=torch.long,
                         device=batch['attention_mask'].device) * self.sep_token_id
        att_mask = torch.cat([batch['attention_mask'], ones,
                              batch['sentence2_attention_mask']], dim=1)
        if 'logits' in batch:
            input_ids = torch.cat([batch['input_ids'], sep], dim=1)
            embeds1 = self.shared(input_ids)
            logits = batch['logits']
            scores = F.softmax(logits, dim=-1)
            onehot = F.one_hot(logits.argmax(-1),
                               num_classes=logits.shape[2]).float().to(logits.device)
            onehot_ = scores - scores.detach() + onehot
            embeds2 = onehot_ @ self.shared.weight
            embeds1_2 = torch.cat([embeds1, embeds2], dim=1)
            hidden_units = super().forward(
                inputs_embeds=embeds1_2,
                attention_mask=att_mask).last_hidden_state.mean(1)
        elif 'sentence2_input_ids' in batch:
            input_ids = torch.cat([batch['input_ids'], sep,
                                   batch['sentence2_input_ids']], dim=1)
            hidden_units = super().forward(
                input_ids=input_ids,
                attention_mask=att_mask).last_hidden_state.mean(1)
        probs = self.projection(hidden_units)
        return probs
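
# Sketch of the straight-through one-hot used in LingDisc and SemEmb: the forward
# pass sees a hard one-hot, while gradients flow through the softmax scores.
def _demo_straight_through():
    logits = torch.randn(1, 4, 10, requires_grad=True)
    scores = F.softmax(logits, dim=-1)
    onehot = F.one_hot(logits.argmax(-1), num_classes=logits.shape[-1]).float()
    onehot_st = scores - scores.detach() + onehot
    onehot_st.sum().backward()        # gradients reach `logits`
    assert logits.grad is not None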
def prepare_inputs_for_generation(
    combine_method,
    ling2_only,
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    head_mask=None,
    decoder_head_mask=None,
    cross_attn_head_mask=None,
    use_cache=None,
    encoder_outputs=None,
    sentence1_ling=None,
    sentence2_ling=None,
    **kwargs
):
    # Cut decoder_input_ids when past key values are used.
    if past_key_values is not None:
        input_ids = input_ids[:, -1:]
    cached = use_cache and past_key_values is not None and len(past_key_values) > 0
    input_ids = input_ids.clone()
    decoder_inputs_embeds = self.shared(input_ids)
    ling_embed = None
    if combine_method == 'layer_injection':
        # For layer injection, pass the ling embeddings separately.
        ling_embed = sentence2_ling if ling2_only else (sentence1_ling + sentence2_ling)
    elif combine_method == 'decoder_add_first' and not cached:
        sentence2_ling = torch.cat(
            [sentence2_ling,
             torch.repeat_interleave(torch.zeros_like(sentence2_ling),
                                     input_ids.shape[1] - 1, dim=1)], dim=1)
    elif combine_method == 'decoder_concat':
        if ling2_only:
            decoder_inputs_embeds = torch.cat([sentence2_ling, decoder_inputs_embeds], dim=1)
        else:
            decoder_inputs_embeds = torch.cat(
                [sentence1_ling, sentence2_ling, decoder_inputs_embeds], dim=1)
    if combine_method == 'decoder_add' or (not cached and combine_method == 'decoder_add_first'):
        if ling2_only:
            decoder_inputs_embeds = decoder_inputs_embeds + sentence2_ling
        else:
            decoder_inputs_embeds = decoder_inputs_embeds + sentence1_ling + sentence2_ling
    return {
        "decoder_inputs_embeds": decoder_inputs_embeds,
        "past_key_values": past_key_values,
        "encoder_outputs": encoder_outputs,
        "attention_mask": attention_mask,
        "head_mask": head_mask,
        "decoder_head_mask": decoder_head_mask,
        "cross_attn_head_mask": cross_attn_head_mask,
        "use_cache": use_cache,
        "ling_embed": ling_embed,
    }


class LogitsAdd(LogitsProcessor):
    def __init__(self, sentence2_ling):
        super().__init__()
        self.sentence2_ling = sentence2_ling

    def __call__(self, input_ids, scores):
        # Bias every decoding step's vocabulary scores by a fixed vector.
        return scores + self.sentence2_ling
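
# Sketch: a LogitsProcessor is called once per decoding step with (input_ids, scores);
# LogitsAdd simply shifts the vocabulary distribution by a fixed bias vector.
def _demo_logits_add():
    bias = torch.zeros(1, 10)
    bias[0, 3] = 5.0                 # favor token id 3
    proc = LogitsAdd(bias)
    scores = torch.zeros(1, 10)
    biased = proc(torch.zeros(1, 1, dtype=torch.long), scores)
    assert biased.argmax(-1).item() == 3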
class EncoderDecoderVAE(LingConvT5ForConditionalGeneration):
    def __init__(self, config, args, pad_token_id, sepeos_token_id, vocab_size=32128):
        if args.combine_method == 'layer_injection':
            if args.injection_layer < 0 or args.injection_layer >= config.num_decoder_layers:
                raise ValueError(
                    f"Invalid injection layer: {args.injection_layer}. "
                    f"Must be between 0 and {config.num_decoder_layers - 1}.")
            config.ling_injection_layer = args.injection_layer
            config.ling_injection_type = args.injection_type  # 'first' or 'all'
        super().__init__(config)
        self.prepare_inputs_for_generation = types.MethodType(
            partial(prepare_inputs_for_generation, args.combine_method, args.ling2_only),
            self)
        self.args = args
        self.pad_token_id = pad_token_id
        self.eos_token_id = sepeos_token_id
        hidden_dim = self.config.d_model if 'logits' not in args.combine_method else vocab_size
        if args.combine_method == 'fusion1':
            self.fusion = nn.Sequential(
                nn.Linear(hidden_dim + 2 * args.lng_dim, hidden_dim),
            )
        elif args.combine_method == 'fusion2':
            self.fusion = nn.Sequential(
                nn.Linear(hidden_dim + 2 * args.lng_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )
        elif ('concat' in args.combine_method or 'add' in args.combine_method
              or 'layer_injection' in args.combine_method):
            if args.ling_embed_type == 'two-layer':
                self.ling_embed = nn.Sequential(
                    nn.Linear(args.lng_dim, args.lng_dim),
                    nn.ReLU(),
                    nn.Linear(args.lng_dim, hidden_dim),
                )
            else:
                self.ling_embed = nn.Linear(args.lng_dim, hidden_dim)
            self.ling_dropout = nn.Dropout(args.ling_dropout)
            self.ling_embed.apply(self._init_weights)
            if args.ling_vae:
                self.ling_mu = nn.Linear(hidden_dim, hidden_dim)
                self.ling_logvar = nn.Linear(hidden_dim, hidden_dim)
                # Xavier init applies to the single-layer embed; the two-layer
                # variant is already initialized by _init_weights above.
                if isinstance(self.ling_embed, nn.Linear):
                    nn.init.xavier_uniform_(self.ling_embed.weight)
                nn.init.xavier_uniform_(self.ling_mu.weight)
                nn.init.xavier_uniform_(self.ling_logvar.weight)

        generate_with_grad = unwrap(super().generate)
        self.generate_with_grad = MethodType(generate_with_grad, self)
        self.generate_original = super().generate

    def _init_weights(self, module):
        std = self.args.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def get_fusion_layer(self):
        if 'fusion' in self.args.combine_method:
            return self.fusion
        elif 'concat' in self.args.combine_method or 'add' in self.args.combine_method:
            return self.ling_embed
        else:
            return None

    def sample(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)
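
    # Overview of combine_method handling in this class (summarized from the code):
    #   - input_concat / input_add: applied in encode() to the encoder inputs
    #   - embed_concat, decoder_concat, decoder_add, decoder_add_first, logits_add,
    #     layer_injection: applied in decode() and prepare_inputs_for_generation()
    #   - fusion1 / fusion2: build a fusion MLP exposed through get_fusion_layer()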
    def _process_ling_embeddings(self, sentence1_ling, sentence2_ling,
                                 sentence1_ling_embed, sentence2_ling_embed, bs):
        """Embed (or pass through) the linguistic feature vectors of both sentences."""
        cache = {}

        # Process the sentence1 embedding.
        if sentence1_ling_embed is not None:
            sentence1_ling = sentence1_ling_embed
        elif sentence1_ling is not None:
            sentence1_ling = self.ling_embed(self.ling_dropout(sentence1_ling))

        # Process the sentence2 embedding.
        if sentence2_ling_embed is not None:
            sentence2_ling = sentence2_ling_embed
        elif sentence2_ling is not None:
            sentence2_ling = self.ling_embed(self.ling_dropout(sentence2_ling))

        # Apply the VAE reparameterization if configured.
        if self.args.ling_vae and sentence1_ling is not None and sentence2_ling is not None:
            sentence1_ling = F.leaky_relu(sentence1_ling)
            sent1_mu, sent1_logvar = self.ling_mu(sentence1_ling), self.ling_logvar(sentence1_ling)
            sentence1_ling = self.sample(sent1_mu, sent1_logvar)

            sentence2_ling = F.leaky_relu(sentence2_ling)
            sent2_mu, sent2_logvar = self.ling_mu(sentence2_ling), self.ling_logvar(sentence2_ling)
            sentence2_ling = self.sample(sent2_mu, sent2_logvar)

            cache.update({
                'sent1_mu': sent1_mu,
                'sent1_logvar': sent1_logvar,
                'sent2_mu': sent2_mu,
                'sent2_logvar': sent2_logvar,
                'sentence1_ling': sentence1_ling,
                'sentence2_ling': sentence2_ling,
            })
        else:
            if sentence2_ling is not None:
                cache['sentence2_ling'] = sentence2_ling
            if sentence1_ling is not None:
                cache['sentence1_ling'] = sentence1_ling

        # Reshape the embeddings to (batch, 1, hidden).
        if sentence1_ling is not None:
            sentence1_ling = sentence1_ling.view(bs, 1, -1)
        if sentence2_ling is not None:
            sentence2_ling = sentence2_ling.view(bs, 1, -1)

        return sentence1_ling, sentence2_ling, cache

    def encode(self,
               input_ids=None,
               attention_mask=None,
               sentence1_ling=None,
               sentence2_ling=None,
               sentence1_ling_embed=None,
               sentence2_ling_embed=None,
               inputs_embeds=None):
        if inputs_embeds is None:
            inputs_embeds = self.shared(input_ids)
        inputs_att_mask = attention_mask if attention_mask is not None else torch.ones_like(input_ids)
        bs = inputs_embeds.shape[0]
        if self.args.combine_method in ('input_concat', 'input_add'):
            sentence1_ling, sentence2_ling, cache = self._process_ling_embeddings(
                sentence1_ling, sentence2_ling,
                sentence1_ling_embed, sentence2_ling_embed, bs)
            if self.args.combine_method == 'input_concat':
                if self.args.ling2_only:
                    inputs_embeds = torch.cat([inputs_embeds, sentence2_ling], dim=1)
                    inputs_att_mask = torch.cat(
                        [inputs_att_mask, torch.ones((bs, 1)).to(inputs_embeds.device)], dim=1)
                else:
                    inputs_embeds = torch.cat(
                        [inputs_embeds, sentence1_ling, sentence2_ling], dim=1)
                    inputs_att_mask = torch.cat(
                        [inputs_att_mask, torch.ones((bs, 2)).to(inputs_embeds.device)], dim=1)
            elif self.args.combine_method == 'input_add':
                if self.args.ling2_only:
                    inputs_embeds = inputs_embeds + sentence2_ling
                else:
                    inputs_embeds = inputs_embeds + sentence1_ling + sentence2_ling
        else:
            cache = {}
        return (self.encoder(inputs_embeds=inputs_embeds, attention_mask=inputs_att_mask),
                inputs_att_mask, cache)
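
    # decode() handles the decoder-side combine methods:
    #   - decoder_concat: prepend ling embeddings to the decoder input embeddings
    #   - decoder_add / decoder_add_first: add them to all / only the first position
    #   - logits_add: add them to the output logits (via LogitsAdd during generation)
    #   - layer_injection: forward them to LingConvT5 as `ling_embed`
    # With generate=True it dispatches to HF generate, optionally through the
    # unwrapped, gradient-tracking variant.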
    def decode(self,
               sentence2_input_ids=None,
               sentence1_ling=None,
               sentence2_ling=None,
               encoder_outputs=None,
               encoder_attention_mask=None,
               decoder_inputs_embeds=None,
               decoder_attention_mask=None,
               generate=False,
               sentence1_ling_embed=None,
               sentence2_ling_embed=None,
               ling_embed=None,
               generate_with_grad=False,
               **kwargs):
        bs = encoder_outputs[0].shape[0]
        cache = {}
        if decoder_inputs_embeds is None:
            if self.args.combine_method in ('embed_concat', 'decoder_concat', 'decoder_add',
                                            'logits_add', 'decoder_add_first', 'layer_injection'):
                sentence1_ling, sentence2_ling, cache = self._process_ling_embeddings(
                    sentence1_ling, sentence2_ling,
                    sentence1_ling_embed, sentence2_ling_embed, bs)
                if (self.args.combine_method == 'decoder_add_first'
                        or (self.args.combine_method == 'layer_injection'
                            and self.args.injection_type == 'first')) and not generate:
                    sentence2_ling = torch.cat(
                        [sentence2_ling,
                         torch.repeat_interleave(torch.zeros_like(sentence2_ling),
                                                 sentence2_input_ids.shape[1] - 1, dim=1)],
                        dim=1)
            else:
                sentence1_ling, sentence2_ling = None, None

        if generate:
            if self.args.combine_method == 'logits_add':
                logits_processor = LogitsProcessorList([LogitsAdd(sentence2_ling.view(bs, -1))])
            else:
                logits_processor = LogitsProcessorList()
            generate_fn = self.generate_with_grad if generate_with_grad else self.generate_original
            dec_output = generate_fn(
                attention_mask=encoder_attention_mask,
                encoder_outputs=encoder_outputs,
                sentence1_ling=sentence1_ling,
                sentence2_ling=sentence2_ling,
                logits_processor=logits_processor,
                # renormalize_logits=True,
                # do_sample=True,
                # top_p=0.8,
                eos_token_id=self.eos_token_id,
                # min_new_tokens=3,
                # repetition_penalty=1.2,
                max_length=self.args.max_length,
                use_cache=True,
                **kwargs)
            return dec_output, cache

        if sentence2_input_ids is not None:
            labels = sentence2_input_ids.clone()
            labels[labels == self.pad_token_id] = -100
        else:
            labels = None

        if decoder_inputs_embeds is None:
            decoder_input_ids = self._shift_right(sentence2_input_ids)
            decoder_inputs_embeds = self.shared(decoder_input_ids)
            if self.args.combine_method == 'decoder_concat':
                if self.args.ling2_only:
                    decoder_inputs_embeds = torch.cat([sentence2_ling, decoder_inputs_embeds], dim=1)
                    decoder_attention_mask = torch.cat(
                        [torch.ones((bs, 1)).to(decoder_inputs_embeds.device),
                         decoder_attention_mask], dim=1)
                    labels = torch.cat(
                        [torch.ones((bs, 1), dtype=torch.int64).to(decoder_inputs_embeds.device)
                         * self.pad_token_id, labels], dim=1)
                else:
                    decoder_inputs_embeds = torch.cat(
                        [sentence1_ling, sentence2_ling, decoder_inputs_embeds], dim=1)
                    decoder_attention_mask = torch.cat(
                        [torch.ones((bs, 2)).to(decoder_inputs_embeds.device),
                         decoder_attention_mask], dim=1)
                    labels = torch.cat(
                        [torch.ones((bs, 2), dtype=torch.int64).to(decoder_inputs_embeds.device)
                         * self.pad_token_id, labels], dim=1)
            elif self.args.combine_method in ('decoder_add', 'decoder_add_first'):
                if self.args.ling2_only:
                    decoder_inputs_embeds = decoder_inputs_embeds + self.args.combine_weight * sentence2_ling
                else:
                    decoder_inputs_embeds = decoder_inputs_embeds + sentence1_ling + sentence2_ling

        if ling_embed is None:
            ling_embed = sentence2_ling

        dec_output = super().forward(
            decoder_inputs_embeds=decoder_inputs_embeds,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            attention_mask=encoder_attention_mask,
            labels=labels,
            ling_embed=ling_embed,
            **kwargs)

        if self.args.combine_method == 'logits_add':
            dec_output.logits = dec_output.logits + self.args.combine_weight * sentence2_ling
            vocab_size = dec_output.logits.size(-1)
            # Recompute the loss on the biased logits (-100 labels are ignored).
            dec_output.loss = F.cross_entropy(dec_output.logits.view(-1, vocab_size),
                                              labels.view(-1))
        return dec_output, cache

    def generate(self, *args, **kwargs):
        return self.forward(*args, **kwargs, generate=True)

    def forward(self,
                input_ids=None,
                attention_mask=None,
                labels=None,
                decoder_attention_mask=None,
                decoder_inputs_embeds=None,
                sentence1_ling=None,
                sentence2_ling=None,
                sentence1_ling_embed=None,
                sentence2_ling_embed=None,
                inputs_embeds=None,
                generate=False,
                encoder_outputs=None,
                encoder_attention_mask=None,
                ling_embed=None,
                generate_with_grad=False,
                **kwargs):
        cache = {}
        if encoder_outputs is None:
            encoder_outputs, encoder_attention_mask, cache = self.encode(
                input_ids=input_ids,
                attention_mask=attention_mask,
                sentence1_ling=sentence1_ling,
                sentence2_ling=sentence2_ling,
                sentence1_ling_embed=sentence1_ling_embed,
                sentence2_ling_embed=sentence2_ling_embed,
                inputs_embeds=inputs_embeds)

        dec_output, cache2 = self.decode(
            sentence2_input_ids=labels,
            sentence1_ling=sentence1_ling,
            sentence2_ling=sentence2_ling,
            decoder_inputs_embeds=decoder_inputs_embeds,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            encoder_attention_mask=encoder_attention_mask,
            generate=generate,
            sentence1_ling_embed=sentence1_ling_embed,
            sentence2_ling_embed=sentence2_ling_embed,
            ling_embed=ling_embed,
            generate_with_grad=generate_with_grad,
            **kwargs)
        cache.update(cache2)

        if generate:
            return dec_output

        return MySeq2SeqLMOutput(
            loss=dec_output.loss,
            logits=dec_output.logits,
            past_key_values=dec_output.past_key_values,
            decoder_hidden_states=dec_output.decoder_hidden_states,
            decoder_attentions=dec_output.decoder_attentions,
            cross_attentions=dec_output.cross_attentions,
            encoder_last_hidden_state=encoder_outputs[0],
            encoder_hidden_states=getattr(encoder_outputs, 'hidden_states', None),
            encoder_attentions=getattr(encoder_outputs, 'attentions', None),
            cache=cache)
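
    # Inference-time feedback (infer_with_feedback_BP): iteratively refine either the
    # target ling embedding ('l'), the input embeddings ('s'), or the output logits
    # ('logits') by backpropagating the discriminator loss through generation, using a
    # doubling-step line search that also requires the semantic-similarity probability
    # to stay at or above 0.90.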
    def infer_with_cache(self, batch):
        # forward(generate=True) returns only the HF generate output; request scores so
        # callers (e.g. infer_with_feedback_BP) can read them from the cache.
        dec_output = self(**batch, generate=True, output_scores=True,
                          return_dict_in_generate=True)
        cache = {'scores': torch.stack(dec_output.scores, dim=1)}
        return dec_output.sequences, cache

    def infer(self, batch):
        dec_output, _ = self.infer_with_cache(batch)
        return dec_output

    def infer_with_feedback_BP(self, ling_disc, sem_emb, batch, tokenizer, progress=None):
        from torch.autograd import grad
        interpolations = []

        def line_search():
            # Doubling-step line search: accept the first step that lowers the
            # discriminator loss while keeping semantic similarity >= 0.90.
            eta = 1e3
            sem_prob = 1
            patience = 4
            while patience > 0:
                param_ = param - eta * grads
                with torch.no_grad():
                    new_loss, pred = get_loss(param_)
                max_len = pred.shape[1]
                lens = torch.where(pred == self.eos_token_id, 1, 0).argmax(-1) + 1
                sem_batch = {**batch,
                             'sentence2_input_ids': pred,
                             'sentence2_attention_mask': sequence_mask(lens, max_len=max_len)}
                sem_prob = torch.sigmoid(sem_emb.compare_sem(**sem_batch)).item()
                if new_loss < loss and sem_prob >= 0.90 and lens.item() > 1:
                    return param_
                eta *= 2.25
                patience -= 1
            return False

        def get_loss(param):
            if self.args.feedback_param == 'l':
                batch.update({'sentence2_ling_embed': param})
            elif self.args.feedback_param == 's':
                batch.update({'inputs_embeds': param})

            if self.args.feedback_param == 'logits':
                logits = param
                pred = param.argmax(-1)
            else:
                outputs = self.generate(**batch,
                                        output_scores=True,
                                        return_dict_in_generate=True,
                                        generate_with_grad=True)
                pred = outputs.sequences
                logits = torch.stack(outputs.scores, dim=1)
            out = ling_disc(logits=logits)
            if ling_disc.quant:
                loss = F.cross_entropy(out, batch['sentence2_discr'])
            else:
                loss = F.mse_loss(out, batch['sentence2_ling'])
            return loss, pred

        if self.args.feedback_param == 'l':
            ling2_embed = self.ling_embed(batch['sentence2_ling'])
            param = torch.nn.Parameter(ling2_embed, requires_grad=True)
        elif self.args.feedback_param == 's':
            inputs_embeds = self.shared(batch['input_ids'])
            param = torch.nn.Parameter(inputs_embeds, requires_grad=True)
        elif self.args.feedback_param == 'logits':
            logits = self.infer_with_cache(batch)[1]['scores']
            param = torch.nn.Parameter(logits, requires_grad=True)

        num_iter = 0
        while num_iter < 3:
            loss, pred = get_loss(param)
            pred_text = tokenizer.batch_decode(pred.cpu().numpy(), skip_special_tokens=True)[0]
            interpolations.append(pred_text)
            if loss < 1:
                break
            self.zero_grad()
            grads = grad(loss, param)[0]
            param = line_search()
            if param is False:
                break
            num_iter += 1
            if progress is not None:
                progress((num_iter, None), unit='intermediate paraphrase generated.')
        return pred, [pred_text, interpolations]


def set_grad(module, state):
    if module is not None:
        for p in module.parameters():
            p.requires_grad = state


def set_grad_except(model, name, state):
    for n, p in model.named_parameters():
        if name not in n:
            p.requires_grad = state


class SemEmbPipeline():
    def __init__(self, ckpt="/data/mohamed/checkpoints/ling_conversion_sem_emb_best.pt"):
        self.tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
        # SemEmb is a T5EncoderModel subclass, so build it via from_pretrained and
        # pass the separator token id ('</s>' doubles as T5's EOS).
        self.model = SemEmb.from_pretrained('google/flan-t5-base',
                                            self.tokenizer.get_vocab()['</s>'])
        state = torch.load(ckpt)
        self.model.load_state_dict(state['model'], strict=False)
        self.model.eval()
        self.model.cuda()

    def __call__(self, sentence1, sentence2):
        sentence1 = self.tokenizer(sentence1, return_attention_mask=True, return_tensors='pt')
        sentence2 = self.tokenizer(sentence2, return_attention_mask=True, return_tensors='pt')
        sem_logit = self.model.compare_sem(
            input_ids=sentence1.input_ids.cuda(),
            attention_mask=sentence1.attention_mask.cuda(),
            sentence2_input_ids=sentence2.input_ids.cuda(),
            sentence2_attention_mask=sentence2.attention_mask.cuda(),
        )
        sem_prob = torch.sigmoid(sem_logit).item()
        return sem_prob
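
# Hedged usage sketch for the pipeline above: requires a CUDA device and the
# default checkpoint path; sentences are illustrative.
def _demo_sem_emb_pipeline():
    sem = SemEmbPipeline()
    prob = sem("The cat sat on the mat.", "A cat was sitting on the mat.")
    print(f"semantic similarity probability: {prob:.3f}")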
class LingDiscPipeline():
    def __init__(self,
                 model_name="google/flan-t5-base",
                 disc_type='deberta',
                 disc_ckpt='mohdelgaar/lingconv-discriminator',
                 # disc_type='t5',
                 # disc_ckpt='/data/mohamed/checkpoints/ling_conversion_ling_disc.pt',
                 ):
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = LingDisc(model_name, disc_type, disc_ckpt)
        self.model.eval()
        self.model.cuda()

    def __call__(self, sentence):
        inputs = self.tokenizer(sentence, return_tensors='pt')
        with torch.no_grad():
            ling_pred = self.model(input_ids=inputs.input_ids.cuda())
        return ling_pred


def get_model(args, tokenizer, device):
    if args.pretrain_disc or args.disc_loss or args.disc_ckpt:
        ling_disc = LingDisc(args.model_name, args.disc_type, args.disc_model_path).to(device)
    else:
        ling_disc = None

    if args.model_path:
        model = EncoderDecoderVAE.from_pretrained(
            args.model_path, args, tokenizer.pad_token_id, tokenizer.eos_token_id).to(device)
    else:
        model = EncoderDecoderVAE.from_pretrained(
            args.model_name, args, tokenizer.pad_token_id, tokenizer.eos_token_id).to(device)

    if args.sem_loss or args.model_path:
        if args.sem_loss_type == 'shared':
            sem_emb = model.encoder
        elif args.sem_loss_type == 'dedicated':
            sem_emb = SemEmb.from_pretrained(args.sem_model_path,
                                             tokenizer.eos_token_id).to(device)
        else:
            raise NotImplementedError('Semantic loss type')
    else:
        sem_emb = None

    return model, ling_disc, sem_emb


@dataclass
class MySeq2SeqLMOutput(Seq2SeqLMOutput):
    """
    Extends Seq2SeqLMOutput to include a cache dictionary for additional model outputs.

    Args:
        cache (`Dict[str, Any]`):
            Dictionary containing additional model outputs like linguistic features,
            VAE parameters, scores, etc.
    """
    cache: Optional[Dict[str, Any]] = None
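
# Hedged usage sketch for LingDiscPipeline: requires a CUDA device and downloads
# the discriminator checkpoint from the Hugging Face hub on first use.
def _demo_ling_disc_pipeline():
    disc = LingDiscPipeline()
    ling = disc("The quick brown fox jumps over the lazy dog.")
    print(ling.shape)  # (1, 40): one prediction per linguistic feature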