# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from enum import Enum from typing import List from pydantic import BaseModel, ConfigDict, Field, field_validator __all__ = ["TransformArgs", "TransformLocation"] class TransformLocation(str, Enum): """ Enum representing which parameters/activations a transform weight should be applied to on a given module. | -------------------------------------------------------------------------------------------------------- | # noqa: E501 | Name | Runtime | Values | Locations Where Inverse Could Be Applied | # noqa: E501 | --------------- | ----------- | ------------- | -------------------------------------------------------- | # noqa: E501 | `INPUT` | online | activations | `prev.WEIGHT_OUTPUT`, `prev.OUTPUT`, `this.WEIGHT_INPUT` | # noqa: E501 | `WEIGHT_INPUT` | offline | weight | `prev.WEIGHT_OUTPUT`, `prev.OUTPUT`, `this.INPUT` | # noqa: E501 | `WEIGHT_OUTPUT` | offline | weight | `this.OUTPUT`, `next.INPUT`, `next.WEIGHT_INPUT` | # noqa: E501 | `OUTPUT` | online | activations | `this.WEIGHT_OUTPUT`, `next.INPUT`, `next.WEIGHT_INPUT` | # noqa: E501 | `K_CACHE` | online | key_values | `q_proj.Q_ATTN` | # noqa: E501 | `Q_ATTN` | online | query_values | `k_proj.K_CACHE` | # noqa: E501 | -------------------------------------------------------------------------------------------------------- | # noqa: E501 """ INPUT = "input" WEIGHT_INPUT = "weight_input" WEIGHT_OUTPUT = "weight_output" OUTPUT = "output" K_CACHE = "k_cache" Q_ATTN = "q_attn" def is_online(self) -> bool: """ Returns True if the transform location is online (applied at runtime), False otherwise """ return self not in ( TransformLocation.WEIGHT_INPUT, TransformLocation.WEIGHT_OUTPUT, ) class TransformArgs(BaseModel, use_enum_values=True): """ Arguments which define *how* and where a transform should be applied to a model :param targets: list of modules to apply transforms to :param location: where to apply transform on module, one of (`input`, `weight`, `output`, `k_cache`, `q_attn`) :param inverse: whether or not to apply the inverse of a transform :param ignore: any modules which should be ignored from the targets list """ targets: List[str] location: TransformLocation inverse: bool = Field(default=False) ignore: List[str] = Field(default_factory=list) @field_validator("targets", "ignore", mode="before") @classmethod def wrap_singleton(cls, value): if isinstance(value, str): return [value] return value def is_online(self) -> bool: return TransformLocation(self.location).is_online() model_config = ConfigDict(extra="forbid")