| | |
| |
|
| | import logging |
| | import re |
| | from functools import wraps |
| | from re import Pattern |
| | from typing import Callable, Dict, Optional, Tuple |
| |
|
| | import numpy as np |
| | import torch as T |
| |
|
| | from .modules import TTCompressedLinear |
| |
|
| |
|
def map_module(root: T.nn.Module,
               func: Callable[[T.nn.Module, str], T.nn.Module],
               patt: Optional[str] = None) -> T.nn.Module:
    """Apply a function to every module of a module tree whose path matches
    a specified pattern.

    Children are processed before their parent. Paths are built from child
    names joined with ``/`` (e.g. ``/encoder/fc``); the root has path ``/``.
    Whenever ``func`` returns a different module for a child, that module
    replaces the child in its parent.

    Parameters
    ----------
    root : torch.nn.Module
        Root of the module tree to modify.
    func : callable
        Function called as ``func(module, path)`` for every module whose path
        matches the pattern. It must return a ``torch.nn.Module`` (either the
        module it was given or a replacement).
    patt : str, optional
        Regular expression used to filter modules by path. The pattern is
        matched from the beginning of the path. If omitted (or empty), every
        module is matched.

    Returns
    -------
    torch.nn.Module
        Result of mapping the root: the root itself (modified in-place) or
        whatever ``func`` returned for it.

    Raises
    ------
    ValueError
        If ``func`` returns an object which is not a ``torch.nn.Module``.
    """
    @wraps(func)
    def func_safe(*args, **kwargs):
        # Validate the result eagerly so that a broken mapper fails with a
        # clear message instead of corrupting the module tree.
        node = func(*args, **kwargs)
        if not isinstance(node, T.nn.Module):
            raise ValueError('Mapped result must be torch.nn.Module type '
                             f'but given {type(node)}.')
        return node

    return _map_module(root, func_safe, re.compile(patt or r'.*'), '')
| |
|
| |
|
def _map_module(root: T.nn.Module,
                func: Callable[[T.nn.Module, str], T.nn.Module], patt: Pattern,
                path: str) -> T.nn.Module:
    """Recursively apply ``func`` to ``root`` and all of its descendants.

    Children are visited first; a child is re-assigned on its parent when the
    mapping produced a different object. Afterwards ``func`` is applied to
    ``root`` itself if its path matches ``patt``.

    Parameters
    ----------
    root : torch.nn.Module
        Current node of the module tree.
    func : callable
        Function called as ``func(module, path)`` for matched modules.
    patt : re.Pattern
        Compiled pattern matched from the beginning of a module path.
    path : str
        Path of ``root`` in the tree; an empty string denotes the tree root,
        which is reported to ``func`` as ``/``.

    Returns
    -------
    torch.nn.Module
        ``root`` itself or the module returned for it by ``func``.
    """
    for name, child in root.named_children():
        node = _map_module(child, func, patt, f'{path}/{name}')
        # Compare by identity: ``!=`` would defer to a user-defined
        # ``__eq__`` and could silently drop a replacement.
        if node is not child:
            setattr(root, name, node)

    # The tree root is traversed with an empty path but exposed as '/'.
    node_path = path or '/'
    if patt.match(node_path):
        root = func(root, node_path)
    return root
| |
|
| |
|
def convert_linear(module: T.nn.Linear, ctor, **kwargs) -> T.nn.Module:
    """Replace a linear module with a linear module using approximate matmul.

    Modules which are not ``torch.nn.Linear`` are returned intact. Conversion
    of linear modules is not implemented yet, so ``NotImplementedError`` is
    raised for them; ``ctor`` and ``kwargs`` are currently unused.
    """
    if isinstance(module, T.nn.Linear):
        raise NotImplementedError
    return module
| |
|
| |
|
def numel(module: T.nn.Module):
    """Count the number of elements stored by a module tree.

    The base value is the total number of elements of all parameters and
    buffers. It is then corrected for every module in the tree:

    * for each tensor attribute ``<name>_mask`` found directly in the module
      namespace (``vars(module)``) with a sibling attribute ``<name>``, the
      masked-out elements of ``<name>`` are subtracted and the elements of
      the mask itself are added;
    * for ``torch.nn.quantized.Linear`` modules the elements of the weight
      and (if present) the bias are added.

    Every module is visited exactly once, so a submodule shared between
    several parents is not corrected twice (parameters and buffers are
    deduplicated in the same way).

    Parameters
    ----------
    module : torch.nn.Module
        Root of the module tree.

    Returns
    -------
    int
        Number of elements (always a plain ``int``, never a tensor).
    """
    total = sum(x.numel() for x in module.parameters()) + \
        sum(x.numel() for x in module.buffers())

    def pruned_correction(node: T.nn.Module) -> int:
        delta = 0
        for name, attr in vars(node).items():
            if not name.endswith('_mask') or not isinstance(attr, T.Tensor):
                continue

            weight_name = name[:-5]  # Strip the '_mask' suffix.
            if not hasattr(node, weight_name):
                continue

            weight = getattr(node, weight_name)
            # ``Tensor.sum`` returns a 0-d tensor; convert it so that the
            # running total stays an integer. Assumes a 0/1-valued mask.
            kept = int(attr.sum().item())
            delta -= weight.numel() - kept
            delta += attr.numel()
        return delta

    def quantized_correction(node: T.nn.Module) -> int:
        if not isinstance(node, T.nn.quantized.Linear):
            return 0
        delta = node.weight().numel()
        if (bias := node.bias()) is not None:
            delta += bias.numel()
        return delta

    for node in module.modules():
        total += pruned_correction(node)
        total += quantized_correction(node)
    return total
| |
|
| |
|
def sizeof(module: T.nn.Module):
    """Estimate the size in bytes of data stored by a module tree.

    The base value is the total byte size (``numel * element_size``) of all
    parameters and buffers. It is then corrected for every module in the
    tree:

    * for each tensor attribute ``<name>_mask`` found directly in the module
      namespace (``vars(module)``) with a sibling attribute ``<name>``, the
      bytes of masked-out elements of ``<name>`` are subtracted and the bytes
      of the mask itself are added;
    * for ``torch.nn.quantized.Linear`` modules the bytes of the weight and
      (if present) the bias are added.

    Every module is visited exactly once, so a submodule shared between
    several parents is not corrected twice (parameters and buffers are
    deduplicated in the same way).

    Parameters
    ----------
    module : torch.nn.Module
        Root of the module tree.

    Returns
    -------
    int
        Size in bytes (always a plain ``int``, never a tensor).
    """
    total = sum(x.numel() * x.element_size() for x in module.parameters()) + \
        sum(x.numel() * x.element_size() for x in module.buffers())

    def pruned_correction(node: T.nn.Module) -> int:
        delta = 0
        for name, attr in vars(node).items():
            if not name.endswith('_mask') or not isinstance(attr, T.Tensor):
                continue

            weight_name = name[:-5]  # Strip the '_mask' suffix.
            if not hasattr(node, weight_name):
                continue

            weight = getattr(node, weight_name)
            # ``Tensor.sum`` returns a 0-d tensor; convert it so that the
            # running total stays an integer. Assumes a 0/1-valued mask.
            kept = int(attr.sum().item())
            delta -= (weight.numel() - kept) * weight.element_size()
            delta += attr.numel() * attr.element_size()
        return delta

    def quantized_correction(node: T.nn.Module) -> int:
        if not isinstance(node, T.nn.quantized.Linear):
            return 0
        weight = node.weight()
        delta = weight.numel() * weight.element_size()
        if (bias := node.bias()) is not None:
            delta += bias.numel() * bias.element_size()
        return delta

    for node in module.modules():
        total += pruned_correction(node)
        total += quantized_correction(node)
    return total
| |
|
| |
|
def flatten_module(module: T.nn.Module, regexp=None) -> Dict[str, T.nn.Module]:
    """Collect modules of a module tree into a dictionary keyed by path.

    Parameters
    ----------
    module : torch.nn.Module
        Root of the module tree.
    regexp : str, optional
        Pattern forwarded to ``map_module`` to filter modules by path.

    Returns
    -------
    dict
        Mapping from module path to the module found at that path.
    """
    registry = {}

    def collect(node, path):
        # Record the module and hand it back untouched so that the tree is
        # left unmodified.
        registry[path] = node
        return node

    map_module(module, collect, regexp)
    return registry
| |
|
| |
|
def print_flatten(module: T.nn.Module):
    """Print a table of all modules in a module tree.

    Each row contains the ordinal number of a module (in the traversal order
    of ``map_module``), its path and its class name. Column widths are chosen
    to fit both the header labels and the widest value.

    Parameters
    ----------
    module : torch.nn.Module
        Root of the module tree to print.
    """
    paths = []
    names = []

    def collect(node, path):
        paths.append(path)
        names.append(node.__class__.__name__)
        return node

    map_module(module, collect)

    # Width of the index column is the number of digits of the largest index
    # (at least one character, which also fits the '#' header).
    indx_len = len(str(max(len(paths) - 1, 0)))
    # Data columns must be wide enough for their header labels as well;
    # otherwise header, separator and rows are misaligned for short values.
    path_len = max(max(map(len, paths), default=0), len('Path'))
    name_len = max(max(map(len, names), default=0), len('Layer'))

    fmt = f'{{indx:>{indx_len}s}} {{path:{path_len}s}} {{name:{name_len}s}}'
    print(fmt.format(indx='#', path='Path', name='Layer'))
    # Two extra characters account for the spaces between columns.
    print('-' * (indx_len + path_len + name_len + 2))
    for i, (path, name) in enumerate(zip(paths, names)):
        print(fmt.format(indx=str(i), path=path, name=name))
| |
|
| |
|
def compress_linear_tt(module: T.nn.Module, path: str,
                       shape: Tuple[Tuple[int], Tuple[int]],
                       rank: int) -> T.nn.Module:
    """Replace a linear layer with its TT-compressed counterpart.

    Modules which are not ``torch.nn.Linear`` are returned intact, which
    allows the function to be used as a mapper over a whole module tree.

    Parameters
    ----------
    module : torch.nn.Module
        Candidate module.
    path : str
        Path of the module in the module tree (used for logging only).
    shape : tuple of tuples of int
        Pair of factorizations; the products of the factors must be equal to
        the input and output features of the layer. If the pair is given in
        reversed order (output first), it is swapped.
    rank : int
        Rank forwarded to ``TTCompressedLinear.from_linear``.

    Returns
    -------
    torch.nn.Module
        The original module or the result of
        ``TTCompressedLinear.from_linear``.

    Raises
    ------
    ValueError
        If the products of ``shape`` match the layer features in neither
        order.
    """
    if not isinstance(module, T.nn.Linear):
        return module

    first_size = np.prod(shape[0])
    second_size = np.prod(shape[1])
    n_inp, n_out = module.in_features, module.out_features

    matches_directly = first_size == n_inp and second_size == n_out
    matches_reversed = first_size == n_out and second_size == n_inp

    if not matches_directly:
        if not matches_reversed:
            raise ValueError(
                'Input and output features does not match to compression shape: '
                f'{shape[0]} vs {module.in_features} and {shape[1]} vs '
                f'{module.out_features}.')
        # Shape was given as (output, input); put it in (input, output) order.
        shape = (shape[1], shape[0])

    logging.info('apply tt compression to layer %s', path)
    return TTCompressedLinear.from_linear(module, shape, rank)
| |
|