File size: 2,608 Bytes
ebfc6b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# Adapted from: https://github.com/bghira/SimpleTuner/blob/main/helpers/training/quantisation/__init__.py
from typing import Literal

import torch
from optimum.quanto import qtype

from ltx_trainer import logger

# Supported quantization settings: "no_change" disables quantization entirely;
# the remaining options select an optimum-quanto weight dtype (int2/int4/int8
# or one of the float8 variants) — see _quanto_type_map for the mapping.
QuantizationOptions = Literal[
    "no_change",
    "int8-quanto",
    "int4-quanto",
    "int2-quanto",
    "fp8-quanto",
    "fp8uz-quanto",
]


def quantize_model(
    model: torch.nn.Module,
    precision: QuantizationOptions,
    quantize_activations: bool = False,
) -> torch.nn.Module:
    """
    Quantize a model in place using optimum-quanto and freeze the result.

    Args:
        model: The model to quantize (modified in place).
        precision: The precision level to quantize to (e.g. "int8-quanto", "fp8-quanto").
            "no_change" (or None) returns the model untouched.
        quantize_activations: Whether to quantize activations in addition to weights.

    Returns:
        The quantized and frozen model, or the original model if no quantization
        is performed.
    """
    if precision is None or precision == "no_change":
        return model

    # Imported lazily so the module can be used without optimum-quanto installed
    # when quantization is disabled.
    from optimum.quanto import freeze, quantize  # noqa: PLC0415

    weight_quant = _quanto_type_map(precision)
    # Modules matching these patterns are excluded from quantization and kept
    # in their original precision (projections, embeddings, norms, RoPE).
    extra_quanto_args = {
        "exclude": [
            "proj_in",
            "time_embed.*",
            "caption_projection.*",
            "rope",
            "*norm*",
            "proj_out",
        ]
    }
    if quantize_activations:
        # Fixed log message: the operation here is quantization, not freezing
        # (freeze happens below, after quantize()).
        logger.info("Quantizing model weights and activations")
        extra_quanto_args["activations"] = weight_quant
    else:
        logger.info("Quantizing model weights only")

    quantize(model, weights=weight_quant, **extra_quanto_args)
    # Per optimum-quanto docs, freeze() replaces the float weights with their
    # quantized counterparts in place.
    freeze(model)
    return model


def _quanto_type_map(precision: QuantizationOptions) -> torch.dtype | qtype | None:  # noqa: PLR0911
    """
    Map a precision option string to the corresponding optimum-quanto qtype.

    Args:
        precision: The precision option (e.g. "int8-quanto", "fp8uz-quanto").

    Returns:
        The matching quanto qtype, or None for "no_change" or when float8 is
        requested but the MPS backend is active (MPS has no float8 support).

    Raises:
        ValueError: If the precision string is not a recognized option.
    """
    if precision == "no_change":
        return None

    # Imported lazily so the module can be used without optimum-quanto installed
    # when quantization is disabled.
    from optimum.quanto import (  # noqa: PLC0415
        qfloat8,
        qfloat8_e4m3fnuz,
        qint2,
        qint4,
        qint8,
    )

    if precision == "int2-quanto":
        return qint2
    elif precision == "int4-quanto":
        return qint4
    elif precision == "int8-quanto":
        return qint8
    elif precision in ("fp8-quanto", "fp8uz-quanto"):
        if torch.backends.mps.is_available():
            # Fixed message: previously listed "int2, int8, or int8"
            # (int8 duplicated, int4 missing).
            logger.warning(
                "MPS doesn't support dtype float8. "
                "You must select another precision level such as int2, int4, or int8.",
            )
            return None
        if precision == "fp8-quanto":
            return qfloat8
        elif precision == "fp8uz-quanto":
            return qfloat8_e4m3fnuz

    raise ValueError(f"Invalid quantisation level: {precision}")