import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import lightning as L
from lightning.pytorch.loggers import WandbLogger
from torchmetrics.text.rouge import ROUGEScore
from transformers import AutoModelForCausalLM, GPT2TokenizerFast

import utils

# Fix seeds for reproducibility across numpy and torch (CPU and CUDA).
np.random.seed(123)
torch.manual_seed(123)
torch.cuda.manual_seed(123)

def top_p_sampling(logits, p=0.9, temperature=0.5):
    """Sample one token id per row from `logits` via nucleus (top-p) sampling."""
    logits = logits / temperature

    # Sort descending so the nucleus is a prefix of the sorted vocabulary.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Mark tokens outside the nucleus. Shift the mask right by one so the
    # token that crosses the threshold is kept, and always keep the top token.
    sorted_indices_to_remove = cumulative_probs > p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    # Scatter the mask back to the original (unsorted) vocabulary order.
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    logits[indices_to_remove] = float('-inf')

    # Renormalise over the surviving tokens and sample one id per row.
    probs = F.softmax(logits, dim=-1)
    sampled_indices = torch.multinomial(probs, num_samples=1)
    return sampled_indices.squeeze(1)
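
# Shape sketch for top_p_sampling (illustrative only, not part of the
# training pipeline): batched logits in, one sampled id per row out.
#
#     logits = torch.randn(4, 50257)        # (batch, vocab)
#     ids = top_p_sampling(logits, p=0.9)   # -> LongTensor of shape (4,)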

class PromptTuningModel(nn.Module):
    def __init__(self, num_prompts=6):
        super().__init__()
        self.num_prompts = num_prompts

        # Frozen GPT-2 backbone: only the soft prompt below is trainable.
        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        self.model.requires_grad_(False)
        self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        # [START] is assumed to be intended as an extra special token; passing
        # 'pad_token' a second time would silently overwrite [PAD].
        self.tokenizer.add_special_tokens({'additional_special_tokens': ['[START]']})

        self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
        self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]

        # Grow the embedding matrix to cover the two added special tokens.
        self.model.resize_token_embeddings(new_num_tokens=len(self.tokenizer))

        # Initialise the soft prompt from the embeddings of 'summarise', which
        # splits into three BPE tokens, tiled to fill num_prompts rows
        # (this assumes num_prompts is a multiple of 3).
        tmp = self.tokenizer('summarise', return_tensors="pt").input_ids
        token_embedding = self.model.transformer.wte(tmp[0])
        self.token_embedding = token_embedding
        for _ in range(num_prompts // 3 - 1):
            self.token_embedding = torch.cat([self.token_embedding, token_embedding])

        data = self.token_embedding.detach().clone()
        self.learnable_prompt = nn.Parameter(data, requires_grad=True)

    def forward(self, X, y):
        # The prompt is a registered Parameter, so it moves with the module;
        # reassigning `self.learnable_prompt = ...to(device)` (as before)
        # raises a TypeError on nn.Module.
        prompt = self.learnable_prompt
        embeddings = self.model.transformer.wte(X)
        embeddings = torch.cat([prompt[None, :, :].repeat(X.shape[0], 1, 1), embeddings], dim=1)

        out = self.model(inputs_embeds=embeddings)
        # Drop the logits at the prompt positions so the output aligns with X.
        logits = out.logits[:, self.num_prompts:]
        return logits
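
    # Shape sketch for forward (GPT-2 small, hidden size 768):
    #   X: (B, T) ids -> wte(X): (B, T, 768) -> with prompt: (B, num_prompts + T, 768)
    #   out.logits: (B, num_prompts + T, len(tokenizer)) -> returned: (B, T, len(tokenizer))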

    def generate_new(self, X):
        """Autoregressive sampling with a KV cache; returns generated ids."""
        batch_size = X.shape[0]
        embeddings = self.model.transformer.wte(X)
        embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(batch_size, 1, 1), embeddings], dim=1)

        cnt = 0
        past_key_values = None
        generated_ids = torch.tensor([], dtype=torch.long, device=X.device).view(batch_size, 0)

        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values

            # Only the last position is sampled from, so slicing off the
            # prompt positions on the first step is unnecessary.
            logits = out.logits
            # Never sample the added special tokens; GPT-2's own vocabulary
            # ends at id 50256.
            logits[:, :, 50257:] = -1e4

            next_token_ids = top_p_sampling(logits[:, -1, :])
            generated_ids = torch.cat([generated_ids, next_token_ids.unsqueeze(-1)], dim=-1)

            # Feed only the new token back in, keeping a sequence dim of 1 so
            # inputs_embeds stays 3-D: (batch, 1, hidden).
            embeddings = self.model.transformer.wte(next_token_ids.unsqueeze(-1))

            cnt += 1
            # Stop once every sequence in the batch has produced <|endoftext|>.
            if torch.all((generated_ids == self.eot.item()).any(dim=-1)):
                break

        return generated_ids

    def generate(self, X):
        embeddings = self.model.transformer.wte(X)
        embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(X.shape[0], 1, 1), embeddings], dim=1)

        cnt = 0
        past_key_values = None
        final_prediction = torch.empty(X.shape[0], 0, dtype=torch.long, device=X.device)
        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values

            # Only the last position is sampled from, so the cnt == 0 and
            # else branches (identical apart from a slice that never reached
            # the sampled position) are collapsed into one.
            logits = out.logits
            logits[:, :, 50257:] = -1e4

            output = top_p_sampling(logits[:, -1, :], temperature=self.temperature)[:, None]
            final_prediction = torch.cat([final_prediction, output], dim=1)
            embeddings = self.model.transformer.wte(output)

            cnt += 1
            if torch.all((final_prediction == self.eot.item()).any(dim=-1)):
                break

        return final_prediction

class LMModel(nn.Module):
    def __init__(self, num_prompts=0):
        super().__init__()
        self.num_prompts = num_prompts

        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        self.model.requires_grad_(False)
        self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        # As in PromptTuningModel: register [START] separately rather than
        # overwriting the pad token.
        self.tokenizer.add_special_tokens({'additional_special_tokens': ['[START]']})

        self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
        self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]

        # Resize the embedding matrix so the new [PAD]/[START] ids are valid
        # indices; without this, embedding lookups on padded batches fail.
        self.model.resize_token_embeddings(new_num_tokens=len(self.tokenizer))

        # Train only the LM head. Note that GPT-2 ties lm_head.weight to the
        # input embedding matrix, so unfreezing it also unfreezes wte.
        self.model.lm_head.requires_grad_(True)

    def forward(self, X, y):
        embeddings = self.model.transformer.wte(X)
        logits = self.model(inputs_embeds=embeddings).logits
        return logits
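
    # Unlike PromptTuningModel.forward, no soft prompt is prepended here, so
    # the returned logits align position-for-position with X.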

    def generate(self, X):
        embeddings = self.model.transformer.wte(X)

        cnt = 0
        past_key_values = None
        final_prediction = torch.empty(X.shape[0], 0, dtype=torch.long, device=X.device)
        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values

            # Same collapse of the duplicated cnt == 0 / else branches as in
            # PromptTuningModel.generate; only the last position is sampled.
            logits = out.logits
            logits[:, :, 50257:] = -1e4

            output = top_p_sampling(logits[:, -1, :], temperature=self.temperature)[:, None]
            final_prediction = torch.cat([final_prediction, output], dim=1)
            embeddings = self.model.transformer.wte(output)

            cnt += 1
            if torch.all((final_prediction == self.eot.item()).any(dim=-1)):
                break

        return final_prediction

def zero_after_x(tensor, x):
    """
    Replaces all elements in each row of a 2D tensor after the first
    occurrence of x with x itself (despite the name, it fills with x, not
    zero), e.g. to pad everything past the first end-of-text token.

    Args:
        tensor: The input 2D tensor.
        x: The value after which elements are overwritten with x.

    Returns:
        A new tensor with every element after the first x set to x.
    """
    mask = (tensor == x).cumsum(dim=1) > 0
    result = tensor.where(~mask, torch.ones_like(tensor, dtype=torch.long) * x)
    return result
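
# Behaviour sketch for zero_after_x (illustrative):
#   t = torch.tensor([[1, 2, 5, 3],
#                     [5, 1, 2, 3]])
#   zero_after_x(t, 5) -> [[1, 2, 5, 5],
#                          [5, 5, 5, 5]]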

class LitModelPromptTuning(L.LightningModule):
    def __init__(self, model, lr=1e-4, temperature=0.9):
        super().__init__()
        self.model = model
        self.lr = lr
        self.model.temperature = temperature
        # Compute ROUGE over GPT-2 BPE tokens rather than whitespace words.
        tokenize_to_strings = lambda text: self.model.tokenizer.convert_ids_to_tokens(self.model.tokenizer(text)["input_ids"])
        self.rouge = ROUGEScore(tokenizer=tokenize_to_strings)

        self.save_hyperparameters(ignore=['model'])

    def training_step(self, batch, batch_idx):
        X, y = batch

        logits = self.model(X, y)
        # Suppress the added special tokens; [PAD] (id 50257) doubles as the
        # ignore_index so padded positions do not contribute to the loss.
        logits[:, :, 50257:] = -1e4
        # This assumes the dataloader already shifts y so that y[:, t] is the
        # target for logits[:, t]; the final position is dropped from both.
        loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, logits.shape[-1]),
                               target=y[:, :-1].reshape(-1), ignore_index=50257)

        self.log('Training loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        X, y = batch

        logits = self.model(X, y)
        logits[:, :, 50257:] = -1e4
        loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, logits.shape[-1]),
                               target=y[:, :-1].reshape(-1), ignore_index=50257)

        self.log('Validation loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def on_test_epoch_start(self):
        self.all_text = []
        self.predicted_text = []

    def test_step(self, batch, batch_idx):
        # Kept from the original: the first test batch is skipped.
        if batch_idx == 0:
            return
        X, y = batch

        out = self.model.generate(X)

        pred = self.model.tokenizer.batch_decode(out, skip_special_tokens=True)
        gt = self.model.tokenizer.batch_decode(y, skip_special_tokens=True)

        # Log ROUGE per (prediction, reference) pair.
        for p, g in zip(pred, gt):
            score = self.rouge(p, g)
            self.log_dict(score, on_step=True, on_epoch=True, logger=True, sync_dist=True)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr)

if __name__ == '__main__':
    torch.set_float32_matmul_precision('medium')
    # utils.import_data is assumed to return (train, val, test) dataloaders
    # yielding (X, y) batches of token ids.
    dl_train, dl_val, dl_test = utils.import_data(bs=25, fraction=0.1)

    gpt_model = LMModel(num_prompts=0)

    model = LitModelPromptTuning(
        model=gpt_model,
        lr=1e-4,
        temperature=0.9,
    )
    print('Training')

    logger = WandbLogger(project='Anlp-3')
    trainer = L.Trainer(
        accelerator='gpu',
        devices=[2],
        default_root_dir='./logs/',
        num_nodes=1,
        num_sanity_val_steps=1,
        precision='bf16-mixed',
        max_epochs=5,
        check_val_every_n_epoch=1,
        log_every_n_steps=20,
        logger=logger,
    )

    trainer.fit(model, train_dataloaders=dl_train, val_dataloaders=dl_val)
    trainer.test(model, dataloaders=dl_test)