| | import torch |
| |
|
| |
|
def extend_instance(obj, mixin):
    """Dynamically mix `mixin` into a single already-created instance.

    Replaces the instance's class with a freshly built subclass whose MRO
    puts `mixin` first, so mixin methods override the originals for this
    one object only. The new class keeps the original class name.
    """
    original_cls = type(obj)
    # Three-argument type(): name, bases, namespace. Mixin listed first so
    # its attributes win in attribute lookup.
    obj.__class__ = type(original_cls.__name__, (mixin, original_cls), {})
| |
|
| |
|
def getattr_recursive(obj, att):
    """Return a nested attribute of `obj` addressed by a dotted path.

    Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c.
    An empty path returns `obj` itself.
    """
    current = obj
    if att:
        # Walk the dotted path one segment at a time instead of recursing.
        for part in att.split("."):
            current = getattr(current, part)
    return current
| |
|
| |
|
def setattr_recursive(obj, att, val):
    """Set a nested attribute of `obj` addressed by a dotted path.

    Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to
    obj.a.b.c = val. A path with no dots sets the attribute on `obj`
    directly.
    """
    *parents, leaf = att.split(".")
    target = obj
    # Resolve every segment except the last, then assign on that parent.
    for name in parents:
        target = getattr(target, name)
    setattr(target, leaf, val)
| |
|
| |
|
def apply_with_stopping_condition(
    module, apply_fn, apply_condition=None, stopping_condition=None, **other_args
):
    """Recursively apply `apply_fn` over `module` and its submodules.

    Traversal is depth-first via ``module.children()``. `stopping_condition`
    prunes a subtree (the matching module and all its descendants are
    skipped); `apply_condition` gates whether `apply_fn` runs on a given
    module. Extra keyword arguments are forwarded to `apply_fn`.

    Bug fix: both conditions default to None but were called
    unconditionally, so the defaults raised TypeError. None now means
    "never stop" / "always apply" respectively.

    Args:
        module: torch.nn.Module (or anything exposing ``children()``).
        apply_fn: callable invoked as ``apply_fn(module, **other_args)``.
        apply_condition (callable, optional): predicate; None = always apply.
        stopping_condition (callable, optional): predicate; None = never stop.
    """
    if stopping_condition is not None and stopping_condition(module):
        return
    if apply_condition is None or apply_condition(module):
        apply_fn(module, **other_args)
    for child in module.children():
        apply_with_stopping_condition(
            child,
            apply_fn,
            apply_condition=apply_condition,
            stopping_condition=stopping_condition,
            **other_args,
        )
| |
|
| |
|
def num_params(module, filter_to_trainable=False):
    """Count the parameters of `module`.

    Args:
        module: torch.nn.Module whose parameters are counted.
        filter_to_trainable (bool): if True, count only parameters with
            ``requires_grad=True``.

    Returns:
        int: total number of parameter elements.
    """
    params = module.parameters()
    if filter_to_trainable:
        params = (p for p in params if p.requires_grad)
    return sum(p.numel() for p in params)
| |
|
| |
|
def stack_with_padding(list_of_tensors, padding_value=0, padding_side="right"):
    """
    Stack a list of tensors, padding along dim 0 so all entries match.

    Generalized from the original 1-D/2-D-only implementation: the pad
    block's shape is ``(pad_len, *tensor.shape[1:])``, which reproduces the
    old behavior exactly for 1-D and 2-D inputs and also supports any
    number of trailing dimensions (all trailing dims must already agree
    across the list, as torch.cat/stack require).

    Args:
        list_of_tensors (list[torch.Tensor]): Tensors to stack; must share
            dtype-compatible trailing shapes. Must be non-empty.
        padding_value (int, optional): Fill value for the pad. Defaults to 0.
        padding_side (str, optional): "right" appends padding after the
            tensor along dim 0; anything else prepends it. Defaults to "right".

    Returns:
        torch.Tensor: Stacked tensor of shape
            ``(len(list_of_tensors), max_tokens, *trailing)``.
    """
    max_tokens = max(tensor.size(0) for tensor in list_of_tensors)
    padded_tensors = []
    for tensor in list_of_tensors:
        pad_len = max_tokens - tensor.size(0)
        # Pad block mirrors the tensor's trailing shape; works for any rank.
        padding = torch.full(
            (pad_len, *tensor.shape[1:]),
            padding_value,
            dtype=tensor.dtype,
            device=tensor.device,
        )
        padded_tensor = (
            torch.cat((tensor, padding), dim=0)
            if padding_side == "right"
            else torch.cat((padding, tensor), dim=0)
        )
        padded_tensors.append(padded_tensor)
    return torch.stack(padded_tensors)
| |
|
| |
|
def stack_with_padding_2D_attention(list_of_tensors):
    """Stack 2-D attention masks, zero-padding the last two dims to match.

    Each mask is padded with zeros on the right and on the bottom of its
    last two dimensions so that every mask reaches the largest dim-1 size
    in the list, then the results are stacked along a new leading dim.

    NOTE(review): the target size comes from ``size(1)`` while the pad
    amount uses ``shape[-1]`` — this assumes square masks (dim 1 == last
    dim); verify against callers. Must be non-empty.
    """
    target = max(mask.size(1) for mask in list_of_tensors)

    def _pad_square(mask):
        extra = target - mask.shape[-1]
        # F.pad pairs are (last-dim left, last-dim right, 2nd-last top, 2nd-last bottom).
        return torch.nn.functional.pad(mask, (0, extra, 0, extra))

    return torch.stack([_pad_square(mask) for mask in list_of_tensors])