import time
import torch
import warnings
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, GPT2TokenizerFast

warnings.filterwarnings("ignore")

# bitsandbytes quantization runs on GPU only
device = "cuda"

tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")

# WikiText-2 validation set, concatenated into a single token stream
test = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation")
encodings = tokenizer("\n\n".join(test["text"]), return_tensors="pt")
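
# run_experiment scores a model with the standard sliding-window perplexity
# evaluation: a window of max_length tokens slides over the corpus in steps of
# `stride`; tokens already scored by a previous window are masked with -100 so
# each token contributes to the loss exactly once, and the mean negative
# log-likelihood is exponentiated to give perplexity.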
def run_experiment(model):
    print(f'Memory usage of model alone = {model.get_memory_footprint() / 10**6} MB')
    max_length = model.config.n_positions  # 1024 for GPT-2
    stride = 512
    seq_len = encodings.input_ids.size(1)

    nlls = []
    start_time = time.time()
    prev_end_loc = 0
    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # may differ from stride on the last window
        input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
        target_ids = input_ids.clone()
        # mask tokens that were already scored in a previous window
        target_ids[:, :-trg_len] = -100

        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)
            neg_log_likelihood = outputs.loss

        if begin_loc == 0:
            print(f'Memory usage at forward pass = {torch.cuda.memory_allocated(0) / 10**6} MB')
        nlls.append(neg_log_likelihood)

        prev_end_loc = end_loc
        if end_loc == seq_len:
            break

    # perplexity = exp of the mean negative log-likelihood over all windows
    ppl = torch.exp(torch.stack(nlls).mean())
    print(f'Perplexity = {ppl.item()}')
    print(f'Time taken: {time.time() - start_time}')
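
# For context, an unquantized baseline can be scored the same way. This is a
# sketch, not one of the original runs; torch_dtype=torch.float16 is a
# standard from_pretrained argument and roughly halves memory vs. fp32.
model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16).to(device)
print('fp16 baseline model')
run_experiment(model)
print()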


from transformers import BitsAndBytesConfig

# 4-bit quantization (fp4 quant type by default)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
)
model = AutoModelForCausalLM.from_pretrained("gpt2", quantization_config=bnb_config)

print('4 bit model')
run_experiment(model)

torch.save(model, 'bnb-4.pth')
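# Note: torch.save pickles the entire module object, which can be fragile
# across bitsandbytes/transformers versions. Where quantized serialization is
# supported (version-dependent, so an assumption here), save_pretrained is the
# safer alternative:
# model.save_pretrained('bnb-4')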
print()


# 8-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
)
model = AutoModelForCausalLM.from_pretrained("gpt2", quantization_config=bnb_config)
print('8 bit model')
run_experiment(model)
torch.save(model, 'bnb-8.pth')
print()


# 4-bit NF4 (normal-float) quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
)
model = AutoModelForCausalLM.from_pretrained("gpt2", quantization_config=bnb_config)
print('4 bit nf4 model')
run_experiment(model)
torch.save(model, 'bnb-nf4.pth')
print()
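

# Possible extension (a sketch, not one of the original runs): NF4 with nested
# "double" quantization, which also quantizes the quantization constants for a
# further memory saving. bnb_4bit_use_double_quant and bnb_4bit_compute_dtype
# are standard BitsAndBytesConfig options.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained("gpt2", quantization_config=bnb_config)
print('4 bit nf4 + double quant model')
run_experiment(model)
print()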