import torch.nn as nn


def remove_padding(tensors, lengths):
    """Truncate each padded tensor to its true (unpadded) length."""
    return [tensor[:length] for tensor, length in zip(tensors, lengths)]
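
# Usage sketch (assuming a padded batch tensor and a list of true lengths):
#
#     >>> import torch
#     >>> batch = torch.zeros(2, 5)
#     >>> [t.shape for t in remove_padding(batch, [3, 2])]
#     [torch.Size([3]), torch.Size([2])]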


class AutoParams(nn.Module):
    """Base module that maps constructor keyword arguments to attributes.

    Subclasses declare `needed_params` (a list of required argument names)
    and/or `optional_params` (a dict mapping argument names to defaults).
    """

    def __init__(self, **kwargs):
        # Required parameters: each one must be supplied by the caller.
        # `getattr` with a default makes the loop a no-op for subclasses
        # that declare no `needed_params`, while still raising for a
        # declared-but-missing parameter.
        for param in getattr(self, "needed_params", []):
            if param in kwargs:
                setattr(self, param, kwargs[param])
            else:
                raise ValueError(f"{param} is needed.")

        # Optional parameters: fall back to the declared default when the
        # argument is absent or explicitly None.
        for param, default in getattr(self, "optional_params", {}).items():
            if param in kwargs and kwargs[param] is not None:
                setattr(self, param, kwargs[param])
            else:
                setattr(self, param, default)

        # nn.Module.__init__ runs last; the attributes assigned above must be
        # plain values (not Parameters or sub-Modules).
        super().__init__()
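
# Usage sketch (illustrative subclass, not part of this module):
#
#     class Encoder(AutoParams):
#         needed_params = ["input_size"]
#         optional_params = {"dropout": 0.1}
#
#     Encoder(input_size=300).dropout               # 0.1 (default applied)
#     Encoder(input_size=300, dropout=0.3).dropout  # 0.3
#     Encoder()                                     # ValueError: input_size is needed.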


def freeze_params(module: nn.Module) -> None:
    """
    Freeze the parameters of this module,
    i.e. do not update them during training.

    :param module: freeze parameters of this module
    """
    for p in module.parameters():
        p.requires_grad = False
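
# Usage sketch (assumes `model` is an nn.Module with an `encoder` submodule):
#
#     freeze_params(model.encoder)
#     assert all(not p.requires_grad for p in model.encoder.parameters())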