import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from torch.nn.utils.rnn import pad_sequence, pad_packed_sequence, pack_padded_sequence

import einops
import wandb
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2TokenizerFast

import lightning as L
from torchmetrics.text.rouge import ROUGEScore

import utils

# Fix the seeds for reproducibility.
np.random.seed(123)
torch.manual_seed(123)
torch.cuda.manual_seed(123)

def top_p_sampling(logits, p=0.9, temperature=0.5):
    """Nucleus (top-p) sampling over the last dimension of `logits`.

    Keeps the smallest set of tokens whose cumulative probability exceeds p,
    masks everything else, and samples one token id per row.
    """
    logits = logits / temperature

    # Sort tokens by logit and accumulate probability mass over the sorted order.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Tokens outside the nucleus are dropped. Shift the mask right by one so the
    # token that first crosses the threshold is kept, and always keep the top-1.
    sorted_indices_to_remove = cumulative_probs > p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False

    # Scatter the mask back into the original (unsorted) vocabulary order.
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    logits[indices_to_remove] = float('-inf')

    # Renormalise the surviving tokens and draw one sample per row.
    probs = F.softmax(logits, dim=-1)
    sampled_indices = torch.multinomial(probs, num_samples=1)
    return sampled_indices.squeeze(1)

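# Hedged usage sketch (illustrative only; not called by the pipeline): a quick
# sanity check that top_p_sampling returns one token id per batch row. The toy
# shapes and values below are assumptions for the demo.
def _top_p_sampling_demo():
    toy_logits = torch.randn(2, 10)  # (batch=2, vocab=10)
    ids = top_p_sampling(toy_logits, p=0.9, temperature=1.0)
    assert ids.shape == (2,)  # one sampled index per row
    return ids
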

class PromptTuningModel(nn.Module):
    """Frozen GPT-2 with `num_prompts` learnable soft-prompt embeddings."""

    def __init__(self, num_prompts=6):
        super().__init__()
        self.num_prompts = num_prompts

        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.tokenizer.add_special_tokens({'cls_token': '[START]'})

        self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
        self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]
        self.start = self.tokenizer("[START]", return_tensors="pt").input_ids[0]

        self.model.resize_token_embeddings(new_num_tokens=len(self.tokenizer))
        # Freeze the backbone *after* resizing: resize_token_embeddings creates a
        # fresh, trainable embedding matrix, so freezing must come last.
        self.model.requires_grad_(False)

        # Initialise the soft prompt from the embeddings of the word 'summarise',
        # tiled (and trimmed) to fill all num_prompts positions regardless of how
        # many BPE tokens the word splits into.
        tmp = self.tokenizer('summarise', return_tensors="pt").input_ids
        token_embedding = self.model.transformer.wte(tmp[0]).detach()
        reps = -(-num_prompts // token_embedding.shape[0])  # ceil division
        init = token_embedding.repeat(reps, 1)[:num_prompts]
        self.learnable_prompt = nn.Parameter(init.clone(), requires_grad=True)

    def forward(self, X, y):
        # The soft prompt is a registered nn.Parameter, so it is moved to the
        # right device with the module; prepend it to every sequence.
        embeddings = self.model.transformer.wte(X)
        prompt = self.learnable_prompt[None, :, :].repeat(X.shape[0], 1, 1)
        embeddings = torch.cat([prompt, embeddings], dim=1)

        out = self.model(inputs_embeds=embeddings)
        # Drop the logits that correspond to the prompt positions.
        logits = out.logits[:, self.num_prompts:]
        return logits

    def generate_new(self, X):
        batch_size = X.shape[0]
        embeddings = self.model.transformer.wte(X)
        prompt = self.learnable_prompt[None, :, :].repeat(batch_size, 1, 1)
        embeddings = torch.cat([prompt, embeddings], dim=1)

        cnt = 0
        past_key_values = None
        generated_ids = torch.empty(batch_size, 0, dtype=torch.long, device=X.device)

        while cnt < 196:
            # With use_cache=True, after the first step only the newest token's
            # embedding is fed in; the KV cache carries the earlier context.
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values
            logits = out.logits

            # Mask the added special tokens ([PAD], [START]) so they are never sampled.
            logits[:, :, 50257:] = -1e4

            next_token_ids = top_p_sampling(logits[:, -1, :])
            generated_ids = torch.cat([generated_ids, next_token_ids.unsqueeze(-1)], dim=-1)

            # Feed the sampled token back with an explicit sequence dim of 1,
            # giving embeddings of shape (batch, 1, hidden).
            embeddings = self.model.transformer.wte(next_token_ids.unsqueeze(-1))

            cnt += 1

            # Stop once every sequence in the batch has emitted <|endoftext|>.
            if torch.all((generated_ids == self.eot.item()).any(dim=-1)):
                break

        return generated_ids

    def generate(self, X):
        embeddings = self.model.transformer.wte(X)
        prompt = self.learnable_prompt[None, :, :].repeat(X.shape[0], 1, 1)
        embeddings = torch.cat([prompt, embeddings], dim=1)

        cnt = 0
        past_key_values = None
        final_prediction = torch.empty(X.shape[0], 0, dtype=torch.long, device=X.device)
        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values
            logits = out.logits
            # Mask the added special tokens so they cannot be sampled.
            logits[:, :, 50257:] = -1e4

            output = top_p_sampling(logits[:, -1, :], temperature=self.temperature)[:, None]
            final_prediction = torch.cat([final_prediction, output], dim=1)
            # Only the newest token is fed back in; the KV cache holds the rest.
            embeddings = self.model.transformer.wte(output)

            cnt += 1
            # Stop once every sequence has produced <|endoftext|>.
            if torch.all((final_prediction == self.eot.item()).any(dim=-1)):
                break

        return final_prediction

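# Hedged sanity check (illustrative only; not called by the pipeline): with the
# backbone frozen, the soft prompt should be the model's only trainable tensor,
# e.g. _count_trainable(PromptTuningModel(num_prompts=6)) should equal
# 6 * 768 (six prompt vectors of GPT-2's hidden size).
def _count_trainable(module):
    return sum(p.numel() for p in module.parameters() if p.requires_grad)
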

from peft import PeftModel, LoraConfig, get_peft_model


class LoraModel(nn.Module):
    """GPT-2 fine-tuned with LoRA adapters on the attention projections."""

    def __init__(self, dim=8):
        super().__init__()
        self.num_prompts = 0  # no soft prompt; kept for interface parity
        self.dim = dim

        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.tokenizer.add_special_tokens({'cls_token': '[START]'})

        self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
        self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]
        self.start = self.tokenizer("[START]", return_tensors="pt").input_ids[0]

        self.model.resize_token_embeddings(new_num_tokens=len(self.tokenizer))
        self.model.requires_grad_(False)

        # Wrap the frozen backbone with LoRA adapters on GPT-2's fused attention
        # projection (c_attn); only the adapter weights remain trainable.
        lora_config = LoraConfig(
            r=dim,
            lora_alpha=32,
            target_modules=["c_attn"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )
        self.model = get_peft_model(self.model, lora_config)

    def forward(self, X, y):
        # Attribute access (transformer.wte) is forwarded through the PEFT wrapper.
        embeddings = self.model.transformer.wte(X)
        logits = self.model(inputs_embeds=embeddings).logits
        return logits

    def generate(self, X):
        embeddings = self.model.transformer.wte(X)

        cnt = 0
        past_key_values = None
        final_prediction = torch.empty(X.shape[0], 0, dtype=torch.long, device=X.device)
        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values
            logits = out.logits
            # Mask the added special tokens so they cannot be sampled.
            logits[:, :, 50257:] = -1e4

            output = top_p_sampling(logits[:, -1, :], temperature=self.temperature)[:, None]
            final_prediction = torch.cat([final_prediction, output], dim=1)
            # Only the newest token is fed back in; the KV cache holds the rest.
            embeddings = self.model.transformer.wte(output)

            cnt += 1
            # Stop once every sequence has produced <|endoftext|>.
            if torch.all((final_prediction == self.eot.item()).any(dim=-1)):
                break

        return final_prediction

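# Hedged sketch (illustrative only): PEFT can report how many parameters the
# LoRA wrapper left trainable; print_trainable_parameters is part of peft's
# PeftModel API.
def _report_lora_params(lora_model: LoraModel):
    lora_model.model.print_trainable_parameters()
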

class LMModel(nn.Module):
    """GPT-2 with everything frozen except the LM head.

    Note: GPT-2 ties lm_head.weight to the input embedding matrix (wte), so
    unfreezing the head also makes the tied embedding weights trainable.
    """

    def __init__(self, num_prompts=6):
        super().__init__()
        self.num_prompts = num_prompts  # kept for interface parity; no prompt is prepended

        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.tokenizer.add_special_tokens({'cls_token': '[START]'})

        self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
        self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]
        self.start = self.tokenizer("[START]", return_tensors="pt").input_ids[0]

        self.model.resize_token_embeddings(new_num_tokens=len(self.tokenizer))
        self.model.requires_grad_(False)
        self.model.lm_head.requires_grad_(True)

    def forward(self, X, y):
        embeddings = self.model.transformer.wte(X)
        logits = self.model(inputs_embeds=embeddings).logits
        return logits

    def generate(self, X):
        embeddings = self.model.transformer.wte(X)

        cnt = 0
        past_key_values = None
        final_prediction = torch.empty(X.shape[0], 0, dtype=torch.long, device=X.device)
        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values
            logits = out.logits
            # Mask the added special tokens so they cannot be sampled.
            logits[:, :, 50257:] = -1e4

            output = top_p_sampling(logits[:, -1, :], temperature=self.temperature)[:, None]
            final_prediction = torch.cat([final_prediction, output], dim=1)
            # Only the newest token is fed back in; the KV cache holds the rest.
            embeddings = self.model.transformer.wte(output)

            cnt += 1
            # Stop once every sequence has produced <|endoftext|>.
            if torch.all((final_prediction == self.eot.item()).any(dim=-1)):
                break

        return final_prediction

def zero_after_x(arr, x):
    """Replace everything from the first occurrence of x onward with x, per row.

    Args:
        arr: 2-D input tensor.
        x: the value whose first occurrence marks the cutoff.

    Returns:
        A new tensor in which, for each row, every element from the first
        occurrence of x to the end of the row is set to x (useful, e.g., for
        truncating generations at <|endoftext|>).
    """
    # cumsum > 0 turns True at the first match and stays True afterwards.
    mask = (arr == x).cumsum(dim=1) > 0
    return torch.where(mask, x, arr)

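# Hedged usage sketch (illustrative only; not called by the pipeline): the toy
# sentinel value 9 below is an assumption for the demo.
def _zero_after_x_demo():
    arr = torch.tensor([[1, 2, 9, 4, 5],
                        [9, 2, 3, 4, 5]])
    out = zero_after_x(arr, 9)
    assert out.tolist() == [[1, 2, 9, 9, 9],
                            [9, 9, 9, 9, 9]]
    return out
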

class LitModelPromptTuning(L.LightningModule):
    def __init__(self, model, temperature, epoch, lr=1e-4, **kwargs):
        super().__init__()
        self.model = model
        self.lr = lr
        # generate() reads the sampling temperature off the wrapped model.
        self.model.temperature = temperature
        self.epoch = epoch
        self.temperature = temperature

        # Expose any extra hyperparameters (e.g. type_model, dimension) as attributes.
        for key, value in kwargs.items():
            setattr(self, key, value)

        tokenize_to_strings = lambda text: self.model.tokenizer.convert_ids_to_tokens(self.model.tokenizer(text)["input_ids"])
        self.rouge = ROUGEScore(tokenizer=tokenize_to_strings)

        self.save_hyperparameters(ignore=['model'])

    def training_step(self, batch, batch_idx):
        X, y = batch

        logits = self.model(X, y)
        # Mask the added special tokens; [PAD] (id 50257) is also excluded from
        # the loss via ignore_index. y is assumed to already be the next-token
        # target sequence produced by the dataloader.
        logits[:, :, 50257:] = -1e4
        loss = F.cross_entropy(
            logits[:, :-1, :].reshape(-1, logits.shape[-1]),
            target=y[:, :-1].reshape(-1),
            ignore_index=50257,
        )

        self.log('Training loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        X, y = batch

        logits = self.model(X, y)
        logits[:, :, 50257:] = -1e4
        loss = F.cross_entropy(
            logits[:, :-1, :].reshape(-1, logits.shape[-1]),
            target=y[:, :-1].reshape(-1),
            ignore_index=50257,
        )

        self.log('Validation loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def on_test_epoch_start(self):
        self.all_text = []
        self.predicted_text = []

    def test_step(self, batch, batch_idx):
        X, y = batch

        out = self.model.generate(X)

        pred = self.model.tokenizer.batch_decode(out, skip_special_tokens=False)
        gt = self.model.tokenizer.batch_decode(y, skip_special_tokens=False)

        # Keep the decoded texts around for inspection after the epoch.
        self.predicted_text.extend(pred)
        self.all_text.extend(gt)

        # Score each prediction against its reference and log the ROUGE metrics.
        for p, g in zip(pred, gt):
            score = self.rouge(p, g)
            self.log_dict(score, on_step=True, on_epoch=True, logger=True, sync_dist=True)

    def configure_optimizers(self):
        # Optimise only the trainable parameters (soft prompt / LoRA adapters / head).
        params = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.AdamW(params, lr=self.lr)
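

# Hedged sketch (illustrative only; not called by the pipeline): how
# ignore_index=50257 drops [PAD] positions from the cross-entropy used in
# training_step. The toy vocabulary size and ids below are assumptions.
def _masked_ce_demo():
    vocab = 50260                      # GPT-2 vocab plus the two added specials
    logits = torch.randn(2, 4, vocab)  # (batch, seq, vocab)
    targets = torch.tensor([[11, 12, 50257, 50257],
                            [13, 14, 15, 50257]])
    # Padding positions (id 50257) contribute nothing to the loss.
    return F.cross_entropy(logits.reshape(-1, vocab), targets.reshape(-1),
                           ignore_index=50257)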


from lightning.pytorch.loggers import WandbLogger

if __name__ == '__main__':
    train = False

    torch.set_float32_matmul_precision('medium')
    dl_train, dl_val, dl_test = utils.import_data(bs=24, fraction=1)

    if train:
        gpt_model = LoraModel(dim=16)
    else:
        # Load a full model object previously saved with torch.save(model.model, ...).
        gpt_model = torch.load('./model1.bin')

    model = LitModelPromptTuning(
        model=gpt_model,
        lr=1e-3,
        temperature=0.9,
        epoch=5,
        type_model='lora',
        dimension=16,
    )

    logger = WandbLogger(project='Anlp-3')
    trainer = L.Trainer(
        accelerator='gpu',
        devices=1,
        default_root_dir='./logs/',
        num_nodes=1,
        num_sanity_val_steps=1,
        precision='bf16-mixed',
        max_epochs=5,
        check_val_every_n_epoch=1,
        log_every_n_steps=20,
        logger=logger,
    )

    if train:
        print('Training')
        trainer.fit(model, train_dataloaders=dl_train, val_dataloaders=dl_val)
        trainer.test(model, dataloaders=dl_test)
        torch.save(model.model, './model2.bin')
    else:
        trainer.test(model, dataloaders=dl_test)