# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

import torch
from compressed_tensors.transform import TransformArgs
from compressed_tensors.utils import TorchDtype
from pydantic import BaseModel, ConfigDict, Field


__all__ = ["TransformScheme"]


class TransformScheme(BaseModel):
    """
    Scheme used to parameterize a particular transform type and specify how and
    where it should be applied to the model

    :param type: string indicating the particular transform type that should be
        created and applied. This should be one of the registered transform types
        (see `Transforms.registered_names()`)
    :param apply: list of TransformArgs describing the modules that should be
        targeted by the specified transform
    :param randomize: True if uniquely randomized transform weights should be
        used; otherwise identical transform weights are reused where applicable
    :param requires_grad: True if transform weights should be trainable (i.e.
        require gradients)
    :param head_dim: if set, the transform matrix will be block diagonal, with
        each block being a square matrix of this size. The name head_dim was
        chosen because some rotations must be block diagonal with block size
        equal to the attention head dimension, but research has shown value in
        applying some rotations with a smaller block size, irrespective of
        head_dim
    :param precision: precision at which this transform should be applied during
        online rotations. Fused (offline) rotations are always performed in
        float64
    """

    type: str
    apply: List[TransformArgs] = Field(default_factory=list)
    randomize: bool = Field(default=False)
    requires_grad: bool = Field(default=False)
    head_dim: Optional[int] = Field(default=None)
    precision: TorchDtype = Field(default=torch.float32)

    model_config = ConfigDict(extra="forbid")
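

# ---------------------------------------------------------------------------
# Illustrative usage sketch, not part of the library: the "hadamard" type
# string, the TransformArgs fields used below (`targets`, `location`), and the
# head_dim value are assumptions for demonstration; consult
# `Transforms.registered_names()` and the TransformArgs model in your installed
# version before relying on them.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Apply an (assumed registered) Hadamard rotation to the weight input of
    # every Linear module, built from 128x128 block-diagonal blocks and run
    # at bfloat16 precision for online rotations.
    scheme = TransformScheme(
        type="hadamard",
        apply=[TransformArgs(targets=["Linear"], location="weight_input")],
        randomize=False,  # reuse identical transform weights where applicable
        head_dim=128,  # block-diagonal transform with 128x128 blocks
        precision=torch.bfloat16,  # affects online rotations only
    )
    print(scheme.model_dump())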