| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from enum import Enum |
| | from typing import List |
| |
|
| | from pydantic import BaseModel, ConfigDict, Field, field_validator |
| |
|
| |
|
| | __all__ = ["TransformArgs", "TransformLocation"] |
| |
|
| |
|
| | class TransformLocation(str, Enum): |
| | """ |
| | Enum representing which parameters/activations a transform weight should be applied |
| | to on a given module. |
| | |
| | | -------------------------------------------------------------------------------------------------------- | # noqa: E501 |
| | | Name | Runtime | Values | Locations Where Inverse Could Be Applied | # noqa: E501 |
| | | --------------- | ----------- | ------------- | -------------------------------------------------------- | # noqa: E501 |
| | | `INPUT` | online | activations | `prev.WEIGHT_OUTPUT`, `prev.OUTPUT`, `this.WEIGHT_INPUT` | # noqa: E501 |
| | | `WEIGHT_INPUT` | offline | weight | `prev.WEIGHT_OUTPUT`, `prev.OUTPUT`, `this.INPUT` | # noqa: E501 |
| | | `WEIGHT_OUTPUT` | offline | weight | `this.OUTPUT`, `next.INPUT`, `next.WEIGHT_INPUT` | # noqa: E501 |
| | | `OUTPUT` | online | activations | `this.WEIGHT_OUTPUT`, `next.INPUT`, `next.WEIGHT_INPUT` | # noqa: E501 |
| | | `K_CACHE` | online | key_values | `q_proj.Q_ATTN` | # noqa: E501 |
| | | `Q_ATTN` | online | query_values | `k_proj.K_CACHE` | # noqa: E501 |
| | | -------------------------------------------------------------------------------------------------------- | # noqa: E501 |
| | """ |
| |
|
| | INPUT = "input" |
| | WEIGHT_INPUT = "weight_input" |
| | WEIGHT_OUTPUT = "weight_output" |
| | OUTPUT = "output" |
| | K_CACHE = "k_cache" |
| | Q_ATTN = "q_attn" |
| |
|
| | def is_online(self) -> bool: |
| | """ |
| | Returns True if the transform location is online |
| | (applied at runtime), False otherwise |
| | """ |
| | return self not in ( |
| | TransformLocation.WEIGHT_INPUT, |
| | TransformLocation.WEIGHT_OUTPUT, |
| | ) |
| |
|
| |
|
| | class TransformArgs(BaseModel, use_enum_values=True): |
| | """ |
| | Arguments which define *how* and where a transform should be applied to a model |
| | |
| | :param targets: list of modules to apply transforms to |
| | :param location: where to apply transform on module, one of (`input`, `weight`, |
| | `output`, `k_cache`, `q_attn`) |
| | :param inverse: whether or not to apply the inverse of a transform |
| | :param ignore: any modules which should be ignored from the targets list |
| | """ |
| |
|
| | targets: List[str] |
| | location: TransformLocation |
| | inverse: bool = Field(default=False) |
| | ignore: List[str] = Field(default_factory=list) |
| |
|
| | @field_validator("targets", "ignore", mode="before") |
| | @classmethod |
| | def wrap_singleton(cls, value): |
| | if isinstance(value, str): |
| | return [value] |
| | return value |
| |
|
| | def is_online(self) -> bool: |
| | return TransformLocation(self.location).is_online() |
| |
|
| | model_config = ConfigDict(extra="forbid") |
| |
|