| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from typing import Optional |
| |
|
| | import torch |
| | from compressed_tensors.transform import TransformLocation |
| |
|
| |
|
| | __all__ = ["get_transform_size", "apply_transform_weight"] |
| |
|
| |
|
def get_transform_size(
    module: torch.nn.Module,
    location: TransformLocation,
    head_dim: Optional[int] = None,
) -> int:
    """
    Determine the size of a transform matrix given its location on the module

    :param module: module that matrix will be applied to
    :param location: location on module
    :param head_dim: size of head when transform is applied to mha
    :return: size of matrix
    """
    # input-side locations read the module's input dimension; everything
    # else reads the output dimension
    input_side = location in (
        TransformLocation.INPUT,
        TransformLocation.WEIGHT_INPUT,
    )

    if isinstance(module, torch.nn.Linear):
        size = module.in_features if input_side else module.out_features
    elif isinstance(module, torch.nn.Embedding):
        size = module.num_embeddings if input_side else module.embedding_dim
    else:
        # unknown module types are only supported when a head_dim is given
        size = None
        if head_dim is None:
            raise NotImplementedError(
                f"Transforms on {type(module)} are not supported without head_dim"
            )

    if head_dim is None:
        return size

    # when a head dim is provided it overrides the module size, but must
    # evenly divide it when the module size is known
    if size is not None and size % head_dim != 0:
        raise ValueError(
            f"{head_dim} must divide {size} for {type(module)} at {location}"
        )
    return head_dim
| |
|
| |
|
def apply_transform_weight(
    transform_weight: torch.Tensor,
    value: torch.Tensor,
    location: TransformLocation,
    module_type: type[torch.nn.Module],
) -> torch.Tensor:
    """
    Using the transform location, apply the transform_weight to the
    given value wrt linear weights. For more info on input and output transforms,
    see `TransformLocation`

    The following explains how weights should be applied to values according to location

    let x be input activation
        W be weight,
        yh, xh, Wh be transformed output, input, weight

    note that
        y = (x W.T) // torch.nn.Linear

    Choose values for yh, xh, and Wh which incorporate matrix transforms

    let V, Vi be transform matrices on input side
        U, Ui be transform matrices on output side

    pick xh = (x V)
         Wh = (U.T W Vi.T)
         yh = (y U)

    The following shows that `yh = (xh) (Wh).T` for the chosen values of yh, xh, and Wh

    (xh) (Wh).T = (x V) (U.T W Vi.T).T
                = (x V) (Vi W.T U)  // transpose matrix product identity
                = (x W.T) U
                = y U
                = yh

    :param transform_weight: transform weight to apply
    :param value: value to apply transform_weight to
    :param location: determines how weight should be applied
    :param module_type: result of type(module), passed in to determine application of
        weight transform
    :return: value after transform_weight has been applied
    """
    # transform matrices are always square
    assert transform_weight.shape[0] == transform_weight.shape[1]

    # online (activation-side) transforms always multiply on the right
    if TransformLocation(location).is_online():
        return _multihead_matmul(value, transform_weight)

    if module_type == torch.nn.Linear:
        # xh = (x V) => Wh includes Vi.T on the input side
        if location == TransformLocation.WEIGHT_INPUT:
            return _multihead_matmul(value, transform_weight.T)
        # yh = (y U) => Wh includes U.T on the output side
        if location == TransformLocation.WEIGHT_OUTPUT:
            return _multihead_matmul(transform_weight.T, value)
    elif module_type == torch.nn.Embedding:
        # embedding weights are stored (num_embeddings, embedding_dim),
        # so the input transform multiplies on the left
        if location == TransformLocation.WEIGHT_INPUT:
            return _multihead_matmul(transform_weight, value)
        if location == TransformLocation.WEIGHT_OUTPUT:
            return _multihead_matmul(value, transform_weight)

    raise NotImplementedError(
        f"Applying transforms to {module_type} {location} is not supported"
    )
| |
|
| |
|
| | def _multihead_matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: |
| | """ |
| | Performs A @ B for last two dims of two matrices A and B that possibly |
| | have different shapes, as is the case in multi-headed dimension. If |
| | shapes are different, this is equivalent to converting the last two dims |
| | of the smaller matrix into a block-diagonal matrix with the same shape as |
| | the last two dims of the larger matrix. |
| | |
| | E.g. if A is half the size of B, this function will perform |
| | [[A ] @ B |
| | [ A]] |
| | |
| | If B is a third of the size of A, this function will perform |
| | A @ [[B ] |
| | [ B ] |
| | [ B]] |
| | |
| | This function will error out if the shapes are not evenly divisble |
| | |
| | :param A: left-hand tensor |
| | :param B: right-hand tensor |
| | :return: result |
| | """ |
| | if A.shape[-1] > B.shape[-2]: |
| | head_dim = B.shape[-2] |
| | num_heads = A.shape[-1] // head_dim |
| | A = A.unflatten(-1, (num_heads, head_dim)) |
| | return (A @ B).flatten(-2, -1) |
| | elif A.shape[-1] < B.shape[-2]: |
| | head_dim = A.shape[-1] |
| | num_heads = B.shape[-2] // head_dim |
| | B = B.unflatten(-2, (num_heads, head_dim)) |
| | return (A @ B).flatten(-3, -2) |
| | else: |
| | return A @ B |
| |
|