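"""PyTorch Lightning module that fine-tunes a seq2seq (BART-style) model to generate
justifications, logging ROUGE-L and METEOR during validation and testing."""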
import pytorch_lightning as pl
import torch
import numpy as np
import datasets
from transformers import MaxLengthCriteria, StoppingCriteriaList
from transformers.optimization import AdamW
import itertools
from averitec.models.utils import count_stats, f1_metric, pairwise_meteor
from torchmetrics.text.rouge import ROUGEScore
import torch.nn.functional as F
import torchmetrics
from torchmetrics.classification import F1Score

def freeze_params(model):
    """Disable gradient updates for all parameters of the given (sub)module."""
    for layer in model.parameters():
        layer.requires_grad = False

class JustificationGenerationModule(pl.LightningModule):
    def __init__(self, tokenizer, model, learning_rate=1e-3, gen_num_beams=2, gen_max_length=100, should_pad_gen=True):
        super().__init__()
        self.tokenizer = tokenizer
        self.model = model
        self.learning_rate = learning_rate
        self.gen_num_beams = gen_num_beams
        self.gen_max_length = gen_max_length
        self.should_pad_gen = should_pad_gen

        # self.metrics = datasets.load_metric('meteor')

        freeze_params(self.model.get_encoder())
        self.freeze_embeds()
    def freeze_embeds(self):
        ''' freeze the positional embedding parameters of the model; adapted from finetune.py '''
        freeze_params(self.model.model.shared)
        for d in [self.model.model.encoder, self.model.model.decoder]:
            freeze_params(d.embed_positions)
            freeze_params(d.embed_tokens)
    # Do a forward pass through the model
    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer
    def shift_tokens_right(self, input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
        """
        Shift input ids one token to the right.
        https://github.com/huggingface/transformers/blob/main/src/transformers/models/bart/modeling_bart.py
        """
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
        shifted_input_ids[:, 0] = decoder_start_token_id

        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")
        # Replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        return shifted_input_ids
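    # Worked example of the shift above (illustrative values, not from the original
    # file): labels [[5, -100, -100]] with pad_token_id=1 and decoder_start_token_id=2
    # become decoder inputs [[2, 5, 1]]; the -100 loss-masking values turn into pads.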
    def run_model(self, batch):
        src_ids, src_mask, tgt_ids = batch[0], batch[1], batch[2]
        decoder_input_ids = self.shift_tokens_right(
            # BART uses the EOS token to start generation as well. Might have to change for other models.
            tgt_ids, self.tokenizer.pad_token_id, self.tokenizer.pad_token_id
        )
        outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
        return outputs
    def compute_loss(self, batch):
        tgt_ids = batch[2]
        logits = self.run_model(batch)[0]
        # Flatten [batch, tgt_len, vocab] logits against flattened targets;
        # pad positions are excluded from the loss via ignore_index.
        cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
        loss = cross_entropy(logits.view(-1, logits.shape[-1]), tgt_ids.view(-1))
        return loss
    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)
        self.log("train_loss", loss, on_epoch=True)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        preds, loss, tgts = self.generate_and_compute_loss_and_tgts(batch)

        if self.should_pad_gen:
            preds = F.pad(preds, pad=(0, self.gen_max_length - preds.shape[1]), value=self.tokenizer.pad_token_id)

        self.log('val_loss', loss, prog_bar=True, sync_dist=True)
        return {'loss': loss, 'pred': preds, 'target': tgts}

    def test_step(self, batch, batch_idx):
        test_preds, test_loss, test_tgts = self.generate_and_compute_loss_and_tgts(batch)

        if self.should_pad_gen:
            test_preds = F.pad(test_preds, pad=(0, self.gen_max_length - test_preds.shape[1]), value=self.tokenizer.pad_token_id)

        self.log('test_loss', test_loss, prog_bar=True, sync_dist=True)
        return {'loss': test_loss, 'pred': test_preds, 'target': test_tgts}
    def test_epoch_end(self, outputs):
        self.handle_end_of_epoch_scoring(outputs, "test")

    def validation_epoch_end(self, outputs):
        self.handle_end_of_epoch_scoring(outputs, "val")

    def handle_end_of_epoch_scoring(self, outputs, prefix):
        gen = {}
        tgt = {}

        rouge = ROUGEScore()
        rouge_metric = lambda x, y: rouge(x, y)["rougeL_precision"]

        for out in outputs:
            preds = out['pred']
            tgts = out['target']

            preds = self.do_batch_detokenize(preds)
            tgts = self.do_batch_detokenize(tgts)

            for pred, t in zip(preds, tgts):
                rouge_d = rouge_metric(pred, t)
                self.log(prefix + "_rouge", rouge_d)

                meteor_d = pairwise_meteor(pred, t)
                self.log(prefix + "_meteor", meteor_d)
    def generate_and_compute_loss_and_tgts(self, batch):
        src_ids = batch[0]
        loss = self.compute_loss(batch)
        pred_ids, _ = self.generate_for_batch(src_ids)
        tgt_ids = batch[2]
        return pred_ids, loss, tgt_ids
    def do_batch_detokenize(self, batch):
        tokens = self.tokenizer.batch_decode(
            batch,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )

        # Huggingface's skipping of special tokens doesn't work for all models, so we do it manually as well for safety:
        tokens = [p.replace("<pad>", "") for p in tokens]
        tokens = [p.replace("<s>", "") for p in tokens]
        tokens = [p.replace("</s>", "") for p in tokens]

        return [t for t in tokens if t != ""]
    def generate_for_batch(self, batch):
        generated_ids = self.model.generate(
            batch,
            decoder_start_token_id=self.tokenizer.pad_token_id,
            num_beams=self.gen_num_beams,
            max_length=self.gen_max_length
        )

        generated_tokens = self.tokenizer.batch_decode(
            generated_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )

        return generated_ids, generated_tokens
    def generate(self, text, max_input_length=512, device=None):
        encoded_dict = self.tokenizer(
            [text],
            max_length=max_input_length,
            padding="longest",
            truncation=True,
            return_tensors="pt",
            add_prefix_space=True
        )

        input_ids = encoded_dict['input_ids']

        if device is not None:
            input_ids = input_ids.to(device)

        with torch.no_grad():
            _, generated_tokens = self.generate_for_batch(input_ids)

        return generated_tokens[0]
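

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes a BART
    # checkpoint ("facebook/bart-base" is an illustrative choice), since
    # freeze_embeds() relies on BART-style model.model.shared / embed_positions /
    # embed_tokens attributes.
    from transformers import BartForConditionalGeneration, BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

    module = JustificationGenerationModule(tokenizer, model)

    # Generate a justification for a single (hypothetical) claim + evidence string.
    print(module.generate("Claim: ... Evidence: ...", max_input_length=512))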