| | import torch |
| |
|
| |
|
def test_lora_layer_replacement(lit_llama):
    """Verify that a LLaMA built inside the ``lora`` context uses LoRA attention.

    A two-layer model is instantiated under ``lora(...)`` and the attention
    module of every transformer block is checked to be the LoRA variant of
    ``CausalSelfAttention``.

    Args:
        lit_llama: Not used directly in the body; presumably a pytest fixture
            that makes the ``lit_llama`` package importable (confirm in
            conftest).
    """
    from lit_llama.lora import CausalSelfAttention as LoRACausalSelfAttention
    from lit_llama.lora import lora
    from lit_llama.model import LLaMA, LLaMAConfig

    # Start from the default config and shrink it field by field (the fields
    # are assigned after construction, in this order) to keep the model tiny.
    config = LLaMAConfig()
    tiny_overrides = {
        "n_layer": 2,
        "n_head": 4,
        "n_embd": 8,
        "block_size": 8,
        "vocab_size": 8,
    }
    for field_name, field_value in tiny_overrides.items():
        setattr(config, field_name, field_value)

    # The model has to be constructed while the LoRA context is active.
    with lora(r=8, alpha=8, dropout=0.1):
        model = LLaMA(config)

    # Both transformer blocks must carry the LoRA-enabled attention module.
    for block_index in (0, 1):
        attention = model.transformer.h[block_index].attn
        assert isinstance(attention, LoRACausalSelfAttention)
| |
|
| |
|
def test_lora_merge_unmerge(lit_llama):
    """Exercise how ``train()`` / ``eval()`` affect the LoRA ``c_attn`` weight.

    The test checks, in order, that:

    1. entering train mode on a fresh model leaves ``c_attn.weight`` untouched;
    2. an SGD step taken after ``mark_only_lora_as_trainable`` does not change
       ``c_attn.weight``;
    3. ``eval()`` changes the weight and a following ``train()`` restores it
       (up to ``torch.allclose`` tolerance);
    4. repeating ``eval()`` or ``train()`` does not change the weight again,
       and the ``merged`` flag follows the current mode.

    Args:
        lit_llama: Not used directly in the body; presumably a pytest fixture
            that makes the ``lit_llama`` package importable (confirm in
            conftest).
    """
    from lit_llama.lora import lora, mark_only_lora_as_trainable
    from lit_llama.model import LLaMA, LLaMAConfig

    config = LLaMAConfig(n_layer=1, n_head=2, n_embd=8, block_size=8, vocab_size=8)
    with lora(r=8, alpha=8, dropout=0.1):
        model = LLaMA(config)

    def c_attn():
        # Re-resolve the layer on every call so each check reads the live module.
        return model.transformer.h[0].attn.c_attn

    # Phase 1: switching a fresh model to train mode must not touch the weight.
    initial_weight = c_attn().weight.clone()
    model.train()
    assert torch.equal(c_attn().weight, initial_weight)

    # Phase 2: take one optimizer step with only the LoRA parameters trainable.
    mark_only_lora_as_trainable(model)
    optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
    token_ids = torch.randint(0, 8, size=(2, 4), dtype=torch.int64)
    loss = model(token_ids).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # The c_attn weight itself is expected to be unchanged by that step.
    assert torch.equal(c_attn().weight, initial_weight)

    # Phase 3: eval() followed by train() should cancel each other out.
    weight_before = c_attn().weight.clone()
    model.eval()
    assert not torch.equal(c_attn().weight, weight_before)
    model.train()

    # Compared with allclose rather than equal -- presumably because the
    # round trip is not bit-exact in floating point.
    assert torch.allclose(c_attn().weight, weight_before)

    # Phase 4a: repeated eval() calls must leave the weight as the first did.
    model.eval()
    assert c_attn().merged
    weight_after = c_attn().weight.clone()
    for _ in range(2):
        model.eval()
    assert torch.equal(c_attn().weight, weight_after)

    # Phase 4b: likewise, repeated train() calls must not change it further.
    model.train()
    assert not c_attn().merged
    weight_after = c_attn().weight.clone()
    for _ in range(2):
        model.train()
    assert torch.equal(c_attn().weight, weight_after)
| |
|