File size: 99,452 Bytes
5000658 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 
1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656 1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805 1806 1807 1808 1809 1810 1811 1812 1813 1814 1815 1816 1817 1818 1819 1820 1821 
1822 1823 1824 1825 1826 1827 1828 1829 1830 1831 1832 1833 1834 1835 1836 1837 1838 1839 1840 1841 1842 1843 1844 1845 1846 1847 1848 1849 1850 1851 1852 1853 1854 1855 1856 1857 1858 1859 1860 1861 1862 1863 1864 1865 1866 1867 1868 1869 1870 1871 1872 1873 1874 1875 1876 1877 1878 1879 1880 1881 1882 1883 1884 1885 1886 1887 1888 1889 1890 1891 1892 1893 1894 1895 1896 1897 1898 1899 1900 1901 1902 1903 1904 1905 1906 1907 1908 1909 1910 1911 1912 1913 1914 1915 1916 1917 1918 1919 1920 1921 1922 1923 1924 1925 1926 1927 1928 1929 1930 1931 1932 1933 1934 1935 1936 1937 1938 1939 1940 1941 1942 1943 1944 1945 1946 1947 1948 1949 1950 1951 1952 1953 1954 1955 1956 1957 1958 1959 1960 1961 1962 1963 1964 1965 1966 1967 1968 1969 1970 1971 1972 1973 1974 1975 1976 1977 1978 1979 1980 1981 1982 1983 1984 1985 1986 1987 1988 1989 1990 1991 1992 1993 1994 1995 1996 1997 1998 1999 2000 2001 2002 2003 2004 2005 2006 2007 2008 2009 2010 2011 2012 2013 2014 2015 2016 2017 2018 2019 2020 2021 2022 2023 2024 2025 2026 2027 2028 2029 2030 2031 2032 2033 2034 2035 2036 2037 2038 2039 2040 2041 2042 2043 2044 2045 2046 2047 2048 2049 2050 2051 2052 2053 2054 2055 2056 2057 2058 2059 2060 2061 2062 2063 2064 2065 2066 2067 2068 2069 2070 2071 2072 2073 2074 2075 2076 2077 2078 2079 2080 2081 2082 2083 2084 2085 2086 2087 2088 2089 2090 2091 2092 2093 2094 2095 2096 2097 2098 2099 2100 2101 2102 2103 2104 2105 2106 2107 2108 2109 2110 2111 2112 2113 2114 2115 2116 2117 2118 2119 2120 2121 2122 2123 2124 2125 2126 2127 2128 2129 2130 2131 2132 2133 2134 2135 2136 2137 2138 2139 2140 2141 2142 2143 2144 2145 2146 2147 2148 2149 2150 2151 2152 2153 2154 2155 2156 2157 2158 2159 2160 2161 2162 2163 2164 2165 2166 2167 2168 2169 2170 2171 2172 2173 2174 2175 2176 2177 2178 2179 2180 2181 2182 2183 2184 2185 2186 2187 2188 2189 2190 2191 2192 2193 2194 2195 2196 2197 2198 2199 2200 2201 2202 2203 2204 2205 2206 2207 2208 2209 2210 2211 2212 2213 2214 2215 2216 2217 2218 2219 2220 2221 
2222 2223 2224 2225 2226 2227 2228 2229 2230 2231 2232 2233 2234 2235 2236 |
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import functools
import json
import os
import sys
import time
from collections import defaultdict
from pathlib import Path
from typing import List, Optional
import numpy as np
import safetensors
import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.pytorch_utils import Conv1D
from ..._utils import pad_vocab_size, release_gc, str_dtype_to_torch
from ...logger import logger
from ...quantization import QuantAlgo
from ..convert_utils import (iterate_shard_files, load_calib_dataset,
load_state_dict, retrieved_layer_index_from_name)
from ..modeling_utils import PretrainedConfig
from .config import LLaMAConfig
def generate_int8(weights, act_range, is_qkv=False, multi_query_mode=False):
    """
    Quantize ``weights`` to INT8 and compute the SmoothQuant scaling factors.

    This function has two purposes:
      - compute quantized weights, scaled either per-tensor or per-column
      - compute scaling factors

    Depending on the GEMM API (CUTLASS/CUBLAS) the required scaling factors differ.
    CUTLASS uses two sets of scaling factors: one for the activation X, one for the weight W.
    CUBLAS only has one (we can't do per-row scaling), so we must provide a pre-multiplied scaling factor.

    Here is the list of what we need (T means per-tensor, C per-column):
      - scale_x_orig_quant puts fp activation into the quantized range
        (127 / max|x|, i.e. the symmetric int8 range). Used before the GEMM. (T)
      - scale_y_quant_orig puts quantized activation into the fp range. Used if the GEMM outputs int8. (T)
      - scale_w_quant_orig puts weights from quant range to fp range (used with CUTLASS) (T, C)
      - scale_y_accum_quant puts the GEMM result (XW) from accumulation range (int32)
        to quant range (int8) (used for CUBLAS) (T, C)

    Note that we don't do anything special about row-parallel GEMM. Theoretically, we could have per-GPU
    scaling factors too, but then the model would change depending on the number of GPUs used.

    For QKV projection, the behavior is special. Even if we have a single matrix to perform QKV projection,
    we consider it as three different matrices: Q, K, and V. So per-tensor actually means one scaling factor
    for each of Q, K and V. For our GEMM implementation to respect this behavior, we use per-column mode and
    replicate values along columns.

    Args:
        weights: floating-point weight tensor to quantize. In the MQA/GQA QKV
            case ``weights.shape[0]`` is taken as the hidden (Q) size --
            presumably a (hidden, hidden + 2 * kv) layout; TODO confirm.
        act_range: dict with abs-max statistics under "x" (GEMM input),
            "y" (GEMM output) and "w" (weight, one value per output channel),
            presumably as collected by ``capture_activation_range``.
        is_qkv: treat the weight as a fused Q/K/V projection.
        multi_query_mode: Q, K and V have different widths (MQA/GQA).

    Returns:
        dict with "weight.int8", "weight.int8.col" (int8 tensors) and the
        float32 scales "scale_x_orig_quant", "scale_w_quant_orig",
        "scale_w_quant_orig.col", "scale_y_accum_quant",
        "scale_y_accum_quant.col", "scale_y_quant_orig".
    """

    # compute weight scaling factors for fp->int8 and int8->fp
    if is_qkv and not multi_query_mode:
        # Equal-width Q/K/V: one per-tensor scale for each of the three parts.
        scale_w_orig_quant_t = 127. / act_range["w"].reshape(3, -1).max(
            dim=-1, keepdims=True)[0]
        scale_w_orig_quant_c = 127. / act_range["w"].reshape(3, -1)
    elif is_qkv and multi_query_mode:
        # Unequal widths: slice the fused per-channel maxima into Q, K and V
        # (Q first, then K, then V, each K/V part kv_dim wide).
        hidden_dim = weights.shape[0]
        local_dim = act_range["w"].shape[0]
        kv_dim = (local_dim - hidden_dim) // 2
        scale_w_q = act_range["w"][0:hidden_dim]
        scale_w_k = act_range["w"][hidden_dim:hidden_dim + kv_dim]
        scale_w_v = act_range["w"][-kv_dim:]

        # Three per-tensor maxima (one each for Q, K, V).
        scale_w_qkv_t = torch.concat([
            scale_w_q.max(dim=0, keepdim=True)[0],
            scale_w_k.max(dim=0, keepdim=True)[0],
            scale_w_v.max(dim=0, keepdim=True)[0]
        ])

        scale_w_orig_quant_t = 127. / scale_w_qkv_t
        scale_w_orig_quant_c = 127. / act_range["w"]
    else:
        scale_w_orig_quant_t = 127. / act_range["w"].max()
        scale_w_orig_quant_c = 127. / act_range["w"]
    # Dequantization scales are taken before the float32 cast below.
    scale_w_quant_orig_t = 1.0 / scale_w_orig_quant_t
    scale_w_quant_orig_c = 1.0 / scale_w_orig_quant_c

    scale_w_orig_quant_c = scale_w_orig_quant_c.to(torch.float32)
    scale_w_orig_quant_t = scale_w_orig_quant_t.to(torch.float32)

    # compute the rest of needed scaling factors
    scale_x_orig_quant_t = 127. / act_range["x"].max()
    scale_y_orig_quant_t = 127. / act_range["y"].max()
    scale_y_quant_orig_t = act_range["y"].max() / 127.
    scale_y_accum_quant_t = scale_y_orig_quant_t / (scale_x_orig_quant_t *
                                                    scale_w_orig_quant_t)
    scale_y_accum_quant_c = scale_y_orig_quant_t / (scale_x_orig_quant_t *
                                                    scale_w_orig_quant_c)
    if is_qkv and not multi_query_mode:
        # Replicate the three per-part scales along columns ("per-tensor"
        # emulated in per-column mode, see docstring).
        scale_y_accum_quant_t = torch.broadcast_to(scale_y_accum_quant_t,
                                                   scale_w_orig_quant_c.shape)
        scale_w_quant_orig_t = torch.broadcast_to(scale_w_quant_orig_t,
                                                  scale_w_orig_quant_c.shape)
    if is_qkv and multi_query_mode:
        # Same replication, but each part has its own width.
        scale_q_y_accum_t = torch.broadcast_to(scale_y_accum_quant_t[0],
                                               scale_w_q.shape)
        scale_k_y_accum_t = torch.broadcast_to(scale_y_accum_quant_t[1],
                                               scale_w_k.shape)
        scale_v_y_accum_t = torch.broadcast_to(scale_y_accum_quant_t[2],
                                               scale_w_v.shape)
        scale_y_accum_quant_t = torch.concat(
            [scale_q_y_accum_t, scale_k_y_accum_t, scale_v_y_accum_t])
        scale_w_quant_orig_t = torch.concat([
            torch.broadcast_to(scale_w_quant_orig_t[0], scale_w_q.shape),
            torch.broadcast_to(scale_w_quant_orig_t[1], scale_w_k.shape),
            torch.broadcast_to(scale_w_quant_orig_t[2], scale_w_v.shape)
        ])

    # Round, saturate to the symmetric range [-127, 127] and cast to int8.
    to_i8 = lambda x: x.round().clip(-127, 127).to(torch.int8)

    if is_qkv and multi_query_mode:
        # Divide by the column-replicated dequant scale built above.
        weight_int8 = to_i8(weights / scale_w_quant_orig_t)
    else:
        weight_int8 = to_i8(weights * scale_w_orig_quant_t)
    return {
        "weight.int8": weight_int8,
        "weight.int8.col": to_i8(weights * scale_w_orig_quant_c),
        "scale_x_orig_quant": scale_x_orig_quant_t.to(torch.float32),
        "scale_w_quant_orig": scale_w_quant_orig_t.to(torch.float32),
        "scale_w_quant_orig.col": scale_w_quant_orig_c.to(torch.float32),
        "scale_y_accum_quant": scale_y_accum_quant_t.to(torch.float32),
        "scale_y_accum_quant.col": scale_y_accum_quant_c.to(torch.float32),
        "scale_y_quant_orig": scale_y_quant_orig_t.to(torch.float32),
    }
@torch.no_grad()
def apply_smoothing(scales,
                    gemm_weights,
                    layernorm_weights=None,
                    layernorm_bias=None,
                    dtype=torch.float32,
                    layernorm_1p=False):
    """Fold SmoothQuant ``scales`` into a preceding layernorm and its GEMMs, in place.

    The layernorm weight/bias are divided by ``scales`` and every GEMM weight
    has its input columns multiplied by ``scales``. The layernorm output feeds
    the GEMMs, so the product of the two is unchanged while activation outliers
    are moved into the weights.

    Args:
        scales: 1-D smoothing vector, one entry per input channel.
        gemm_weights: a weight tensor or list of them, each shaped
            (out_features, in_features); modified in place.
        layernorm_weights: optional layernorm weight; divided in place.
        layernorm_bias: optional layernorm bias; divided in place.
        dtype: kept for backward compatibility and unused. In-place ops
            already preserve each tensor's own dtype. (The original code called
            ``.to(dtype)`` on the results and discarded them.)
        layernorm_1p: the stored layernorm weight is ``gamma - 1`` (a
            "1 + weight" layernorm). The weight is then corrected so that the
            effective gamma is ``gamma / scales``.

    Raises:
        ValueError: if ``layernorm_1p`` is set without ``layernorm_weights``.
            This is checked before anything is mutated.
    """
    if not isinstance(gemm_weights, list):
        gemm_weights = [gemm_weights]

    if layernorm_1p and layernorm_weights is None:
        raise ValueError("layernorm_1p=True requires layernorm_weights")

    if layernorm_weights is not None:
        assert layernorm_weights.numel() == scales.numel()
        layernorm_weights.div_(scales)
    if layernorm_bias is not None:
        assert layernorm_bias.numel() == scales.numel()
        layernorm_bias.div_(scales)
    if layernorm_1p:
        # Stored w = gamma - 1. After div_ we hold w / s, and we want
        # (gamma / s) - 1 = w / s + 1 / s - 1.
        layernorm_weights += (1 / scales) - 1

    # Scale each GEMM's input columns (dim 1).
    for gemm in gemm_weights:
        gemm.mul_(scales.view(1, -1))
@torch.no_grad()
def smooth_gemm(gemm_weights,
                act_scales,
                layernorm_weights=None,
                layernorm_bias=None,
                alpha=0.5,
                weight_scales=None):
    """Compute SmoothQuant scales for GEMMs sharing one input and apply them in place.

    scales = act_scales ** alpha / weight_scales ** (1 - alpha), where
    weight_scales is the per-input-channel abs-max over all ``gemm_weights``.
    Both factors and the result are clamped to at least 1e-5. The scales are
    then applied through ``apply_smoothing``.

    Args:
        gemm_weights: a weight tensor or list of them, each shaped
            (out_features, in_features); modified in place.
        act_scales: 1-D per-input-channel activation abs-max.
        layernorm_weights: optional preceding layernorm weight to fold into.
        layernorm_bias: optional preceding layernorm bias to fold into.
        alpha: migration strength (exponent split between act and weight).
        weight_scales: optional precomputed per-input-channel weight abs-max.

    Returns:
        The float64 smoothing vector ``scales``.
    """
    if not isinstance(gemm_weights, list):
        gemm_weights = [gemm_weights]
    orig_dtype = gemm_weights[0].dtype
    device = gemm_weights[0].device

    for gemm in gemm_weights:
        # Each weight is (out, in); its input dim must match the activations.
        assert gemm.shape[1] == act_scales.numel()

    if weight_scales is None:
        weight_scales = torch.cat(
            [gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights],
            dim=0)
        weight_scales = weight_scales.max(dim=0)[0]
    # The original called .clamp() without keeping the result. An all-zero
    # input channel then produced 0 ** (1 - alpha) in the denominator and inf
    # scales.
    weight_scales = weight_scales.to(device).to(float).clamp(min=1e-5)

    scales = (act_scales.to(device).to(float).pow(alpha) /
              weight_scales.pow(1 - alpha)).clamp(min=1e-5)

    apply_smoothing(scales, gemm_weights, layernorm_weights, layernorm_bias,
                    orig_dtype)

    return scales
@torch.no_grad()
def smooth_gemm_fc1_gate(fc1_weights,
                         gate_weights,
                         act_scales,
                         layernorm_weights=None,
                         layernorm_bias=None,
                         alpha=0.5,
                         weight_scales=None):
    """Smooth two parallel GEMMs (e.g. MLP gate/up projections) that share one input.

    Works like ``smooth_gemm``, except that each ``fc1_weights[i]`` and
    ``gate_weights[i]`` pair is concatenated along the output dim before the
    per-input-channel weight abs-max is taken. One smoothing vector thus covers
    both projections. That vector is applied in place to every fc1 and gate
    weight, and folded into ``layernorm_weights`` / ``layernorm_bias`` when
    they are given.

    Args:
        fc1_weights: a weight tensor or list of them, each (out, in).
        gate_weights: a weight tensor or list of them, paired with fc1_weights.
        act_scales: 1-D per-input-channel activation abs-max.
        layernorm_weights: optional preceding layernorm weight to fold into.
        layernorm_bias: optional preceding layernorm bias to fold into.
        alpha: migration strength (exponent split between act and weight).
        weight_scales: optional precomputed per-input-channel weight abs-max.

    Returns:
        The float64 smoothing vector ``scales``.

    Raises:
        ValueError: if fc1_weights and gate_weights differ in length.
    """
    if not isinstance(fc1_weights, list):
        fc1_weights = [fc1_weights]
    if not isinstance(gate_weights, list):
        gate_weights = [gate_weights]
    if len(fc1_weights) != len(gate_weights):
        raise ValueError(
            f"fc1_weights ({len(fc1_weights)}) and gate_weights "
            f"({len(gate_weights)}) must pair up one-to-one")

    gemm_weights = [
        torch.cat([fc1, gate], dim=0)
        for fc1, gate in zip(fc1_weights, gate_weights)
    ]
    orig_dtype = gemm_weights[0].dtype
    device = gemm_weights[0].device

    for gemm in gemm_weights:
        # Each weight is (out, in); its input dim must match the activations.
        assert gemm.shape[1] == act_scales.numel()

    if weight_scales is None:
        weight_scales = torch.cat(
            [gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights],
            dim=0)
        weight_scales = weight_scales.max(dim=0)[0]
    # Keep the clamped result. Discarding it (as before) let all-zero
    # channels produce inf scales.
    weight_scales = weight_scales.to(device).to(float).clamp(min=1e-5)

    scales = (act_scales.to(device).to(float).pow(alpha) /
              weight_scales.pow(1 - alpha)).clamp(min=1e-5)

    # Apply to the original tensors, not the concatenated copies.
    apply_smoothing(scales, fc1_weights + gate_weights, layernorm_weights,
                    layernorm_bias, orig_dtype)

    return scales
@torch.no_grad()
def smooth_llama_model(model, scales, alpha, llama_qkv_para, llama_smoother):
    """Apply SmoothQuant smoothing to every LLaMA-style decoder layer of ``model``.

    Visits modules that are ``LlamaDecoderLayer`` instances, or whose class is
    named ``InternLMDecoderLayer`` or ``MistralDecoderLayer``. For each one:

    - QKV: smooths the concatenation of the q/k/v_proj weights against
      q_proj's input statistics and folds the smoother into
      ``input_layernorm.weight``.
    - o_proj, mlp.down_proj (and residual_mlp.w2 if present): smooths the
      weight in place. There is no layernorm to fold into, so the smoother is
      recorded in ``llama_smoother``.
    - mlp.gate_proj + mlp.up_proj (and residual_mlp.w1 + w3 if present) are
      smoothed jointly. The smoother is folded into
      ``post_attention_layernorm.weight`` (respectively
      ``residual_layernorm.weight``).

    After each step, the "x" statistic of the affected layers is divided by the
    smoother. "w" is recomputed as the per-output-row abs-max of the smoothed
    weight.

    Args:
        model: module whose ``named_modules()`` are scanned; modified in place.
        scales: per-module {"x", "y", "w"} statistics, updated in place.
            ``<layer>.self_attn.qkv_proj`` entries are indexed before they
            are ever assigned, so this must auto-create missing keys. It is
            presumably the defaultdict from ``capture_activation_range``.
        alpha: migration strength forwarded to the smoothing helpers.
        llama_qkv_para: output dict. It receives the smoothed fused QKV weight,
            transposed, under ``<layer>.self_attn.qkv_proj``.
        llama_smoother: output dict. It receives the float32 smoothers for
            o_proj, down_proj and residual_mlp.w2.
    """
    # Smooth the activation and weights with smoother = $\diag{s}$
    for name, module in model.named_modules():
        if not isinstance(
                module,
                LlamaDecoderLayer) and not module.__class__.__name__ in [
                    "InternLMDecoderLayer", "MistralDecoderLayer"
                ]:
            continue
        # qkv_proj
        layer_name_q = name + ".self_attn.q_proj"
        layer_name_k = name + ".self_attn.k_proj"
        layer_name_v = name + ".self_attn.v_proj"
        layer_name_qkv = name + ".self_attn.qkv_proj"

        # torch.cat builds a new tensor. The smoothing below therefore lands
        # on this fused copy, not on q/k/v_proj.weight. Only input_layernorm
        # is modified on the module itself.
        weight = torch.cat([
            module.self_attn.q_proj.weight, module.self_attn.k_proj.weight,
            module.self_attn.v_proj.weight
        ],
                           dim=0)

        # Q, K and V share one input, so q_proj's "x" stats stand in for all.
        smoother = smooth_gemm(weight, scales[layer_name_q]["x"],
                               module.input_layernorm.weight, None, alpha)

        scales[layer_name_qkv]["x"] = scales[layer_name_q]["x"] / smoother
        scales[layer_name_qkv]["w"] = weight.abs().max(dim=1)[0]
        scales[layer_name_qkv]["y"] = torch.cat([
            scales[layer_name_q]["y"], scales[layer_name_k]["y"],
            scales[layer_name_v]["y"]
        ],
                                                dim=0)

        # see transpose_weights function
        llama_qkv_para[layer_name_qkv] = weight.transpose(0, 1)

        # =================================================================
        layer_name = name + ".self_attn.o_proj"
        # No layernorm precedes o_proj, so the smoother is kept separately.
        smoother = smooth_gemm(module.self_attn.o_proj.weight,
                               scales[layer_name]["x"], None, None, alpha)
        llama_smoother[layer_name] = smoother.float()

        scales[layer_name]["x"] = scales[layer_name]["x"] / smoother
        scales[layer_name]["w"] = module.self_attn.o_proj.weight.abs().max(
            dim=1)[0]

        # ==================================================================
        # Note the naming: "fc1" is gate_proj and "gate" is up_proj.
        fc1_layer_name = name + ".mlp.gate_proj"
        gate_layer_name = name + ".mlp.up_proj"

        smoother = smooth_gemm_fc1_gate(module.mlp.gate_proj.weight,
                                        module.mlp.up_proj.weight,
                                        scales[fc1_layer_name]["x"],
                                        module.post_attention_layernorm.weight,
                                        None, alpha)

        scales[fc1_layer_name]["x"] = scales[fc1_layer_name]["x"] / smoother
        scales[fc1_layer_name]["w"] = module.mlp.gate_proj.weight.abs().max(
            dim=1)[0]

        scales[gate_layer_name]["x"] = scales[gate_layer_name]["x"] / smoother
        scales[gate_layer_name]["w"] = module.mlp.up_proj.weight.abs().max(
            dim=1)[0]

        # ==================================================================
        layer_name = name + ".mlp.down_proj"
        smoother = smooth_gemm(module.mlp.down_proj.weight,
                               scales[layer_name]["x"], None, None, alpha)
        llama_smoother[layer_name] = smoother.float()
        scales[layer_name]["x"] = scales[layer_name]["x"] / smoother
        scales[layer_name]["w"] = module.mlp.down_proj.weight.abs().max(
            dim=1)[0]

        # ==================================================================
        # Optional extra MLP (w1/w3 -> w2) behind residual_layernorm.
        if hasattr(module, 'residual_mlp'):
            fc1_layer_name = name + ".residual_mlp.w1"
            gate_layer_name = name + ".residual_mlp.w3"

            smoother = smooth_gemm_fc1_gate(module.residual_mlp.w1.weight,
                                            module.residual_mlp.w3.weight,
                                            scales[fc1_layer_name]["x"],
                                            module.residual_layernorm.weight,
                                            None, alpha)

            scales[fc1_layer_name]["x"] = scales[fc1_layer_name]["x"] / smoother
            scales[fc1_layer_name]["w"] = module.residual_mlp.w1.weight.abs(
            ).max(dim=1)[0]

            scales[gate_layer_name][
                "x"] = scales[gate_layer_name]["x"] / smoother
            scales[gate_layer_name]["w"] = module.residual_mlp.w3.weight.abs(
            ).max(dim=1)[0]

            # ==================================================================
            layer_name = name + ".residual_mlp.w2"
            smoother = smooth_gemm(module.residual_mlp.w2.weight,
                                   scales[layer_name]["x"], None, None, alpha)
            llama_smoother[layer_name] = smoother.float()
            scales[layer_name]["x"] = scales[layer_name]["x"] / smoother
            scales[layer_name]["w"] = module.residual_mlp.w2.weight.abs().max(
                dim=1)[0]
@torch.no_grad()
def capture_activation_range(model,
                             tokenizer,
                             dataset,
                             num_samples=512,
                             seq_len=512):
    """Record per-channel abs-max statistics for every linear layer of ``model``.

    Registers a forward hook on each ``nn.Linear`` / ``Conv1D`` module and runs
    calibration prompts through ``model``. For each module name it returns:

    - "x": running per-channel max of |input| (over the last dim),
    - "y": running per-channel max of |output| (over the last dim),
    - "w": per-row max of |weight| (clipped below at 1e-8), taken on the
      module's first call.

    The hooks are always removed before returning, even when a forward pass
    raises.

    Args:
        model: model to calibrate; switched to eval mode.
        tokenizer: tokenizer for ``model``. Its ``pad_token`` is set to
            ``eos_token`` as a side effect.
        dataset: sliceable sequence of calibration strings.
        num_samples: number of samples to run. If the dataset is shorter, the
            value is capped at ``len(dataset)`` and a warning is logged.
        seq_len: maximum token length per sample (longer inputs are truncated).

    Returns:
        defaultdict mapping module name -> {"x", "y", "w"} tensors.
    """
    model.eval()
    device = next(model.parameters()).device
    act_scales = defaultdict(lambda: {"x": None, "y": None, "w": None})

    tokenizer.pad_token = tokenizer.eos_token

    if num_samples > len(dataset):
        logger.warning(
            f"Requested {num_samples} calibration samples but the dataset "
            f"only has {len(dataset)}; using all of them.")
        num_samples = len(dataset)

    def stat_tensor(name, tensor, act_scales, key):
        # Collapse the leading dims and take the per-channel max of |tensor|.
        # reshape (not view) also handles non-contiguous outputs.
        hidden_dim = tensor.shape[-1]
        tensor = tensor.reshape(-1, hidden_dim).abs().detach()
        coming_max = torch.max(tensor, dim=0)[0].float()

        if act_scales[name][key] is None:
            act_scales[name][key] = coming_max
        else:
            act_scales[name][key] = torch.max(act_scales[name][key],
                                              coming_max)

    def stat_input_hook(m, x, y, name):
        # Forward hooks receive the positional inputs as a tuple.
        if isinstance(x, tuple):
            x = x[0]
        stat_tensor(name, x, act_scales, "x")
        stat_tensor(name, y, act_scales, "y")

        if act_scales[name]["w"] is None:
            act_scales[name]["w"] = m.weight.abs().clip(1e-8,
                                                        None).max(dim=1)[0]

    hooks = []
    for name, m in model.named_modules():
        if isinstance(m, (nn.Linear, Conv1D)):
            hooks.append(
                m.register_forward_hook(
                    functools.partial(stat_input_hook, name=name)))

    try:
        for i in tqdm(range(num_samples), desc="calibrating model"):
            datapoint = dataset[i:i + 1]
            line = copy.copy(datapoint)
            line[0] = line[0] + ' TL;DR: '
            line[0] = line[0].strip()
            line[0] = line[0].replace(" n't", "n't")
            input_ids = tokenizer(line,
                                  return_tensors="pt",
                                  max_length=seq_len,
                                  padding=True,
                                  truncation=True).input_ids.to(device)
            model(input_ids)
    finally:
        # Never leave the caller's model instrumented, even on failure.
        for h in hooks:
            h.remove()

    return act_scales
def split(v, tp_size, idx, dim=0):
    """Return the ``idx``-th of ``tp_size`` chunks of tensor ``v``.

    With ``tp_size == 1`` the input itself is returned (no copy). 1-D tensors
    are always chunked along their only axis, and ``dim`` is ignored for them.
    Otherwise ``v`` is chunked along ``dim``. A chunked result is made
    contiguous.
    """
    if tp_size == 1:
        return v
    axis = 0 if v.dim() == 1 else dim
    pieces = torch.chunk(v, tp_size, dim=axis)
    return pieces[idx].contiguous()
def split_qkv_tp(v, n_head, n_hidden, tensor_parallel, rank):
    """Shard a fused QKV weight matrix across tensor-parallel ranks.

    ``v`` is viewed as three stacked ``(n_hidden, n_hidden)`` matrices (Q, K,
    V). Each is split along its row dimension, and the ``rank``-th pieces are
    re-stacked into a ``(3 * (n_hidden // tensor_parallel), n_hidden)``
    tensor. ``n_head`` is accepted for signature symmetry but not used.
    """
    per_proj = v.reshape(3, n_hidden, n_hidden)
    local = split(per_proj, tensor_parallel, rank, dim=1)
    local_rows = 3 * (n_hidden // tensor_parallel)
    return local.reshape(local_rows, n_hidden)
def split_qkv_bias_tp(v, n_head, n_hidden, tensor_parallel, rank):
    """Shard a fused QKV bias across tensor-parallel ranks.

    ``v`` is viewed as ``(3, n_hidden)`` (the Q, K and V biases). Each row is
    split into ``tensor_parallel`` pieces, and the ``rank``-th pieces are
    flattened into a 1-D tensor of length ``3 * (n_hidden // tensor_parallel)``.
    ``n_head`` is unused.
    """
    stacked = v.reshape(3, n_hidden)
    local = split(stacked, tensor_parallel, rank, dim=1)
    return local.reshape(3 * (n_hidden // tensor_parallel))
def split_matrix_tp(v, tensor_parallel, rank, dim):
    """Return ``rank``'s shard of ``v`` along ``dim`` (a thin wrapper over ``split``)."""
    return split(v, tp_size=tensor_parallel, idx=rank, dim=dim)
def get_weight(config, prefix, dtype):
    """Return ``config[prefix + '.weight']`` detached and cast to ``dtype``.

    If the stored tensor has another dtype, its ``.data`` is replaced in place.
    Later lookups in ``config`` therefore see the converted tensor too.
    """
    tensor = config[prefix + '.weight']
    if tensor.dtype != dtype:
        tensor.data = tensor.to(dtype)
    return tensor.detach()
def get_bias(config, prefix, dtype):
    """Return ``config[prefix + '.bias']`` detached and cast to ``dtype``.

    If the stored tensor has another dtype, its ``.data`` is replaced in place.
    Later lookups in ``config`` therefore see the converted tensor too.
    """
    tensor = config[prefix + '.bias']
    if tensor.dtype != dtype:
        tensor.data = tensor.to(dtype)
    return tensor.detach()
def get_weight_and_bias(config, prefix, dtype):
    """Return ``(weight, bias)`` for ``prefix``, both cast to ``dtype``."""
    weight = get_weight(config, prefix, dtype)
    bias = get_bias(config, prefix, dtype)
    return weight, bias
def fp8_per_channel_quant_weight_gpu(weight, rank=0):
weight = weight.to("cuda:" + str(rank))
# activation range bound.
x = weight.to(torch.float32).clamp(-1200.0, 1200.0)
xmax = x.abs().max(-1, keepdim=True).values
# minimum scaling factor.
torch_weight_scales = (xmax / 448.0).clamp(min=1.0 / (448.0 * 512.0))
torch_weight_scales = torch_weight_scales.reshape(-1)
out = x * 448.0 / xmax
out = torch.clamp(out, -448, 448)
processed_torch_weights = out.to(torch.float8_e4m3fn)
processed_torch_weights = processed_torch_weights.to(
torch.float8_e4m3fn).cpu()
torch_weight_scales = torch_weight_scales.cpu()
return processed_torch_weights, torch_weight_scales
def get_tllm_linear_weight(weight,
                           prefix,
                           bias=None,
                           use_weight_only=False,
                           plugin_weight_only_quant_type=torch.int8,
                           dtype='float32',
                           use_gemm_woq_plugin=True,
                           use_fp8_rowwise=False,
                           tp_rank=0,
                           postfix='weight',
                           quant_scale_name=None):
    """Package a linear layer's weight (optionally quantized) for TRT-LLM.

    Modes:
      - weight-only: the weight is transposed (last two dims for batched
        weights) and quantized with ``symmetric_quantize_last_axis_of_batched_matrix``.
        The quantized tensor is stored when ``use_gemm_woq_plugin`` is set;
        otherwise the transposed weight is stored, cast to ``dtype``.
      - fp8 rowwise: per-row FP8 quantization on GPU ``tp_rank``.
      - otherwise: the weight is stored unchanged.

    Quantization scales go to ``quant_scale_name`` if given, else to
    ``prefix + 'per_channel_scale'``. The weight is stored under
    ``prefix + postfix`` and the bias (if any) under ``prefix + 'bias'``.

    Args:
        weight: weight tensor.
        prefix: key prefix for the outputs.
        bias: optional bias tensor, stored unchanged.
        use_weight_only: enable weight-only quantization.
        plugin_weight_only_quant_type: quantized dtype for weight-only mode.
        dtype: torch dtype or dtype name (e.g. 'float32'). Used only for the
            weight-only path without the GEMM plugin.
        use_gemm_woq_plugin: store the quantized weights (weight-only mode).
        use_fp8_rowwise: enable per-row FP8 quantization.
        tp_rank: tensor-parallel rank; selects the GPU for FP8 quantization.
        postfix: key suffix for the weight.
        quant_scale_name: explicit key for the scale tensor.

    Returns:
        dict mapping names to tensors.
    """
    results = {}
    scale_key = (quant_scale_name if quant_scale_name is not None else
                 prefix + 'per_channel_scale')

    if use_weight_only:
        if weight.dim() > 2:
            v = weight.transpose(1, 2).contiguous()
        else:
            v = weight.t().contiguous()
        processed_torch_weights, torch_weight_scales = \
            torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
                v.cpu(), plugin_weight_only_quant_type)
        if not use_gemm_woq_plugin:
            # Tensor.to('float32') would be parsed as a device string, so
            # dtype names must be converted to a torch dtype first.
            if isinstance(dtype, str):
                dtype = str_dtype_to_torch(dtype)
            results[prefix + postfix] = v.to(dtype)
        else:
            results[prefix + postfix] = processed_torch_weights
        results[scale_key] = torch_weight_scales
    elif use_fp8_rowwise:
        # Quantize on this rank's GPU; previously tp_rank was ignored and
        # every rank used cuda:0.
        processed_torch_weights, torch_weight_scales = \
            fp8_per_channel_quant_weight_gpu(weight, tp_rank)
        results[prefix + postfix] = processed_torch_weights
        results[scale_key] = torch_weight_scales.to(torch.float32)
    else:
        results[prefix + postfix] = weight

    if bias is not None:
        results[prefix + 'bias'] = bias

    return results
def dup_kv_weight(v, num_head, tp_size):
    """Replicate each KV head so that there are ``tp_size`` heads in total.

    ``v`` is a 2-D weight whose rows are ``num_head`` consecutive head blocks.
    Each block is repeated ``tp_size // num_head`` times in place, giving
    [h0, h0, ..., h1, h1, ...]. The result is a detached copy.
    """
    assert tp_size % num_head == 0
    reps = tp_size // num_head
    head_size = v.shape[0] // num_head
    cols = v.shape[1]
    per_head = v.reshape(num_head, head_size, -1)
    tiled = per_head.unsqueeze(1).expand(num_head, reps, head_size, cols)
    return tiled.reshape(num_head * reps * head_size, -1).clone().detach()
def get_tllm_linear_sq_weight(vals,
                              prefix,
                              shape,
                              tensor_parallel,
                              is_qkv=False,
                              per_token=False,
                              per_channel=False,
                              last_prefix=None,
                              bias=None,
                              smoother_value=None,
                              smoother_shape=None,
                              rank=0,
                              cat_dim=0,
                              multi_query_mode=False):
    """Slice SmoothQuant INT8 tensors for one TP rank and name them for TRT-LLM.

    ``vals`` is read through the keys produced by ``generate_int8``:
    "weight.int8", "weight.int8.col", "scale_w_quant_orig[.col]",
    "scale_y_accum_quant[.col]", "scale_y_quant_orig" and "scale_x_orig_quant".
    The weight (and, where applicable, its scales) is chunked into
    ``tensor_parallel`` pieces and the ``rank``-th piece is kept. Chunking is
    along ``cat_dim``, or Q/K/V-aware via ``multi_query_split`` when
    ``multi_query_mode`` is set.

    Args:
        vals: dict of quantized tensors and scales (see above).
        prefix: key prefix for the per-layer outputs ("weight",
            "per_channel_scale", "act_scale", "smoother", "bias").
        shape: shape for the per-channel scale when ``is_qkv`` or
            ``per_channel``. Otherwise the scale is reshaped to [1, 1].
        tensor_parallel: number of TP ranks.
        is_qkv: the weight is a fused QKV projection.
        per_token: per-token activation-quantization variant.
        per_channel: use the per-column (".col") weight and scales.
        last_prefix: full key for the activation-input scale.
        bias: stored unchanged under ``prefix + 'bias'`` when given.
        smoother_value: optional smoothing vector. It is chunked along
            ``cat_dim`` and stored as float32 under ``prefix + 'smoother'``.
        smoother_shape: shape the smoother slice is reshaped to.
        rank: TP rank whose slice is returned.
        cat_dim: dimension along which non-MQA tensors are chunked.
        multi_query_mode: split Q, K and V separately (MQA/GQA fused QKV).

    Returns:
        dict mapping TRT-LLM weight names to tensors. Most entries are moved
        to the current CUDA device with ``.cuda()``. The per-token
        "per_channel_scale" and the "smoother" entries are not moved here.
    """
    results = {}

    def multi_query_split(data, local_dim, head_size, tp_size, cur_rank):
        # Split the last dim into [Q (local_dim) | K (head_size) | V (head_size)],
        # chunk each part into tp_size pieces, and re-fuse cur_rank's pieces.
        q, k, v = torch.split(data, [local_dim, head_size, head_size], dim=-1)
        q_split = torch.chunk(q, tp_size, dim=-1)
        k_split = torch.chunk(k, tp_size, dim=-1)
        v_split = torch.chunk(v, tp_size, dim=-1)
        return [
            torch.concat((q_split[ii], k_split[ii], v_split[ii]), axis=-1)
            for ii in range(tp_size)
        ][cur_rank]

    col_shape = shape if (is_qkv or per_channel) else [1, 1]

    if per_token:
        if per_channel:
            original_weights = torch.Tensor(vals["weight.int8.col"]).cuda()
        else:
            original_weights = torch.Tensor(vals["weight.int8"]).cuda()
        # Only meaningful for a fused QKV weight. Assumes a
        # (hidden, hidden + 2 * kv) layout -- TODO confirm.
        local_dim = original_weights.shape[0]
        head_size = (original_weights.shape[1] - local_dim) // 2

        if multi_query_mode:
            cur_weights = multi_query_split(original_weights, local_dim,
                                            head_size, tensor_parallel, rank)
        else:
            cur_weights = torch.chunk(original_weights,
                                      tensor_parallel,
                                      dim=cat_dim)[rank]
        if is_qkv:
            hidden_dim = cur_weights.shape[0]
            cur_weights = cur_weights.reshape(hidden_dim, -1)
        results[prefix + 'weight'] = cur_weights.t().contiguous()
        if smoother_value is None:
            # Unit input scale. NOTE(review): presumably the per-token path
            # quantizes activations dynamically -- confirm.
            results[last_prefix] = torch.Tensor([1.0]).to(torch.float32).cuda()

        # Per-token mode stores the weight dequantization scale.
        if per_channel:
            cur_per_channel_value = vals["scale_w_quant_orig.col"]
            # NOTE(review): with a smoother the column scales are not chunked
            # by rank, presumably because such layers are row-parallel -- confirm.
            if smoother_value is None:
                if multi_query_mode:
                    cur_per_channel_value = multi_query_split(
                        vals["scale_w_quant_orig.col"], local_dim, head_size,
                        tensor_parallel, rank)
                else:
                    cur_per_channel_value = torch.chunk(
                        vals["scale_w_quant_orig.col"],
                        tensor_parallel,
                        dim=cat_dim)[rank]
        else:
            cur_per_channel_value = vals["scale_w_quant_orig"]
            if is_qkv:
                if multi_query_mode:
                    cur_per_channel_value = multi_query_split(
                        vals["scale_w_quant_orig"], local_dim, head_size,
                        tensor_parallel, rank)
                else:
                    cur_per_channel_value = torch.chunk(
                        vals["scale_w_quant_orig"],
                        tensor_parallel,
                        dim=cat_dim)[rank]

        results[prefix + 'per_channel_scale'] = cur_per_channel_value.reshape(
            col_shape).contiguous()
    else:
        if per_channel:
            original_weights = torch.Tensor(vals["weight.int8.col"]).cuda()
        else:
            original_weights = torch.Tensor(vals["weight.int8"]).cuda()
        # See the per-token branch: only meaningful for fused QKV.
        local_dim = original_weights.shape[0]
        head_size = (original_weights.shape[1] - local_dim) // 2

        if multi_query_mode:
            cur_weights = multi_query_split(original_weights, local_dim,
                                            head_size, tensor_parallel, rank)
        else:
            cur_weights = torch.chunk(original_weights,
                                      tensor_parallel,
                                      dim=cat_dim)[rank]
        if is_qkv:
            hidden_dim = cur_weights.shape[0]
            cur_weights = cur_weights.reshape(hidden_dim, -1)
        results[prefix + 'weight'] = cur_weights.t().contiguous()

        # Per-tensor activation mode stores the int32->int8 output scale.
        if per_channel:
            cur_per_channel_value = vals["scale_y_accum_quant.col"]
            if smoother_value is None:
                if multi_query_mode:
                    cur_per_channel_value = multi_query_split(
                        vals["scale_y_accum_quant.col"], local_dim, head_size,
                        tensor_parallel, rank)
                else:
                    cur_per_channel_value = torch.chunk(
                        vals["scale_y_accum_quant.col"],
                        tensor_parallel,
                        dim=cat_dim)[rank]
        else:
            cur_per_channel_value = vals["scale_y_accum_quant"]
            # QKV is always per_channel
            if is_qkv:
                if multi_query_mode:
                    cur_per_channel_value = multi_query_split(
                        vals["scale_y_accum_quant"], local_dim, head_size,
                        tensor_parallel, rank)
                else:
                    cur_per_channel_value = torch.chunk(
                        vals["scale_y_accum_quant"],
                        tensor_parallel,
                        dim=cat_dim)[rank]

        results[prefix +
                'per_channel_scale'] = torch.Tensor(cur_per_channel_value).to(
                    torch.float32).reshape(col_shape).contiguous().cuda()

        results[prefix + 'act_scale'] = torch.Tensor([[
            vals['scale_y_quant_orig']
        ]]).to(torch.float32).contiguous().cuda()

        results[last_prefix] = torch.Tensor([vals['scale_x_orig_quant']]).to(
            torch.float32).contiguous().cuda()

    if smoother_value is not None:
        cur_smoother_value = torch.chunk(smoother_value,
                                         tensor_parallel,
                                         dim=cat_dim)[rank]
        results[prefix + 'smoother'] = cur_smoother_value.reshape(
            smoother_shape).contiguous().to(torch.float32)

    if bias is not None:
        results[prefix + 'bias'] = bias

    return results
def load_hf_llama(model_dir: str, load_model_on_cpu: bool = False):
    """Load a Hugging Face LLaMA-family checkpoint and return its language model.

    - VILA (``"vila"`` appears in ``model_dir``): makes the sibling
      ``../VILA`` checkout importable, loads with ``AutoModel`` and returns
      ``model.llm``.
    - ``llava`` / ``llava_next`` configs: loads the multimodal model and
      returns ``model.language_model``.
    - Anything else: loads with ``AutoModelForCausalLM``.

    Args:
        model_dir: checkpoint directory.
        load_model_on_cpu: place all weights on CPU instead of using
            ``device_map='auto'``. This is honored for every branch, including
            VILA, which previously ignored it.
    """
    device_map = 'cpu' if load_model_on_cpu else 'auto'

    if "vila" in model_dir:
        vila_path = os.path.join(model_dir, "..", "VILA")
        # Avoid growing sys.path with duplicates on repeated calls.
        if vila_path not in sys.path:
            sys.path.append(vila_path)
        # Imported for its side effects (class registration); names unused.
        from llava.model import LlavaLlamaConfig, LlavaLlamaModel  # noqa
        from transformers import AutoModel
        model = AutoModel.from_pretrained(
            model_dir,
            device_map=device_map,
            trust_remote_code=True,
        )
        return model.llm

    hf_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    model_cls = AutoModelForCausalLM
    if hf_config.model_type == "llava":
        from transformers import LlavaForConditionalGeneration
        model_cls = LlavaForConditionalGeneration
    elif hf_config.model_type == "llava_next":
        from transformers import LlavaNextForConditionalGeneration
        model_cls = LlavaNextForConditionalGeneration

    model = model_cls.from_pretrained(
        model_dir,
        device_map=device_map,
        torch_dtype='auto',
        trust_remote_code=True,
    )
    if hf_config.model_type in ["llava", "llava_next"]:
        model = model.language_model
    return model
def load_weights_from_hf_model(hf_model,
                               config: LLaMAConfig,
                               act_range: Optional[dict] = None,
                               qkv_para: Optional[dict] = None,
                               smoother: Optional[dict] = None):
    """Convert an in-memory Hugging Face LLaMA model to TRT-LLM weights for one rank.

    The target rank comes from ``config.mapping``. Only the decoder layers
    returned by ``mapping.pp_layers`` are converted, and tensor-parallel
    weights are sliced for ``mapping.tp_rank``.

    Args:
        hf_model: Hugging Face model whose ``named_parameters()`` use the
            ``model.layers.{i}.*``, ``model.embed_tokens``, ``model.norm`` and
            (when embeddings are not tied) ``lm_head`` names.
        config: TRT-LLM LLaMA config (dtype, mapping, quantization, MoE).
        act_range: Calibrated activation ranges. Required when SmoothQuant or
            the INT8 KV cache is enabled.
        qkv_para: Fused QKV weights keyed by ``...self_attn.qkv_proj``.
            Required together with ``act_range``.
        smoother: SmoothQuant smoothing factors for the inputs of
            ``self_attn.o_proj`` / ``mlp.down_proj`` / ``residual_mlp.w2``.
            Required together with ``act_range``.

    Returns:
        dict mapping TRT-LLM weight names to tensors.

    Note:
        Entries of each converted layer are set to None in a local copy of
        ``named_parameters()``; the model object itself is not modified by
        that step.
    """
    quant_algo = config.quantization.quant_algo
    use_weight_only = quant_algo in [QuantAlgo.W8A16, QuantAlgo.W4A16]
    if quant_algo == QuantAlgo.W8A16:
        plugin_weight_only_quant_type = torch.int8
    elif quant_algo == QuantAlgo.W4A16:
        plugin_weight_only_quant_type = torch.quint4x2
    else:
        plugin_weight_only_quant_type = None
    use_gemm_woq_plugin = (not config.disable_weight_only_quant_plugin)
    use_fp8_rowwise = quant_algo in [QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN]

    use_smooth_quant = config.quantization.use_plugin_sq
    # `and` short-circuits, so quant_algo is only searched when SmoothQuant is on.
    per_channel = use_smooth_quant and 'PER_CHANNEL' in quant_algo
    per_token = use_smooth_quant and 'PER_TOKEN' in quant_algo
    int8_kv_cache = config.quantization.kv_cache_quant_algo == QuantAlgo.INT8
    fp8_kv_cache = config.quantization.kv_cache_quant_algo == QuantAlgo.FP8
    if use_smooth_quant or int8_kv_cache:
        assert act_range is not None
        assert qkv_para is not None
        assert smoother is not None

    weights = {}
    tik = time.time()

    model_params = dict(hf_model.named_parameters())
    dtype = getattr(torch, config.dtype)
    mapping = config.mapping
    moe_config = config.moe
    mha_mode = (config.num_key_value_heads == config.num_attention_heads)
    layers_range = config.mapping.pp_layers(config.num_hidden_layers)

    def convert_layer(l):
        """Convert HF decoder layer ``l`` and add its tensors to ``weights``."""
        prefix = f'model.layers.{l}.'
        # TRT-LLM layer indices are local to this pipeline stage.
        tllm_prex = f'transformer.layers.{l - layers_range[0]}.'
        q_weight = get_weight(model_params, prefix + 'self_attn.q_proj', dtype)
        k_weight = get_weight(model_params, prefix + 'self_attn.k_proj', dtype)
        v_weight = get_weight(model_params, prefix + 'self_attn.v_proj', dtype)

        if not mha_mode:
            if config.num_key_value_heads < mapping.tp_size:
                # duplicate the KV heads up to tensor_parallel
                k_weight = dup_kv_weight(k_weight, config.num_key_value_heads,
                                         mapping.tp_size)
                v_weight = dup_kv_weight(v_weight, config.num_key_value_heads,
                                         mapping.tp_size)
            assert (k_weight.shape[0] %
                    (mapping.tp_size * config.head_size)) == 0
            assert (v_weight.shape[0] %
                    (mapping.tp_size * config.head_size)) == 0

            wq = split(q_weight, mapping.tp_size, mapping.tp_rank)
            wk = split(k_weight, mapping.tp_size, mapping.tp_rank)
            wv = split(v_weight, mapping.tp_size, mapping.tp_rank)

            split_v = torch.concat((wq, wk, wv))
        else:
            qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0)
            split_v = split_qkv_tp(qkv_weight, config.num_attention_heads,
                                   config.hidden_size, mapping.tp_size,
                                   mapping.tp_rank)

        if prefix + 'self_attn.q_proj.bias' in model_params:
            # only used in Internlm 7B models
            q_bias = get_bias(model_params, prefix + 'self_attn.q_proj', dtype)
            k_bias = get_bias(model_params, prefix + 'self_attn.k_proj', dtype)
            v_bias = get_bias(model_params, prefix + 'self_attn.v_proj', dtype)
            qkv_bias = torch.cat((q_bias, k_bias, v_bias))
            split_bias_v = split_qkv_bias_tp(qkv_bias,
                                             config.num_attention_heads,
                                             config.hidden_size,
                                             mapping.tp_size, mapping.tp_rank)
        else:
            split_bias_v = None

        if use_smooth_quant:
            qkv_weight = qkv_para[prefix + 'self_attn.qkv_proj']
            qkv_out_dim = qkv_weight.shape[1]

            if not mha_mode:
                local_dim = qkv_weight.shape[0]
                kv_hidden_size = (qkv_weight.shape[-1] - local_dim) // 2
                qkv_weight = qkv_weight.reshape(local_dim,
                                                local_dim + 2 * kv_hidden_size)
            else:
                qkv_weight = qkv_weight.reshape(config.hidden_size, 3,
                                                config.hidden_size)

            int8_weights = generate_int8(qkv_weight,
                                         act_range.get(prefix +
                                                       'self_attn.qkv_proj'),
                                         is_qkv=True,
                                         multi_query_mode=bool(not mha_mode))

            weights.update(
                get_tllm_linear_sq_weight(int8_weights,
                                          tllm_prex + 'attention.qkv.',
                                          [1, qkv_out_dim // mapping.tp_size],
                                          mapping.tp_size,
                                          is_qkv=True,
                                          bias=split_bias_v,
                                          per_token=per_token,
                                          per_channel=per_channel,
                                          last_prefix=tllm_prex +
                                          'input_layernorm.scale_to_int',
                                          smoother_value=None,
                                          smoother_shape=None,
                                          rank=mapping.tp_rank,
                                          cat_dim=-1,
                                          multi_query_mode=bool(not mha_mode)))
        else:
            weights.update(
                get_tllm_linear_weight(split_v,
                                       tllm_prex + 'attention.qkv.',
                                       split_bias_v,
                                       use_weight_only,
                                       plugin_weight_only_quant_type,
                                       dtype,
                                       use_gemm_woq_plugin,
                                       use_fp8_rowwise=False))

        if int8_kv_cache:
            # One scale for the whole KV cache: max |y| over q/k/v outputs / 127.
            qkv_y = torch.cat([
                act_range.get(prefix + 'self_attn.q_proj')["y"],
                act_range.get(prefix + 'self_attn.k_proj')["y"],
                act_range.get(prefix + 'self_attn.v_proj')["y"]
            ],
                              dim=0)

            int8_kv_scales = qkv_y.max() / 127.

            kv_cache_weights = {}

            kv_cache_weights[
                tllm_prex +
                'attention.kv_cache_scaling_factor'] = int8_kv_scales.reshape(
                    [1])

            weights.update(kv_cache_weights)
        elif fp8_kv_cache:
            # FIXME: set it to 1.0f for fp8 kv cache.
            weights[tllm_prex +
                    'attention.kv_cache_scaling_factor'] = torch.tensor(
                        [1.0], dtype=torch.float32)

        attn_dense_weight = get_weight(model_params,
                                       prefix + 'self_attn.o_proj', dtype)
        # o_proj is row-parallel: split along the input dimension.
        split_v = split_matrix_tp(attn_dense_weight,
                                  mapping.tp_size,
                                  mapping.tp_rank,
                                  dim=1)
        if prefix + 'self_attn.o_proj.bias' in model_params:
            attn_dense_bias = get_bias(model_params,
                                       prefix + 'self_attn.o_proj', dtype)
        else:
            attn_dense_bias = None
        if use_smooth_quant:
            attn_dense_weight = attn_dense_weight.t()
            int8_weights = generate_int8(
                attn_dense_weight, act_range.get(prefix + 'self_attn.o_proj'))
            weights.update(
                get_tllm_linear_sq_weight(
                    int8_weights,
                    tllm_prex + 'attention.dense.', [1, config.hidden_size],
                    mapping.tp_size,
                    is_qkv=False,
                    bias=attn_dense_bias,
                    per_token=per_token,
                    per_channel=per_channel,
                    last_prefix=tllm_prex +
                    'attention.quantization_scaling_factor',
                    smoother_value=smoother[(prefix + 'self_attn.o_proj')],
                    smoother_shape=[1, config.hidden_size // mapping.tp_size],
                    rank=mapping.tp_rank,
                    cat_dim=0))
        else:
            weights.update(
                get_tllm_linear_weight(split_v,
                                       tllm_prex + 'attention.dense.',
                                       attn_dense_bias,
                                       use_weight_only,
                                       plugin_weight_only_quant_type,
                                       dtype,
                                       use_gemm_woq_plugin,
                                       use_fp8_rowwise=False))

        if moe_config.has_moe():
            rank_experts = list(range(moe_config.num_experts))
            if mapping.has_moe_ep():
                rank_experts = mapping.ep_experts(moe_config.num_experts)
            # Stack the per-expert matrices owned by this rank into one tensor
            # per projection (leading dim = expert).
            for suffix in ["w1", "w2", "w3"]:
                model_params[f'model.layers.{l}.block_sparse_moe.experts.{suffix}.weight'] = \
                    torch.stack([
                        model_params[f'model.layers.{l}.block_sparse_moe.experts.{expert}.{suffix}.weight'].detach()
                        for expert in rank_experts
                    ])
            w3 = model_params[
                f'model.layers.{l}.block_sparse_moe.experts.w3.weight']
            w2 = model_params[
                f'model.layers.{l}.block_sparse_moe.experts.w2.weight']
            w1 = model_params[
                f'model.layers.{l}.block_sparse_moe.experts.w1.weight']
            if mapping.has_moe_tp():
                w3 = split(w3, mapping.moe_tp_size, mapping.moe_tp_rank, dim=1)
                w2 = split(w2, mapping.moe_tp_size, mapping.moe_tp_rank, dim=2)
                w1 = split(w1, mapping.moe_tp_size, mapping.moe_tp_rank, dim=1)

            model_params[
                f'model.layers.{l}.block_sparse_moe.experts.w3w1.weight'] = torch.concat(
                    [w3, w1], dim=-2)

            model_params[
                f'model.layers.{l}.block_sparse_moe.experts.w2.weight'] = w2

            ## block_sparse_moe.experts.w2.weight
            moe_experts_w2_weights = get_weight(
                model_params, prefix + 'block_sparse_moe.experts.w2', dtype)
            weights.update(
                get_tllm_linear_weight(moe_experts_w2_weights,
                                       tllm_prex + 'mlp.proj.', None,
                                       use_weight_only,
                                       plugin_weight_only_quant_type, dtype,
                                       use_gemm_woq_plugin))
            ##block_sparse_moe.experts.w3w1.weight
            moe_experts_w3w1_weights = get_weight(
                model_params, prefix + 'block_sparse_moe.experts.w3w1', dtype)
            weights.update(
                get_tllm_linear_weight(moe_experts_w3w1_weights,
                                       tllm_prex + 'mlp.fc.', None,
                                       use_weight_only,
                                       plugin_weight_only_quant_type, dtype,
                                       use_gemm_woq_plugin))

            if config.residual_mlp:
                residual_mlp_gate_weights = get_weight(
                    model_params, prefix + 'residual_mlp.w3', dtype)
                if use_smooth_quant:
                    residual_mlp_gate_weights = residual_mlp_gate_weights.t()
                    int8_weights = generate_int8(
                        residual_mlp_gate_weights,
                        act_range.get(prefix + 'residual_mlp.w3'))

                    weights.update(
                        get_tllm_linear_sq_weight(
                            int8_weights,
                            tllm_prex + 'residual_mlp.gate.',
                            [1, config.hidden_size // mapping.tp_size],
                            mapping.tp_size,
                            is_qkv=False,
                            per_token=per_token,
                            per_channel=per_channel,
                            last_prefix=tllm_prex +
                            'post_layernorm.scale_to_int',
                            smoother_value=None,
                            smoother_shape=None,
                            rank=mapping.tp_rank,
                            cat_dim=-1))
                else:
                    split_v = split_matrix_tp(residual_mlp_gate_weights,
                                              mapping.tp_size,
                                              mapping.tp_rank,
                                              dim=0)
                    weights.update(
                        get_tllm_linear_weight(split_v,
                                               tllm_prex + 'residual_mlp.gate.',
                                               None, use_weight_only,
                                               plugin_weight_only_quant_type,
                                               dtype, use_gemm_woq_plugin))

                residual_mlp_fc_weight = get_weight(model_params,
                                                    prefix + 'residual_mlp.w1',
                                                    dtype)
                if use_smooth_quant:
                    residual_mlp_fc_weight = residual_mlp_fc_weight.t(
                    )  #verified
                    int8_weights = generate_int8(
                        residual_mlp_fc_weight,
                        act_range.get(prefix + 'residual_mlp.w1'))
                    weights.update(
                        get_tllm_linear_sq_weight(
                            int8_weights,
                            tllm_prex + 'residual_mlp.fc.',
                            [1, config.hidden_size // mapping.tp_size],
                            mapping.tp_size,
                            is_qkv=False,
                            per_token=per_token,
                            per_channel=per_channel,
                            last_prefix=tllm_prex +
                            'post_layernorm.scale_to_int',
                            smoother_value=None,
                            smoother_shape=None,
                            rank=mapping.tp_rank,
                            cat_dim=-1))
                else:
                    split_v = split_matrix_tp(residual_mlp_fc_weight,
                                              mapping.tp_size,
                                              mapping.tp_rank,
                                              dim=0)
                    weights.update(
                        get_tllm_linear_weight(split_v,
                                               tllm_prex + 'residual_mlp.fc.',
                                               None, use_weight_only,
                                               plugin_weight_only_quant_type,
                                               dtype, use_gemm_woq_plugin))

                residual_mlp_proj_weight = get_weight(
                    model_params, prefix + 'residual_mlp.w2', dtype)

                if use_smooth_quant:
                    residual_mlp_proj_weight = residual_mlp_proj_weight.t()
                    int8_weights = generate_int8(
                        residual_mlp_proj_weight,
                        act_range.get(prefix + 'residual_mlp.w2'))
                    weights.update(
                        get_tllm_linear_sq_weight(
                            int8_weights,
                            tllm_prex + 'residual_mlp.proj.',
                            [1, config.hidden_size],
                            mapping.tp_size,
                            is_qkv=False,
                            per_token=per_token,
                            per_channel=per_channel,
                            last_prefix=tllm_prex +
                            'residual_mlp.quantization_scaling_factor',
                            smoother_value=smoother[prefix + 'residual_mlp.w2'],
                            smoother_shape=[
                                1, config.hidden_size // mapping.tp_size
                            ],
                            rank=mapping.tp_rank,
                            cat_dim=0))
                else:
                    split_v = split_matrix_tp(residual_mlp_proj_weight,
                                              mapping.tp_size,
                                              mapping.tp_rank,
                                              dim=1)
                    weights.update(
                        get_tllm_linear_weight(split_v,
                                               tllm_prex + 'residual_mlp.proj.',
                                               None, use_weight_only,
                                               plugin_weight_only_quant_type,
                                               dtype, use_gemm_woq_plugin))

            moe_experts_gate_weights = get_weight(
                model_params, prefix + 'block_sparse_moe.gate', torch.float32)
            weights.update(
                get_tllm_linear_weight(
                    moe_experts_gate_weights,
                    tllm_prex + 'mlp.router.',
                    None,
                    False,  # Router should never be quantized
                    plugin_weight_only_quant_type,
                    dtype,
                    use_gemm_woq_plugin))
        else:
            # HF up_proj -> TRT-LLM mlp.gate (column-parallel).
            mlp_gate_weight = get_weight(model_params, prefix + 'mlp.up_proj',
                                         dtype)
            split_v = split_matrix_tp(mlp_gate_weight,
                                      mapping.tp_size,
                                      mapping.tp_rank,
                                      dim=0)
            if use_smooth_quant:
                mlp_gate_weight = mlp_gate_weight.t()
                int8_weights = generate_int8(
                    mlp_gate_weight, act_range.get(prefix + 'mlp.up_proj'))

                weights.update(
                    get_tllm_linear_sq_weight(
                        int8_weights,
                        tllm_prex + 'mlp.gate.',
                        [1, config.intermediate_size // mapping.tp_size],
                        mapping.tp_size,
                        is_qkv=False,
                        per_token=per_token,
                        per_channel=per_channel,
                        last_prefix=tllm_prex + 'post_layernorm.scale_to_int',
                        smoother_value=None,
                        smoother_shape=None,
                        rank=mapping.tp_rank,
                        cat_dim=-1))
            else:
                weights.update(
                    get_tllm_linear_weight(split_v, tllm_prex + 'mlp.gate.',
                                           None, use_weight_only,
                                           plugin_weight_only_quant_type, dtype,
                                           use_gemm_woq_plugin,
                                           use_fp8_rowwise))

            # HF gate_proj -> TRT-LLM mlp.fc (column-parallel).
            mlp_fc_weight = get_weight(model_params, prefix + 'mlp.gate_proj',
                                       dtype)
            split_v = split_matrix_tp(mlp_fc_weight,
                                      mapping.tp_size,
                                      mapping.tp_rank,
                                      dim=0)

            if use_smooth_quant:
                mlp_fc_weight = mlp_fc_weight.t()  #verified
                int8_weights = generate_int8(
                    mlp_fc_weight, act_range.get(prefix + 'mlp.gate_proj'))
                weights.update(
                    get_tllm_linear_sq_weight(
                        int8_weights,
                        tllm_prex + 'mlp.fc.',
                        [1, config.intermediate_size // mapping.tp_size],
                        mapping.tp_size,
                        is_qkv=False,
                        per_token=per_token,
                        per_channel=per_channel,
                        last_prefix=tllm_prex + 'post_layernorm.scale_to_int',
                        smoother_value=None,
                        smoother_shape=None,
                        rank=mapping.tp_rank,
                        cat_dim=-1))
            else:
                weights.update(
                    get_tllm_linear_weight(split_v, tllm_prex + 'mlp.fc.', None,
                                           use_weight_only,
                                           plugin_weight_only_quant_type, dtype,
                                           use_gemm_woq_plugin,
                                           use_fp8_rowwise))

            # HF down_proj -> TRT-LLM mlp.proj (row-parallel).
            mlp_proj_weight = get_weight(model_params, prefix + 'mlp.down_proj',
                                         dtype)
            split_v = split_matrix_tp(mlp_proj_weight,
                                      mapping.tp_size,
                                      mapping.tp_rank,
                                      dim=1)

            if use_smooth_quant:
                mlp_proj_weight = mlp_proj_weight.t()
                int8_weights = generate_int8(
                    mlp_proj_weight, act_range.get(prefix + 'mlp.down_proj'))
                weights.update(
                    get_tllm_linear_sq_weight(
                        int8_weights,
                        tllm_prex + 'mlp.proj.', [1, config.hidden_size],
                        mapping.tp_size,
                        is_qkv=False,
                        per_token=per_token,
                        per_channel=per_channel,
                        last_prefix=tllm_prex +
                        'mlp.quantization_scaling_factor',
                        smoother_value=smoother[prefix + 'mlp.down_proj'],
                        smoother_shape=[
                            1, config.intermediate_size // mapping.tp_size
                        ],
                        rank=mapping.tp_rank,
                        cat_dim=0))
            else:
                weights.update(
                    get_tllm_linear_weight(split_v, tllm_prex + 'mlp.proj.',
                                           None, use_weight_only,
                                           plugin_weight_only_quant_type, dtype,
                                           use_gemm_woq_plugin,
                                           use_fp8_rowwise))

        # Layer norms do not use tensor parallelism
        input_ln_weight = get_weight(model_params, prefix + 'input_layernorm',
                                     dtype)
        weights[tllm_prex + 'input_layernorm.weight'] = input_ln_weight

        post_ln_weight = get_weight(model_params,
                                    prefix + 'post_attention_layernorm', dtype)
        weights[tllm_prex + 'post_layernorm.weight'] = post_ln_weight

        if config.residual_mlp:
            residual_ln_weight = get_weight(model_params,
                                            prefix + 'residual_layernorm',
                                            dtype)
            weights[tllm_prex +
                    'residual_layernorm.weight'] = residual_ln_weight

        # Drop this layer's entries from the local parameter dict.
        cur_block_weights = [
            weight_name for weight_name in model_params
            if weight_name.find(prefix) != -1
        ]

        for weight_name in cur_block_weights:
            model_params[weight_name] = None

    for l in layers_range:
        convert_layer(l)
        release_gc()

    tie_word_embeddings = hf_model.config.tie_word_embeddings
    v = get_weight(model_params, 'model.embed_tokens', dtype)
    if tie_word_embeddings:
        # lm_head.weight has the same weights as embedding
        if mapping.is_last_pp_rank():
            if config.vocab_size % mapping.tp_size != 0:
                # padding
                vocab_size_padded = pad_vocab_size(config.vocab_size,
                                                   mapping.tp_size)
                pad_width = vocab_size_padded - config.vocab_size
                v = torch.nn.functional.pad(v, (0, 0, 0, pad_width), 'constant',
                                            0)
            weights['lm_head.weight'] = split(v, mapping.tp_size,
                                              mapping.tp_rank)

    if config.use_parallel_embedding:
        v = split_matrix_tp(v,
                            mapping.tp_size,
                            mapping.tp_rank,
                            dim=config.embedding_sharding_dim)

    if mapping.is_first_pp_rank():
        weights['transformer.vocab_embedding.weight'] = v

    if mapping.is_last_pp_rank():
        # With tied embeddings named_parameters() de-duplicates the shared
        # tensor, so 'lm_head.weight' is usually absent; the tied lm_head built
        # above is then kept instead of failing on the missing key.
        if 'lm_head.weight' in model_params or not tie_word_embeddings:
            lm_head_weights = get_weight(model_params, 'lm_head', dtype)
            if config.vocab_size % mapping.tp_size != 0:
                # padding
                vocab_size_padded = pad_vocab_size(config.vocab_size,
                                                   mapping.tp_size)
                pad_width = vocab_size_padded - config.vocab_size
                lm_head_weights = torch.nn.functional.pad(lm_head_weights,
                                                          (0, 0, 0, pad_width),
                                                          'constant',
                                                          value=0)
            weights['lm_head.weight'] = split_matrix_tp(lm_head_weights,
                                                        mapping.tp_size,
                                                        mapping.tp_rank,
                                                        dim=0)
        ln_f_w = get_weight(model_params, 'model.norm', dtype)
        weights['transformer.ln_f.weight'] = ln_f_w

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    print(f'Weights loaded. Total time: {t}')
    return weights
def smooth_quant(model,
                 tokenizer,
                 dataset,
                 smoothquant: Optional[float] = None):
    """Calibrate activation ranges and optionally apply SmoothQuant.

    Args:
        model: Hugging Face model to calibrate (must not be None).
        tokenizer: Tokenizer passed to ``capture_activation_range``.
        dataset: Calibration dataset passed to ``capture_activation_range``.
        smoothquant: SmoothQuant strength. When None, only activation ranges
            are captured and no smoothing is performed.

    Returns:
        ``(act_range, qkv_para, smoother)``. The two dicts start empty and are
        handed to ``smooth_llama_model`` (presumably filled there); they stay
        empty when ``smoothquant`` is None.
    """
    assert model is not None
    llama_qkv_para = {}
    # smoother for inputs of self_attn.o_proj and mlp.down_proj
    llama_smoother = {}
    act_range = capture_activation_range(model, tokenizer, dataset)
    if smoothquant is not None:
        smooth_llama_model(model, act_range, smoothquant, llama_qkv_para,
                           llama_smoother)
    return act_range, llama_qkv_para, llama_smoother
def quantize(hf_model_dir: str,
             output_dir: str,
             config: LLaMAConfig,
             device: str = 'cuda',
             calib_dataset: str = 'cnn_dailymail'):
    '''
    Quantize and save the model as a TRT-LLM checkpoint to output_dir.

    Writes ``config.json`` plus one ``rank{r}.safetensors`` file per rank in
    ``config.mapping.world_size``. The HF model is loaded and calibrated once
    and then converted for every rank in turn.

    Args:
        hf_model_dir: Hugging Face checkpoint directory.
        output_dir: Existing directory that receives the TRT-LLM checkpoint.
        config: Rank-agnostic config (``config.mapping.rank`` must be -1).
            It is not modified; a deep copy is specialised for each rank.
        device: ``'cpu'`` loads the HF model on CPU; anything else uses
            ``device_map='auto'``.
        calib_dataset: Name or path handed to ``load_calib_dataset``.

    Raises:
        AssertionError: if the config is rank-specific, requests neither
            SmoothQuant nor INT8 KV cache, or requests SmoothQuant without a
            smoothing value, or the model type is LLaVA-based.
    '''
    #TODO: currently only smooth quant and kv cache quantization are supported, needs to support more quant algorithms calling modelopt
    mapping = config.mapping
    assert mapping.rank == -1, "You shall call quantize only once in one rank, assert rank==-1 for precaution"
    quant_config = config.quantization

    use_smooth_quant = quant_config.use_plugin_sq
    int8_kv_cache = quant_config.kv_cache_quant_algo == QuantAlgo.INT8

    assert use_smooth_quant or int8_kv_cache, "Call from_hugging_face when there is no quantization"
    if use_smooth_quant:
        assert quant_config.smoothquant_val is not None, "A smooth value must be specified when using smooth quant"

    assert hf_model_dir is not None

    # Written only after validation so a rejected call leaves no stale config.
    with open(os.path.join(output_dir, 'config.json'), 'w') as f:
        json.dump(config.to_dict(), f, indent=4)

    ## only load and call smooth quant routine once for all ranks
    hf_config = AutoConfig.from_pretrained(hf_model_dir, trust_remote_code=True)
    assert "llava" not in hf_config.model_type, "Smooth quant llava/vila/llava_next is not supported yet"
    hf_model = AutoModelForCausalLM.from_pretrained(
        hf_model_dir,
        device_map='auto' if device != 'cpu' else 'cpu',
        torch_dtype='auto' if not use_smooth_quant else torch.float16,
        trust_remote_code=True)

    # Respect a user-provided setting; default to disabling tokenizer threads.
    os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get(
        "TOKENIZERS_PARALLELISM", "false")
    tokenizer = AutoTokenizer.from_pretrained(hf_model_dir,
                                              trust_remote_code=True,
                                              use_fast=False,
                                              padding_side='left')

    dataset = load_calib_dataset(calib_dataset)
    act_range, qkv_para, smoother = smooth_quant(hf_model, tokenizer, dataset,
                                                 quant_config.smoothquant_val)

    for rank in range(mapping.world_size):
        # Deep-copy so the caller's rank-agnostic config is never mutated.
        rank_config = copy.deepcopy(config)
        rank_config.set_rank(rank)
        weights = load_weights_from_hf_model(
            hf_model,
            config=rank_config,
            act_range=act_range,
            qkv_para=qkv_para,
            smoother=smoother,
        )
        safetensors.torch.save_file(
            weights, os.path.join(output_dir, f'rank{rank}.safetensors'))
        del weights
class QkvWeightHelper:
    """ A helper utility for loading QKV weights from sharded files.

    The q/k/v projections of one layer can arrive in different shard files.
    They are buffered per layer index until all three are present, then
    fused and sliced for this tensor-parallel rank.
    """

    def __init__(self, config: 'PretrainedConfig'):
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.tp_size = config.mapping.tp_size
        self.tp_rank = config.mapping.tp_rank
        self.is_mha = self.num_heads == self.num_kv_heads
        # Optional explicit head size; None means hidden_size // num_heads.
        self.head_size = getattr(config, "head_size", None)
        # layer index -> {'q' | 'k' | 'v': weight}
        self._qkv_weights = {}

    @staticmethod
    def is_qkv_weight(name):
        """Return True if ``name`` is a self-attention q/k/v projection parameter."""
        return 'self_attn' in name and any(
            k in name for k in ['q_proj', 'k_proj', 'v_proj'])

    def add_weight(self, i: int, name: str, weight: torch.Tensor):
        """Buffer one q/k/v weight of layer ``i``.

        Raises:
            ValueError: if ``name`` is not a q/k/v projection name.
        """
        if 'q_proj' in name:
            tag = 'q'
        elif 'k_proj' in name:
            tag = 'k'
        elif 'v_proj' in name:
            tag = 'v'
        else:
            raise ValueError(f'Got an unexpected parameter of name {name}')
        self._qkv_weights.setdefault(i, {})[tag] = weight

    def is_qkv_prepared(self, layer_idx):
        """Return True once q, k and v have all been buffered for ``layer_idx``."""
        weights = self._qkv_weights.get(layer_idx)
        if weights is None:
            return False
        return 'q' in weights and 'k' in weights and 'v' in weights

    def split_qkv_weights(self, layer_idx):
        """Fuse layer ``layer_idx``'s q/k/v and return this TP rank's slice.

        Returns None if not all three weights have been added yet. On success
        the buffered weights are removed from the helper.
        """
        if not self.is_qkv_prepared(layer_idx):
            return None
        weights = self._qkv_weights.pop(layer_idx)  # to prevent memory leak.
        # Tensors are only detached (torch.tensor(tensor) would make a full
        # copy and warn); every path below builds a fresh tensor via cat.
        q, k, v = (weights[t].detach() if isinstance(weights[t], torch.Tensor)
                   else torch.tensor(weights[t]) for t in ['q', 'k', 'v'])
        if not self.is_mha:
            head_size = self.hidden_size // self.num_heads if self.head_size is None else self.head_size
            if self.num_kv_heads < self.tp_size:
                # duplicate the KV heads up to tensor_parallel
                k = dup_kv_weight(k, self.num_kv_heads, self.tp_size)
                v = dup_kv_weight(v, self.num_kv_heads, self.tp_size)
            assert k.shape[0] % (self.tp_size * head_size) == 0
            assert v.shape[0] % (self.tp_size * head_size) == 0
            wq = split(q, self.tp_size, self.tp_rank)
            wk = split(k, self.tp_size, self.tp_rank)
            wv = split(v, self.tp_size, self.tp_rank)
            fused_qkv = torch.cat((wq, wk, wv), dim=0)
        else:
            # MHA: q, k, v share q's shape; view as (3, rows, cols) and slice
            # the row dimension for this rank.
            qkv = torch.cat([q, k, v], dim=0)
            qkv = qkv.reshape(3, q.shape[0], q.shape[1])
            fused_qkv = split(qkv, self.tp_size, self.tp_rank, dim=1)
            fused_qkv = fused_qkv.reshape(3 * (q.shape[0] // self.tp_size),
                                          q.shape[1])
        return fused_qkv
def load_weights_from_hf_by_shard(model_dir: str, config: LLaMAConfig):
    '''Load TRT-LLM weights for the current rank by streaming HF shard files.

    Supported quantization: INT8/INT4 weight-only for all linear layers, and
    FP8 row-wise for the MLP layers only (attention weights are kept
    unquantized in that mode). MoE checkpoints are not supported.

    Args:
        model_dir: Hugging Face checkpoint directory.
        config: TRT-LLM LLaMA config; ``config.mapping`` selects the rank.

    Returns:
        dict mapping TRT-LLM weight names to tensors.
    '''
    logger.info('Loading weights from HF LLaMA...')
    weights = {}
    tik = time.time()
    dtype = getattr(torch, config.dtype)
    mapping = config.mapping
    moe_config = config.moe
    assert not moe_config.has_moe(), "MoE does not support sharded load"

    from transformers import AutoConfig
    hf_config = AutoConfig.from_pretrained(model_dir)

    quant_mode = config.quant_mode
    if quant_mode.is_int8_weight_only():
        plugin_weight_only_quant_type = torch.int8
    elif quant_mode.is_int4_weight_only():
        plugin_weight_only_quant_type = torch.quint4x2
    elif quant_mode.has_fp8_rowwise():
        plugin_weight_only_quant_type = torch.float8_e4m3fn
    else:
        plugin_weight_only_quant_type = None
    use_weight_only = quant_mode.is_weight_only()
    use_fp8_rowwise = quant_mode.has_fp8_rowwise()

    def _pad_vocab(param):
        """Zero-pad dim 0 (vocab) up to a multiple of tp_size, if needed."""
        if config.vocab_size % mapping.tp_size == 0:
            return param
        vocab_size_padded = pad_vocab_size(config.vocab_size, mapping.tp_size)
        pad_width = vocab_size_padded - config.vocab_size
        # torch-native padding (unlike a numpy round-trip) also handles bf16.
        return torch.nn.functional.pad(param, (0, 0, 0, pad_width), 'constant',
                                       0)

    def _store_linear(name, split_v, allow_fp8_rowwise):
        """Store ``name.weight`` (and ``name.per_channel_scale`` when quantized)."""
        if use_weight_only:
            # The op quantizes along the last axis, so pass W^T (in, out).
            processed_torch_weights, torch_weight_scales = \
                torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
                    split_v.t().contiguous(), plugin_weight_only_quant_type)
            weights[name + '.weight'] = processed_torch_weights
            weights[name + '.per_channel_scale'] = torch_weight_scales
        elif allow_fp8_rowwise and use_fp8_rowwise:
            processed_torch_weights, torch_weight_scales = fp8_per_channel_quant_weight_gpu(
                split_v)
            weights[name + '.weight'] = processed_torch_weights.view(
                plugin_weight_only_quant_type)
            weights[name + '.per_channel_scale'] = torch_weight_scales.to(
                torch.float32)
        else:
            weights[name + '.weight'] = split_v

    layers_range = mapping.pp_layers(config.num_hidden_layers)
    qkv_weight_helper = QkvWeightHelper(config)

    for model_file in iterate_shard_files(model_dir,
                                          rank=mapping.tp_rank,
                                          progress_bar=False):
        logger.debug(f'Loading file {str(model_file)}...')
        model_params = load_state_dict(model_file, dtype=dtype)
        for name, param in model_params.items():
            logger.debug(f'Converting weight {name}...')
            layer_idx = retrieved_layer_index_from_name(name)
            tllm_prex = f'transformer.layers.{layer_idx}.'
            if layer_idx is not None:
                if layer_idx not in layers_range:
                    continue  # owned by another pipeline stage
                # TRT-LLM layer indices are local to this pipeline stage.
                tllm_prex = f'transformer.layers.{layer_idx - layers_range[0]}.'

            if 'model.embed_tokens.weight' in name:
                if hf_config.tie_word_embeddings:
                    # lm_head.weight has the same weights as embedding
                    if mapping.is_last_pp_rank():
                        param = _pad_vocab(param)
                        weights['lm_head.weight'] = split(
                            param, mapping.tp_size, mapping.tp_rank)
                if config.use_parallel_embedding:
                    param = split(param, mapping.tp_size, mapping.tp_rank,
                                  config.embedding_sharding_dim)
                if mapping.is_first_pp_rank():
                    weights['transformer.vocab_embedding.weight'] = param
            elif 'model.norm.weight' in name:
                if mapping.is_last_pp_rank():
                    weights['transformer.ln_f.weight'] = param
            elif 'lm_head.weight' in name:
                if mapping.is_last_pp_rank():
                    param = _pad_vocab(param)
                    weights['lm_head.weight'] = split(param, mapping.tp_size,
                                                      mapping.tp_rank)
            elif 'input_layernorm.weight' in name:
                weights[tllm_prex + 'input_layernorm.weight'] = param
            elif 'post_attention_layernorm.weight' in name:
                weights[tllm_prex + 'post_layernorm.weight'] = param
            elif qkv_weight_helper.is_qkv_weight(name):
                qkv_weight_helper.add_weight(layer_idx, name, param)
                if not qkv_weight_helper.is_qkv_prepared(layer_idx):
                    continue
                split_v = qkv_weight_helper.split_qkv_weights(layer_idx)
                _store_linear(tllm_prex + 'attention.qkv', split_v,
                              allow_fp8_rowwise=False)
            elif 'self_attn.o_proj.weight' in name:
                split_v = split(param, mapping.tp_size, mapping.tp_rank, dim=1)
                _store_linear(tllm_prex + 'attention.dense', split_v,
                              allow_fp8_rowwise=False)
            elif 'mlp.up_proj.weight' in name:
                split_v = split(param, mapping.tp_size, mapping.tp_rank, dim=0)
                _store_linear(tllm_prex + 'mlp.gate', split_v,
                              allow_fp8_rowwise=True)
            elif 'mlp.down_proj.weight' in name:
                split_v = split(param, mapping.tp_size, mapping.tp_rank, dim=1)
                _store_linear(tllm_prex + 'mlp.proj', split_v,
                              allow_fp8_rowwise=True)
            elif 'mlp.gate_proj.weight' in name:
                split_v = split(param, mapping.tp_size, mapping.tp_rank, dim=0)
                _store_linear(tllm_prex + 'mlp.fc', split_v,
                              allow_fp8_rowwise=True)

        del model_params

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Weights loaded. Total time: {t}')
    return weights
def load_weights_from_hf_safetensors(model_dir: str, config: LLaMAConfig):
    """Load TRT-LLM weights for the current rank directly from HF safetensors.

    Every shard is opened with ``safetensors.safe_open`` on CPU, and only the
    slice needed by this rank is read via ``get_slice``. When
    ``model.safetensors.index.json`` exists, it selects which files are
    shards and which shard holds each tensor. Otherwise every
    ``*.safetensors`` file in the directory is used, and lookups default to
    the first one.

    Args:
        model_dir: Hugging Face checkpoint directory.
        config: TRT-LLM LLaMA config; ``config.mapping`` selects the rank.

    Returns:
        dict mapping TRT-LLM weight names to tensors.

    Raises:
        FileNotFoundError: if the index references a shard file that is not
            present in ``model_dir``.
    """
    logger.info('Loading weights from Huggingface LLaMA safetensors...')
    tik = time.time()
    import json
    import os
    import safetensors
    weights = {}

    model_dir = model_dir if model_dir.endswith("/") else model_dir + "/"
    sharding_map = None
    try:
        with open(model_dir + "model.safetensors.index.json", 'r') as fr:
            sharding_map = json.load(fr)
    except FileNotFoundError:
        pass

    indexed_files = (set(sharding_map['weight_map'].values())
                     if sharding_map is not None else None)
    shard_files = sorted(
        name for name in os.listdir(model_dir)
        if name.endswith(".safetensors") and (
            indexed_files is None or name in indexed_files))
    # Map each tensor to its shard by file name rather than by parsing a
    # fixed character range of the name, so any shard naming scheme works.
    shard_index = {name: idx for idx, name in enumerate(shard_files)}
    safetensors_map = {}
    if sharding_map is not None:
        for k, v in sharding_map['weight_map'].items():
            if v not in shard_index:
                raise FileNotFoundError(
                    f"Shard {v} listed in model.safetensors.index.json is missing from {model_dir}"
                )
            safetensors_map[k] = shard_index[v]

    safetensors_ptrs = [
        safetensors.safe_open(model_dir + shard_file,
                              framework="pt",
                              device="cpu") for shard_file in shard_files
    ]

    mapping = config.mapping
    num_hidden_layers = config.num_hidden_layers
    vocab_size = config.vocab_size
    pad_vocab = vocab_size % mapping.tp_size != 0
    vocab_size_padded = pad_vocab_size(config.vocab_size, mapping.tp_size)
    dtype = config.dtype

    moe_config = config.moe

    # With fewer KV heads than TP ranks, several ranks share one KV head:
    # K/V are split num_key_value_heads ways and each rank picks its group.
    kv_tp_size = None
    kv_tp_rank = None
    if config.num_key_value_heads < mapping.tp_size:
        kv_tp_size = config.num_key_value_heads
        kv_tp_rank = mapping.tp_rank * kv_tp_size // mapping.tp_size

    model_prefix = "model."
    key_list = [
        "embed_tokens.weight",  # vocab_embedding
        "lm_head.weight",  # lm_head
        "norm.weight",  # ln_f
        "self_attn.",  # attention.qkv
        "_proj.weight",  # qkv suffix
        "self_attn.o_proj.weight",  # attention.dense
        "mlp.up_proj.weight",  # mlp.gate
        "mlp.down_proj.weight",  # mlp.proj
        "mlp.gate_proj.weight",  # mlp.fc
        "input_layernorm.weight",  # input_layernorm
        "post_attention_layernorm.weight",  # post_layernorm
    ]

    torch_dtype = str_dtype_to_torch(dtype)

    def load(key,
             tp_dim=-1,
             no_prefix=0,
             is_expert_weights=False,
             tp_size=None,
             tp_rank=None):
        """Read tensor ``key``, sliced along ``tp_dim`` for TP (-1 = whole).

        Returns None if the tensor is not in its shard. Router weights
        (``block_sparse_moe.gate``) are returned as float32, everything else
        as the model dtype.
        """
        if not no_prefix:
            key = model_prefix + key
        ptr_idx = safetensors_map[key] if key in safetensors_map else 0
        if key not in safetensors_ptrs[ptr_idx].keys():
            return None
        tensor_slice = safetensors_ptrs[ptr_idx].get_slice(key)
        tensor_shape = tensor_slice.get_shape()
        if tp_dim == -1:
            res = tensor_slice[:]
        elif tp_dim >= 0 and tp_dim < len(tensor_shape):
            if is_expert_weights:
                default_size, default_rank = mapping.moe_tp_size, mapping.moe_tp_rank
            else:
                default_size, default_rank = mapping.tp_size, mapping.tp_rank
            # Compare with None: rank 0 is a valid explicit rank and must not
            # fall back to the mapping's rank (as `tp_rank or ...` would).
            if tp_size is None:
                tp_size = default_size
            if tp_rank is None:
                tp_rank = default_rank
            dim_size = tensor_shape[tp_dim]
            if dim_size % tp_size != 0:
                logger.error(
                    f"Current weight {key}'s shape {tensor_shape} is invalid at dimension {tp_dim} for TP size {tp_size}"
                )
            indices = [slice(None)] * len(tensor_shape)
            indices[tp_dim] = slice(dim_size * tp_rank // tp_size,
                                    dim_size * (tp_rank + 1) // tp_size)
            res = tensor_slice[indices]
        else:
            raise ValueError(
                f"Invalid TP dim {tp_dim} for weight {key}'s shape {tensor_shape}"
            )
        if "block_sparse_moe.gate" in key:
            return res.to(torch.float32)
        return res.to(torch_dtype).contiguous()

    def load_and_set(target,
                     key,
                     tp_dim=-1,
                     no_prefix=0,
                     is_expert_weights=False):
        """Load ``key`` into ``weights[target]`` plus its bias, if present."""
        res = load(key, tp_dim, no_prefix, is_expert_weights)
        weights[target] = res
        if "weight" in key:
            # A bias spans the output features: split it only when the weight
            # is split along its output dim (0); row-parallel keeps it whole.
            bias_tp_dim = 0 if tp_dim == 0 else -1
            bias = load(key.replace("weight", "bias"), bias_tp_dim, no_prefix,
                        is_expert_weights)
            if bias is not None:
                weights[target.replace("weight", "bias")] = bias

    if mapping.is_first_pp_rank():
        weights['transformer.vocab_embedding.weight'] = load(
            key_list[0], config.embedding_sharding_dim
            if config.use_parallel_embedding else -1)  # vocab_embedding

    if mapping.is_last_pp_rank():
        # When padding is needed, load whole, pad, then split manually.
        lm_head_tp_dim = -1 if pad_vocab else 0
        v = load(key_list[1], lm_head_tp_dim, 1)  # lm_head
        if v is None:
            # Tied word embeddings: the checkpoint has no separate lm_head.
            v = load(key_list[0], lm_head_tp_dim)
        if pad_vocab:
            v = torch.nn.functional.pad(
                v, (0, 0, 0, vocab_size_padded - vocab_size), 'constant', 0)
            v = split(v, mapping.tp_size, mapping.tp_rank)
        weights['lm_head.weight'] = v
        weights['transformer.ln_f.weight'] = load(key_list[2])  # ln_f

    layers_range = mapping.pp_layers(num_hidden_layers)
    for l in layers_range:
        layer_idx = l - layers_range[0]
        prefix = f'layers.{l}.'
        tllm_prex = f'transformer.layers.{layer_idx}'

        # Attention
        qkv_list = []
        for comp in ["q", "k", "v"]:
            tp_size = kv_tp_size if comp != "q" else None
            tp_rank = kv_tp_rank if comp != "q" else None
            weight_part = load(prefix + key_list[3] + comp + key_list[4],
                               0,
                               tp_size=tp_size,
                               tp_rank=tp_rank)
            qkv_list.append(weight_part)
            bias_part = load(
                (prefix + key_list[3] + comp + key_list[4]).replace(
                    "weight", "bias"),
                0,
                tp_size=tp_size,
                tp_rank=tp_rank)
            if bias_part is not None:
                qkv_list.append(bias_part)
        if len(qkv_list) == 3:
            # No bias
            weights[f'{tllm_prex}.attention.qkv.weight'] = torch.cat(
                qkv_list, 0)
        else:
            # Interleaved [w_q, b_q, w_k, b_k, w_v, b_v].
            weights[f'{tllm_prex}.attention.qkv.weight'] = torch.cat(
                qkv_list[::2], 0)
            weights[f'{tllm_prex}.attention.qkv.bias'] = torch.cat(
                qkv_list[1::2], 0)
        load_and_set(f'{tllm_prex}.attention.dense.weight',
                     prefix + key_list[5], 1)  # attention.dense

        # MLP
        if not moe_config.has_moe():
            load_and_set(f'{tllm_prex}.mlp.gate.weight', prefix + key_list[6],
                         0)  # mlp.gate
            load_and_set(f'{tllm_prex}.mlp.proj.weight', prefix + key_list[7],
                         1)  # mlp.proj
            load_and_set(f'{tllm_prex}.mlp.fc.weight', prefix + key_list[8],
                         0)  # mlp.fc

        else:
            weights[f'{tllm_prex}.mlp.router.weight'] = load(
                prefix + 'block_sparse_moe.gate.weight')
            rank_experts = list(range(moe_config.num_experts))
            if mapping.has_moe_ep():
                rank_experts = mapping.ep_experts(moe_config.num_experts)

            expert_weight_list = []
            for suffix in range(3):
                tp_dim = -1
                if mapping.has_moe_tp():
                    # w2 is row-parallel (split inputs); w1/w3 split outputs.
                    tp_dim = 1 if suffix == 1 else 0
                expert_weight_list.append(
                    torch.stack(
                        list(
                            load(
                                prefix +
                                f'block_sparse_moe.experts.{expert}.w{suffix + 1}.weight',
                                tp_dim=tp_dim,
                                is_expert_weights=True)
                            for expert in rank_experts)))

            w1 = expert_weight_list[0]
            w2 = expert_weight_list[1]
            w3 = expert_weight_list[2]

            weights[f'{tllm_prex}.mlp.fc.weight'] = \
                torch.concat([w3, w1], dim=-2).contiguous()
            weights[f'{tllm_prex}.mlp.proj.weight'] = w2.contiguous()

        load_and_set(f'{tllm_prex}.input_layernorm.weight',
                     prefix + key_list[9])  # input_layernorm
        load_and_set(f'{tllm_prex}.post_layernorm.weight',
                     prefix + key_list[10])  # post_layernorm

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Weights loaded. Total time: {t}')
    return weights
def load_weights_from_gptq(quant_ckpt_path: str, config: LLaMAConfig):
    """Convert a groupwise GPTQ LLaMA safetensors checkpoint into TRT-LLM weights.

    Quantized linear layers (attention qkv/dense, mlp gate/proj/fc) are
    unpacked from their int32-packed GPTQ form, repacked with
    ``torch.ops.trtllm.pack_int8_tensor_to_packed_int4`` /
    ``preprocess_weights_for_mixed_gemm`` and sharded for this tensor-parallel
    rank. The embedding, lm_head and layer norms are stored unquantized in
    ``config.dtype``. Only the layers owned by this pipeline-parallel rank are
    converted.

    Args:
        quant_ckpt_path: path to the GPTQ ``.safetensors`` file.
        config: model config; ``config.mapping`` selects the TP/PP rank.

    Returns:
        Dict mapping TRT-LLM weight names to tensors for this rank.
    """
    logger.info('Loading weights from groupwise GPTQ LLaMA safetensors...')
    weights = {}
    tik = time.time()

    num_hidden_layers = config.num_hidden_layers
    vocab_size = config.vocab_size
    dtype = config.dtype
    mapping = config.mapping

    gptq_llama = safetensors.safe_open(quant_ckpt_path,
                                       framework="pt",
                                       device=0)
    gptq_prefix = "model."
    gptq_suffix_list = [".qweight", ".qzeros", ".scales"]
    gptq_key_list = [
        "embed_tokens.weight",  # vocab_embedding
        "lm_head.weight",  # lm_head
        "norm.weight",  # ln_f
        "self_attn.",  # attention.qkv
        "_proj",  # qkv suffix
        "self_attn.o_proj",  # attention.dense
        "mlp.up_proj",  # mlp.gate
        "mlp.down_proj",  # mlp.proj
        "mlp.gate_proj",  # mlp.fc
        "input_layernorm.weight",  # input_layernorm
        "post_attention_layernorm.weight",  # post_layernorm
    ]
    split_sym = "."

    packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4
    preprocessor = torch.ops.trtllm.preprocess_weights_for_mixed_gemm
    torch_dtype = str_dtype_to_torch(dtype)

    def load(key, no_prefix=0):
        """Read tensor ``key`` from the checkpoint, prepending ``model.`` unless ``no_prefix``."""
        if no_prefix:
            return gptq_llama.get_tensor(key)
        return gptq_llama.get_tensor(gptq_prefix + key)

    def torch_split(v, dim):
        """Return this TP rank's equal-sized slice of ``v`` along ``dim``."""
        if v.shape[dim] % mapping.tp_size != 0:
            logger.error(
                "Current weight shape is invalid for mapping.tp_size=" +
                str(mapping.tp_size))
            assert False, "Invalid TP size"
        return v.split(v.shape[dim] // mapping.tp_size,
                       dim=dim)[mapping.tp_rank]

    def unpack_int32_into_int8(w_packed):
        """Split every byte of a 2-D packed tensor into two nibbles in [0, 15].

        The low nibble lands in the even output column and the high nibble in
        the odd column; the result is stored as int8.
        """
        w_packed_int4x2 = w_packed.contiguous().view(torch.uint8)
        w_unpacked = torch.zeros(w_packed_int4x2.shape[0],
                                 w_packed_int4x2.shape[1] * 2,
                                 dtype=torch.int8)
        w_unpacked[:, ::2] = w_packed_int4x2 % 16
        w_unpacked[:, 1::2] = w_packed_int4x2 // 16
        return w_unpacked.contiguous()

    def process_and_assign_weight(v: List[torch.Tensor],
                                  tllm_prex: str,
                                  tp_dim: int = -1):
        """Convert one GPTQ linear layer into TRT-LLM weight/scale/zero tensors.

        Args:
            v: the ``[qweight, qzeros, scales]`` tensors of the layer.
            tllm_prex: TRT-LLM name prefix of the layer.
            tp_dim: dim along which to take this rank's TP slice of each
                tensor; ``-1`` means the tensors are already sliced.

        Returns:
            Dict with ``.weight``, ``.weights_scaling_factor`` and ``.zero``
            entries under ``tllm_prex``.
        """
        if tp_dim == -1:
            qweight_int32, qzeros_int32, scales_fp16 = [
                item.cpu() for item in v
            ]
        else:
            qweight_int32, qzeros_int32, scales_fp16 = [
                torch_split(item, tp_dim).cpu() for item in v
            ]

        USE_UINT4_INPUT = 1  # Set to true if checkpoint store UINT4 weights
        USE_GPTQ_FOR_LLAMA = 1  # GPTQ-for-LLaMA added 1 to zeros

        # Unpacked nibbles are in [0, 15]; subtracting 8 moves them to [-8, 7]
        # before they are repacked as signed int4.
        qweight_unpacked_int8 = unpack_int32_into_int8(
            qweight_int32.T).T.contiguous() - 8
        qweight_interleaved = preprocessor(packer(qweight_unpacked_int8),
                                           torch.quint4x2,
                                           torch.float16).view(torch.float16)

        # zeros = zeros * scales
        qzeros_unpacked_int32 = unpack_int32_into_int8(qzeros_int32)
        if not USE_UINT4_INPUT:
            # NOTE(review): dead code while USE_UINT4_INPUT == 1. The "masks"
            # below are value selections rather than boolean masks, so the
            # arithmetic likely fails or is wrong if this is ever enabled.
            # Verify before use.
            # Correcting UINT4 values back to INT4 order
            mask_negative = qzeros_unpacked_int32[qzeros_unpacked_int32 < 0]
            mask_positive = qzeros_unpacked_int32[qzeros_unpacked_int32 >= 0]
            qzeros_unpacked_int32 = qzeros_unpacked_int32 + 16 * mask_negative - 16 * mask_positive
        zeros_x_scales_fp16 = (-qzeros_unpacked_int32 + 8 * USE_UINT4_INPUT -
                               USE_GPTQ_FOR_LLAMA) * scales_fp16
        zeros_x_scales_fp16 = zeros_x_scales_fp16.half()

        return {
            f'{tllm_prex}.weight': qweight_interleaved,
            f'{tllm_prex}.weights_scaling_factor': scales_fp16,
            f'{tllm_prex}.zero': zeros_x_scales_fp16,
        }

    # Load weights from the GPTQ checkpoint into TRT-LLM names. Tensors owned
    # by a single PP rank are only read on that rank, so other ranks never
    # pull them onto the device.
    # 1. vocab_embedding
    if mapping.is_first_pp_rank():
        v = load(gptq_key_list[0])
        weights['transformer.vocab_embedding.weight'] = v.to(torch_dtype)

    if mapping.is_last_pp_rank():
        # 2. lm_head
        v = load(gptq_key_list[1], "no_prefix")
        if vocab_size % mapping.tp_size != 0:
            # Pad the vocab dim with zero rows so it splits evenly across TP.
            vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size)
            pad_width = vocab_size_padded - vocab_size
            v = torch.from_numpy(
                np.pad(v.detach().cpu().numpy(), ((0, pad_width), (0, 0)),
                       'constant',
                       constant_values=0))
        weights['lm_head.weight'] = torch_split(v, 0).to(torch_dtype)
        # 3. ln_f
        v = load(gptq_key_list[2])
        weights['transformer.ln_f.weight'] = v.to(torch_dtype)

    # 4. Weights inside each layer
    layers_range = mapping.pp_layers(num_hidden_layers)
    for l in layers_range:
        layer_idx = l - layers_range[0]
        # Checkpoint keys use the global layer index `l`; TRT-LLM names are
        # numbered locally within this PP rank.
        prefix = "layers" + split_sym + str(l) + split_sym
        logger.info(f'Process weights in layer: {layer_idx}')
        tllm_prex = f'transformer.layers.{layer_idx}'

        # 4.1 attention.qkv: slice q, k and v for this TP rank separately,
        # then fuse them along the output dim.
        qkv_weight_list = []
        for suf in gptq_suffix_list:
            qkv_list = []
            for comp in ["q", "k", "v"]:
                comp_part = load(prefix + gptq_key_list[3] + comp +
                                 gptq_key_list[4] + suf)
                qkv_list.append(torch_split(comp_part, 1))
            qkv_weight_list.append(torch.cat(qkv_list, dim=1))
        weights.update(
            process_and_assign_weight(qkv_weight_list,
                                      f'{tllm_prex}.attention.qkv'))

        # 4.2 - 4.5: remaining quantized linears as
        # (gptq_key_list index, TRT-LLM name, TP split dim).
        for key_idx, tllm_name, tp_dim in (
            (5, 'attention.dense', 0),
            (6, 'mlp.gate', 1),
            (7, 'mlp.proj', 0),
            (8, 'mlp.fc', 1),
        ):
            v = [
                load(prefix + gptq_key_list[key_idx] + suf)
                for suf in gptq_suffix_list
            ]
            weights.update(
                process_and_assign_weight(v,
                                          f'{tllm_prex}.{tllm_name}',
                                          tp_dim=tp_dim))

        # 4.6 input_layernorm
        v = load(prefix + gptq_key_list[9])
        weights[f'{tllm_prex}.input_layernorm.weight'] = v.to(torch_dtype)
        # 4.7 post_layernorm
        v = load(prefix + gptq_key_list[10])
        weights[f'{tllm_prex}.post_layernorm.weight'] = v.to(torch_dtype)

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Weights loaded. Total time: {t}')
    return weights
def load_weights_from_meta_ckpt(meta_ckpt_dir: str, config: LLaMAConfig):
    """Convert Meta-format LLaMA checkpoints (``consolidated.NN.pth``) to TRT-LLM weights.

    The checkpoint files are regrouped to match ``config.mapping.tp_size``:
    - more files than TP ranks: consecutive files are concatenated;
    - fewer files than TP ranks: one file is sliced across several ranks;
    - equal counts: the file for this rank is used directly.
    Only the layers owned by this pipeline-parallel rank are converted.
    With FP8 rowwise quantization enabled, the MLP weights (w1/w2/w3) are
    quantized with ``fp8_per_channel_quant_weight_gpu``.

    Args:
        meta_ckpt_dir: directory holding the ``consolidated.*.pth`` files.
        config: model config; ``mapping`` selects the TP/PP rank.

    Returns:
        Dict mapping TRT-LLM weight names to tensors for this rank.

    Raises:
        FileNotFoundError: if ``meta_ckpt_dir`` has no ``consolidated.*.pth`` files.
    """
    torch_dtype = str_dtype_to_torch(config.dtype)
    mapping = config.mapping
    use_fp8_rowwise = config.quant_mode.has_fp8_rowwise()
    if config.quant_mode.has_any_quant() and not use_fp8_rowwise:
        logger.error(
            "Meta ckpts only support fp8_rowwise quantization currently.")
    # Previously a bare print(); report through the logger like the rest of
    # the loader.
    logger.info(f"fp8 rowwise {use_fp8_rowwise}")
    weights = {}

    def load_ckpt_file(file_id):
        """Load ``consolidated.{file_id:02d}.pth`` onto the CPU."""
        # NOTE(review): torch.load unpickles the file and can execute
        # arbitrary code; only point this loader at trusted checkpoints.
        return torch.load(Path(meta_ckpt_dir, f"consolidated.{file_id:02d}.pth"),
                          map_location="cpu")

    def infer_embedding_shard_dim(shard, num_shards):
        """Guess which dim ``tok_embeddings`` was sharded on, or return None.

        Returns 0 if the shard holds ``vocab_size // num_shards`` rows and 1 if
        it holds ``hidden_size // num_shards`` columns.
        """
        if shard.shape[0] == config.vocab_size // num_shards:
            return 0
        if shard.shape[1] == config.hidden_size // num_shards:
            return 1
        return None

    def gather_ckpts(ckpts):
        """Concatenate several checkpoint shards into one TP rank's tensors."""
        # The len(ckpts) files are part of len(ckpts) * tp_size shards in total.
        total_shards = len(ckpts) * mapping.tp_size
        gathered = {}
        for k in ckpts[0]:
            if "norm" in k or "rope" in k:  # no TP
                gathered[k] = ckpts[0][k].clone()
                continue
            if "tok_embeddings" in k:
                # The embedding's shard dim differs between checkpoint
                # releases, so infer it from the shape. If the shape is
                # inconclusive, fall back to dim 0 (the former hard-coded choice).
                d = infer_embedding_shard_dim(ckpts[0][k], total_shards)
                if d is None:
                    d = 0
            elif any(n in k for n in ("wo", "w2")):
                d = 1
            else:
                d = 0
            # torch.cat already returns a new tensor; no extra clone needed.
            gathered[k] = torch.cat([pt[k] for pt in ckpts], dim=d)
        return gathered

    def split_ckpt(ckpt, ranks_per_ckpt, ckpt_rank):
        """Take slice ``ckpt_rank`` of ``ranks_per_ckpt`` from every tensor in ``ckpt``."""
        result = {}
        for k, v in ckpt.items():
            d = 0
            if any(n in k for n in
                   ["wo", "feed_forward.w2", "tok", "feed_forward.gate"]):
                d = 1
            if "norm" in k or "rope" in k:  # no TP
                result[k] = v.clone()
            elif config.num_key_value_heads < mapping.tp_size and any(
                    n in k for n in ["wk", "wv"]):
                assert mapping.tp_size % config.num_key_value_heads == 0
                # special case: we need to duplicate KV head
                tmp = dup_kv_weight(v, config.num_key_value_heads,
                                    mapping.tp_size)
                result[k] = torch.split(tmp,
                                        tmp.shape[d] // ranks_per_ckpt,
                                        dim=d)[ckpt_rank].clone()
            else:
                result[k] = torch.split(v,
                                        v.shape[d] // ranks_per_ckpt,
                                        dim=d)[ckpt_rank].clone()
        return result

    def get_current_weights(num_ckpts):
        """Return this TP rank's checkpoint tensors, regrouping files as needed."""
        if num_ckpts > mapping.tp_size:
            # combine ckpts
            assert (num_ckpts % mapping.tp_size) == 0
            nf = num_ckpts // mapping.tp_size
            fs = nf * mapping.tp_rank
            return gather_ckpts(
                [load_ckpt_file(f) for f in range(fs, fs + nf)])
        if num_ckpts < mapping.tp_size:
            # split ckpt
            assert (mapping.tp_size % num_ckpts) == 0
            ranks_per_ckpt = mapping.tp_size // num_ckpts
            ckpt_fid = mapping.tp_rank // ranks_per_ckpt
            ckpt_rank = mapping.tp_rank % ranks_per_ckpt
            nH_per_ckpt = config.num_attention_heads // num_ckpts
            assert (nH_per_ckpt % ranks_per_ckpt) == 0
            return split_ckpt(load_ckpt_file(ckpt_fid), ranks_per_ckpt,
                              ckpt_rank)
        # num_ckpts == tensor_parallel, 1:1 mapping from files to TP
        return load_ckpt_file(mapping.tp_rank)

    def permute(w, nH, d, dH):
        """Reorder each head's rows from (dH//2, 2) interleaving to (2, dH//2)."""
        # due to MQA's wk, nH*dH != d could be true
        return w.view(nH, dH // 2, 2, d).transpose(1, 2).reshape(nH * dH, d)

    def extract_layer_idx(name):
        """Return the first all-digit dotted component of ``name`` (as str), else None."""
        for s in name.split('.'):
            if s.isdigit():
                return s
        return None

    if not hasattr(load_weights_from_meta_ckpt, "saved_embed"):
        load_weights_from_meta_ckpt.saved_embed = None

    def combine_embeddings(embeds, num_ckpts):
        """Concatenate embedding shards along the inferred shard dim."""
        if len(embeds) == 1:
            return embeds[0]
        # The old `assert [...]` tested a non-empty list and always passed.
        assert all(e.shape == embeds[0].shape for e in embeds[1:]), \
            "Embedding shards have mismatched shapes"
        merge_dim = infer_embedding_shard_dim(embeds[0], num_ckpts)
        if merge_dim is None:
            logger.error("Unable to infer embedding split dimension")
            assert False, "Unable to infer embedding split dimension"
        return torch.cat(embeds, dim=merge_dim)

    def gather_embedding(cur_embed, name: str, num_ckpts):
        """Return the full embedding ``name`` gathered from every checkpoint file.

        The result is cached on the function object and reused by later calls.
        The cache is keyed by checkpoint dir, tensor name, shard count and
        dtype, so a call for a different checkpoint never gets a stale tensor.
        """
        if mapping.tp_size == 1:
            # even if num_ckpts > 1, get_current_weights will already have it gathered
            return cur_embed
        cache_key = (str(Path(meta_ckpt_dir).resolve()), name, num_ckpts,
                     torch_dtype)
        fn = load_weights_from_meta_ckpt
        if (fn.saved_embed is None
                or getattr(fn, "saved_embed_key", None) != cache_key):
            embeds = [load_ckpt_file(i)[name] for i in range(num_ckpts)]
            fn.saved_embed = combine_embeddings(embeds,
                                                num_ckpts).to(torch_dtype)
            fn.saved_embed_key = cache_key
        return fn.saved_embed

    def set_linear_weight(tllm_name, v):
        """Store ``v`` as ``tllm_name.weight``, quantized to FP8 rowwise if enabled."""
        if use_fp8_rowwise:
            processed_torch_weights, torch_weight_scales = fp8_per_channel_quant_weight_gpu(
                v)
            weights[tllm_name + '.weight'] = processed_torch_weights
            weights[tllm_name + '.per_channel_scale'] = torch_weight_scales
        else:
            weights[tllm_name + '.weight'] = v

    logger.info('Loading weights from Meta LLaMA checkpoints ...')
    tik = time.time()

    num_kv_heads = config.num_key_value_heads
    ckpts = list(Path(meta_ckpt_dir).glob("consolidated.*.pth"))
    num_ckpts = len(ckpts)
    if num_ckpts == 0:
        # Without this, get_current_weights would fail with a ZeroDivisionError.
        raise FileNotFoundError(
            f"No consolidated.*.pth checkpoint files found in {meta_ckpt_dir}")
    # llama/llama2 doesn't have MQA. So, simplifying loader logic by not worrying about it.
    assert num_kv_heads > 1 or num_kv_heads >= num_ckpts, \
        f"We don't know how the {num_kv_heads} KV heads are distributed among {num_ckpts} checkpoints."

    head_size = config.hidden_size // config.num_attention_heads
    ckpt = get_current_weights(num_ckpts)
    layers_range = mapping.pp_layers(config.num_hidden_layers)

    # Fuse each owned layer's (permuted) wq/wk with wv into a single qkv entry.
    for l in layers_range:
        prefix = f'layers.{l}.attention.'
        q_weight = permute(ckpt[prefix + 'wq.weight'].clone(),
                           nH=(config.num_attention_heads // mapping.tp_size),
                           d=config.hidden_size,
                           dH=head_size)
        if num_kv_heads < mapping.tp_size and num_ckpts >= mapping.tp_size:
            assert mapping.tp_size % num_kv_heads == 0
            assert False, "Not supported yet"
        k_weight = permute(ckpt[prefix + 'wk.weight'].clone(),
                           nH=((num_kv_heads + mapping.tp_size - 1) //
                               mapping.tp_size),
                           d=config.hidden_size,
                           dH=head_size)
        v_weight = ckpt[prefix + 'wv.weight'].clone()
        ckpt[prefix + 'qkv.weight'] = torch.cat([q_weight, k_weight, v_weight],
                                                dim=0)

    for k, v in tqdm(ckpt.items()):
        # 'feed_forward.gate' tensors are kept in fp32; all others use the model dtype.
        dtype = torch_dtype if 'feed_forward.gate' not in k else torch.float32
        v = v.to(dtype)
        if "tok_embeddings" in k:
            if not config.use_parallel_embedding:
                v = gather_embedding(v, k, num_ckpts)
            elif config.embedding_sharding_dim == 0:
                # this needs a gather and then resplit along different dims
                v = gather_embedding(v, k, num_ckpts)
                v = split(v, mapping.tp_size, mapping.tp_rank, 0)
            if mapping.is_first_pp_rank():
                weights['transformer.vocab_embedding.weight'] = v
        elif "output" in k:
            if mapping.is_last_pp_rank():
                if config.vocab_size % mapping.tp_size != 0:
                    # padding
                    vocab_size_padded = pad_vocab_size(config.vocab_size,
                                                       mapping.tp_size)
                    pad_width = vocab_size_padded - config.vocab_size
                    v = torch.from_numpy(
                        np.pad(v.detach().cpu().numpy(),
                               ((0, pad_width), (0, 0)),
                               'constant',
                               constant_values=0))
                weights['lm_head.weight'] = v
        elif k == "norm.weight":
            if mapping.is_last_pp_rank():
                weights['transformer.ln_f.weight'] = v
        else:
            # layer specific weights
            layer_idx = extract_layer_idx(k)
            if layer_idx is None or int(layer_idx) not in layers_range:
                continue
            idx = int(layer_idx) - layers_range[0]
            tllm_prex = f'transformer.layers.{idx}.'
            if 'attention_norm.weight' in k:
                weights[tllm_prex + 'input_layernorm.weight'] = v
            elif 'ffn_norm.weight' in k:
                weights[tllm_prex + 'post_layernorm.weight'] = v
            elif 'feed_forward.w3.weight' in k:
                set_linear_weight(tllm_prex + 'mlp.gate', v)
            elif 'feed_forward.w2.weight' in k:
                set_linear_weight(tllm_prex + 'mlp.proj', v)
            elif 'feed_forward.w1.weight' in k:
                set_linear_weight(tllm_prex + 'mlp.fc', v)
            elif 'attention.wo.weight' in k:
                weights[tllm_prex + 'attention.dense.weight'] = v
            elif 'attention.qkv.weight' in k:
                weights[tllm_prex + 'attention.qkv.weight'] = v
            elif 'feed_forward.gate' in k:
                weights[tllm_prex + 'mlp.router.weight'] = v

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Weights loaded. Total time: {t}')
    return weights
|