# Example of fine-tuning the 7.1-billion-parameter Bloom model with 8-bit weights

This notebook shows an example of how to fine-tune Bloom with Low-Rank Adapters (LoRA). It is heavily inspired by [Hivemind's work](https://colab.research.google.com/drive/1ft6wQU0BhqG5PRlwgaZJv2VukKKjU4Es).

### Load and convert original Bloom structure to 8-bit LoRA

You can load an already compressed 8-bit version of Bloom from [joaoalvarenga/bloom-8bit](https://huggingface.co/joaoalvarenga/bloom-8bit), but first we need to make some adaptations to the original model structure. Some of the following code is adapted from [Hivemind's GPT-J 8-bit fine-tuning notebook](https://colab.research.google.com/drive/1ft6wQU0BhqG5PRlwgaZJv2VukKKjU4Es).

In [1]:
import transformers

import torch
import torch.nn.functional as F
from torch import nn
from torch.cuda.amp import custom_fwd, custom_bwd

from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise

from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm



Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link
CUDA SETUP: Loading binary /home/dm/.local/lib/python3.8/site-packages/bitsandbytes/libbitsandbytes_cpu.so...


  warn(
  warn(
  warn(
  warn(


In [2]:
class FrozenBNBLinear(nn.Module):
    """Linear layer whose weight is held as frozen, blockwise-quantized 8-bit data.

    The quantized ``weight`` together with its ``absmax`` and ``code``
    quantization state are registered as buffers with gradients disabled.
    ``bias`` (if given) is kept as a regular parameter. A trainable ``adapter``
    module may be attached after construction; when present, its output is
    added to the output of the frozen layer.
    """

    def __init__(self, weight, absmax, code, bias=None):
        """Store the quantized weight and its quantization state.

        Args:
            weight: 2-D tensor of shape ``(out_features, in_features)``.
            absmax: quantization state tensor passed on to ``dequantize_blockwise``.
            code: quantization code tensor passed on to ``dequantize_blockwise``.
            bias: optional ``nn.Parameter``; ``None`` means no bias.
        """
        assert isinstance(bias, nn.Parameter) or bias is None
        super().__init__()
        self.out_features, self.in_features = weight.shape
        self.register_buffer("weight", weight.requires_grad_(False))
        self.register_buffer("absmax", absmax.requires_grad_(False))
        self.register_buffer("code", code.requires_grad_(False))
        self.adapter = None  # optional trainable module, attached later
        self.bias = bias

    def forward(self, input):
        """Apply the dequantized linear map, plus the adapter output if one is attached."""
        output = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)
        # Compare against None explicitly: the truth value of an nn.Module
        # container is derived from its length, which is not what "an adapter
        # has been attached" means.
        if self.adapter is not None:
            output += self.adapter(input)
        return output

    @classmethod
    def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear":
        """Build a frozen 8-bit layer by quantizing the weight of an existing ``nn.Linear``."""
        weights_int8, state = quantize_blockise_lowmemory(linear.weight)
        return cls(weights_int8, *state, linear.bias)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.in_features}, {self.out_features})"
 
 
class DequantizeAndLinear(torch.autograd.Function):
    """Autograd function that dequantizes an 8-bit weight on the fly and applies ``F.linear``.

    Gradients are produced only for ``input`` and ``bias``; the quantized
    weight and its quantization state (``absmax``, ``code``) are treated as
    constants.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,
                absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):
        """Return ``F.linear(input, dequantized_weight, bias)``."""
        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
        # Only the quantization state is needed to rebuild the weight in
        # backward; `input` is never read there, so it is not saved.
        ctx.save_for_backward(weights_quantized, absmax, code)
        ctx._has_bias = bias is not None
        return F.linear(input, weights_deq, bias)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output: torch.Tensor):
        """Return gradients for ``(input, weights_quantized, absmax, code, bias)``.

        The three middle entries are always ``None``; the weight is frozen.
        """
        assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]
        weights_quantized, absmax, code = ctx.saved_tensors
        grad_input = None
        if ctx.needs_input_grad[0]:
            # grad_output: [*batch, out_features]
            # Dequantize again rather than keeping the full-precision weight
            # alive between forward and backward.
            weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
            grad_input = grad_output @ weights_deq
        # Sum over every leading (batch-like) dimension to get one value per output feature.
        grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None
        return grad_input, None, None, None, grad_bias
 
 
class FrozenBNBEmbedding(nn.Module):
    """Embedding table held as frozen, blockwise-quantized 8-bit data.

    The quantized ``weight`` together with its ``absmax`` and ``code``
    quantization state are registered as buffers with gradients disabled.
    A trainable ``adapter`` module may be attached after construction; when
    present, its output is added to the looked-up embeddings.
    """

    def __init__(self, weight, absmax, code):
        """Store the quantized table and its quantization state.

        Args:
            weight: 2-D tensor of shape ``(num_embeddings, embedding_dim)``.
            absmax: quantization state tensor passed on to ``dequantize_blockwise``.
            code: quantization code tensor passed on to ``dequantize_blockwise``.
        """
        super().__init__()
        self.num_embeddings, self.embedding_dim = weight.shape
        self.register_buffer("weight", weight.requires_grad_(False))
        self.register_buffer("absmax", absmax.requires_grad_(False))
        self.register_buffer("code", code.requires_grad_(False))
        self.adapter = None  # optional trainable module, attached later

    def forward(self, input, **kwargs):
        """Look up ``input`` in the dequantized table; extra kwargs go to ``F.embedding``."""
        with torch.no_grad():
            # note: both quantized weights and input indices are *not* differentiable
            weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)
            output = F.embedding(input, weight_deq, **kwargs)
        # Compare against None explicitly: the truth value of an nn.Module
        # container is derived from its length, which is not what "an adapter
        # has been attached" means.
        if self.adapter is not None:
            output += self.adapter(input)
        return output

    @classmethod
    def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding":
        """Build a frozen 8-bit table by quantizing the weight of an existing ``nn.Embedding``."""
        weights_int8, state = quantize_blockise_lowmemory(embedding.weight)
        return cls(weights_int8, *state)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})"
 
 
def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):
    """Quantize ``matrix`` with ``quantize_blockwise`` one chunk at a time.

    The flattened tensor is processed in slices of ``chunk_size`` elements so
    that only one cloned slice is handed to ``quantize_blockwise`` at a time.
    The ``code`` returned for one chunk is passed back in for the next one.

    Args:
        matrix: tensor to quantize; it is flattened with ``view(-1)``.
        chunk_size: number of elements per slice; must be a multiple of 4096
            (presumably the quantization block size -- confirm against the
            installed bitsandbytes version).

    Returns:
        ``(matrix_i8, (absmax, code))`` where ``matrix_i8`` has the shape of
        ``matrix`` and ``absmax`` is the concatenation of the per-chunk values.
    """
    assert chunk_size % 4096 == 0
    code = None
    quantized_parts, absmax_parts = [], []
    flat = matrix.view(-1)
    n_chunks = (matrix.numel() - 1) // chunk_size + 1
    for index in range(n_chunks):
        begin = index * chunk_size
        piece = flat[begin: begin + chunk_size].clone()
        piece_i8, (piece_absmax, code) = quantize_blockwise(piece, code=code)
        quantized_parts.append(piece_i8)
        absmax_parts.append(piece_absmax)

    matrix_i8 = torch.cat(quantized_parts).reshape_as(matrix)
    return matrix_i8, (torch.cat(absmax_parts), code)


def convert_to_int8(model):
    """Replace linear and embedding sub-modules with zero-filled 8-bit placeholders.

    Every ``nn.Linear`` child becomes a ``FrozenBNBLinear`` (keeping the
    original bias object) and every ``nn.Embedding`` child becomes a
    ``FrozenBNBEmbedding``. The placeholders carry all-zero ``uint8`` weights,
    one ``absmax`` entry per 4096 weight elements (rounded up) and a
    256-entry ``code``. The name and repr of each replaced linear layer is
    printed. No adapters are attached here.
    """
    # Snapshot the module list first: children are swapped while walking it.
    for parent in list(model.modules()):
        for child_name, child in parent.named_children():
            if isinstance(child, nn.Linear):
                print(child_name, child)
                replacement = FrozenBNBLinear(
                    weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),
                    absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
                    code=torch.zeros(256),
                    bias=child.bias,
                )
            elif isinstance(child, nn.Embedding):
                replacement = FrozenBNBEmbedding(
                    weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),
                    absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
                    code=torch.zeros(256),
                )
            else:
                continue
            setattr(parent, child_name, replacement)

In [3]:
class BloomBlock(transformers.models.bloom.modeling_bloom.BloomBlock):
    """Bloom transformer block whose attention and MLP layers are converted to 8-bit."""

    def __init__(self, config):
        super().__init__(config)
        # Only these two sub-modules of the block are converted.
        for submodule in (self.self_attention, self.mlp):
            convert_to_int8(submodule)


class BloomModel(transformers.models.bloom.modeling_bloom.BloomModel):
    """Stock ``BloomModel`` whose linear and embedding sub-modules are swapped for 8-bit placeholders."""
    def __init__(self, config):
        super().__init__(config)
        # Build the regular model first, then replace every nn.Linear /
        # nn.Embedding found under it with its frozen 8-bit counterpart.
        convert_to_int8(self)
        

class BloomForCausalLM(transformers.models.bloom.modeling_bloom.BloomForCausalLM):
    """Stock ``BloomForCausalLM`` whose linear and embedding sub-modules are swapped for 8-bit placeholders."""
    def __init__(self, config):
        super().__init__(config)
        # Build the regular model first, then replace every nn.Linear /
        # nn.Embedding found under it with its frozen 8-bit counterpart.
        convert_to_int8(self)
        
# Rebind the library's module-level BloomBlock to the 8-bit subclass defined above
# (presumably so that models built by transformers pick up the quantized block -- confirm).
transformers.models.bloom.modeling_bloom.BloomBlock = BloomBlock

In [None]:
#!g1.1
# NOTE: `BloomForCausalLM` is intentionally NOT imported from transformers here.
# Importing it would rebind the name to the stock class and silently bypass the
# 8-bit `BloomForCausalLM` subclass defined earlier in this notebook.
from transformers import AutoModel
tokenizer = transformers.AutoTokenizer.from_pretrained("bigscience/bloom-7b1",  cache_dir="mycache")
model = BloomForCausalLM.from_pretrained('bloom-8bit-v4.pt')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
pass  # keeps the notebook from echoing the model repr as the cell result

query_key_value Linear(in_features=4096, out_features=12288, bias=True)
dense Linear(in_features=4096, out_features=4096, bias=True)
dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)
dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)
query_key_value Linear(in_features=4096, out_features=12288, bias=True)
dense Linear(in_features=4096, out_features=4096, bias=True)
dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)
dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)
query_key_value Linear(in_features=4096, out_features=12288, bias=True)
dense Linear(in_features=4096, out_features=4096, bias=True)
dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)
dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)
query_key_value Linear(in_features=4096, out_features=12288, bias=True)
dense Linear(in_features=4096, out_features=4096, bias=True)
dense_h_to_4h Linear(in_features=4096, out_featu

In [None]:
# Bare expression: as the last statement of the cell, the notebook displays the model's repr.
model

In [None]:
# NOTE(review): the literal below opens with four quote characters, so the prompt
# itself starts with a `"` -- confirm that this is intended.
prefix = """"It is a fantasy role-play game.

Game Master: You are John, a wizard living in the kingdom of Larion. You have a staff and a spellbook. You finish your long journey and finally arrive at the ruin you've been looking for. You have come here searching for a mystical spellbook of great power called the book of essence. You look around and see the ancient ruins of an elf tower. The ruins have not been touched for decades. You look at the tower, and you can see a set of stone stairs that seem to lead somewhere deep inside the tower.
Player: I walk upstairs
Game Master: You climb up the stairs in the ruined tower. There is a door on the second floor of the tower, the door seems to be made of enchanted wood.
Player: I ask the door if I may to come in
Game Master: The door sighs open and you walk into the room.
Player: I take a look around
Game Master:"""

print(end=prefix)
past_key_values = None  # used to keep track of conversation history
# The model was moved to `device`; its inputs must live on the same device.
input_dict = tokenizer([prefix], return_tensors='pt', padding=False).to(device)

output = ""  # text generated so far, used to detect the stop markers

with torch.inference_mode():
    for _ in range(200):  # hard cap on the number of generated tokens
        outputs = model.forward(**input_dict, use_cache=True, past_key_values=past_key_values)
        last_logits = outputs.logits[0, -1]

        last_logits[last_logits.topk(k=10).indices] += 10 # other logits are now e^10 times less likely to be chosen

        past_key_values = outputs.past_key_values
        token_ix = torch.multinomial(last_logits.softmax(-1), 1).item()
        token_text = tokenizer.decode([token_ix])  # decode once, reuse below
        output += token_text
        # Stop (without printing the current token) once the model starts the next speaker's turn.
        if any(marker in output for marker in ('player', 'Player', 'Master')):
            break
        print(end=token_text, flush=True)

        # With the cache holding the history, only the newly sampled token is fed back.
        input_dict = dict(input_ids=torch.tensor([[token_ix]], device=device))
print()

In [None]:
# NOTE(review): the literal below opens with four quote characters, so the prompt
# itself starts with a `"` -- confirm that this is intended.
prefix = """"It is a fantasy role-play game.

Game Master: You are John, a wizard living in the kingdom of Larion. You have a staff and a spellbook. You finish your long journey and finally arrive at the ruin you've been looking for. You have come here searching for a mystical spellbook of great power called the book of essence. You look around and see the ancient ruins of an elf tower. The ruins have not been touched for decades. You look at the tower, and you can see a set of stone stairs that seem to lead somewhere deep inside the tower.
Player: I walk upstairs
Game Master: You climb up the stairs in the ruined tower. There is a door on the second floor of the tower, the door seems to be made of enchanted wood.
Player: I ask the door if I may to come in
Game Master: The door sighs open and you walk into the room.
Player: I take a look around
Game Master:"""

print(end=prefix)
past_key_values = None  # used to keep track of conversation history
# The model was moved to `device`; its inputs must live on the same device.
input_dict = tokenizer([prefix], return_tensors='pt', padding=False).to(device)

output = ""  # text generated so far, used to detect the stop markers

with torch.inference_mode():
    for _ in range(200):  # hard cap on the number of generated tokens
        outputs = model.forward(**input_dict, use_cache=True, past_key_values=past_key_values)
        last_logits = outputs.logits[0, -1]

        last_logits[last_logits.topk(k=10).indices] += 10 # other logits are now e^10 times less likely to be chosen

        past_key_values = outputs.past_key_values
        token_ix = torch.multinomial(last_logits.softmax(-1), 1).item()
        token_text = tokenizer.decode([token_ix])  # decode once, reuse below
        output += token_text
        # Stop (without printing the current token) once the model starts the next speaker's turn.
        if any(marker in output for marker in ('player', 'Player', 'Master')):
            break
        print(end=token_text, flush=True)

        # With the cache holding the history, only the newly sampled token is fed back.
        input_dict = dict(input_ids=torch.tensor([[token_ix]], device=device))
print()

In [None]:
#!g1.1
# Tokenize the prompt and move it to the device the model was placed on;
# otherwise `generate` fails with a device mismatch when the model is on the GPU.
prompt = tokenizer("A cat sat on a mat and", return_tensors='pt').to(device)
# min_length == max_length pins the total sequence length to exactly 10 tokens.
out = model.generate(**prompt, min_length=10, max_length=10, do_sample=True)
tokenizer.decode(out[0])

### Fine-tune and save model

In [None]:
#!g1.1
def add_adapters(model, adapter_dim=16):
    """Attach a low-rank adapter to every frozen 8-bit layer of ``model``.

    Each adapter is a two-stage ``nn.Sequential``: a projection down to
    ``adapter_dim`` (an ``nn.Linear`` without bias for linear layers, an
    ``nn.Embedding`` for embedding layers) followed by a bias-free
    ``nn.Linear`` back up to the layer's output width. The second stage is
    zero-initialised, so a freshly attached adapter contributes nothing.

    Args:
        model: module tree to walk; modified in place.
        adapter_dim: inner width of the adapters; must be positive.
    """
    assert adapter_dim > 0

    for layer in model.modules():
        if isinstance(layer, FrozenBNBLinear):
            down = nn.Linear(layer.in_features, adapter_dim, bias=False)
            out_width = layer.out_features
        elif isinstance(layer, FrozenBNBEmbedding):
            down = nn.Embedding(layer.num_embeddings, adapter_dim)
            out_width = layer.embedding_dim
        else:
            continue
        up = nn.Linear(adapter_dim, out_width, bias=False)
        layer.adapter = nn.Sequential(down, up)
        nn.init.zeros_(layer.adapter[1].weight)

# Attach adapters (default adapter_dim) to the frozen 8-bit layers, then move
# the model -- now including the new adapter modules -- to `device`.
add_adapters(model)
model.to(device)

In [None]:
#!g1.1
from datasets import load_dataset
from bitsandbytes.optim import Adam8bit

model.gradient_checkpointing_enable()
# `from_pretrained` returns the model in eval mode; switch to train mode for fine-tuning.
model.train()

wikisql = load_dataset("wikisql", streaming=True)
optimizer = Adam8bit(model.parameters(), lr=1e-5)

for row in tqdm(wikisql['train']):

    # One training example: the question immediately followed by the human-readable SQL.
    batch = tokenizer(row['question'] + row['sql']['human_readable'], truncation=True, max_length=128, return_tensors='pt')
    # Put the inputs on the same device the model was moved to.
    batch = {k: v.to(device) for k, v in batch.items()}

    # Autocast wraps only the forward pass and the loss; backward and the
    # optimizer step run outside of it.
    with torch.cuda.amp.autocast():
        out = model.forward(**batch,)

        # Next-token objective: logits at position t are scored against the token at t + 1.
        loss = F.cross_entropy(out.logits[:, :-1, :].flatten(0, -2), batch['input_ids'][:, 1:].flatten(),
                               reduction='mean')
    print(loss)
    loss.backward()

    optimizer.step()
    optimizer.zero_grad()

In [None]:
#!g1.1
# Persist the fine-tuned model with `save_pretrained` under 'bloom-8bit-fine-tuned'.
model.save_pretrained('bloom-8bit-fine-tuned')