File size: 6,750 Bytes
7a87926
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
"""
Model quantization utilities for faster inference and lower memory usage.

Supports FP16, INT8, and dynamic quantization.
"""

import logging
from pathlib import Path
from typing import Dict, Optional
import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


def quantize_fp16(model: nn.Module) -> nn.Module:
    """
    Cast every parameter and buffer of *model* to half precision.

    ``nn.Module.half()`` converts the module in place and returns it, so the
    returned object is the same module the caller passed in.

    Args:
        model: Model to quantize

    Returns:
        FP16 quantized model
    """
    half_model = model.half()
    logger.info("Model quantized to FP16")
    return half_model


def quantize_dynamic_int8(
    model: nn.Module,
    quantizable_modules: Optional[list] = None,
) -> nn.Module:
    """
    Apply dynamic INT8 quantization to model.

    Args:
        model: Model to quantize
        quantizable_modules: List of module types to quantize (default: Linear, Conv2d)

    Returns:
        INT8 quantized model
    """
    if quantizable_modules is None:
        quantizable_modules = [torch.nn.Linear, torch.nn.Conv2d]

    try:
        quantized_model = torch.quantization.quantize_dynamic(
            model,
            quantizable_modules,
            dtype=torch.qint8,
        )
        logger.info(f"Model quantized to INT8 (modules: {quantizable_modules})")
        return quantized_model
    except Exception as e:
        logger.error(f"INT8 quantization failed: {e}")
        logger.warning("Falling back to FP16 quantization")
        return quantize_fp16(model)


def quantize_static_int8(
    model: nn.Module,
    calibration_data,
    quantizable_modules: Optional[list] = None,
) -> nn.Module:
    """
    Apply static (post-training) INT8 quantization with calibration data.

    WARNING: the input ``model`` is mutated in place — it is switched to
    eval mode, given a qconfig, and has observers inserted by
    ``torch.quantization.prepare(..., inplace=True)``. Only the final
    ``convert`` step produces a separate module.

    Args:
        model: Model to quantize (mutated in place; see above)
        calibration_data: DataLoader or list of inputs for calibration.
            At most 100 samples are used. Batches that are lists/tuples are
            assumed to be ``(inputs, ...)`` pairs; only element 0 is fed in.
        quantizable_modules: Accepted for interface compatibility, but not
            applied in this eager-mode flow — qconfig propagation decides
            what gets quantized. TODO(review): wire this up or deprecate it.

    Returns:
        INT8 quantized model
    """
    max_calibration_samples = 100

    model.eval()

    # Prepare model for quantization (fbgemm = x86 server backend).
    model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
    torch.quantization.prepare(model, inplace=True)

    # Calibrate with data so the observers can pick quantization ranges.
    logger.info("Calibrating model for static quantization...")
    with torch.no_grad():
        if hasattr(calibration_data, "__iter__"):
            for i, data in enumerate(calibration_data):
                # Check BEFORE the forward pass so exactly 100 samples run
                # (the old `if i >= 100: break` after model() ran 101).
                if i >= max_calibration_samples:
                    break
                inputs = data[0] if isinstance(data, (list, tuple)) else data
                model(inputs)
        else:
            # Fallback assumes a sliceable sequence — TODO(review): confirm
            # callers ever hit this branch; non-iterables rarely support [:].
            for inputs in calibration_data[:max_calibration_samples]:
                model(inputs)

    # Convert observers + float modules into quantized equivalents.
    quantized_model = torch.quantization.convert(model, inplace=False)
    logger.info("Model quantized to static INT8")
    return quantized_model


def save_quantized_model(
    model: nn.Module,
    output_path: Path,
    quantization_type: str = "fp16",
):
    """
    Persist a quantized model to disk, creating parent directories as needed.

    FP16 models are stored as a plain state dict; any other type is pickled
    whole, because INT8 modules carry quantization state that a bare state
    dict would lose.

    Args:
        model: Quantized model
        output_path: Path to save model
        quantization_type: Type of quantization ('fp16', 'int8')
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    payload = model.state_dict() if quantization_type == "fp16" else model
    torch.save(payload, output_path)

    logger.info(f"Quantized model saved to {output_path}")


def load_quantized_model(
    model: nn.Module,
    checkpoint_path: Path,
    quantization_type: str = "fp16",
    device: str = "cuda",
) -> nn.Module:
    """
    Load quantized model.

    Args:
        model: Base model architecture (used only for the fp16 path; the
            int8 path replaces it entirely with the unpickled module)
        checkpoint_path: Path to quantized checkpoint
        quantization_type: Type of quantization ('fp16' expects a state
            dict; anything else expects a fully pickled module)
        device: Device to load on

    Returns:
        Loaded quantized model
    """
    checkpoint_path = Path(checkpoint_path)

    if quantization_type == "fp16":
        # A state dict contains only tensors, so restrict unpickling to the
        # safe tensor-only path.
        state_dict = torch.load(
            checkpoint_path, map_location=device, weights_only=True
        )
        model.load_state_dict(state_dict)
        model = model.half()
    else:
        # INT8 checkpoints are fully pickled modules. weights_only must be
        # False explicitly: torch >= 2.6 flipped the default to True, which
        # would make this branch fail on previously saved checkpoints.
        # SECURITY: full unpickling executes arbitrary code — only load
        # checkpoints from trusted sources.
        model = torch.load(checkpoint_path, map_location=device, weights_only=False)

    logger.info(f"Quantized model loaded from {checkpoint_path}")
    return model


def compare_model_sizes(
    model_fp32: nn.Module,
    model_quantized: nn.Module,
) -> Dict[str, float]:
    """
    Report the in-memory footprint of a model before and after quantization.

    Size is the raw tensor storage of all parameters and buffers
    (numel * bytes-per-element); Python object overhead is not counted.

    Args:
        model_fp32: Original FP32 model
        model_quantized: Quantized model

    Returns:
        Dict with sizes in MB and the percentage reduction
    """

    def _storage_bytes(module: nn.Module) -> int:
        total = 0
        for tensor in module.parameters():
            total += tensor.numel() * tensor.element_size()
        for tensor in module.buffers():
            total += tensor.numel() * tensor.element_size()
        return total

    size_fp32 = _storage_bytes(model_fp32)
    size_quantized = _storage_bytes(model_quantized)

    return {
        "fp32_size_mb": size_fp32 / 1024 / 1024,
        "quantized_size_mb": size_quantized / 1024 / 1024,
        "reduction_percent": (1 - size_quantized / size_fp32) * 100,
    }


def benchmark_quantized_model(
    model: nn.Module,
    sample_input,
    num_runs: int = 100,
    device: str = "cuda",
) -> Dict[str, float]:
    """
    Benchmark quantized model inference speed.

    Args:
        model: Model to benchmark
        sample_input: Sample input tensor, or a list of tensors. NOTE(review):
            list inputs are routed through ``model.inference(...)`` — assumes
            the model exposes that method; confirm against callers.
        num_runs: Number of inference runs
        device: Device to run on ('cuda' or 'cpu')

    Returns:
        Dict with timing statistics (avg ms per run, fps, total seconds)
    """
    import time

    model.eval()
    model = model.to(device)

    if isinstance(sample_input, list):
        sample_input = [x.to(device) for x in sample_input]
    else:
        sample_input = sample_input.to(device)

    # Only synchronize on CUDA devices; the original unconditional
    # torch.cuda.synchronize() raised on CPU-only machines / device="cpu".
    use_cuda_sync = device.startswith("cuda") and torch.cuda.is_available()

    def _forward():
        # List inputs use the model's dedicated inference entry point.
        if isinstance(sample_input, list):
            return model.inference(sample_input)
        return model(sample_input)

    # Warmup to amortize one-time costs (allocator, cudnn autotune, etc.).
    with torch.no_grad():
        for _ in range(10):
            _forward()

    if use_cuda_sync:
        torch.cuda.synchronize()
    # perf_counter is monotonic and high-resolution; time.time() is wall
    # clock and can jump (NTP), skewing benchmarks.
    start_time = time.perf_counter()

    with torch.no_grad():
        for _ in range(num_runs):
            _forward()

    if use_cuda_sync:
        torch.cuda.synchronize()
    end_time = time.perf_counter()

    total_time = end_time - start_time
    avg_time = total_time / num_runs
    # Guard against a zero reading for ultra-fast models / coarse clocks.
    fps = 1.0 / avg_time if avg_time > 0 else float("inf")

    return {
        "avg_inference_time_ms": avg_time * 1000,
        "fps": fps,
        "total_time_s": total_time,
    }