#!/usr/bin/env python3
import gc

import torch
# Side length of the square test matrices. The original `(10,000)` evaluated
# to the tuple (10, 0) rather than ten thousand, which made the torch.ones()
# call below fail; `10_000` is the intended integer.
shape = 10_000
# A (shape, shape) tensor of ones on the GPU (~400 MB if the default dtype is
# float32). NOTE(review): the name shadows the builtin ``input``; it is kept
# because the benchmark loop refers to it by this name.
input = torch.ones((shape, shape), device="cuda")
def clear_memory(model):
    """Move ``model`` off the GPU and release cached allocator memory.

    Moves the module's parameters and buffers to the CPU, runs the Python
    garbage collector so unreachable tensors are freed, then (if CUDA is in
    use) returns cached CUDA blocks to the driver, and finally clears the
    autocast cache.

    Only memory that is no longer referenced can be released: tensors the
    caller still holds (for example a forward-pass output) stay on the GPU.

    Args:
        model: a ``torch.nn.Module``; it is moved to the CPU in place.
    """
    model.to("cpu")
    gc.collect()
    # Only touch the CUDA runtime when it is already initialised:
    # ipc_collect() would otherwise lazily initialise CUDA (creating a new
    # context on an otherwise GPU-idle process) and raises on a host without
    # CUDA. empty_cache() is already a no-op when CUDA is uninitialised.
    if torch.cuda.is_initialized():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    torch.clear_autocast_cache()
# Repeatedly build a (shape x shape) Linear layer on the GPU, run one forward
# pass over ``input``, then try to give the GPU memory back.
for _ in range(6):
    linear = torch.nn.Linear(shape, shape).to("cuda")
    output = linear(input)
    # Drop the forward result before clearing: clear_memory() can only free
    # tensors nothing references, and ``output`` (a shape x shape GPU tensor)
    # would otherwise stay alive until the next iteration rebinds it.
    del output
    clear_memory(linear)