File size: 1,136 Bytes
12c5b18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
"""
量化配置(语义分析、信息密度分析共用)

从环境变量读取并返回设备相关的量化策略:
- FORCE_INT8=1: INT8 量化(CPU/CUDA 支持,MPS 不支持)
- CPU_FORCE_BFLOAT16=1: CPU 使用 bfloat16
"""

import os
from typing import NamedTuple

import torch


class QuantizationConfig(NamedTuple):
    """量化配置,语义模型和信息密度模型共用"""
    use_int8: bool
    dtype: torch.dtype


def get_quantization_config(device: torch.device) -> QuantizationConfig:
    """
    根据设备和环境变量返回量化配置。

    Returns:
        QuantizationConfig: use_int8, dtype
    """
    force_int8 = os.environ.get("FORCE_INT8") == "1"
    force_bfloat16 = os.environ.get("CPU_FORCE_BFLOAT16") == "1"

    if device.type == "cpu":
        use_int8 = force_int8
        dtype = torch.bfloat16 if force_bfloat16 else torch.float32
    elif device.type == "cuda":
        use_int8 = force_int8
        dtype = torch.float16
    else:
        # MPS 不支持 INT8
        use_int8 = False
        dtype = torch.float16

    return QuantizationConfig(use_int8=use_int8, dtype=dtype)