|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import torch
|
| from torch import nn
|
| from transformers import AutoModelForCausalLM
|
| from transformers.testing_utils import torch_device
|
| from transformers.utils import is_peft_available
|
|
|
| from trl.models.activation_offloading import NoOpManager, OffloadActivations
|
|
|
| from .testing_utils import TrlTestCase, require_peft, require_torch_accelerator
|
|
|
|
|
| if is_peft_available():
|
| from peft import LoraConfig, get_peft_model
|
|
|
|
|
class TestActivationOffloading(TrlTestCase):
    """Accelerator tests for ``OffloadActivations`` and ``NoOpManager``.

    Most tests run the same forward/backward pass with and without the
    context manager under test and check that the gradients agree.
    """

    # Tiny checkpoint shared by every test that loads a Hugging Face model.
    _MODEL_ID = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"

    @require_torch_accelerator
    @require_peft
    def test_offloading_with_peft_models(self) -> None:
        """Test that activation offloading works with PEFT models."""
        model = AutoModelForCausalLM.from_pretrained(self._MODEL_ID).to(torch_device)
        peft_config = LoraConfig(
            lora_alpha=16,
            lora_dropout=0.1,
            r=8,
            bias="none",
            task_type="CAUSAL_LM",
        )

        model = get_peft_model(model, peft_config)
        inp = torch.randint(0, 100, (2, 10), device=torch_device)

        # Reference forward/backward pass without offloading.
        torch.manual_seed(42)
        loss = model(inp, labels=inp).loss
        loss.backward()

        # Keep only the gradients of trainable parameters, keyed by name so
        # the comparison below is a direct lookup instead of a nested scan.
        grads_original = {
            name: param.grad.clone()
            for name, param in model.named_parameters()
            if param.requires_grad and param.grad is not None
        }
        # Guard against a vacuous pass: with nothing to compare, the loop
        # below would succeed without checking anything.
        assert grads_original, "Reference pass produced no gradients to compare"

        # Reset gradients so the second pass does not accumulate into them.
        for p in model.parameters():
            p.grad = None

        # Same pass again, with the forward run under OffloadActivations.
        torch.manual_seed(42)
        with OffloadActivations():
            loss_c = model(inp, labels=inp).loss
        loss_c.backward()

        params_by_name = dict(model.named_parameters())
        for name, grad_orig in grads_original.items():
            grad_offloaded = params_by_name[name].grad
            # A gradient that vanished under offloading is a failure, not
            # something to skip silently.
            assert grad_offloaded is not None, f"Missing gradient for {name} after offloading"
            torch.testing.assert_close(
                grad_orig,
                grad_offloaded,
                rtol=1e-4,
                atol=1e-5,
                # Prefix the generated diagnostics with the parameter name.
                msg=lambda details, name=name: f"Gradient mismatch for {name}\n{details}",
            )

    @require_torch_accelerator
    def test_noop_manager_with_offloading(self):
        """Test that a pass nested in ``NoOpManager`` yields the same gradients."""
        model = AutoModelForCausalLM.from_pretrained(self._MODEL_ID).to(torch_device)
        inp = torch.randint(0, 100, (2, 10), device=torch_device)

        with OffloadActivations():
            # First forward/backward pass directly under OffloadActivations.
            torch.manual_seed(42)
            out1 = model(inp, labels=inp)
            out1.loss.backward()
            grads1 = [p.grad.clone() for p in model.parameters()]

            # Reset gradients before the second pass.
            for p in model.parameters():
                p.grad = None

            # Second forward/backward pass nested inside NoOpManager.
            with NoOpManager():
                torch.manual_seed(42)
                out2 = model(inp, labels=inp)
                out2.loss.backward()

            grads2 = [p.grad.clone() for p in model.parameters()]

        # strict=True makes a length mismatch between the two lists an error.
        for g1, g2 in zip(grads1, grads2, strict=True):
            torch.testing.assert_close(g1, g2, rtol=1e-4, atol=1e-5)

    @require_torch_accelerator
    def test_min_offload_size(self):
        """Test that tensors smaller than min_offload_size aren't offloaded"""
        model = nn.Sequential(
            nn.Linear(5, 5),
            nn.Linear(5, 1000),
        ).to(torch_device)

        inp = torch.randn(2, 5, device=torch_device)

        with OffloadActivations(min_offload_size=1000):
            out = model(inp)
            out.sum().backward()

        # Primarily a smoke test of a forward/backward pass with a size
        # threshold set; additionally make sure backward actually produced
        # usable gradients for every parameter.
        for name, param in model.named_parameters():
            assert param.grad is not None, f"Missing gradient for {name}"
            assert torch.isfinite(param.grad).all(), f"Non-finite gradient for {name}"

    @require_torch_accelerator
    def test_real_hf_model(self):
        """Test with an actual HuggingFace model"""
        model = AutoModelForCausalLM.from_pretrained(self._MODEL_ID).to(torch_device)

        inp = torch.randint(0, 100, (2, 10), device=torch_device)

        # Baseline pass without offloading.
        torch.manual_seed(42)
        out1 = model(inp, labels=inp).loss
        out1.backward()
        grads1 = [p.grad.clone() for p in model.parameters()]

        # Reset gradients before the offloaded pass.
        for p in model.parameters():
            p.grad = None

        # Same pass under OffloadActivations.
        with OffloadActivations():
            torch.manual_seed(42)
            out2 = model(inp, labels=inp).loss
            out2.backward()

        grads2 = [p.grad.clone() for p in model.parameters()]

        # Loss and every gradient must match the baseline.
        torch.testing.assert_close(out1, out2)
        for g1, g2 in zip(grads1, grads2, strict=True):
            torch.testing.assert_close(g1, g2)

    @require_torch_accelerator
    def test_tensor_deduplication(self):
        """Test that deduplication works correctly for tensors sharing storage"""

        class ModelWithViews(nn.Module):
            """Linear layer whose output is consumed through two views."""

            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(100, 100)

            def forward(self, x):
                out = self.linear(x)
                # Both results are views of `out`, not copies.
                view1 = out.view(-1)
                view2 = out.transpose(0, 1)
                return view1.sum() + view2.sum()

        model = ModelWithViews().to(torch_device)
        offload_ctx = OffloadActivations(min_offload_size=1)
        offload_ctx.update_model_params(model)

        x = torch.randn(10, 100, device=torch_device, requires_grad=True)
        with offload_ctx:
            loss = model(x)

        total_tensor_ids = offload_ctx.tensor_id
        assert total_tensor_ids > 0, "Should have created tensor IDs"

        # Each tracker value is unpacked as a 5-tuple whose second field is a
        # `modified` flag: flagged entries are counted as offloaded, unflagged
        # ones as deduplicated.
        deduplicated_count = sum(1 for _, modified, _, _, _ in offload_ctx.tracker.values() if not modified)
        offloaded_count = sum(1 for _, modified, _, _, _ in offload_ctx.tracker.values() if modified)

        assert offloaded_count > 0, "Should have offloaded at least one tensor"
        assert deduplicated_count > 0, "Should have deduplicated at least one tensor (view)"

        unique_storages_offloaded = len(offload_ctx.storage_to_tensor_id)
        assert unique_storages_offloaded < total_tensor_ids, (
            f"Deduplication should result in fewer storages ({unique_storages_offloaded}) "
            f"than total tensors ({total_tensor_ids})"
        )

        # Backward must still run after the bookkeeping checks above.
        loss.backward()

    @require_torch_accelerator
    def test_parameter_filtering(self):
        """Test that model parameters are filtered during offloading"""
        model = nn.Sequential(nn.Linear(10, 20), nn.Linear(20, 10)).to(torch_device)
        offload_ctx = OffloadActivations()
        offload_ctx.update_model_params(model)

        assert len(offload_ctx.param_storages) > 0, "Should have tracked parameter storages"

        # The tracked set must equal the storage data pointers of the
        # model's parameters.
        param_ptrs = {p.data.untyped_storage().data_ptr() for p in model.parameters()}
        assert offload_ctx.param_storages == param_ptrs, "Tracked storages should match parameter storages"
|
|
|