from typing import Literal

import torch
from optimum.quanto import qtype

from ltx_trainer import logger

QuantizationOptions = Literal[
    "no_change",
    "int8-quanto",
    "int4-quanto",
    "int2-quanto",
    "fp8-quanto",
    "fp8uz-quanto",
]


def quantize_model(
    model: torch.nn.Module,
    precision: QuantizationOptions,
    quantize_activations: bool = False,
) -> torch.nn.Module:
    """Quantize a model using the specified precision settings.

    Args:
        model: The model to quantize.
        precision: The precision level to quantize to (e.g. "int8-quanto", "fp8-quanto").
        quantize_activations: Whether to quantize activations in addition to weights.

    Returns:
        The quantized model, or the original model if no quantization is performed.
    """
    if precision is None or precision == "no_change":
        return model

    from optimum.quanto import freeze, quantize

    weight_quant = _quanto_type_map(precision)
    # Keep projections, embeddings, RoPE, and norm layers in full precision;
    # quanto matches these as glob patterns against module names.
    extra_quanto_args = {
        "exclude": [
            "proj_in",
            "time_embed.*",
            "caption_projection.*",
            "rope",
            "*norm*",
            "proj_out",
        ]
    }
    if quantize_activations:
        logger.info("Quantizing model weights and activations")
        extra_quanto_args["activations"] = weight_quant
    else:
        logger.info("Quantizing model weights only")

    quantize(model, weights=weight_quant, **extra_quanto_args)
    freeze(model)
    return model
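
# Usage sketch (illustrative; `transformer` is a stand-in for any
# torch.nn.Module, e.g. a diffusers transformer):
#
#     transformer = quantize_model(transformer, precision="int8-quanto")
#
# Weights are quantized and frozen in place. With quantize_activations=True,
# activations get the same qtype; quanto typically also expects a calibration
# pass (its Calibration context) before activation ranges are meaningful.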


def _quanto_type_map(precision: QuantizationOptions) -> qtype | None:
    """Map a precision option to its quanto qtype (None for "no_change")."""
    if precision == "no_change":
        return None

    from optimum.quanto import (
        qfloat8,
        qfloat8_e4m3fnuz,
        qint2,
        qint4,
        qint8,
    )

    if precision == "int2-quanto":
        return qint2
    elif precision == "int4-quanto":
        return qint4
    elif precision == "int8-quanto":
        return qint8
    elif precision in ("fp8-quanto", "fp8uz-quanto"):
        # MPS has no float8 support, so fall back to no quantization.
        if torch.backends.mps.is_available():
            logger.warning(
                "MPS doesn't support dtype float8. "
                "You must select another precision level such as int2, int4, or int8."
            )
            return None
        if precision == "fp8-quanto":
            return qfloat8
        elif precision == "fp8uz-quanto":
            return qfloat8_e4m3fnuz

    raise ValueError(f"Invalid quantization level: {precision}")
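
if __name__ == "__main__":
    # Minimal smoke test (an illustrative addition, not part of the original
    # trainer): quantize a toy module and run a forward pass. The exclude
    # patterns above simply don't match these layer names, so both Linear
    # layers are quantized.
    demo = torch.nn.Sequential(
        torch.nn.Linear(16, 32),
        torch.nn.ReLU(),
        torch.nn.Linear(32, 16),
    )
    demo = quantize_model(demo, precision="int8-quanto")
    print(demo(torch.randn(2, 16)).shape)  # expected: torch.Size([2, 16])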