File size: 1,136 Bytes
12c5b18 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 | """
量化配置(语义分析、信息密度分析共用)
从环境变量读取并返回设备相关的量化策略:
- FORCE_INT8=1: INT8 量化(CPU/CUDA 支持,MPS 不支持)
- CPU_FORCE_BFLOAT16=1: CPU 使用 bfloat16
"""
import os
from typing import NamedTuple
import torch
class QuantizationConfig(NamedTuple):
"""量化配置,语义模型和信息密度模型共用"""
use_int8: bool
dtype: torch.dtype
def get_quantization_config(device: torch.device) -> QuantizationConfig:
"""
根据设备和环境变量返回量化配置。
Returns:
QuantizationConfig: use_int8, dtype
"""
force_int8 = os.environ.get("FORCE_INT8") == "1"
force_bfloat16 = os.environ.get("CPU_FORCE_BFLOAT16") == "1"
if device.type == "cpu":
use_int8 = force_int8
dtype = torch.bfloat16 if force_bfloat16 else torch.float32
elif device.type == "cuda":
use_int8 = force_int8
dtype = torch.float16
else:
# MPS 不支持 INT8
use_int8 = False
dtype = torch.float16
return QuantizationConfig(use_int8=use_int8, dtype=dtype)
|