| import gc |
| import torch as t |
|
|
def freeze_model(model):
    """Switch *model* to eval mode and stop all of its parameters from
    accumulating gradients (e.g. for use as a fixed feature extractor)."""
    model.eval()
    for p in model.parameters():
        # In-place toggle; parameters are always leaf tensors, so this is
        # equivalent to assigning `requires_grad = False` directly.
        p.requires_grad_(False)
|
|
|
|
def unfreeze_model(model):
    """Switch *model* back to train mode and re-enable gradient tracking
    on every parameter (inverse of ``freeze_model``)."""
    model.train()
    for p in model.parameters():
        # In-place toggle; parameters are leaf tensors, so this matches
        # assigning `requires_grad = True` directly.
        p.requires_grad_(True)
|
|
def zero_grad(model):
    """Clear accumulated gradients on all of *model*'s parameters.

    Sets each ``.grad`` to ``None`` (rather than zero-filling), which frees
    the gradient buffers and lets the next backward pass allocate fresh ones.

    Fix: the previous version only cleared gradients on parameters with
    ``requires_grad=True``, so a parameter frozen *after* a backward pass
    kept a stale ``.grad`` forever. Clearing is safe and cheap for every
    parameter, so no filter is needed.
    """
    for p in model.parameters():
        if p.grad is not None:
            p.grad = None
|
|
def empty_cache():
    """Run a full Python garbage-collection pass, then release unused cached
    CUDA memory back to the driver.

    The CUDA call is guarded so the function is explicitly a cheap no-op on
    CPU-only machines (``t.cuda.empty_cache`` already no-ops when CUDA is
    uninitialized, but the guard makes the intent clear).
    """
    gc.collect()
    if t.cuda.is_available():
        t.cuda.empty_cache()
|
|
def assert_shape(x, exp_shape):
    """Raise ``AssertionError`` unless ``x.shape == exp_shape``.

    Fix: uses an explicit ``raise`` instead of a bare ``assert`` statement,
    so the check still runs when Python is started with ``-O`` (which strips
    asserts). Same exception type and message as before.
    """
    if x.shape != exp_shape:
        raise AssertionError(f"Expected {exp_shape} got {x.shape}")
|
|
def count_parameters(model):
    """Return the total element count of *model*'s trainable parameters
    (those with ``requires_grad=True``)."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
|
def count_state(model):
    """Return the total element count across every tensor in *model*'s
    ``state_dict`` — parameters and buffers alike, trainable or not."""
    sizes = (tensor.numel() for tensor in model.state_dict().values())
    return sum(sizes)
|
|
|
|