# Adapted from: https://github.com/bghira/SimpleTuner/blob/main/helpers/training/quantisation/__init__.py
from typing import Literal

import torch
from optimum.quanto import qtype

from ltx_trainer import logger

QuantizationOptions = Literal[
    "no_change",
    "int8-quanto",
    "int4-quanto",
    "int2-quanto",
    "fp8-quanto",
    "fp8uz-quanto",
]
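
# Every option other than "no_change" maps to an optimum.quanto qtype via
# _quanto_type_map() below.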


def quantize_model(
    model: torch.nn.Module,
    precision: QuantizationOptions,
    quantize_activations: bool = False,
) -> torch.nn.Module:
    """
    Quantize a model using the specified precision settings.

    Args:
        model: The model to quantize.
        precision: The precision level to quantize to (e.g. "int8-quanto", "fp8-quanto").
        quantize_activations: Whether to quantize activations in addition to weights.

    Returns:
        The quantized model, or the original model if no quantization is performed.
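
    Example (illustrative sketch, assuming optimum.quanto is installed; the tiny
    Linear module below is just a placeholder, not a model from this project):
        >>> model = torch.nn.Linear(64, 64)
        >>> model = quantize_model(model, precision="int8-quanto")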
"""
    if precision is None or precision == "no_change":
        return model

    from optimum.quanto import freeze, quantize  # noqa: PLC0415

    weight_quant = _quanto_type_map(precision)
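
    # Keep input/output projections, embeddings, RoPE, and normalization layers in
    # full precision: quantizing these small but sensitive modules tends to hurt
    # quality far more than it saves memory.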
    extra_quanto_args = {
        "exclude": [
            "proj_in",
            "time_embed.*",
            "caption_projection.*",
            "rope",
            "*norm*",
            "proj_out",
        ]
    }

    if quantize_activations:
        logger.info("Quantizing model weights and activations")
        extra_quanto_args["activations"] = weight_quant
    else:
        logger.info("Quantizing model weights only")
    quantize(model, weights=weight_quant, **extra_quanto_args)
    freeze(model)
    return model


def _quanto_type_map(precision: QuantizationOptions) -> torch.dtype | qtype | None:  # noqa: PLR0911
    if precision == "no_change":
        return None

    from optimum.quanto import (  # noqa: PLC0415
        qfloat8,
        qfloat8_e4m3fnuz,
        qint2,
        qint4,
        qint8,
    )
if precision == "int2-quanto":
return qint2
elif precision == "int4-quanto":
return qint4
elif precision == "int8-quanto":
return qint8
elif precision in ("fp8-quanto", "fp8uz-quanto"):
if torch.backends.mps.is_available():
logger.warning(
"MPS doesn't support dtype float8. "
"you must select another precision level such as int2, int8, or int8.",
)
return None
if precision == "fp8-quanto":
return qfloat8
elif precision == "fp8uz-quanto":
return qfloat8_e4m3fnuz

    raise ValueError(f"Invalid quantization level: {precision}")