Spaces:
Running
on
Zero
Running
on
Zero
| # Adapted from: https://github.com/bghira/SimpleTuner/blob/main/helpers/training/quantisation/__init__.py | |
| from typing import Literal | |
| import torch | |
| from optimum.quanto import qtype | |
| from ltx_trainer import logger | |
# Precision names accepted by `quantize_model`; "no_change" leaves the model as-is.
QuantizationOptions = Literal[
    "no_change",
    "int8-quanto",
    "int4-quanto",
    "int2-quanto",
    "fp8-quanto",
    "fp8uz-quanto",
]
def quantize_model(
    model: torch.nn.Module,
    precision: QuantizationOptions,
    quantize_activations: bool = False,
) -> torch.nn.Module:
    """
    Quantize a model using the specified precision settings.

    The model is quantized and frozen in place with optimum-quanto; the same
    object that was passed in is returned.

    Args:
        model: The model to quantize.
        precision: The precision level to quantize to (e.g. "int8-quanto", "fp8-quanto").
            ``None`` or "no_change" skips quantization entirely.
        quantize_activations: Whether to quantize activations in addition to weights.
            Activations use the same quantization type as the weights.

    Returns:
        The quantized model, or the original model if no quantization is performed
        (precision is ``None``/"no_change", or the requested precision could not be
        resolved to a quantization type).

    Raises:
        ValueError: If ``precision`` is not a recognized quantization option.
    """
    if precision is None or precision == "no_change":
        return model

    # Imported lazily so optimum-quanto is only required when quantization is requested.
    from optimum.quanto import freeze, quantize  # noqa: PLC0415

    weight_quant = _quanto_type_map(precision)
    if weight_quant is None:
        # The type map yields None (after logging a warning) when the requested
        # precision is unavailable, e.g. fp8 while MPS is available. Calling
        # quantize() with weights=None would still rewrite modules without
        # quantizing anything, so leave the model untouched instead.
        return model

    # Module-name patterns passed to quanto's `exclude` so these submodules keep
    # their original precision.
    extra_quanto_args = {
        "exclude": [
            "proj_in",
            "time_embed.*",
            "caption_projection.*",
            "rope",
            "*norm*",
            "proj_out",
        ]
    }
    if quantize_activations:
        logger.info("Freezing model weights and activations")
        extra_quanto_args["activations"] = weight_quant
    else:
        logger.info("Freezing model weights only")

    quantize(model, weights=weight_quant, **extra_quanto_args)
    freeze(model)
    return model
def _quanto_type_map(precision: QuantizationOptions) -> torch.dtype | qtype | None:  # noqa: PLR0911
    """
    Map a precision option name to the corresponding optimum-quanto quantization type.

    Args:
        precision: One of the ``QuantizationOptions`` strings.

    Returns:
        The quanto qtype for ``precision``, or ``None`` when ``precision`` is
        "no_change" or when an fp8 variant is requested while the MPS backend
        is available (a warning is logged in that case).

    Raises:
        ValueError: If ``precision`` is not a recognized quantization option.
    """
    if precision == "no_change":
        return None

    # Imported lazily so optimum-quanto is only required when a qtype is actually needed.
    from optimum.quanto import (  # noqa: PLC0415
        qfloat8,
        qfloat8_e4m3fnuz,
        qint2,
        qint4,
        qint8,
    )

    if precision == "int2-quanto":
        return qint2
    elif precision == "int4-quanto":
        return qint4
    elif precision == "int8-quanto":
        return qint8
    elif precision in ("fp8-quanto", "fp8uz-quanto"):
        # float8 is rejected whenever the MPS backend is available on this machine.
        if torch.backends.mps.is_available():
            logger.warning(
                "MPS doesn't support dtype float8. "
                "you must select another precision level such as int2, int4, or int8.",
            )
            return None
        if precision == "fp8-quanto":
            return qfloat8
        elif precision == "fp8uz-quanto":
            return qfloat8_e4m3fnuz

    raise ValueError(f"Invalid quantisation level: {precision}")