"""This module reuses parts of rut5compressed. It shares the module structure
of the model used in the neural network compression experiments with
rut5compressed.
"""
|
|
| from functools import partial |
| from typing import Optional, Tuple |
|
|
| import numpy as np |
| import torch as T |
| from transformers import BartForConditionalGeneration |
|
|
| from .configuration_bart import TTCompressedBartConfig |
| from .linalg import ttd |
| from .modules import TTCompressedLinear |
| from .util import compress_linear_tt, map_module |
|
|
|
|
class TTCompressedBartForConditionGeneration(BartForConditionalGeneration):
    """BART model for conditional generation whose selected linear layers are
    replaced with tensor-train (TT) compressed linear layers.

    The layers whose module paths match :attr:`LAYERS` are substituted at
    construction time. With ``compress=True`` they are handed to
    ``compress_linear_tt``. Otherwise :meth:`convert` replaces them with
    randomly initialised :class:`TTCompressedLinear` modules of the configured
    TT shape and rank.
    """

    # Regex over module paths (as consumed by ``map_module``) selecting the
    # ``fc1``/``fc2`` layers of every numbered encoder and decoder layer.
    LAYERS = r'/(de|en)coder/layers/\d+/fc[12]'

    config_class = TTCompressedBartConfig

    def __init__(self, config: TTCompressedBartConfig,
                 shape: Optional[Tuple[Tuple[int, ...],
                                       Tuple[int, ...]]] = None,
                 rank: Optional[int] = None,
                 compress: bool = False):
        """Build the BART model and swap the matched layers for TT layers.

        Args:
            config: Model configuration. ``config.rank``, ``config.shape_in``
                and ``config.shape_out`` supply the defaults for ``rank`` and
                ``shape``.
            shape: Pair ``(shape_in, shape_out)`` of TT factorisation shapes.
                Defaults to ``(tuple(config.shape_in),
                tuple(config.shape_out))``.
            rank: TT rank. Defaults to ``config.rank``.
            compress: If true, matched layers are passed to
                ``compress_linear_tt``. Otherwise they are replaced through
                :meth:`convert` with randomly initialised TT layers.
        """
        super().__init__(config)

        # Explicit ``None`` check: ``rank or config.rank`` would silently
        # discard an explicitly passed falsy rank.
        self.rank = self.config.rank if rank is None else rank
        if shape is None:
            shape = (tuple(self.config.shape_in),
                     tuple(self.config.shape_out))
        self.shape = shape

        if compress:
            # NOTE(review): ``self.shape`` is not forwarded on this path;
            # confirm that ``compress_linear_tt`` derives the factorisation
            # on its own.
            compress_fn = partial(compress_linear_tt, rank=self.rank)
        else:
            compress_fn = self.convert
        self.model = map_module(self.model, compress_fn, self.LAYERS)

    def convert(self, module: T.nn.Module, path: str) -> T.nn.Module:
        """Replace a dense linear layer with a randomly initialised TT layer.

        Args:
            module: Module selected by :attr:`LAYERS`.
            path: Module path. It is unused and is accepted only to match the
                callback signature used with ``map_module``.

        Returns:
            A new :class:`TTCompressedLinear` when ``module`` is a
            ``torch.nn.Linear``. Any other module is returned unchanged.
        """
        if not isinstance(module, T.nn.Linear):
            return module

        # ``self.shape`` is applied as (in, out) to layers whose input is not
        # wider than their output. It is swapped for layers that shrink the
        # feature dimension, so one configured pair serves both directions.
        in_shape, out_shape = self.shape
        if module.in_features > module.out_features:
            in_shape, out_shape = out_shape, in_shape

        bias = module.bias is not None
        return TTCompressedLinear.from_random((in_shape, out_shape),
                                              self.rank, bias)
|
|
|
|
# Register the model class with the ``AutoModelForSeq2SeqLM`` auto class.
TTCompressedBartForConditionGeneration.register_for_auto_class(
    'AutoModelForSeq2SeqLM')
|
|