# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
import torch.nn as nn
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from torch.utils.checkpoint import checkpoint


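# Toy model whose FFN block can be wrapped either by torch.utils.checkpoint
# or by fairseq's checkpoint_wrapper, so the two paths can be compared.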
class Model(nn.Module):
    def __init__(
        self, use_pytorch_checkpoint=False, use_fairseq_checkpoint=False, **kwargs
    ):
        super().__init__()
        torch.manual_seed(0)
        self.use_pytorch_checkpoint = use_pytorch_checkpoint
        self.ffn = nn.Sequential(
            nn.Linear(32, 128),
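            # Dropout makes the test sensitive to RNG state: the checkpointed
            # re-forward must replay the same random mask as the original pass.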
            nn.Dropout(p=0.5),
            nn.Linear(128, 32),
        )
        if use_fairseq_checkpoint:
            self.ffn = checkpoint_wrapper(self.ffn, **kwargs)
        self.out = nn.Linear(32, 1)

    def forward(self, x):
        if self.use_pytorch_checkpoint:
            x = checkpoint(self.ffn, x)
        else:
            x = self.ffn(x)
        return self.out(x)


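# Checkpointing should be functionally transparent: loss and gradient norm
# must match the no-checkpoint baseline for every wrapper variant.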
class TestActivationCheckpointing(unittest.TestCase):
    def _test_checkpoint_wrapper(self, device, log_memory_usage=False):
        def get_loss_and_gnorm(model):
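            # Reseed before each run so every model variant sees identical
            # inputs and dropout masks.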
            torch.manual_seed(1)
            input = torch.rand(2, 16, 32).requires_grad_(True).to(device)
            model.zero_grad()
            loss = model(input).sum()
            loss.backward()
            gnorm = torch.norm(
                torch.stack([torch.norm(p.grad.detach()) for p in model.parameters()])
            )
            return {"loss": loss, "gnorm": gnorm}

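        # Baseline: no activation checkpointing.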
        model = Model().to(device)
        no_cpt = get_loss_and_gnorm(model)

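        # torch.utils.checkpoint must reproduce the baseline.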
        model = Model(use_pytorch_checkpoint=True).to(device)
        pyt_cpt = get_loss_and_gnorm(model)
        torch.testing.assert_close(no_cpt["loss"], pyt_cpt["loss"])
        torch.testing.assert_close(no_cpt["gnorm"], pyt_cpt["gnorm"])

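        # fairseq's checkpoint_wrapper with default settings.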
| | model = Model(use_fairseq_checkpoint=True).to(device) |
| | fairseq_cpt = get_loss_and_gnorm(model) |
| | torch.testing.assert_allclose(no_cpt["loss"], fairseq_cpt["loss"]) |
| | torch.testing.assert_allclose(no_cpt["gnorm"], fairseq_cpt["gnorm"]) |
| |
|
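        # checkpoint_wrapper can additionally offload saved activations to CPU.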
        model = Model(use_fairseq_checkpoint=True, offload_to_cpu=True).to(device)
        fairseq_cpt_offload = get_loss_and_gnorm(model)
        torch.testing.assert_close(no_cpt["loss"], fairseq_cpt_offload["loss"])
        torch.testing.assert_close(no_cpt["gnorm"], fairseq_cpt_offload["gnorm"])

    def test_checkpoint_wrapper_cpu(self):
        self._test_checkpoint_wrapper(device=torch.device("cpu"))

    @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
    def test_checkpoint_wrapper_cuda(self):
        self._test_checkpoint_wrapper(device=torch.device("cuda"))


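# Runs the suite when executed directly; pytest discovery also works.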
if __name__ == "__main__":
    unittest.main()