from functools import cache

# Flipped to True as soon as any quantization backend imports successfully.
SUPPORT_QUANT = False

try:
    from bitsandbytes.nn import LinearNF4, Linear8bitLt, LinearFP4

    SUPPORT_QUANT = True
except Exception:
    # bitsandbytes is absent or failed to initialize (it can raise more than
    # ImportError, e.g. on a broken CUDA setup).  Define plain nn.Linear
    # subclasses under the same names so isinstance checks against them stay
    # valid but never match a real quantized layer.
    import torch.nn as nn

    class LinearNF4(nn.Linear):
        pass

    class Linear8bitLt(nn.Linear):
        pass

    class LinearFP4(nn.Linear):
        pass
try:
    from quanto.nn import QLinear, QConv2d, QLayerNorm

    SUPPORT_QUANT = True
except Exception:
    # quanto unavailable: provide stand-ins that mirror the corresponding
    # torch base classes so the names below always resolve, while never
    # matching an actual quantized module in isinstance checks.
    import torch.nn as nn

    class QLinear(nn.Linear):
        pass

    class QConv2d(nn.Conv2d):
        pass

    class QLayerNorm(nn.LayerNorm):
        pass
try:
    # Newer releases ship quanto under the optimum namespace; alias the
    # classes so both package layouts can be recognized side by side.
    from optimum.quanto.nn import (
        QLinear as QLinearOpt,
        QConv2d as QConv2dOpt,
        QLayerNorm as QLayerNormOpt,
    )

    SUPPORT_QUANT = True
except Exception:
    # optimum-quanto unavailable: fall back to torch base-class stand-ins.
    import torch.nn as nn

    class QLinearOpt(nn.Linear):
        pass

    class QConv2dOpt(nn.Conv2d):
        pass

    class QLayerNormOpt(nn.LayerNorm):
        pass
from ..logging import logger
# Every quantized (or stand-in) layer type defined above, collected as a
# tuple so callers can pass it straight to isinstance() when deciding
# whether a module needs bypass handling.
QuantLinears = (
    Linear8bitLt,
    LinearFP4,
    LinearNF4,
    QLinear,
    QConv2d,
    QLayerNorm,
    QLinearOpt,
    QConv2dOpt,
    QLayerNormOpt,
)
@cache
def log_bypass():
    """Warn that quantized layers force bypass mode.

    Decorated with @cache so the warning is emitted at most once per
    process no matter how often this is called.
    """
    message = (
        "Using bnb/quanto/optimum-quanto with LyCORIS will enable force-bypass mode."
    )
    return logger.warning(message)
@cache
def log_suspect():
    """Warn that a non-native Linear triggered automatic force-bypass.

    Decorated with @cache so the warning is emitted at most once per
    process no matter how often this is called.
    """
    message = (
        "Non-native Linear detected but bypass_mode is not set. "
        "Automatically using force-bypass mode to avoid possible issues. "
        "Please set bypass_mode=False explicitly if there are no quantized layers."
    )
    return logger.warning(message)