File size: 78,208 Bytes
dc9bb20 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 
525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 
1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 
1420 1421 1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656 1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805 1806 1807 1808 1809 1810 1811 1812 1813 1814 1815 1816 1817 1818 1819 
1820 1821 1822 1823 1824 1825 1826 1827 1828 1829 1830 1831 1832 1833 1834 1835 1836 1837 1838 1839 1840 1841 1842 1843 1844 1845 1846 1847 1848 1849 1850 1851 1852 1853 1854 1855 1856 1857 1858 1859 1860 1861 1862 1863 1864 1865 1866 1867 1868 1869 1870 1871 1872 1873 1874 1875 1876 1877 1878 1879 1880 1881 1882 1883 1884 1885 1886 1887 1888 1889 1890 1891 1892 1893 1894 1895 1896 1897 1898 1899 1900 1901 1902 1903 1904 1905 1906 1907 1908 1909 1910 1911 1912 1913 1914 1915 1916 1917 1918 1919 1920 1921 1922 1923 1924 1925 1926 1927 1928 1929 1930 1931 1932 1933 1934 1935 1936 1937 1938 1939 1940 1941 1942 1943 1944 1945 1946 1947 1948 1949 1950 1951 1952 1953 1954 1955 1956 1957 1958 1959 1960 1961 1962 1963 1964 1965 1966 1967 1968 1969 1970 1971 1972 1973 1974 1975 1976 1977 1978 1979 1980 1981 1982 1983 1984 1985 1986 1987 1988 1989 1990 1991 1992 1993 1994 1995 1996 1997 1998 1999 2000 2001 2002 2003 2004 2005 2006 2007 2008 2009 2010 2011 2012 2013 2014 2015 2016 2017 2018 2019 2020 2021 2022 2023 2024 2025 2026 2027 2028 2029 2030 2031 2032 2033 2034 2035 2036 2037 2038 2039 2040 2041 2042 2043 2044 2045 2046 2047 2048 2049 2050 2051 2052 2053 2054 2055 2056 2057 2058 2059 2060 2061 2062 2063 2064 2065 2066 2067 2068 2069 2070 2071 2072 2073 2074 2075 2076 2077 2078 2079 2080 2081 2082 2083 2084 2085 2086 2087 2088 2089 2090 2091 2092 2093 2094 2095 2096 2097 2098 2099 2100 2101 2102 2103 2104 2105 2106 2107 2108 2109 2110 2111 2112 2113 2114 2115 2116 2117 2118 2119 2120 2121 2122 2123 2124 2125 2126 2127 2128 2129 2130 2131 2132 2133 2134 2135 2136 2137 2138 2139 2140 2141 2142 2143 2144 2145 2146 2147 2148 2149 2150 2151 2152 2153 2154 2155 2156 2157 2158 2159 2160 2161 2162 2163 2164 2165 2166 2167 2168 2169 2170 2171 2172 2173 2174 2175 2176 2177 2178 2179 2180 2181 2182 2183 2184 2185 2186 2187 2188 2189 2190 2191 2192 2193 2194 2195 2196 2197 2198 2199 2200 2201 2202 2203 2204 2205 2206 2207 2208 2209 2210 2211 2212 2213 2214 2215 2216 2217 2218 2219 
2220 2221 2222 2223 2224 2225 2226 2227 2228 2229 2230 2231 2232 2233 2234 2235 2236 2237 2238 2239 2240 2241 2242 2243 2244 2245 2246 2247 2248 2249 2250 2251 2252 2253 2254 2255 2256 2257 2258 2259 2260 2261 2262 2263 2264 2265 2266 2267 2268 2269 2270 2271 2272 2273 2274 2275 2276 2277 2278 2279 2280 2281 2282 2283 2284 2285 2286 2287 2288 2289 2290 2291 2292 2293 2294 2295 2296 2297 | # Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections.abc import Iterable
import ctypes as ct
import itertools
from math import prod
from typing import Any, Optional
import numpy as np
import torch
from torch import Tensor
from typing_extensions import deprecated
from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict
from .cextension import lib
# Module-level cache of quantization lookup tables keyed by name. quantize_blockwise and
# dequantize_blockwise lazily store the default dynamic map here under the key "dynamic".
name2qmap = {}
"""C FUNCTIONS FOR OPTIMIZERS"""
# Maps an optimizer name to a pair of native 8-bit optimizer kernels from `lib`.
# Judging by the symbol names, index 0 is the kernel for 32-bit gradients and index 1 the
# kernel for 16-bit gradients (NOTE(review): confirm against the callers that index this table).
str2optimizer8bit = {
    "adam": (
        lib.cadam_static_8bit_grad_32,
        lib.cadam_static_8bit_grad_16,
    ),
    "momentum": (
        lib.cmomentum_static_8bit_grad_32,
        lib.cmomentum_static_8bit_grad_16,
    ),
    "rmsprop": (
        lib.crmsprop_static_8bit_grad_32,
        lib.crmsprop_static_8bit_grad_16,
    ),
    "lion": (
        lib.clion_static_8bit_grad_32,
        lib.clion_static_8bit_grad_16,
    ),
    # NOTE(review): "lamb" reuses the Adam kernels and "lars" the Momentum kernels; presumably
    # the layer-wise trust-ratio scaling is applied elsewhere — confirm.
    "lamb": (
        lib.cadam_static_8bit_grad_32,
        lib.cadam_static_8bit_grad_16,
    ),
    "lars": (
        lib.cmomentum_static_8bit_grad_32,
        lib.cmomentum_static_8bit_grad_16,
    ),
}
class GlobalPageManager:
    """Singleton holding the list of paged tensors; obtain it via ``get_instance()``."""

    _instance = None

    def __init__(self):
        # Direct construction is refused; get_instance() builds the object through __new__.
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        """Reset the manager to an empty list of paged tensors."""
        self.paged_tensors = []

    @classmethod
    def get_instance(cls):
        """Return the shared instance, creating and initializing it on first use."""
        if cls._instance is None:
            instance = cls.__new__(cls)
            cls._instance = instance
            instance.initialize()
        return cls._instance

    def prefetch_all(self, to_cpu=False):
        """Prefetch every registered paged tensor (to the CPU when ``to_cpu`` is set).

        Tensors are visited newest-first. The earliest-registered tensors are assumed to be
        used first, so they are swapped in last in case they get evicted again.
        """
        for tensor in reversed(self.paged_tensors):
            prefetch_tensor(tensor, to_cpu)
class CUBLAS_Context:
    """Singleton caching one native context pointer per CUDA device index.

    Obtain it via ``get_instance()``; the pointer for a device is created lazily by
    ``lib.get_context()`` while that device is current.
    """

    _instance = None

    def __init__(self):
        # Direct construction is refused; get_instance() builds the object through __new__.
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        """Start with an empty per-device cache (device index -> ``ct.c_void_p``)."""
        self.context = {}

    @classmethod
    def get_instance(cls):
        """Return the shared instance, creating and initializing it on first use."""
        if cls._instance is None:
            cls._instance = cls.__new__(cls)
            cls._instance.initialize()
        return cls._instance

    def get_context(self, device):
        """Return the cached context pointer for ``device``, creating it on first request.

        The context is created with ``device`` temporarily made current. The previously
        current device is restored afterwards, even if context creation raises.
        """
        if device.index not in self.context:
            prev_device = torch.cuda.current_device()
            torch.cuda.set_device(device)
            try:
                self.context[device.index] = ct.c_void_p(lib.get_context())
            finally:
                # Without this, a failure in lib.get_context() would leave the caller on `device`.
                torch.cuda.set_device(prev_device)
        return self.context[device.index]
class Cusparse_Context:
    """Singleton wrapping the pointer returned by ``lib.get_cusparse()``.

    The pointer is presumably a cuSPARSE handle (inferred from the name). Obtain the
    instance via ``get_instance()``.
    """

    _instance = None

    def __init__(self):
        # Direct construction is refused; get_instance() builds the object through __new__.
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        """Fetch the native pointer once and keep it as ``self.context``."""
        handle = lib.get_cusparse()
        self.context = ct.c_void_p(handle)

    @classmethod
    def get_instance(cls):
        """Return the shared instance, creating and initializing it on first use."""
        if cls._instance is None:
            instance = cls.__new__(cls)
            cls._instance = instance
            instance.initialize()
        return cls._instance
FIRST_CUDA_DEVICE = torch.device("cuda", index=0)
# When multiple GPUs are present, we use a context manager to
# switch to the correct device of a tensor before invoking our CUDA
# kernels in the C++ library. However, when there's only one device
# there is no need to incur the overhead of cudaGetDevice/cudaSetDevice.
if torch.cuda.device_count() > 1:
def _cuda_device_of(a: torch.Tensor):
return torch.cuda.device_of(a)
else:
import contextlib
def _cuda_device_of(a: torch.Tensor):
return contextlib.nullcontext()
def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE):
    """Allocate a paged tensor backed by memory from ``lib.cget_managed_ptr``.

    Args:
        *shape: Dimensions of the tensor.
        dtype: Element type of the tensor.
        device: Device whose index is recorded as ``page_deviceid`` (the prefetch target).

    Returns:
        A tensor of ``shape``/``dtype`` viewing the allocation, tagged with ``is_paged=True``
        and ``page_deviceid``.
    """
    numel = prod(shape)
    num_bytes = dtype.itemsize * numel
    cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes))
    # View the allocation as exactly `num_bytes` raw bytes. The previous int32 view sized by
    # `shape` covered only 4 * numel bytes, which is too small for 8-byte dtypes (e.g. float64)
    # and makes torch.frombuffer fail.
    byte_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_uint8))
    byte_view = np.ctypeslib.as_array(byte_ptr, shape=(num_bytes,))
    out = torch.frombuffer(byte_view, dtype=dtype, count=numel).view(shape)
    out.is_paged = True
    out.page_deviceid = device.index
    return out
def prefetch_tensor(A: torch.Tensor, to_cpu=False):
    """Ask the native library to prefetch the memory of a paged tensor.

    Args:
        A: A tensor tagged with ``is_paged=True`` (as produced by ``get_paged``).
        to_cpu: If True, prefetch with device id -1 (CPU) instead of ``A.page_deviceid``.

    Raises:
        AssertionError: If ``A`` is not a paged tensor.
    """
    # getattr: ordinary tensors have no `is_paged` attribute. They should fail this check with
    # the intended message rather than with an AttributeError.
    assert getattr(A, "is_paged", False), "Only paged tensors can be prefetched!"
    deviceid = -1 if to_cpu else A.page_deviceid
    lib.cprefetch(get_ptr(A), ct.c_size_t(A.nbytes), ct.c_int32(deviceid))
def elementwise_func(func_name, A, B, value, prefetch=True):
    """Run the native element-wise kernel ``c{func_name}_{fp32|uint8}`` on ``A`` (and ``B``).

    Args:
        func_name: Kernel base name, e.g. ``"fill"`` -> ``lib.cfill_fp32`` / ``lib.cfill_uint8``.
        A: float32 or uint8 tensor; its dtype selects the kernel variant.
        B: Optional second tensor (``None`` is passed to the kernel as a NULL pointer).
        value: Scalar argument, converted to ``c_float`` or ``c_uint8`` to match ``A``.
        prefetch: Whether to prefetch managed tensors before the call.

    Raises:
        NotImplementedError: If no kernel exists for ``func_name`` and ``A.dtype``.
    """
    func = None
    if A.dtype == torch.float32:
        func = getattr(lib, f"c{func_name}_fp32", None)
        cvalue = ct.c_float(value)
    elif A.dtype == torch.uint8:
        func = getattr(lib, f"c{func_name}_uint8", None)
        cvalue = ct.c_uint8(value)
    if func is None:
        raise NotImplementedError(f"Function not implemented: {func_name}")

    # NOTE(review): nothing visible here sets `is_managed` (get_paged sets `is_paged`), so this
    # prefetch branch may never run — confirm the intended attribute.
    is_managed = getattr(A, "is_managed", False)
    if is_managed and prefetch:
        prefetch_tensor(A)
        if B is not None:
            prefetch_tensor(B)

    func(get_ptr(A), get_ptr(B), cvalue, ct.c_int64(A.numel()))
    # getattr: B may be None (e.g. from fill) and regular tensors carry no `is_paged` attribute.
    if getattr(A, "is_paged", False) or getattr(B, "is_paged", False):
        # paged function are fully asynchronous
        # if we return from this function, we want to the tensor
        # to be in the correct state, that is the final state after the
        # operation occurred. So we synchronize.
        torch.cuda.synchronize()
def fill(A, value, device=None, prefetch=True):
    """Fill ``A`` in place with ``value`` using the native ``cfill_*`` kernel.

    Args:
        A: float32 or uint8 tensor to overwrite.
        value: Fill value.
        device: Unused; kept for backward compatibility.
        prefetch: Forwarded to ``elementwise_func`` (it was previously ignored).
    """
    elementwise_func("fill", A, None, value, prefetch=prefetch)
def _mul(A, B, device=None):
    """Invoke the native ``c_mul_*`` element-wise kernel on ``A`` and ``B``.

    ``device`` is unused, and the kernel's scalar argument is fixed at 0.
    """
    elementwise_func(func_name="_mul", A=A, B=B, value=0)
def create_linear_map(signed=True, total_bits=8, add_zero=True):
    """Create a linearly spaced quantization map padded to 256 entries.

    Args:
        signed: If True the values span [-1, 1], otherwise [0, 1].
        total_bits: Number of bits of the simulated data type.
        add_zero: For signed maps, use one value fewer so the grid is centered on zero. This is
            always applied when ``total_bits < 8``.

    Returns:
        The evenly spaced values. If fewer than 256 are produced, zeros are inserted in the
        middle of the grid to fill the tensor up to 256 entries.
    """
    lower = -1.0 if signed else 0.0
    n_values = 2**total_bits
    # Fewer bits are simulated by padding with zeros, so the grid must be centered around
    # zero, which costs a signed map a single value.
    if signed and (add_zero or total_bits < 8):
        n_values -= 1
    grid = torch.linspace(lower, 1.0, n_values)

    padding = 256 - grid.numel()
    if not padding:
        return grid
    half = grid.numel() // 2
    lower_half = grid[:half].tolist()
    upper_half = grid[half:].tolist()
    return torch.Tensor(lower_half + [0] * padding + upper_half)
def create_normal_map(offset=0.9677083, use_extra_value=True):
    """Create the NormalFloat (NF4) quantization map.

    Constructs a lookup table of 16 quantization levels (15 when ``use_extra_value=False``)
    derived from quantiles of the standard normal distribution N(0, 1). The table is stored in
    a 256-element tensor for indexing convenience. The non-zero levels sit at evenly spaced
    quantiles, so each bin covers a roughly equal share of probability mass. This suits
    approximately normally distributed data such as neural network weights.

    Unlike floating-point types (FP4, FP8), NF4 is NOT a float encoding — the 4-bit index is
    simply a lookup into this table. There is no sign/exponent/mantissa decomposition.

    The levels are computed with ``scipy.stats.norm.ppf()`` (inverse CDF) at evenly spaced
    quantile points from ``offset`` down to 0.5 (0.5 itself excluded). The points are mirrored
    for the negative side, and the result is divided by its maximum so the range is [-1, 1].

    For more details, see: QLoRA: Efficient Finetuning of Quantized LLMs
    (https://arxiv.org/abs/2305.14314)

    Args:
        offset: The outermost quantile. ``norm.ppf(offset)`` is the largest level before
            normalization; the default (0.9677083) corresponds to roughly 1.85 standard
            deviations.
        use_extra_value: If True, creates an asymmetric type with 8 positive and 7 negative
            levels plus zero (15 non-zero values). If False, creates a symmetric type with
            7 positive and 7 negative levels plus zero (14 non-zero values).

    Returns:
        A sorted 256-element tensor normalized to [-1, 1]. The padding zeros are sorted together
        with the levels, so the layout is: the negative levels, then a run of zeros (the zero
        level plus padding), then the positive levels.

    Raises:
        ImportError: If scipy is not installed.
    """
    try:
        from scipy.stats import norm
    except ImportError as ie:
        raise ImportError(
            "Scipy is required for `create_normal_map`. Install `bitsandbytes` with the `[test]` extra.",
        ) from ie

    # The extra value is one more positive level; the negative side always has 7.
    num_positive = 8 if use_extra_value else 7
    positive = norm.ppf(torch.linspace(offset, 0.5, num_positive + 1)[:-1]).tolist()
    negative = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()
    # Zero padding up to 256 entries; one of these zeros doubles as the zero level.
    padding = [0] * (256 - len(positive) - len(negative))

    values = torch.Tensor(positive + padding + negative)
    values = values.sort().values
    values /= values.max()
    assert values.numel() == 256
    return values
def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8):
    """Create a floating-point quantization map with configurable bit layout.

    Generates a lookup table for a custom floating-point format with configurable exponent and
    mantissa (precision) bits. Despite the name, this function handles any total bit width
    (including FP4 when called with ``total_bits=4``).

    The encoding, as implemented below, uses:
    - Exponent bias: ``2^(exponent_bits - 1)``
    - Normal values (exponent field ``e > 0``): ``(1 + mantissa) * 2^(bias + 1 - e)``
    - Subnormal values (exponent field ``e == 0``): ``mantissa * 2^(-bias)``
    Larger exponent fields therefore map to smaller magnitudes. Because the table is sorted and
    normalized afterwards, only the resulting set of values matters.

    The returned values are divided by their maximum, so the represented range is [-1, 1].

    A signed 4-bit float with 2 exponent bits and 1 mantissa bit is produced by
    ``create_fp8_map(signed=True, exponent_bits=2, precision_bits=1, total_bits=4)``.

    Args:
        signed: Whether the format includes a sign bit.
        exponent_bits: Number of bits for the exponent field.
        precision_bits: Number of bits for the mantissa (precision/fraction) field.
        total_bits: Total number of bits per value (must equal sign + exponent + precision).

    Returns:
        A tensor of sorted quantization levels normalized to [-1, 1]. For types with fewer than
        8 bits it is zero-padded to 256 entries.
    """
    has_sign = 1 if signed else 0
    assert exponent_bits + precision_bits == total_bits - has_sign

    bias = 2 ** (exponent_bits - 1)
    mantissa_patterns = list(itertools.product([0, 1], repeat=precision_bits))
    values = []
    for evalue in range(2**exponent_bits):
        for bit_pattern in mantissa_patterns:
            # Normals have an implicit leading 1; subnormals (exponent field 0) do not.
            value = 1 if evalue != 0 else 0
            for i, pval in enumerate(bit_pattern):
                value += pval * (2 ** -(i + 1))
            if evalue == 0:
                # subnormals
                value = value * 2**-(bias)
            else:
                # normals
                value = value * 2 ** -(evalue - bias - 1)
            values.append(value)
            if signed:
                values.append(-value)

    assert len(values) == 2**total_bits
    if total_bits < 8:
        values.extend([0] * (256 - len(values)))
    values.sort()
    code = torch.tensor(values)
    code /= code.max()
    return code
def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
    """
    Creates the dynamic quantization map.

    The dynamic data type is made up of a dynamic exponent and a fraction. As the exponent
    increases from 0 to -7, the number of bits available for the fraction shrinks.

    This is a generalization of the dynamic type where a certain number of the bits can be
    reserved for the linear quantization region (the fraction). ``max_exponent_bits``
    determines the maximum number of exponent bits.

    For more details see
    (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561]

    Args:
        signed: Whether to also generate the negative values.
        max_exponent_bits: Number of decimal exponent levels (``10**-(max_exponent_bits - 1)``
            up to ``10**0``). ``0`` yields a purely linear map in [0.1, 1].
        total_bits: Number of bits of the data type; ``2**total_bits`` values are generated.

    Returns:
        A sorted float32 tensor containing the generated values (including ``0`` and ``1.0``).
        It is zero-padded to 256 entries when ``total_bits < 8``.
    """
    data = []
    # these are additional items that come from the case
    # where all the exponent bits are zero and no
    # indicator bit is present
    non_sign_bits = total_bits - 1
    additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1
    for i in range(max_exponent_bits):
        fraction_items = int(
            2 ** (i + non_sign_bits - max_exponent_bits) + 1
            if signed
            else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1,
        )
        boundaries = torch.linspace(0.1, 1, fraction_items, dtype=torch.float32)
        means = (boundaries[:-1] + boundaries[1:]) / 2.0
        scale = 10 ** (-(max_exponent_bits - 1) + i)
        data += (scale * means).tolist()
        if signed:
            data += (-scale * means).tolist()

    if additional_items > 0:
        # The extra linear region uses the largest exponent, whose scale
        # 10 ** (-(max_exponent_bits - 1) + (max_exponent_bits - 1)) is always 1. Spelling
        # that out avoids reading the loop variable, which is undefined when the loop is empty.
        boundaries = torch.linspace(0.1, 1, additional_items + 1, dtype=torch.float32)
        means = (boundaries[:-1] + boundaries[1:]) / 2.0
        data += means.tolist()
        if signed:
            data += (-means).tolist()

    data.append(0)
    data.append(1.0)
    assert len(data) == 2**total_bits

    # Pad with zeros up to 256 entries (no-op for 8-bit maps).
    data.extend([0] * (256 - len(data)))
    data.sort()
    return torch.tensor(data, dtype=torch.float32)
def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]):
    """Verifies that the input tensors are all on the same non-CPU device.

    ``None`` entries are ignored, as are tensors marked as paged (``is_paged=True``), whose
    device placement does not matter.

    Args:
        tensors (`Iterable[Optional[torch.Tensor]]`): The tensors to verify. Any iterable is
            accepted, including a one-shot generator.

    Raises:
        `RuntimeError`: If a checked tensor is on the CPU, or the checked tensors span more than
            one device.

    Returns:
        `Literal[True]`
    """
    # Materialize once: the error messages iterate a second time, which would see nothing if
    # `tensors` were a one-shot iterator.
    tensors = list(tensors)
    on_gpu = True
    gpu_ids = set()
    for t in tensors:
        # NULL pointers and paged tensors are OK.
        if t is not None and not getattr(t, "is_paged", False):
            on_gpu &= t.device.type != "cpu"
            gpu_ids.add((t.device.type, t.device.index))

    if not on_gpu or len(gpu_ids) > 1:
        # None entries are reported as-is instead of crashing on `None.shape`.
        details = [(t.shape, t.device) if t is not None else None for t in tensors]
        if not on_gpu:
            raise RuntimeError(
                f"All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {details}",
            )
        raise RuntimeError(
            f"Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {details}",
        )
    return on_gpu
def _get_tensor_stream(tensor: Tensor) -> ct.c_void_p:
    """Return the raw current stream of ``tensor``'s device as a ``ct.c_void_p``.

    XPU tensors query the XPU stream; every other device type queries the CUDA stream.
    The raw stream accessor is used for performance reasons.
    """
    device = tensor.device
    if device.type != "xpu":
        raw_stream = torch._C._cuda_getCurrentRawStream(device.index)
    else:
        raw_stream = torch._C._xpu_getCurrentRawStream(device.index)
    return ct.c_void_p(raw_stream)
def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]:
    """Return the memory address of the first element of a tensor.

    Args:
        A (`Optional[Tensor]`): A PyTorch tensor, or ``None``.

    Returns:
        `Optional[ct.c_void_p]`: The tensor's data pointer, or ``None`` when ``A`` is ``None``
        (ctypes passes ``None`` as a NULL pointer).
    """
    return None if A is None else ct.c_void_p(A.data_ptr())
class QuantState:
    """container for quantization state components to work with Params4bit and similar classes

    Holds the per-block scales (``absmax``), the lookup table (``code``), ``blocksize``,
    ``quant_type``, the original input ``dtype`` and ``shape``. For nested (double)
    quantization it also holds ``offset`` and ``state2``. In that case ``absmax`` is itself
    quantized with ``state2``, and ``offset`` is added back after dequantizing it.
    """

    valid_quant_types = ("fp4", "nf4")
    valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types]
    valid_qs_keys = [
        "absmax",
        "quant_map",
        "nested_absmax",
        "nested_quant_map",
        "quant_state",
        "quant_type",
        "blocksize",
        "dtype",
        "shape",
        "nested_blocksize",
        "nested_dtype",
        "nested_offset",
    ]

    def __init__(
        self,
        absmax,
        shape=None,
        code=None,
        blocksize=None,
        quant_type=None,
        dtype=None,
        offset=None,
        state2=None,
    ):
        self.absmax = absmax
        self.shape = shape
        self.code = code
        self.dtype = dtype
        self.blocksize = blocksize
        self.quant_type = quant_type
        self.offset = offset
        self.state2 = state2
        # A state is "nested" exactly when a second-level state for absmax is present.
        self.nested = state2 is not None

    def __getitem__(self, idx):
        """
        ensures compatibility with older quant state scheme with nested lists.
        assumes the following layout:
        state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type]
        state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type]
        """
        if self.nested:
            list_repr = [
                self.absmax,
                self.shape,
                self.dtype,
                self.blocksize,
                [self.offset, self.state2],
                self.quant_type,
            ]
        else:
            list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, None, self.quant_type]
        return list_repr[idx]

    @classmethod
    def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> "QuantState":
        """
        unpacks components of state_dict into QuantState
        where necessary, convert into strings, torch.dtype, ints, etc.

        qs_dict: based on state_dict, with only relevant keys, stripped of prefixes.
            Either packed (as produced by ``as_dict(packed=True)``) or unpacked
            (as produced by ``as_dict()``); the caller's dict is not modified.

        item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain minor and non-tensor quant state items.

        Raises:
            ValueError: if neither a packed `quant_state` tensor nor an unpacked `quant_type`
                entry is present, or if the packed `quant_state` items are malformed.
        """
        # Work on a shallow copy so popping the packed item does not mutate the caller's dict.
        qs_dict = dict(qs_dict)

        # unpacking tensor with non-tensor components
        qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
        if not qs_key and "quant_type" not in qs_dict:
            raise ValueError("Expected packed or unpacked quant_state items, found neither")
        # Only validate the packed item when there is one; unpacked dicts have none.
        if qs_key and (len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys):
            raise ValueError(
                f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.",
            )

        # unpacking minor and non-tensor quant state items if necessary
        if len(qs_key) == 1:
            first_qs_key = qs_key[0]
            qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(first_qs_key)))

        qs_dict = {k.split(".")[-1]: v for k, v in qs_dict.items()}  # strip prefixes
        assert set(qs_dict.keys()).issubset(cls.valid_qs_keys)

        if "nested_absmax" in qs_dict:
            offset = torch.tensor(float(qs_dict["nested_offset"])).to(device)
            state2 = cls(
                absmax=qs_dict["nested_absmax"].to(device),
                blocksize=qs_dict["nested_blocksize"],
                code=qs_dict["nested_quant_map"].to(device),
                dtype=getattr(torch, qs_dict["nested_dtype"]),
            )
        else:
            offset, state2 = None, None

        quant_state = cls(
            quant_type=qs_dict["quant_type"],
            absmax=qs_dict["absmax"].to(device),
            blocksize=qs_dict["blocksize"],
            code=qs_dict["quant_map"].to(device),
            dtype=getattr(torch, qs_dict["dtype"]),
            shape=torch.Size(qs_dict["shape"]) if qs_dict["shape"] is not None else None,
            offset=offset,
            state2=state2,
        )
        return quant_state

    def as_dict(self, packed=False):
        """
        returns dict of tensors and strings to use in serialization via _save_to_state_dict()
        param: packed -- returns dict[str, torch.Tensor] for state_dict fit for safetensors saving

        ``shape`` is stored as a tuple, or ``None`` when the state carries no shape. Dtypes are
        stored as their ``torch`` attribute names (e.g. ``"float16"``).
        """
        qs_dict = {
            "quant_type": self.quant_type,
            "absmax": self.absmax,
            "blocksize": self.blocksize,
            "quant_map": self.code,
            # removeprefix, not strip(): strip("torch.") removes any of those characters from
            # both ends (e.g. "torch.complex64" -> "mplex64").
            "dtype": str(self.dtype).removeprefix("torch."),
            # States built by quantize_blockwise carry no shape; keep that as None.
            "shape": tuple(self.shape) if self.shape is not None else None,
        }
        if self.nested:
            qs_dict.update(
                {
                    "nested_absmax": self.state2.absmax,
                    "nested_blocksize": self.state2.blocksize,
                    "nested_quant_map": self.state2.code.clone(),  # un-shared to avoid restoring it after shared tensors are removed by safetensors
                    "nested_dtype": str(self.state2.dtype).removeprefix("torch."),
                    "nested_offset": self.offset.item(),
                },
            )
        if not packed:
            return qs_dict

        # packed format allows serialization of non-tensor components, critical for saving in safetensors format
        qs_packed_dict = {k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor)}
        non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)}
        qs_packed_dict["quant_state." + "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict)
        return qs_packed_dict

    def to(self, device):
        """Move the state's tensors to ``device`` in place (returns ``None``)."""
        # make sure the quantization state is on the right device
        self.code = self.code.to(device)
        self.absmax = self.absmax.to(device)
        if self.nested:
            self.offset = self.offset.to(device)
            self.state2.absmax = self.state2.absmax.to(device)
            self.state2.code = self.state2.code.to(device)

    def __eq__(self, other):
        # Tensors are compared with allclose (atol=1e-6); other fields must match exactly.
        if not isinstance(other, QuantState):
            return False

        return (
            torch.allclose(self.absmax, other.absmax, atol=1e-6)
            and self.shape == other.shape
            and torch.allclose(self.code, other.code, atol=1e-6)
            and self.dtype == other.dtype
            and self.blocksize == other.blocksize
            and self.quant_type == other.quant_type
            and (
                self.offset == other.offset
                if self.offset is not None and other.offset is not None
                else self.offset is other.offset
            )
            and (
                self.state2 == other.state2
                if self.state2 is not None and other.state2 is not None
                else self.state2 is other.state2
            )
        )
def quantize_blockwise(
    A: torch.Tensor,
    code: Optional[torch.Tensor] = None,
    absmax: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    blocksize=4096,
    nested=False,
) -> tuple[torch.Tensor, QuantState]:
    """Quantize a tensor blockwise.

    ``A`` is split into blocks of ``blocksize`` values. Each block's absolute maximum is
    computed and used to scale that block before the non-linear mapping through ``code``.

    Args:
        A (`torch.Tensor`): The input tensor. Supports `float16`, `bfloat16`, or `float32` datatypes.
        code (`torch.Tensor`, *optional*):
            A mapping describing the low-bit data type. Defaults to a signed 8-bit dynamic type.
            For more details, see (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561].
        absmax (`torch.Tensor`, *optional*): A tensor that receives a copy of the (possibly nested) absmax values.
        out (`torch.Tensor`, *optional*): A tensor that receives a copy of the quantized result.
        blocksize (`int`, *optional*):
            The size of the blocks. Defaults to 4096.
            Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
        nested (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False.

    Raises:
        ValueError: Raised when the input data type is not supported.

    Returns:
        `Tuple[torch.Tensor, QuantState]`: The quantized tensor, and the [`QuantState`] needed to undo the quantization.
    """
    # Default to the shared signed dynamic map, created once and cached in `name2qmap`.
    if code is None:
        if "dynamic" not in name2qmap:
            name2qmap["dynamic"] = create_dynamic_map().to(A.device)
        code = name2qmap["dynamic"]

    target = A.device
    quantized, block_absmax = torch.ops.bitsandbytes.quantize_blockwise.default(
        A,
        code.to(target),
        blocksize,
    )

    if not nested:
        quant_state = QuantState(
            absmax=block_absmax,
            code=code.to(target, copy=True),
            blocksize=blocksize,
            dtype=A.dtype,
        )
    else:
        # Double quantization: center the block scales on their mean, then quantize them too.
        offset = block_absmax.mean()
        block_absmax -= offset
        nested_absmax, nested_state = quantize_blockwise(block_absmax, blocksize=blocksize, nested=False)
        quant_state = QuantState(
            absmax=nested_absmax,
            code=code.to(target, copy=True),
            blocksize=blocksize,
            dtype=A.dtype,
            offset=offset,
            state2=nested_state,
        )

    # TODO(matthewdouglas): Deprecate out kwarg
    if out is None:
        out = quantized
    else:
        out = out.copy_(quantized)

    # TODO(matthewdouglas): Deprecate absmax kwarg
    if absmax is not None:
        quant_state.absmax = absmax.copy_(quant_state.absmax)

    return out, quant_state
def dequantize_blockwise(
    A: torch.Tensor,
    quant_state: Optional[QuantState] = None,
    absmax: Optional[torch.Tensor] = None,
    code: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    blocksize: int = 4096,
    nested=False,
) -> torch.Tensor:
    """Dequantize a tensor in blocks of values.

    The input tensor is dequantized by dividing it into blocks of `blocksize` values.
    The absolute maximum value within each block is used for scaling
    the non-linear dequantization.

    Args:
        A (`torch.Tensor`): The quantized input tensor.
        quant_state ([`QuantState`], *optional*):
            The quantization state as returned by [`quantize_blockwise`].
            Required if `absmax` is not provided.
        absmax (`torch.Tensor`, *optional*):
            A tensor containing the scaling values.
            Required if `quant_state` is not provided and ignored otherwise.
        code (`torch.Tensor`, *optional*):
            A mapping describing the low-bit data type. Defaults to a signed 8-bit dynamic type.
            For more details, see [8-Bit Approximations for Parallelism in Deep Learning](https://arxiv.org/abs/1511.04561).
            Ignored when `quant_state` is provided.
        out (`torch.Tensor`, *optional*): A tensor to use to store the result.
        blocksize (`int`, *optional*):
            The size of the blocks. Defaults to 4096.
            Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
            Ignored when `quant_state` is provided.
        nested (`bool`, *optional*):
            Not read by this function; whether the scales are nested is taken from `quant_state.nested`.

    Raises:
        ValueError: Raised when the input data type is not supported.

    Returns:
        `torch.Tensor`:
            The dequantized tensor. The datatype is indicated by `quant_state.dtype` and defaults to `torch.float32`.
    """
    assert quant_state is not None or absmax is not None

    if quant_state is None:
        # Build a throwaway state from the loose arguments.
        if code is None:
            # The default signed dynamic map is created once and cached in name2qmap.
            if "dynamic" not in name2qmap:
                name2qmap["dynamic"] = create_dynamic_map().to(A.device)
            code = name2qmap["dynamic"]
        quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32)

    scales = quant_state.absmax
    if quant_state.nested:
        # The scales were themselves blockwise-quantized after subtracting an offset.
        scales = dequantize_blockwise(quant_state.absmax, quant_state.state2)
        scales += quant_state.offset
    if scales.dtype != torch.float32:
        scales = scales.float()

    op_args = (A, scales, quant_state.code.to(A.device), quant_state.blocksize, quant_state.dtype)
    if out is None:
        return torch.ops.bitsandbytes.dequantize_blockwise.default(*op_args)
    torch.ops.bitsandbytes.dequantize_blockwise.out(*op_args, out=out)
    return out
def get_4bit_type(typename, device=None, blocksize=64):
    """Return the 16-entry lookup table for a 4-bit quantization data type.

    Args:
        typename (`str`): One of `"nf4"`, `"fp4"`, `"int4"`, or `"af4"`.
        device (*optional*): Device of the returned tensor. Defaults to `"cuda"`.
        blocksize (`int`, *optional*): Only consulted for `"af4"`, which supports 64 only. Defaults to 64.

    Raises:
        NotImplementedError: If `typename` is unknown, or `"af4"` is requested with a blocksize other than 64.

    Returns:
        `torch.Tensor`: A 16-element floating point tensor (torch default dtype), scaled so that its
        largest magnitude is exactly 1.
    """
    if device is None:
        device = "cuda"
    data = None
    if typename == "nf4":
        # NF4 (NormalFloat4) quantization type.
        #
        # These 16 values are a lookup table derived from quantiles of the standard normal
        # distribution N(0, 1), where each bin has equal probability mass. The 4-bit index
        # is just a position in this table — NF4 is NOT a floating-point encoding (no
        # sign/exponent/mantissa decomposition). This is fundamentally different from FP4.
        #
        # Generated by: create_normal_map(offset=0.9677083, use_extra_value=True)
        # Values are hardcoded to avoid a scipy dependency at runtime.
        #
        # For details see: QLoRA (https://arxiv.org/abs/2305.14314)
        data = [
            -1.0,
            -0.6961928009986877,
            -0.5250730514526367,
            -0.39491748809814453,
            -0.28444138169288635,
            -0.18477343022823334,
            -0.09105003625154495,
            0.0,
            0.07958029955625534,
            0.16093020141124725,
            0.24611230194568634,
            0.33791524171829224,
            0.44070982933044434,
            0.5626170039176941,
            0.7229568362236023,
            1.0,
        ]
    elif typename == "fp4":
        # FP4 (4-bit floating point) quantization type.
        #
        # Unlike NF4, FP4 is an actual floating-point encoding with 1 sign bit, 2 exponent
        # bits, and 1 mantissa bit. Values below are listed in bit-pattern order (not value
        # order), where only the 3 non-sign bits are shown:
        #
        #   0b000 = 0       (subnormal: zero)
        #   0b001 = 0.0625  (subnormal: smallest non-zero magnitude)
        #   0b010 = 8    0b011 = 12    0b100 = 4
        #   0b101 = 6    0b110 = 2     0b111 = 3
        #
        # The second half of the list repeats these magnitudes with the sign bit set.
        # These can be regenerated with:
        #   create_fp8_map(signed=True, exponent_bits=2, precision_bits=1, total_bits=4)
        #
        # All values are normalized to [-1, 1] after construction (see end of function).
        data = [0, 0.0625, 8.0, 12.0, 4.0, 6.0, 2.0, 3.0, -0, -0.0625, -8.0, -12.0, -4.0, -6.0, -2.0, -3.0]
    elif typename == "int4":
        data = [7, 6, 5, 4, 3, 2, 1, 0, -0, -1, -2, -3, -4, -5, -6, -7]
    elif typename == "af4":
        # Taken from: NF4 Isn't Information Theoretically Optimal (and that's Good)
        # https://arxiv.org/abs/2306.06965
        if blocksize == 64:
            data = [
                -1.0,
                -0.69441008,
                -0.51243739,
                -0.3736951,
                -0.25607552,
                -0.14982478,
                -0.04934812,
                0.0,
                0.04273164,
                0.12934483,
                0.21961274,
                0.31675666,
                0.42563882,
                0.55496234,
                0.72424863,
                1.0,
            ][::-1]
        else:
            raise NotImplementedError("4-bit AbnormalFloats currently only support blocksize 64.")

    if data is None:
        raise NotImplementedError(f"Typename {typename} not supported")

    # Force a floating dtype: the int4 table holds only Python ints, which would otherwise
    # produce an int64 tensor that cannot be divided in place (true division yields float).
    # Float tables already use the default dtype, so they are unaffected.
    data = torch.tensor(data, device=device, dtype=torch.get_default_dtype())
    data.div_(data.abs().max())
    assert data.numel() == 16
    return data
def quantize_fp4(
    A: torch.Tensor,
    absmax: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    blocksize=None,
    compress_statistics=False,
    quant_storage=torch.uint8,
):
    """Quantize `A` to packed 4-bit FP4 values.

    Thin wrapper around [`quantize_4bit`] with `quant_type="fp4"`; see that function for
    the meaning of each argument and the returned `(tensor, QuantState)` tuple.
    """
    return quantize_4bit(
        A,
        absmax=absmax,
        out=out,
        blocksize=blocksize,
        compress_statistics=compress_statistics,
        quant_type="fp4",
        quant_storage=quant_storage,
    )
def quantize_nf4(
    A: torch.Tensor,
    absmax: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    blocksize=None,
    compress_statistics=False,
    quant_storage=torch.uint8,
):
    """Quantize `A` to packed 4-bit NF4 values.

    Thin wrapper around [`quantize_4bit`] with `quant_type="nf4"`; see that function for
    the meaning of each argument and the returned `(tensor, QuantState)` tuple.
    """
    return quantize_4bit(
        A,
        absmax=absmax,
        out=out,
        blocksize=blocksize,
        compress_statistics=compress_statistics,
        quant_type="nf4",
        quant_storage=quant_storage,
    )
def quantize_4bit(
    A: torch.Tensor,
    absmax: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    blocksize=None,
    compress_statistics=False,
    quant_type="fp4",
    quant_storage=torch.uint8,
) -> tuple[torch.Tensor, QuantState]:
    """Quantize tensor A in blocks of 4-bit values.

    Quantizes tensor A by dividing it into blocks which are independently quantized.

    Args:
        A (`torch.Tensor`): The input tensor. Supports `float16`, `bfloat16`, or `float32` datatypes.
        absmax (`torch.Tensor`, *optional*): A tensor to use to store the absmax values.
        out (`torch.Tensor`, *optional*): A tensor to use to store the result.
        blocksize (`int`, *optional*):
            The size of the blocks. Defaults to 64.
            Valid values are 32, 64, 128, 256, 512, 1024, 2048, and 4096.
        compress_statistics (`bool`, *optional*):
            Whether to additionally quantize the absmax values. When enabled, the per-block scales
            have their mean subtracted and are then quantized with [`quantize_blockwise`]
            (blocksize 256). Defaults to False.
        quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`.
        quant_storage (`torch.dtype`, *optional*): The dtype of the tensor used to store the result. Defaults to `torch.uint8`.

    Raises:
        ValueError: Raised when the input data type is not supported.

    Returns:
        Tuple[`torch.Tensor`, `QuantState`]: A tuple containing the quantization results.
        - `torch.Tensor`: The quantized tensor with packed 4-bit values.
        - [`QuantState`]: The state object used to undo the quantization.
    """
    blocksize = 64 if blocksize is None else blocksize
    input_shape = A.shape

    packed, block_absmax = torch.ops.bitsandbytes.quantize_4bit.default(A, blocksize, quant_type, quant_storage)
    code = get_4bit_type(quant_type, device=A.device)

    # Fields shared by the plain and the compressed-statistics state.
    common = dict(shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type)
    if compress_statistics:
        offset = block_absmax.mean()
        qabsmax, state2 = quantize_blockwise(block_absmax - offset, blocksize=256)
        del block_absmax  # drop our reference to the unquantized scales
        state = QuantState(absmax=qabsmax, offset=offset, state2=state2, **common)
    else:
        state = QuantState(absmax=block_absmax, **common)

    # TODO(matthewdouglas): Deprecate out kwarg
    if out is not None:
        out = out.copy_(packed)
    else:
        out = packed
    # TODO(matthewdouglas): Deprecate absmax kwarg
    if absmax is not None:
        state.absmax = absmax.copy_(state.absmax)
    return out, state
def dequantize_fp4(
    A: torch.Tensor,
    quant_state: Optional[QuantState] = None,
    absmax: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    blocksize: Optional[int] = None,
) -> torch.Tensor:
    """Dequantize a packed FP4 tensor.

    Thin wrapper around [`dequantize_4bit`] with `quant_type="fp4"`.
    """
    return dequantize_4bit(
        A,
        quant_state=quant_state,
        absmax=absmax,
        out=out,
        blocksize=blocksize,
        quant_type="fp4",
    )
def dequantize_nf4(
    A: torch.Tensor,
    quant_state: Optional[QuantState] = None,
    absmax: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    blocksize: Optional[int] = None,
) -> torch.Tensor:
    """Dequantize a packed NF4 tensor.

    Thin wrapper around [`dequantize_4bit`] with `quant_type="nf4"`.
    """
    return dequantize_4bit(
        A,
        quant_state=quant_state,
        absmax=absmax,
        out=out,
        blocksize=blocksize,
        quant_type="nf4",
    )
def dequantize_4bit(
    A: torch.Tensor,
    quant_state: Optional[QuantState] = None,
    absmax: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    blocksize: Optional[int] = None,
    quant_type="fp4",
) -> torch.Tensor:
    """Dequantizes a packed 4-bit quantized tensor.

    The input tensor is dequantized by dividing it into blocks of `blocksize` values.
    The absolute maximum value within each block is used for scaling
    the non-linear dequantization.

    Args:
        A (`torch.Tensor`): The quantized input tensor.
        quant_state ([`QuantState`], *optional*):
            The quantization state as returned by [`quantize_4bit`].
            Required if `absmax` is not provided.
        absmax (`torch.Tensor`, *optional*):
            A tensor containing the scaling values.
            Required if `quant_state` is not provided and ignored otherwise.
        out (`torch.Tensor`, *optional*):
            A tensor to use to store the result. Required when `quant_state` is not provided,
            because its shape and dtype then define the output.
        blocksize (`int`, *optional*):
            The size of the blocks. Defaults to 64.
            Valid values are 32, 64, 128, 256, 512, 1024, 2048, and 4096.
        quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`.

    Raises:
        ValueError: Raised when the input data type or blocksize is not supported.

    Returns:
        `torch.Tensor`: The dequantized tensor.
    """
    if blocksize is None:
        blocksize = 64

    if quant_state is not None:
        absmax = quant_state.absmax
    else:
        # Without a state, `out` is the only source for the target shape and dtype.
        assert absmax is not None and out is not None
        quant_state = QuantState(
            absmax=absmax,
            shape=out.shape,
            dtype=out.dtype,
            blocksize=blocksize,
            quant_type=quant_type,
        )

    if quant_state.nested:
        # Undo the second-level quantization of the scales and re-add their offset.
        absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
        absmax += quant_state.offset
    if absmax.dtype != torch.float32:
        absmax = absmax.float()

    op_args = (A, absmax, quant_state.blocksize, quant_state.quant_type, quant_state.shape, quant_state.dtype)
    if out is None:
        out = torch.ops.bitsandbytes.dequantize_4bit.default(*op_args)
    else:
        torch.ops.bitsandbytes.dequantize_4bit.out(*op_args, out=out)

    # A packed input whose leading dimension is 1 is treated as transposed; transpose back.
    return out.t() if A.shape[0] == 1 else out
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def quantize(
    A: Tensor,
    code: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
) -> tuple[Tensor, tuple[Tensor, Tensor]]:
    """Quantize `A` to 8 bits using one tensor-wide absmax scale (deprecated).

    Args:
        A: The input tensor.
        code: The quantization map. Defaults to the cached signed dynamic map.
        out: Optional pre-allocated output tensor, forwarded to `quantize_no_absmax`.

    Returns:
        The quantized tensor and an `(absmax, code)` tuple accepted by `dequantize` as `state`.
    """
    if code is None:
        if "dynamic" not in name2qmap:
            name2qmap["dynamic"] = create_dynamic_map().to(A.device)
        code = name2qmap["dynamic"]
    # Always place the map next to A, including a caller-supplied map:
    # `quantize_no_absmax` hands the raw pointer of `code` to the native kernel.
    code = code.to(A.device)
    absmax = torch.abs(A).max()
    if absmax.dtype != torch.float32:
        absmax = absmax.float()
    inp = A / absmax
    out = quantize_no_absmax(inp, code, out)
    return out, (absmax, code)
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def dequantize(
    A: Tensor,
    state: Optional[tuple[Tensor, Tensor]] = None,
    absmax: Optional[torch.Tensor] = None,
    code: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
) -> Tensor:
    """Dequantize a tensor produced by `quantize` (deprecated).

    Args:
        A: The 8-bit quantized tensor.
        state: The `(absmax, code)` tuple returned by `quantize`. Required if `absmax` is not given.
        absmax: The tensor-wide scale. Used only when `state` is None.
        code: The quantization map. Used only when `state` is None; defaults to the cached
            signed dynamic map.
        out: Optional pre-allocated output tensor, forwarded to `dequantize_no_absmax`.

    Returns:
        The dequantized tensor, scaled by `absmax`.
    """
    assert state is not None or absmax is not None
    if code is None and state is None:
        if "dynamic" not in name2qmap:
            name2qmap["dynamic"] = create_dynamic_map().to(A.device)
        code = name2qmap["dynamic"]
    # Guard the device move: when `state` is given and `code` is omitted, `code` is None,
    # and the map actually used is `state[1]`.
    if code is not None:
        code = code.to(A.device)
    if state is None:
        state = (absmax, code)
    out = dequantize_no_absmax(A, state[1], out)
    return out * state[0]
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor:
    """
    Quantizes input tensor to 8-bit.

    Quantizes the 32-bit input tensor `A` to the 8-bit output tensor
    `out` using the quantization map `code`.

    Parameters
    ----------
    A : torch.Tensor
        The input tensor.
    code : torch.Tensor
        The quantization map. Must be on the same device as `A`.
    out : torch.Tensor, optional
        The output tensor. Needs to be of type byte. A zero-filled `torch.uint8`
        tensor shaped like `A` is allocated when omitted.

    Returns
    -------
    torch.Tensor:
        Quantized 8-bit tensor.
    """
    with _cuda_device_of(A):
        if out is None:
            out = torch.zeros_like(A, dtype=torch.uint8)
        # `code` is passed to the kernel as a raw pointer too, so it is included in the
        # device check (as dequantize_no_absmax already does).
        is_on_gpu([code, A, out])
        lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
    return out
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor:
    """
    Dequantizes the 8-bit tensor to 32-bit.

    Maps every 8-bit entry of `A` through the quantization map `code` into the
    32-bit tensor `out`. No absmax rescaling is applied here.

    Parameters
    ----------
    A : torch.Tensor
        The 8-bit input tensor.
    code : torch.Tensor
        The quantization map.
    out : torch.Tensor
        The 32-bit output tensor. A zero-filled `torch.float32` tensor shaped like
        `A` is allocated when omitted.

    Returns
    -------
    torch.Tensor:
        32-bit output tensor.
    """
    with _cuda_device_of(A):
        result = out if out is not None else torch.zeros_like(A, dtype=torch.float32)
        is_on_gpu([code, A, result])
        stream = _get_tensor_stream(A)
        lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(result), ct.c_int(A.numel()), stream)
    return result
def optimizer_update_32bit(
    optimizer_name: str,
    g: Tensor,
    p: Tensor,
    state1: Tensor,
    beta1: float,
    eps: float,
    step: int,
    lr: float,
    state2: Optional[torch.Tensor] = None,
    beta2: float = 0.0,
    beta3: float = 0.0,
    alpha: float = 0.0,
    weight_decay: float = 0.0,
    gnorm_scale: float = 1.0,
    unorm_vec: Optional[torch.Tensor] = None,
    max_unorm: float = 0.0,
    skip_zeros=False,
) -> None:
    """
    Performs an inplace optimizer update with one or two optimizer states.

    Universal optimizer update for 32-bit state and 32/16-bit gradients/weights.

    Parameters
    ----------
    optimizer_name : str
        The name of the optimizer: {adam}.
    g : torch.Tensor
        Gradient tensor.
    p : torch.Tensor
        Parameter tensor.
    state1 : torch.Tensor
        Optimizer state 1.
    beta1 : float
        Optimizer beta1.
    eps : float
        Optimizer epsilon.
    step : int
        Current optimizer step.
    lr : float
        The learning rate.
    state2 : torch.Tensor
        Optimizer state 2.
    beta2 : float
        Optimizer beta2.
    beta3 : float
        Optimizer beta3.
    alpha : float
        Optimizer alpha.
    weight_decay : float
        Weight decay.
    gnorm_scale : float
        The factor to rescale the gradient to the max clip value.
    unorm_vec : torch.Tensor
        The tensor for the update norm.
    max_unorm : float
        The maximum update norm relative to the weight norm. The parameter norm is
        only computed when this is positive.
    skip_zeros : bool
        Whether to skip zero-valued gradients or not (default: False).
    """
    param_norm = torch.norm(p.data.float()) if max_unorm > 0.0 else 0.0

    is_on_gpu([g, p, state1, state2, unorm_vec])
    # Positional order expected by the registered op.
    kernel_args = (
        optimizer_name,
        g,
        p,
        state1,
        state2,
        unorm_vec,
        max_unorm,
        param_norm,
        beta1,
        beta2,
        beta3,
        alpha,
        eps,
        weight_decay,
        step,
        lr,
        gnorm_scale,
        skip_zeros,
    )
    torch.ops.bitsandbytes.optimizer_update_32bit(*kernel_args)
@deprecated(
    "This function is deprecated and will be removed in a future release. "
    "Please use optimizer_update_8bit_blockwise instead. ",
    category=FutureWarning,
)
def optimizer_update_8bit(
    optimizer_name: str,
    g: Tensor,
    p: Tensor,
    state1: Tensor,
    state2: Optional[torch.Tensor],
    beta1: float,
    beta2: float,
    eps: float,
    step: int,
    lr: float,
    qmap1: Tensor,
    qmap2: Optional[torch.Tensor],
    max1: Tensor,
    max2: Optional[torch.Tensor],
    new_max1: Tensor,
    new_max2: Optional[torch.Tensor],
    weight_decay: float = 0.0,
    gnorm_scale: float = 1.0,
    unorm_vec: Optional[torch.Tensor] = None,
    max_unorm: float = 0.0,
) -> None:
    """
    Performs an inplace Adam update.

    Universal Adam update for 32/8-bit state and 32/16-bit gradients/weights.
    Uses AdamW formulation if weight decay > 0.0.

    Parameters
    ----------
    optimizer_name : str
        The name of the optimizer. Choices {adam, momentum}
    g : torch.Tensor
        Gradient tensor (`torch.float32` or `torch.float16`).
    p : torch.Tensor
        Parameter tensor.
    state1 : torch.Tensor
        Adam state 1 (must be `torch.uint8`).
    state2 : torch.Tensor
        Adam state 2.
    beta1 : float
        Adam beta1.
    beta2 : float
        Adam beta2.
    eps : float
        Adam epsilon.
    step : int
        Current optimizer step.
    lr : float
        The learning rate.
    qmap1 : torch.Tensor
        Quantization map for first Adam state.
    qmap2 : torch.Tensor
        Quantization map for second Adam state.
    max1 : torch.Tensor
        Max value for first Adam state update.
    max2 : torch.Tensor
        Max value for second Adam state update.
    new_max1 : torch.Tensor
        Max value for the next Adam update of the first state.
    new_max2 : torch.Tensor
        Max value for the next Adam update of the second state.
    weight_decay : float
        Weight decay.
    gnorm_scale : float
        The factor to rescale the gradient to the max clip value.
    unorm_vec : torch.Tensor
        The tensor for the update norm.
    max_unorm : float
        The maximum update norm relative to the weight norm.

    Raises
    ------
    ValueError
        If the gradient/state dtype combination has no kernel.
    """
    param_norm = torch.norm(p.data.float()) if max_unorm > 0.0 else 0.0

    with _cuda_device_of(g):
        is_on_gpu([g, p, state1, state2, unorm_vec, qmap1, qmap2, max1, max2, new_max1, new_max2])

        # Kernels exist only for uint8 state with fp32 or fp16 gradients. In
        # str2optimizer8bit[name], slot 0 is the fp32-gradient kernel and slot 1 the fp16 one.
        if state1.dtype != torch.uint8 or g.dtype not in (torch.float32, torch.float16):
            raise ValueError(
                f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
            )
        kernel = str2optimizer8bit[optimizer_name][0 if g.dtype == torch.float32 else 1]
        kernel(
            get_ptr(p),
            get_ptr(g),
            get_ptr(state1),
            get_ptr(state2),
            get_ptr(unorm_vec),
            ct.c_float(max_unorm),
            ct.c_float(param_norm),
            ct.c_float(beta1),
            ct.c_float(beta2),
            ct.c_float(eps),
            ct.c_int32(step),
            ct.c_float(lr),
            get_ptr(qmap1),
            get_ptr(qmap2),
            get_ptr(max1),
            get_ptr(max2),
            get_ptr(new_max1),
            get_ptr(new_max2),
            ct.c_float(weight_decay),
            ct.c_float(gnorm_scale),
            ct.c_int32(g.numel()),
        )
def optimizer_update_8bit_blockwise(
    optimizer_name: str,
    g: Tensor,
    p: Tensor,
    state1: Tensor,
    state2: Optional[torch.Tensor],
    beta1: float,
    beta2: float,
    beta3: float,
    alpha: float,
    eps: float,
    step: int,
    lr: float,
    qmap1: Tensor,
    qmap2: Optional[torch.Tensor],
    absmax1: Tensor,
    absmax2: Optional[torch.Tensor],
    weight_decay: float = 0.0,
    gnorm_scale: float = 1.0,
    skip_zeros=False,
) -> None:
    """Perform an in-place optimizer step with blockwise 8-bit optimizer state.

    Validates devices with `is_on_gpu`, then forwards every argument unchanged to the
    `bitsandbytes::optimizer_update_8bit_blockwise` op.

    Args:
        optimizer_name: Name of the optimizer.
        g: Gradient tensor.
        p: Parameter tensor, updated in place.
        state1, state2: Quantized optimizer states (`state2` may be None).
        beta1, beta2, beta3, alpha, eps: Optimizer hyperparameters.
        step: Current optimizer step.
        lr: The learning rate.
        qmap1, qmap2: Quantization maps for the two states.
        absmax1, absmax2: Per-block scales for the two states.
        weight_decay: Weight decay. Defaults to 0.0.
        gnorm_scale: Factor to rescale the gradient to the max clip value. Defaults to 1.0.
        skip_zeros: Whether to skip zero-valued gradients. Defaults to False.
    """
    is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2])
    # Positional order expected by the registered op.
    op_args = (
        optimizer_name,
        g,
        p,
        state1,
        state2,
        beta1,
        beta2,
        beta3,
        alpha,
        eps,
        step,
        lr,
        qmap1,
        qmap2,
        absmax1,
        absmax2,
        weight_decay,
        gnorm_scale,
        skip_zeros,
    )
    torch.ops.bitsandbytes.optimizer_update_8bit_blockwise(*op_args)
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5):
    """Applies percentile clipping.

    grad: torch.Tensor
        The gradient tensor (`torch.float32` or `torch.float16`).
    gnorm_vec: torch.Tensor
        Vector of gradient norms. 100 elements expected. Entries are square-rooted
        below, so they presumably hold squared norms — confirm against the kernel.
    step: int
        The current optimization steps (number of past gradient norms).
    percentile: int
        Index into the ascending-sorted `gnorm_vec` used as the clipping threshold. Defaults to 5.

    Returns:
        `(current_gnorm, clip_value, gnorm_scale)`, where `gnorm_scale` is
        `clip_value / current_gnorm` if the current norm exceeds the threshold, else `1.0`.

    Raises:
        ValueError: If `grad` has an unsupported dtype.
    """
    with _cuda_device_of(grad):
        is_on_gpu([grad, gnorm_vec])
        if grad.dtype == torch.float32:
            clip_kernel = lib.cpercentile_clipping_g32
        elif grad.dtype == torch.float16:
            clip_kernel = lib.cpercentile_clipping_g16
        else:
            raise ValueError(f"Gradient type {grad.dtype} not supported!")
        clip_kernel(
            get_ptr(grad),
            get_ptr(gnorm_vec),
            ct.c_int32(step),
            ct.c_int32(grad.numel()),
        )

    current_gnorm = torch.sqrt(gnorm_vec[step % 100])
    clip_value = torch.sqrt(torch.sort(gnorm_vec).values[percentile])
    gnorm_scale = clip_value / current_gnorm if current_gnorm > clip_value else 1.0
    return current_gnorm, clip_value, gnorm_scale
def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
    """Validate operand dtypes and shapes for an integer matmul and return the output shape.

    Supports 2-D x 2-D, 3-D x 2-D and 3-D x 3-D operands, each optionally transposed.

    Args:
        A, B: The operands; both must have dtype `expected_type`.
        out: Optional pre-allocated output. When given, its shape is returned; for 3-D x 3-D
            inputs it can also satisfy the backprop special case below.
        transposed_A, transposed_B: Whether each operand is used transposed.
        expected_type: Required operand dtype. Defaults to `torch.int8`.

    Raises:
        TypeError: If either operand's dtype differs from `expected_type`.
        ValueError: If the contraction dimensions do not match, or (with `out=None`) the
            operand ranks are not one of the supported combinations.

    Returns:
        The output shape (`out.shape` when `out` is given).
    """
    # The checks below are pure shape logic; only initialize CUDA where it exists.
    if torch.cuda.is_available() and not torch.cuda.is_initialized():
        torch.cuda.init()
    if A.dtype != expected_type or B.dtype != expected_type:
        raise TypeError(f"Expected {expected_type} input tensors A and B, but got {A.dtype} and {B.dtype}")

    sA = A.shape
    sB = B.shape
    tA = transposed_A
    tB = transposed_B

    # Compare the contraction dimension for each rank/transpose combination.
    correct = True
    if len(sA) == 2 and len(sB) == 2:
        if not tA and not tB and A.shape[1] != B.shape[0]:
            correct = False
        elif tA and not tB and A.shape[0] != B.shape[0]:
            correct = False
        elif tA and tB and A.shape[0] != B.shape[1]:
            correct = False
        elif not tA and tB and A.shape[1] != B.shape[1]:
            correct = False
    elif len(sA) == 3 and len(sB) == 2:
        if not tA and not tB and A.shape[2] != B.shape[0]:
            correct = False
        elif tA and not tB and A.shape[1] != B.shape[0]:
            correct = False
        elif tA and tB and A.shape[1] != B.shape[1]:
            correct = False
        elif not tA and tB and A.shape[2] != B.shape[1]:
            correct = False
    elif len(sA) == 3 and len(sB) == 3:
        if not tA and not tB and A.shape[2] != B.shape[1]:
            correct = False
        elif tA and not tB and A.shape[1] != B.shape[1]:
            correct = False
        elif tA and tB and A.shape[1] != B.shape[2]:
            correct = False
        elif not tA and tB and A.shape[2] != B.shape[2]:
            correct = False

    sout = None
    if out is not None:
        sout = out.shape
        # special case common in backprop
        if not correct and len(sA) == 3 and len(sB) == 3:
            if sout[0] == sA[2] and sout[1] == sB[2] and sA[0] == sB[0] and sA[1] == sB[1]:
                correct = True
    else:
        if len(sA) == 2 and len(sB) == 2:
            if not tA and not tB:
                sout = (sA[0], sB[1])
            elif tA and tB:
                sout = (sA[1], sB[0])
            elif tA and not tB:
                sout = (sA[1], sB[1])
            elif not tA and tB:
                sout = (sA[0], sB[0])
        elif len(sA) == 3 and len(sB) == 2:
            if not tA and not tB:
                sout = (sA[0], sA[1], sB[1])
            elif tA and tB:
                sout = (sA[0], sA[2], sB[0])
            elif tA and not tB:
                sout = (sA[0], sA[2], sB[1])
            elif not tA and tB:
                sout = (sA[0], sA[1], sB[0])
        elif len(sA) == 3 and len(sB) == 3:
            if not tA and not tB:
                sout = (sA[0], sA[1], sB[2])
            elif tA and tB:
                sout = (sA[0], sA[2], sB[1])
            elif tA and not tB:
                sout = (sA[0], sA[2], sB[2])
            elif not tA and tB:
                sout = (sA[0], sA[1], sB[1])

    if not correct:
        raise ValueError(
            f"Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}.",
        )
    if sout is None:
        # Previously surfaced as an UnboundLocalError for unsupported rank combinations.
        raise ValueError(f"Unsupported tensor ranks for matrix multiplication: A x B: {sA} x {sB}.")
    return sout
def gemv_4bit(
    A: Tensor,
    B: Tensor,
    out: Optional[torch.Tensor] = None,
    transposed_A=False,
    transposed_B=False,
    state=None,
):
    """Run the `gemv_4bit` op of `A` against the packed 4-bit tensor `B`.

    Args:
        A: The dense operand.
        B: The packed 4-bit operand produced by [`quantize_4bit`].
        out: Optional pre-allocated output tensor.
        transposed_A, transposed_B: Not read by this function.
        state: The [`QuantState`] returned by [`quantize_4bit`] for `B`. Required.

    Raises:
        ValueError: If `state` is None.

    Returns:
        The op's result (`out` when provided).
    """
    if state is None:
        raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()")

    # Nested states store quantized scales plus an offset; recover full-precision scales.
    absmax = dequantize_blockwise(state.absmax, state.state2) + state.offset if state.nested else state.absmax

    op_args = (A, B, state.shape, absmax, state.code, state.blocksize)
    if out is None:
        return torch.ops.bitsandbytes.gemv_4bit.default(*op_args)
    torch.ops.bitsandbytes.gemv_4bit.out(*op_args, out=out)
    return out
def igemm(
    A: Tensor,
    B: Tensor,
    out: Optional[torch.Tensor] = None,
    transposed_A=False,
    transposed_B=False,
):
    """Integer matrix multiplication through the native `cigemm` cuBLAS wrapper.

    Operand dtypes and shapes are validated by `check_matmul` (int8 by default).
    3-D x 3-D inputs with matching batch and inner dimensions are delegated to
    `batched_igemm`.

    Args:
        A: Left operand.
        B: Right operand.
        out: Optional pre-allocated output. Otherwise a zero-filled `torch.int32` tensor of
            the shape computed by `check_matmul` is allocated on `A.device`.
        transposed_A, transposed_B: Transpose hints. For 2-D `B` they may be overridden
            below based on the operands' strides.

    Returns:
        The result tensor `out`.
    """
    sout = check_matmul(A, B, out, transposed_A, transposed_B)
    if out is None:
        out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)
    if len(A.shape) == 3 and len(B.shape) == 3:
        if A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1]:
            return batched_igemm(A, B, out)

    # Logical (post-transpose) shapes.
    sA = A.shape
    sB = B.shape
    if transposed_A and len(sA) == 2:
        sA = (sA[1], sA[0])
    elif transposed_A and len(sA) == 3:
        # NOTE(review): a transpose of the last two dims would be (sA[0], sA[2], sA[1]);
        # the trailing sA[0] looks suspicious — confirm before relying on transposed 3-D A.
        sA = (sA[0], sA[2], sA[0])
    if transposed_B and len(sB) == 2:
        sB = (sB[1], sB[0])
    elif transposed_B and len(sB) == 3:
        # NOTE(review): same suspicious pattern as for sA above.
        sB = (sB[0], sB[2], sB[0])
    # cuBLAS expects column-major matrices, but PyTorch tensors are row-major.
    # Reading a row-major matrix as column-major yields its transpose, so instead of
    # A @ B = C we compute B^T @ A^T = C^T and swap the dimensions of each matrix
    # in the cuBLAS arguments accordingly.
    # column major: A @ B = C: [m, k] @ [k, n] = [m, n]
    # row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n]
    # column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m]
    if len(sB) == 2:
        # Derive the transpose flags from the actual memory layout; if neither stride
        # pattern matches, the caller-provided flag is kept.
        if B.stride()[0] == B.shape[1]:
            transposed_B = False
        elif B.stride()[1] == B.shape[0]:
            transposed_B = True
        if len(A.shape) == 2:
            if A.stride()[0] == A.shape[1]:
                transposed_A = False
            elif A.stride()[1] == A.shape[0]:
                transposed_A = True
        else:
            if A.stride()[1] == A.shape[2]:
                transposed_A = False
            elif A.stride()[2] == A.shape[1]:
                transposed_A = True

        if len(sA) == 2:
            n = sA[0]
            ldb = A.stride()[1 if transposed_A else 0]
        elif len(sA) == 3 and len(sB) == 2:
            # Batch and sequence dims of A are flattened into one row dimension.
            n = sA[0] * sA[1]
            ldb = sA[2]

        m = sB[1]
        k = sB[0]
        lda = B.stride()[(1 if transposed_B else 0)]
        ldc = sB[1]
    elif len(sB) == 3:
        # special case: contraction over the two leading dims (bsi,bso->io).
        assert len(sA) == 3
        if not (sA[0] == sB[0] and sA[1] == sB[1]):
            raise ValueError(
                f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}",
            )

        transposed_A = True
        transposed_B = False

        m = sB[2]
        n = sA[2]
        k = sB[0] * sB[1]

        lda = m
        ldb = sA[2]
        ldc = m

    ptr = CUBLAS_Context.get_instance().get_context(A.device)

    # B^T @ A^T = C^T
    # [km, nk -> mn]
    is_on_gpu([B, A, out])
    lib.cigemm(
        ptr,
        ct.c_bool(transposed_B),
        ct.c_bool(transposed_A),
        ct.c_int32(m),
        ct.c_int32(n),
        ct.c_int32(k),
        get_ptr(B),
        get_ptr(A),
        get_ptr(out),
        ct.c_int32(lda),
        ct.c_int32(ldb),
        ct.c_int32(ldc),
    )
    return out
def batched_igemm(
    A: Tensor,
    B: Tensor,
    out: Optional[torch.Tensor] = None,
    transposed_A=False,
    transposed_B=False,
):
    """Batched integer matrix multiplication through the native `cbatched_igemm` wrapper.

    Args:
        A: 3-D left operand `[batch, n, k]`.
        B: 3-D right operand `[batch, k, m]`.
        out: Optional pre-allocated output. Otherwise a zero-filled `torch.int32` tensor of
            the shape computed by `check_matmul` is allocated on `A.device`.
        transposed_A, transposed_B: Transpose hints, re-derived from strides below.

    Raises:
        ValueError: If either operand is not 3-dimensional.

    Returns:
        The result tensor `out`.
    """
    if not len(A.shape) == 3 or not len(B.shape) == 3:
        raise ValueError(f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}")

    sout = check_matmul(A, B, out, transposed_A, transposed_B)
    if out is None:
        out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)

    # Because of the operand swap used for cuBLAS (see below), `transposed_A` is derived
    # from B's strides here and `transposed_B` from A's.
    if B.is_contiguous():
        lda = B.stride()[1]
        transposed_A = False
    else:
        s = B.stride()
        # NOTE(review): this compares the batch *stride* against the batch *size*;
        # intent unclear — confirm.
        if s[0] != B.shape[0]:
            B = B.contiguous()
            lda = B.stride()[1]
        elif s[2] == B.shape[1]:
            transposed_A = True
            lda = B.stride()[2]
        else:
            # NOTE(review): all three branches below are identical, and unlike the A path,
            # none of them resets transposed_A after making B contiguous — confirm.
            if s[2] == 1:
                B = B.contiguous()
                lda = B.stride()[1]
            elif s[1] == 1:
                B = B.contiguous()
                lda = B.stride()[1]
            else:
                B = B.contiguous()
                lda = B.stride()[1]

    if A.is_contiguous():
        ldb = A.stride()[1]
        transposed_B = False
    else:
        s = A.stride()
        if s[0] != A.shape[0]:
            A = A.contiguous()
            ldb = A.stride()[1]
            transposed_B = False
        elif s[2] == A.shape[1]:
            ldb = A.stride()[2]
            transposed_B = True
        else:
            A = A.contiguous()
            ldb = A.stride()[1]
            transposed_B = False

    # cuBLAS expects column-major matrices, but PyTorch tensors are row-major.
    # Reading a row-major matrix as column-major yields its transpose, so instead of
    # A @ B = C we compute B^T @ A^T = C^T and swap the dimensions of each matrix
    # in the cuBLAS arguments accordingly.
    # column major: A @ B = C: [batch, m, k] @ [batch, k, n] = [batch, m, n]
    # row major: B^T @ A^T = C^T: [batch, m, k] @ [batch, k, n] = [batch, m, n]
    # column major with row major layout: B^T @ A^T = C^T: [batch, k, m] @ [batch, n, k] = [batch, n, m]
    num_batch = A.shape[0]
    n = A.shape[1]
    m = B.shape[2]
    k = B.shape[1]

    ldc = m

    # Element offsets between consecutive batch matrices of B, A and out.
    strideA = B.shape[1] * B.shape[2]
    strideB = A.shape[1] * A.shape[2]
    strideC = A.shape[1] * B.shape[2]

    ptr = CUBLAS_Context.get_instance().get_context(A.device)

    is_on_gpu([B, A, out])
    lib.cbatched_igemm(
        ptr,
        ct.c_bool(transposed_B),
        ct.c_bool(transposed_A),
        ct.c_int32(m),
        ct.c_int32(n),
        ct.c_int32(k),
        get_ptr(B),
        get_ptr(A),
        get_ptr(out),
        ct.c_int32(lda),
        ct.c_int32(ldb),
        ct.c_int32(ldc),
        ct.c_long(strideA),
        ct.c_long(strideB),
        ct.c_long(strideC),
        ct.c_uint32(num_batch),
    )
    return out
def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32):
    """Performs an 8-bit integer matrix multiplication.

    A linear transformation is applied such that `out = A @ B.T`. When possible, integer tensor core hardware is
    utilized to accelerate the operation.

    Args:
        A (`torch.Tensor`): The first matrix operand with the data type `torch.int8`.
        B (`torch.Tensor`): The second matrix operand with the data type `torch.int8`.
        out (`torch.Tensor`, *optional*): A pre-allocated tensor used to store the result.
        dtype (`torch.dtype`, *optional*):
            The expected data type of the output. Defaults to `torch.int32`.
            Not forwarded to the underlying op by this function.

    Raises:
        `NotImplementedError`: The operation is not supported in the current environment.
        `RuntimeError`: Raised when the operation cannot be completed for any other reason.

    Returns:
        `torch.Tensor`: The result of the operation.
    """
    if out is None:
        return torch.ops.bitsandbytes.int8_linear_matmul.default(A, B)
    torch.ops.bitsandbytes.int8_linear_matmul.out(A, B, out)
    return out
def int8_mm_dequant(
    A: torch.Tensor,
    row_stats: torch.Tensor,
    col_stats: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
):
    """Performs dequantization on the result of a quantized int8 matrix multiplication.

    Args:
        A (`torch.Tensor` with dtype `torch.int32`): The result of a quantized int8 matrix multiplication.
        row_stats (`torch.Tensor`): The row-wise quantization statistics for the lhs operand of the matrix multiplication.
        col_stats (`torch.Tensor`): The column-wise quantization statistics for the rhs operand of the matrix multiplication.
        out (`torch.Tensor`, *optional*):
            A pre-allocated tensor; when given, the result is copied into it and it is returned.
        bias (`torch.Tensor`, *optional*): An optional bias vector to add to the result.

    Returns:
        `torch.Tensor`: The dequantized result with an optional bias, with dtype `torch.float16`.
    """
    dequantized = torch.ops.bitsandbytes.int8_mm_dequant.default(
        A, row_stats, col_stats, dtype=torch.float16, bias=bias
    )
    # TODO(matthewdouglas): Deprecate out kwarg
    return dequantized if out is None else out.copy_(dequantized)
class COOSparseTensor:
    """Sparse matrix in coordinate (COO) format.

    Holds three parallel arrays of length `nnz`: `rowidx` and `colidx` (both `torch.int32`)
    and `values` (`torch.float16`), plus the dense `rows` x `cols` extent.
    """

    def __init__(
        self, rows: int, cols: int, nnz: int, rowidx: torch.Tensor, colidx: torch.Tensor, values: torch.Tensor
    ):
        # Index arrays must be int32 and values fp16; all three must hold exactly nnz entries.
        assert rowidx.dtype == colidx.dtype == torch.int32
        assert values.dtype == torch.float16
        assert values.numel() == rowidx.numel() == colidx.numel() == nnz

        self.rows, self.cols, self.nnz = rows, cols, nnz
        self.rowidx, self.colidx, self.values = rowidx, colidx, values
class CSRSparseTensor:
    """Sparse matrix in compressed sparse row (CSR) format.

    `rowptr` holds `rows + 1` int32 offsets; `colidx` (int32) and `values` (fp16)
    each hold `nnz` entries.
    """

    def __init__(self, rows, cols, nnz, rowptr, colidx, values):
        assert rowptr.dtype == colidx.dtype == torch.int32
        assert values.dtype == torch.float16
        assert values.numel() == colidx.numel() == nnz
        assert rowptr.numel() == rows + 1

        self.rows, self.cols, self.nnz = rows, cols, nnz
        self.rowptr, self.colidx, self.values = rowptr, colidx, values
class CSCSparseTensor:
    """Sparse matrix in compressed sparse column (CSC) format.

    `colptr` holds `cols + 1` int32 offsets; `rowidx` (int32) and `values` (fp16)
    each hold `nnz` entries.
    """

    def __init__(self, rows, cols, nnz, colptr, rowidx, values):
        assert colptr.dtype == rowidx.dtype == torch.int32
        assert values.dtype == torch.float16
        assert values.numel() == rowidx.numel() == nnz
        assert colptr.numel() == cols + 1

        self.rows, self.cols, self.nnz = rows, cols, nnz
        self.colptr, self.rowidx, self.values = colptr, rowidx, values
def coo2csr(cooA):
    """Convert a COO sparse matrix to CSR format.

    CSR requires each row's entries to be contiguous, so the entries are first stably
    sorted by row index. This matches `coo2csc`, which likewise reorders by column.
    Input that is already row-ordered keeps its original within-row order.

    Args:
        cooA: A `COOSparseTensor`.

    Returns:
        A `CSRSparseTensor` with `rows + 1` int32 row offsets. Empty rows get a
        zero-length span.
    """
    sorted_rows, order = torch.sort(cooA.rowidx, stable=True)
    colidx = cooA.colidx[order]
    values = cooA.values[order]

    present_rows, counts = torch.unique(sorted_rows, return_counts=True)
    rowptr = torch.zeros((cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device)
    # Row r's entry count goes to slot r + 1; the prefix sum then gives each row's start offset.
    rowptr.scatter_(index=(present_rows + 1).long(), src=counts.int(), dim=0)
    rowptr.cumsum_(0)
    return CSRSparseTensor(cooA.rows, cooA.cols, cooA.nnz, rowptr, colidx, values)
def coo2csc(cooA):
    """Convert a COO sparse matrix to CSC format.

    Entries are reordered by column index via `torch.sort` (default, not stable), and
    `colptr` is built from per-column counts with a prefix sum.

    Args:
        cooA: A `COOSparseTensor`.

    Returns:
        A `CSCSparseTensor` with `cols + 1` int32 column offsets.
    """
    sorted_cols, perm = torch.sort(cooA.colidx)

    present_cols, per_col = torch.unique(sorted_cols, return_counts=True)
    colptr = torch.zeros((cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device)
    # Column c's count goes to slot c + 1; cumsum turns counts into start offsets.
    colptr.scatter_(index=(present_cols + 1).long(), src=per_col.int(), dim=0)
    colptr.cumsum_(0)

    return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, cooA.rowidx[perm], cooA.values[perm])
def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
    """Allocate a zero-filled `COOSparseTensor` with room for `nnz` entries.

    Index arrays are `torch.int32` and values use `dtype` (default `torch.half`);
    `COOSparseTensor` rejects values that are not `torch.float16`.
    """

    def _zeros(element_dtype):
        return torch.zeros((nnz,), dtype=element_dtype, device=device)

    return COOSparseTensor(rows, cols, nnz, _zeros(torch.int32), _zeros(torch.int32), _zeros(dtype))
def int8_double_quant(
    A: torch.Tensor,
    col_stats: Optional[torch.Tensor] = None,
    row_stats: Optional[torch.Tensor] = None,
    out_col: Optional[torch.Tensor] = None,
    out_row: Optional[torch.Tensor] = None,
    threshold=0.0,
):
    """Compute row-wise and column-wise (transposed) `LLM.int8()` quantization of `A`.

    See the [LLM.int8() paper](https://arxiv.org/abs/2208.07339) for background.

    <Tip>
    Intended for training. For inference prefer [`int8_vectorwise_quant`], since the
    extra column-wise (transposed) pass done here is not optimized.
    </Tip>

    Args:
        A (`torch.Tensor` with dtype `torch.float16`): The matrix to quantize.
        col_stats (`torch.Tensor`, *optional*): Must be `None`; pre-allocated column-wise scales are not supported.
        row_stats (`torch.Tensor`, *optional*): Must be `None`; pre-allocated row-wise scales are not supported.
        out_col (`torch.Tensor`, *optional*): Must be `None`; a pre-allocated column-wise output is not supported.
        out_row (`torch.Tensor`, *optional*): Must be `None`; a pre-allocated row-wise output is not supported.
        threshold (`float`, *optional*):
            Threshold for the sparse decomposition of outlier features.
            With the default of 0.0 no outliers are held back.

    Raises:
        ValueError: If any of `row_stats`, `col_stats`, `out_col`, `out_row` is not `None`
            (checked in that order).

    Returns:
        `Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`:
        - `torch.Tensor` with dtype `torch.int8`: row-wise quantized data.
        - `torch.Tensor` with dtype `torch.int8`: column-wise quantized data.
        - `torch.Tensor` with dtype `torch.float32`: row-wise quantization scales.
        - `torch.Tensor` with dtype `torch.float32`: column-wise quantization scales.
        - `torch.Tensor` with dtype `torch.int32`, *optional*: column indices holding outlier features.
    """
    # The pre-allocation parameters remain only for signature compatibility.
    rejected = (
        ("row_stats", row_stats),
        ("col_stats", col_stats),
        ("out_col", out_col),
        ("out_row", out_row),
    )
    for arg_name, arg_value in rejected:
        if arg_value is not None:
            raise ValueError(
                f"{arg_name} must be None. int8_double_quant() does not support pre-allocated {arg_name}."
            )
    return torch.ops.bitsandbytes.int8_double_quant.default(A, threshold=threshold)
def int8_vectorwise_dequant(A: torch.Tensor, stats: torch.Tensor):
    """Convert a `torch.int8` tensor back to `torch.float32` using its row-wise statistics.

    Args:
        A (`torch.Tensor` with dtype `torch.int8`): Quantized values.
        stats (`torch.Tensor` with dtype `torch.float32`): Row-wise quantization statistics.

    Returns:
        `torch.Tensor` with dtype `torch.float32`: The dequantized values.
    """
    # Dequantization divides by 127 (equivalently multiplies by its reciprocal);
    # the arithmetic is performed by the registered bitsandbytes op.
    dequant_op = torch.ops.bitsandbytes.int8_vectorwise_dequant.default
    return dequant_op(A, stats)
def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
    """Quantize a `torch.float16` tensor to `torch.int8` following the `LLM.int8()` scheme.

    See the [LLM.int8() paper](https://arxiv.org/abs/2208.07339) for background.

    Args:
        A (`torch.Tensor` with dtype `torch.float16`): Tensor to quantize.
        threshold (`float`, *optional*):
            Threshold for the sparse decomposition of outlier features.
            With the default of 0.0 no outliers are held back.

    Returns:
        `Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`:
        - `torch.Tensor` with dtype `torch.int8`: the quantized values.
        - `torch.Tensor` with dtype `torch.float32`: the quantization scales.
        - `torch.Tensor` with dtype `torch.int32`, *optional*: column indices holding outlier features.
    """
    quantize_op = torch.ops.bitsandbytes.int8_vectorwise_quant.default
    return quantize_op(A, threshold)
def spmm_coo(
    cooA: COOSparseTensor | torch.Tensor,
    B: torch.Tensor,
    out: Optional[torch.Tensor] = None,
):
    """Multiply a COO sparse matrix by a dense matrix with the native cuSPARSE kernel.

    Args:
        cooA: A ``COOSparseTensor`` or a PyTorch sparse COO tensor with shape
            ``(rows, cols)``. A PyTorch tensor is coalesced first (duplicate
            entries are summed), because ``indices()``/``values()`` raise on
            uncoalesced tensors. It is then wrapped in a ``COOSparseTensor``
            with int32 indices.
        B: Dense matrix with ``B.shape[0] == cols``. A non-contiguous ``B`` is
            flagged to the kernel as transposed and uses ``B.stride()[1]`` as
            its leading dimension. Otherwise ``B.stride()[0]`` is used.
        out: Optional pre-allocated result with shape ``(rows, B.shape[1])``.
            When omitted, an uninitialized tensor with ``B``'s dtype and device
            is allocated for the kernel to fill.

    Returns:
        torch.Tensor: ``out``.

    Raises:
        AssertionError: If ``cooA`` is neither supported format, or on
            inconsistent sizes.
    """
    if not isinstance(cooA, COOSparseTensor):
        assert cooA.is_sparse and cooA.layout == torch.sparse_coo, (
            "Tensor must be `COOSparseTensor` or a PyTorch COO tensor."
        )
        # coalesce() returns self for an already-coalesced tensor, so this is
        # free in the common case and required otherwise.
        cooA = cooA.coalesce()
        indices = cooA.indices()
        # Convert to custom COOSparseTensor
        cooA = COOSparseTensor(
            rows=cooA.shape[0],
            cols=cooA.shape[1],
            nnz=cooA._nnz(),
            rowidx=indices[0].int(),
            colidx=indices[1].int(),
            values=cooA.values(),
        )
    if out is None:
        out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype)
    nnz = cooA.nnz
    assert cooA.rowidx.numel() == nnz
    assert cooA.colidx.numel() == nnz
    assert cooA.values.numel() == nnz
    assert cooA.cols == B.shape[0]

    transposed_B = not B.is_contiguous()
    ldb = B.stride()[(1 if transposed_B else 0)]
    ldc = B.shape[1]

    ptr = Cusparse_Context.get_instance().context
    ptrRowidx = get_ptr(cooA.rowidx)
    ptrColidx = get_ptr(cooA.colidx)
    ptrValues = get_ptr(cooA.values)
    ptrB = get_ptr(B)
    ptrC = get_ptr(out)
    cnnz = ct.c_int32(cooA.nnz)
    crowsA = ct.c_int32(cooA.rows)
    ccolsA = ct.c_int32(cooA.cols)
    ccolsB = ct.c_int32(B.shape[1])
    cldb = ct.c_int32(ldb)
    cldc = ct.c_int32(ldc)

    is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out])
    lib.cspmm_coo(
        ptr,
        ptrRowidx,
        ptrColidx,
        ptrValues,
        cnnz,
        crowsA,
        ccolsA,
        ccolsB,
        cldb,
        ptrB,
        cldc,
        ptrC,
        ct.c_bool(transposed_B),
    )
    return out
def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
    """Multiply a very sparse COO matrix by a dense matrix with a naive native kernel.

    ``cooA`` may hold at most 32 nonzeros per row. Rows are handed to the kernel
    ordered by descending nonzero count (``max_count``/``max_idx``).

    Args:
        cooA: ``COOSparseTensor`` with shape ``(rows, cols)``. Per-row offsets
            are built from the cumulative ``torch.unique`` counts of ``rowidx``,
            so entries presumably must be grouped by ascending row index — TODO
            confirm against the kernel.
        B: Dense ``torch.float16`` or ``torch.int8`` matrix with
            ``B.shape[0] == cooA.cols``.
        dequant_stats: Optional tensor forwarded to the kernel (presumably the
            scales for an int8 ``B`` — confirm against the kernel).
        out: Optional pre-allocated result with shape ``(rows, B.shape[1])``.
            When omitted, a zero-filled tensor with ``cooA.values``' dtype is
            allocated on ``B``'s device.

    Returns:
        torch.Tensor: ``out``. If ``cooA`` has no stored entries, ``out`` is
        returned without calling the kernel.

    Raises:
        AssertionError: On inconsistent sizes, an unsupported ``B`` dtype, or
            a row with more than 32 nonzeros.
    """
    if out is None:
        out = torch.zeros((cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype)
    nnz = cooA.nnz
    assert cooA.rowidx.numel() == nnz
    assert cooA.colidx.numel() == nnz
    assert cooA.values.numel() == nnz
    assert cooA.cols == B.shape[0], f"{cooA.cols} vs {B.shape}"
    assert B.dtype in [torch.float16, torch.int8]
    if nnz == 0:
        # Nothing to multiply. Without this guard, max_count[0] below would raise
        # IndexError on the empty count tensor.
        return out

    _, counts = torch.unique(cooA.rowidx, return_counts=True)
    # offset[i] is the end position (exclusive) of the i-th non-empty row's entries.
    offset = counts.cumsum(0).int()
    max_count, max_idx = torch.sort(counts, descending=True)
    max_idx = max_idx.int()
    max_count = max_count.int()
    assert max_count[0] <= 32, f"Current max count per row is 32 but found {max_count[0]}."

    ptrOffset = get_ptr(offset)
    ptrMaxCount = get_ptr(max_count)
    ptrMaxIdx = get_ptr(max_idx)
    ptrRowidx = get_ptr(cooA.rowidx)
    ptrColidx = get_ptr(cooA.colidx)
    ptrValues = get_ptr(cooA.values)
    ptrB = get_ptr(B)
    ptrC = get_ptr(out)
    ptrDequantStats = get_ptr(dequant_stats)
    cnnz_rows = ct.c_int32(counts.numel())
    cnnz = ct.c_int32(cooA.nnz)
    crowsA = ct.c_int32(cooA.rows)
    # NOTE(review): named rowsB but passes B.shape[1] (the column count). Kept
    # as-is because the kernel's use of this argument is not visible here — confirm.
    crowsB = ct.c_int32(B.shape[1])
    ccolsB = ct.c_int32(B.shape[1])

    with _cuda_device_of(B):
        is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats])
        # The dtype assertion above guarantees exactly one of these applies.
        if B.dtype == torch.float16:
            kernel = lib.cspmm_coo_very_sparse_naive_fp16
        else:
            kernel = lib.cspmm_coo_very_sparse_naive_int8
        kernel(
            ptrMaxCount,
            ptrMaxIdx,
            ptrOffset,
            ptrRowidx,
            ptrColidx,
            ptrValues,
            ptrB,
            ptrC,
            ptrDequantStats,
            cnnz_rows,
            cnnz,
            crowsA,
            crowsB,
            ccolsB,
        )
    return out
def _convert_weight_packed_for_cpu(
    qweight: torch.Tensor, quant_state: QuantState, block_n: int = 32
) -> tuple[torch.Tensor, QuantState]:
    """Repack a 4-bit quantized weight into the blocked layout used by the CPU path.

    ``qweight`` stores two 4-bit codes per byte, with the high nibble holding the
    even-indexed code (see the unpacking below). The codes are unpacked into the
    2-D ``quant_state.shape`` matrix ``[N, K]``. Rows are then grouped into tiles
    of ``block_n``, the tiles are transposed to ``[N/B, K/2, B, 2]``, and the
    codes are repacked two per byte.

    ``quant_state`` is modified in place and also returned:
      * ``original_storage_type`` (only when ``qweight`` is not uint8),
        ``original_dtype``, ``original_nested`` and ``original_qshape`` are
        recorded for ``_convert_weight_packed_for_cpu_inverse``.
      * If nested, ``absmax`` is dequantized with ``state2`` plus ``offset``,
        ``nested`` is set to False and ``state2`` is deleted.
      * ``absmax`` becomes a contiguous bfloat16 tensor of shape
        ``[K // blocksize, N]`` (transposed).
      * ``dtype`` is set to ``torch.bfloat16`` and ``packing_format_for_cpu`` to True.

    Args:
        qweight: packed 4-bit weight of ``N * K / 2`` bytes. A non-uint8 storage
            dtype is reinterpreted via ``view(torch.uint8)``.
        quant_state: quantization state describing ``qweight``.
        block_n: row tile size; ``N`` must be divisible by it.

    Returns:
        ``(final_qweight, quant_state)``, where ``final_qweight`` is uint8 with shape ``[N, K // 2]``.
    """
    # Work on raw bytes; remember the storage dtype so the inverse can view back.
    if qweight.dtype != torch.uint8:
        quant_state.original_storage_type = qweight.dtype
        qweight = qweight.view(torch.uint8)
    # Snapshot fields that are overwritten below (must happen before the mutations).
    quant_state.original_dtype = quant_state.dtype
    quant_state.original_nested = quant_state.nested
    quant_state.original_qshape = qweight.shape
    qweight = qweight.reshape(-1)
    # Unpack: even index <- high nibble, odd index <- low nibble.
    unpacked_w = torch.empty(qweight.shape[0] * 2, dtype=torch.int32, device=qweight.device)
    unpacked_w[1::2] = qweight & 0xF
    unpacked_w[::2] = qweight >> 4
    qweight_final = unpacked_w.reshape(quant_state.shape).to(torch.uint8)  # [N, K] (2-D enforced below)
    # pack weight: [N, K] -> [N, K/2], two codes per byte in the blocked order
    assert len(qweight_final.shape) == 2
    N, K = qweight_final.shape[0], qweight_final.shape[1]
    assert N % block_n == 0, "N must be divisible by block_n"
    assert K % 2 == 0, "K must be even"
    BLOCK_N = block_n
    BIT_COUNT = 32  # each group of 64 codes -> 32 bytes (32 low nibbles + 32 high nibbles)
    new_shape = [N // BLOCK_N, BLOCK_N, K // 2, 2]
    out_shape = [N, K // 2]
    qw = qweight_final.reshape(new_shape)  # (N/B, B, K/2, 2)
    qw = qw.transpose(-3, -2).contiguous()  # (N/B, K/2, B, 2)
    qw = qw.reshape(-1, BIT_COUNT * 2)  # [-1, 64]
    # Within each run of 64 codes, code i goes to the low nibble and code i + 32 to the high nibble.
    high = qw[:, BIT_COUNT:]  # high 32
    low = qw[:, :BIT_COUNT]  # low 32
    packed = ((high << 4) | low).to(torch.uint8)  # combine
    final_qweight = packed.reshape(out_shape)
    if quant_state.nested:
        # Double-quantized absmax: materialize it so it can be re-laid out below.
        absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
        absmax += quant_state.offset
        if absmax.dtype != torch.float32:
            absmax = absmax.float()
        quant_state.absmax = absmax
        quant_state.nested = False
        delattr(quant_state, "state2")
    # Per-block scales as [N, K/blocksize], stored transposed ([K/blocksize, N]) in bfloat16.
    quant_state.absmax = (
        quant_state.absmax.reshape(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize)
        .T.to(torch.bfloat16)
        .contiguous()
    )
    quant_state.dtype = torch.bfloat16
    quant_state.packing_format_for_cpu = True
    return final_qweight, quant_state
def _convert_weight_packed_for_cpu_inverse(
    packed_weight: torch.Tensor,
    quant_state: QuantState,
    block_n: int = 32,
) -> tuple[torch.Tensor, QuantState]:
    """Undo ``_convert_weight_packed_for_cpu`` (best effort).

    Args:
        packed_weight: uint8 ``[N, K/2]`` tensor produced by
            ``_convert_weight_packed_for_cpu`` (its ``final_qweight``).
        quant_state: the ``QuantState`` that function modified. It is restored
            in place and also returned.
        block_n: the tile size used for packing; ``N`` must be divisible by it.

    Returns:
        ``(qweight, recovered_state)``:
          * ``qweight`` is the original nibble-packed weight, reshaped to
            ``original_qshape`` and viewed as ``original_storage_type`` when
            one was recorded.
          * ``recovered_state.absmax`` is returned to its flat, row-major block
            order. If the state was originally nested, it is re-quantized with
            ``quantize_blockwise(blocksize=256)`` and ``offset``/``state2`` are
            restored. Otherwise it becomes a float32 tensor. The bfloat16
            rounding applied by the forward conversion cannot be undone.
    """
    assert quant_state.packing_format_for_cpu, "only for packing format"
    assert packed_weight.dtype == torch.uint8
    assert len(packed_weight.shape) == 2, "packed_weight should be [N, K/2]"
    N, K_half = packed_weight.shape
    K = K_half * 2
    BLOCK_N = block_n
    BIT_COUNT = 32  # each 32-byte group holds 32 low + 32 high codes
    assert N % BLOCK_N == 0, "N must be divisible by block_n"

    # 1) Split every 32-byte group into its 64 codes: low nibbles first, then
    #    high nibbles (mirrors the forward `(high << 4) | low`).
    packed = packed_weight.reshape(-1, BIT_COUNT)  # [-1, 32]
    high = (packed >> 4) & 0xF
    low = packed & 0xF
    qw = torch.cat([low, high], dim=-1)  # [-1, 64], uint8

    # 2) Undo the tile transpose: [N/B, K/2, B, 2] -> [N/B, B, K/2, 2] -> [N, K].
    qw = qw.reshape(N // BLOCK_N, K_half, BLOCK_N, 2)
    qw = qw.transpose(-3, -2).contiguous()
    codes = qw.reshape(N, K).reshape(-1)

    # 3) Re-pack two codes per byte in the original order (even index -> high nibble).
    #    Codes are 0..15, so the uint8 shift cannot overflow.
    qweight = (codes[::2] << 4) | codes[1::2]

    recovered_state = quant_state  # restored in place
    qweight = qweight.reshape(recovered_state.original_qshape)

    # 4) The forward pass stored absmax transposed as [K // blocksize, N]. Restore
    #    the flat [N, K // blocksize] order for BOTH the nested and plain cases;
    #    leaving the transposed layout would pair blocks with the wrong scales.
    absmax = recovered_state.absmax.T.reshape(-1)
    if recovered_state.original_nested:
        absmax = absmax.to(recovered_state.original_dtype)
        offset = absmax.mean()
        qabsmax, state2 = quantize_blockwise(absmax - offset, blocksize=256)
        recovered_state.absmax = qabsmax
        recovered_state.offset = offset
        recovered_state.state2 = state2
        recovered_state.nested = True
    else:
        # float32 holds bfloat16 values exactly and matches the dtype the forward
        # pass normalizes dequantized absmax to.
        recovered_state.absmax = absmax.float()

    recovered_state.dtype = recovered_state.original_dtype
    recovered_state.packing_format_for_cpu = False
    storage_type = getattr(recovered_state, "original_storage_type", None)
    if storage_type is not None:
        qweight = qweight.view(storage_type)
    return qweight, recovered_state
def has_avx512bf16():
    """Report AVX512-BF16 support as returned by the native ``lib.has_avx512bf16_cpu()``.

    The native result is normalized to ``bool``, so both the success path and
    the fallback return the same type.

    Returns:
        bool: The truthiness of the native result, or ``False`` if the symbol is
        missing or the call raises ``AttributeError``, ``RuntimeError`` or ``OSError``.
    """
    try:
        support_avx_bf16 = lib.has_avx512bf16_cpu()
    except (AttributeError, RuntimeError, OSError):
        return False
    # The raw native return value (presumably a ctypes int) is coerced to bool.
    return bool(support_avx_bf16)
# Largest magnitude of a signed int8 value; presumably the scale divisor used by
# the int8 (LLM.int8()) quantization helpers — confirm at the use sites.
C = 127.0
|