File size: 87,584 Bytes
5000658 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 
1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656 1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805 1806 1807 1808 1809 1810 1811 1812 1813 1814 1815 1816 1817 1818 1819 1820 1821 
1822 1823 1824 1825 1826 1827 1828 1829 1830 1831 1832 1833 1834 1835 1836 1837 1838 1839 1840 1841 1842 1843 1844 1845 1846 1847 1848 1849 1850 1851 1852 1853 1854 1855 1856 1857 1858 1859 1860 1861 1862 1863 1864 1865 1866 1867 1868 1869 1870 1871 1872 1873 1874 1875 1876 1877 1878 1879 1880 1881 1882 1883 1884 1885 1886 1887 1888 1889 1890 1891 1892 1893 1894 1895 1896 1897 1898 1899 1900 1901 1902 1903 1904 1905 1906 1907 1908 1909 1910 1911 1912 1913 1914 1915 1916 1917 1918 1919 1920 1921 1922 1923 1924 1925 1926 1927 1928 1929 1930 1931 1932 1933 1934 1935 1936 1937 1938 1939 1940 1941 1942 1943 1944 1945 1946 1947 1948 1949 1950 1951 1952 1953 1954 1955 1956 1957 1958 1959 1960 |
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections import OrderedDict
from typing import List, Optional
import tensorrt as trt
import torch
from tensorrt_llm._common import default_net
from tensorrt_llm._utils import numpy_to_torch, str_dtype_to_torch
from tensorrt_llm.functional import (LayerNormPositionType, LayerNormType,
MLPType, PositionEmbeddingType, Tensor,
assertion, cast, gather_last_token_logits,
gelu, maximum, minimum, recv, send, shape,
slice, transpose)
from tensorrt_llm.layers import (MLP, Attention, AttentionMaskType,
AttentionParams, BertAttention, ColumnLinear,
Conv1d, Embedding, FusedGatedMLP, GatedMLP,
GroupNorm, KeyValueCacheParams, LayerNorm,
LoraParams, PromptTuningEmbedding, RmsNorm)
from tensorrt_llm.lora_manager import (LoraConfig,
get_default_trtllm_modules_to_hf_modules,
use_lora)
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import PretrainedConfig, PretrainedModel
from tensorrt_llm.module import Module, ModuleList
from tensorrt_llm.parameter import Parameter
from tensorrt_llm.plugin.plugin import current_all_reduce_helper
# Normalization layer class to instantiate for each configured LayerNormType
# (looked up as ``layernorm_map[layernorm_type]`` by the layers below).
layernorm_map = {
    LayerNormType.LayerNorm: LayerNorm,
    LayerNormType.RmsNorm: RmsNorm,
    LayerNormType.GroupNorm: GroupNorm,
}
# MLP implementation to instantiate for each configured MLPType
# (looked up as ``mlp_map[mlp_type]`` by the encoder/decoder layers below).
mlp_map = {
    MLPType.MLP: MLP,
    MLPType.GatedMLP: GatedMLP,
    MLPType.FusedGatedMLP: FusedGatedMLP,
}
class EncDecEmbedding(Module):
    """Embedding front-end shared by the encoder and the decoder.

    The vocabulary embedding (multiplied by ``embedding_scale``) is summed with
    optional learned position and token-type embeddings, and the result is
    optionally passed through a normalization layer.

    Args:
        vocab_size: Rows in the vocabulary embedding table.
        hidden_size: Embedding width.
        max_position_embeddings: Rows in the position embedding table. Required
            when ``has_position_embedding`` is True.
        has_position_embedding: Add a learned position embedding.
        type_vocab_size: Rows in the token-type embedding table; a falsy value
            (``None`` or 0) disables the token-type embedding.
        has_embedding_layernorm: Normalize the summed embeddings.
        has_embedding_scale: Multiply the vocabulary embedding by
            ``sqrt(hidden_size)`` instead of 1.0.
        layernorm_eps: Epsilon of the embedding normalization layer.
        layernorm_type: Key into ``layernorm_map`` selecting the norm class.
        dtype: Parameter dtype forwarded to every sub-layer.
        use_parallel_embedding: Shard embedding tables across the
            tensor-parallel group when True.
        embedding_sharding_dim: Dimension along which tables are sharded.
        mapping: Parallel layout providing ``tp_size``/``tp_group``/``tp_rank``.

    Raises:
        ValueError: if ``has_position_embedding`` is True but
            ``max_position_embeddings`` is None.
    """

    def __init__(self,
                 vocab_size,
                 hidden_size,
                 max_position_embeddings=None,
                 has_position_embedding=False,
                 type_vocab_size=None,
                 has_embedding_layernorm=False,
                 has_embedding_scale=False,
                 layernorm_eps=1e-5,
                 layernorm_type=LayerNormType.LayerNorm,
                 dtype=None,
                 use_parallel_embedding=False,
                 embedding_sharding_dim=0,
                 mapping=Mapping()):
        super().__init__()
        # Fail early with a clear message instead of building a position
        # table with an undefined number of rows.
        if has_position_embedding and max_position_embeddings is None:
            raise ValueError(
                "max_position_embeddings must be set when "
                "has_position_embedding is True")

        self.layernorm_type = layernorm_type
        ln_type = layernorm_map[layernorm_type]

        self.vocab_embedding = Embedding(
            vocab_size,
            hidden_size,
            dtype=dtype,
            tp_size=mapping.tp_size if use_parallel_embedding else 1,
            tp_group=mapping.tp_group if use_parallel_embedding else None,
            sharding_dim=embedding_sharding_dim,
            tp_rank=mapping.tp_rank)

        self.position_embedding = None
        self.max_position_embeddings = max_position_embeddings
        if has_position_embedding:
            self.position_embedding = Embedding(
                max_position_embeddings,
                hidden_size,
                dtype=dtype,
                tp_size=mapping.tp_size if use_parallel_embedding else 1,
                tp_group=mapping.tp_group if use_parallel_embedding else None,
                sharding_dim=embedding_sharding_dim,
                tp_rank=mapping.tp_rank)

        self.token_type_embedding = None
        if type_vocab_size:
            self.token_type_embedding = Embedding(
                type_vocab_size,
                hidden_size,
                dtype=dtype,
                tp_size=mapping.tp_size if use_parallel_embedding else 1,
                tp_group=mapping.tp_group if use_parallel_embedding else None,
                sharding_dim=embedding_sharding_dim,
                tp_rank=mapping.tp_rank)

        # e.g. BART true, T5 false
        self.embedding_layernorm = None
        if has_embedding_layernorm:
            self.embedding_layernorm = ln_type(normalized_shape=hidden_size,
                                               eps=layernorm_eps,
                                               dtype=dtype)

        # e.g. BART true, T5 false
        self.embedding_scale = 1.0
        if has_embedding_scale:
            self.embedding_scale = math.sqrt(hidden_size)

        # Note: embedding offset in BART is not considered as a standard. For the specific case,
        # we just need to shrink its position embedding table by [offset:] during weight loading

    def forward(self,
                input_ids,
                position_ids=None,
                token_type_ids=None,
                prompt_embedding_table=None,
                prompt_tasks=None,
                prompt_vocab_size=None):
        """Return the combined embedding for ``input_ids``.

        Args:
            input_ids: Token ids looked up in the vocabulary table.
            position_ids: Ids for the position table; needed whenever a
                position embedding was configured. Provided as an input,
                not derived here.
            token_type_ids: Ids for the token-type table; needed whenever a
                token-type embedding was configured.
            prompt_embedding_table, prompt_tasks, prompt_vocab_size: Passed
                positionally to ``vocab_embedding`` only when
                ``prompt_embedding_table`` is not None.

        Returns:
            The summed (and optionally normalized) embedding tensor.
        """
        # NOTE(review): vocab_embedding is built as a plain Embedding above;
        # confirm it accepts the extra prompt-tuning arguments when given.
        args = [prompt_embedding_table, prompt_tasks, prompt_vocab_size
                ] if prompt_embedding_table is not None else []
        x = self.vocab_embedding(input_ids, *args) * self.embedding_scale
        self.register_network_output('word_embeddings', x)

        # Optional sub-layers are None when disabled; test that explicitly
        # rather than relying on the truthiness of a Module object.
        if self.position_embedding is not None:
            pos_emb = self.position_embedding(position_ids)
            self.register_network_output('position_embeddings', pos_emb)
            x = x + pos_emb
        if self.token_type_embedding is not None:
            x = x + self.token_type_embedding(token_type_ids)
        if self.embedding_layernorm is not None:
            x = self.embedding_layernorm(x)
        return x
class EncoderLayer(Module):
    """Single encoder block: bidirectional self-attention then an MLP.

    Both sub-blocks use a residual connection whose skip branch is multiplied
    by ``residual_scaling``. Each sub-block's normalization runs either before
    the sub-block (pre-layernorm, e.g. T5) or after the residual add
    (post-layernorm, e.g. BART), as selected by ``layernorm_position``.
    With ``fp16_clamping`` set, activations are clamped to [-64000, 64000]
    after every residual add.
    """

    def __init__(self,
                 hidden_size,
                 ffn_hidden_size,
                 num_attention_heads,
                 num_kv_heads,
                 head_size,
                 max_position_embeddings=None,
                 q_scaling=1.0,
                 has_attention_qkvo_bias=False,
                 has_mlp_bias=False,
                 layernorm_position=LayerNormPositionType.pre_layernorm,
                 layernorm_type=LayerNormType.LayerNorm,
                 layernorm_eps=1e-5,
                 hidden_act="relu",
                 mlp_type=MLPType.MLP,
                 mapping=Mapping(),
                 dtype=None,
                 residual_scaling=1.0,
                 relative_attention=False,
                 max_distance=0,
                 num_buckets=0,
                 fp16_clamping=False):
        super().__init__()
        # Norm flavour, e.g. BART uses regular LayerNorm and T5 uses RMS norm.
        self.layernorm_type = layernorm_type
        norm_cls = layernorm_map[layernorm_type]
        # Norm placement, e.g. BART is post-layernorm and T5 is pre-layernorm.
        self.layernorm_position = layernorm_position

        # q_scaling: e.g. BART uses 1.f, T5 uses 1.f/sqrt(head_size).
        self.attention = BertAttention(
            hidden_size,
            num_attention_heads,
            attention_head_size=head_size,
            num_kv_heads=num_kv_heads,
            max_position_embeddings=max_position_embeddings,
            q_scaling=q_scaling,
            bias=has_attention_qkvo_bias,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            tp_rank=mapping.tp_rank,
            dtype=dtype,
            relative_attention=relative_attention,
            max_distance=max_distance,
            num_buckets=num_buckets)
        self.attention_layernorm = norm_cls(normalized_shape=hidden_size,
                                            eps=layernorm_eps,
                                            dtype=dtype)

        # Feed-forward: plain MLP for T5/BART, gated MLP for Flan-T5.
        self.mlp_type = mlp_type
        mlp_cls = mlp_map[mlp_type]
        self.mlp = mlp_cls(hidden_size=hidden_size,
                           ffn_hidden_size=ffn_hidden_size,
                           hidden_act=hidden_act,
                           bias=has_mlp_bias,
                           tp_group=mapping.tp_group,
                           tp_size=mapping.tp_size,
                           dtype=dtype)
        self.mlp_layernorm = norm_cls(normalized_shape=hidden_size,
                                      eps=layernorm_eps,
                                      dtype=dtype)

        self.residual_scaling = residual_scaling
        # Workaround for T5-family models (e.g. t5-large, t5-3b,
        # flan-t5-small) that lose accuracy to fp16 overflow after residual
        # adds: clamp to [-64000, 64000] after each one.
        self.fp16_clamping = fp16_clamping

    def _clamp_fp16(self, activations):
        """Clamp ``activations`` to [-64000, 64000] when clamping is enabled."""
        if not self.fp16_clamping:
            return activations
        lower_bounded = maximum(-64000.0, activations)
        return minimum(64000.0, lower_bounded)

    def forward(self,
                hidden_states: Tensor,
                attention_mask=None,
                input_lengths=None,
                max_input_length=None,
                lora_layer_params=None):
        """Apply self-attention and MLP sub-blocks to ``hidden_states``.

        ``attention_mask``, ``input_lengths`` and ``max_input_length`` are
        forwarded to the attention layer; ``lora_layer_params`` is forwarded
        to both the attention layer and the MLP.
        """
        assert isinstance(hidden_states, Tensor)
        norm_first = self.layernorm_position == LayerNormPositionType.pre_layernorm
        norm_last = self.layernorm_position == LayerNormPositionType.post_layernorm

        # Self-attention sub-block.
        skip = hidden_states * self.residual_scaling
        attn_in = self.attention_layernorm(
            hidden_states) if norm_first else hidden_states
        attn_out = self.attention(attn_in,
                                  attention_mask=attention_mask,
                                  input_lengths=input_lengths,
                                  max_input_length=max_input_length,
                                  lora_layer_params=lora_layer_params)
        self.register_network_output('attention_output', attn_out)
        hidden_states = self._clamp_fp16(skip + attn_out)
        if norm_last:
            hidden_states = self.attention_layernorm(hidden_states)

        # Feed-forward sub-block.
        skip = hidden_states * self.residual_scaling
        mlp_in = self.mlp_layernorm(hidden_states) if norm_first else hidden_states
        mlp_out = self.mlp(mlp_in, lora_layer_params=lora_layer_params)
        self.register_network_output('mlp_output', mlp_out)
        hidden_states = self._clamp_fp16(skip + mlp_out)
        if norm_last:
            hidden_states = self.mlp_layernorm(hidden_states)

        return hidden_states
class DecoderLayer(Module):
    """Single decoder block: causal self-attention, cross-attention, then MLP.

    Each sub-block uses a residual connection whose skip branch is multiplied
    by ``residual_scaling``. Its normalization runs before the sub-block
    (pre-layernorm, e.g. T5) or after the residual add (post-layernorm, e.g.
    BART), as selected by ``layernorm_position``. With ``fp16_clamping`` set,
    activations are clamped to [-64000, 64000] after every residual add.
    """

    def __init__(self,
                 *,
                 local_layer_idx,
                 hidden_size,
                 ffn_hidden_size,
                 num_attention_heads,
                 num_kv_heads,
                 head_size,
                 max_position_embeddings=None,
                 q_scaling=1.0,
                 has_attention_qkvo_bias=False,
                 has_mlp_bias=False,
                 layernorm_position=LayerNormPositionType.pre_layernorm,
                 layernorm_type=LayerNormType.LayerNorm,
                 layernorm_eps=1e-5,
                 hidden_act="relu",
                 mlp_type=MLPType.MLP,
                 mapping=Mapping(),
                 dtype=None,
                 residual_scaling=1.0,
                 relative_attention=False,
                 max_distance=0,
                 num_buckets=0,
                 fp16_clamping=False,
                 skip_cross_qkv=False,
                 use_implicit_relative_attention=False):
        super().__init__()
        # e.g. BART regular, T5 RMS
        self.layernorm_type = layernorm_type
        ln_type = layernorm_map[layernorm_type]
        # e.g. BART post, T5 pre
        self.layernorm_position = layernorm_position

        # e.g. BART q_scaling = 1.f, T5 q_scaling = 1.f/sqrt(head_size)
        # max_distance is only forwarded to self-attention when
        # use_implicit_relative_attention is set; otherwise 0 is passed.
        self.self_attention = Attention(
            local_layer_idx=local_layer_idx,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            attention_head_size=head_size,
            num_kv_heads=num_kv_heads,
            max_position_embeddings=max_position_embeddings,
            q_scaling=q_scaling,
            bias=has_attention_qkvo_bias,
            attention_mask_type=AttentionMaskType.causal,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            tp_rank=mapping.tp_rank,
            dtype=dtype,
            cross_attention=False,
            relative_attention=relative_attention,
            max_distance=max_distance if use_implicit_relative_attention else 0,
            num_buckets=num_buckets,
            position_embedding_type=PositionEmbeddingType.relative
            if relative_attention else PositionEmbeddingType.learned_absolute,
            use_implicit_relative_attention=use_implicit_relative_attention)
        self.self_attention_layernorm = ln_type(normalized_shape=hidden_size,
                                                eps=layernorm_eps,
                                                dtype=dtype)

        # Note: self attn uses MMHA, mask is always causal triangular
        # cross attn has two scenarios:
        # - in context phase, all ones mask, same as padding type
        # - in generation phase, same causal triangular mask as MMHA
        # - context phase special handling is done in plugin by resetting mask type
        #
        # e.g. BART q_scaling = 1.f, T5 q_scaling = 1.f/sqrt(head_size)
        self.cross_attention = Attention(
            local_layer_idx=local_layer_idx,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            attention_head_size=head_size,
            num_kv_heads=num_kv_heads,
            max_position_embeddings=max_position_embeddings,
            q_scaling=q_scaling,
            bias=has_attention_qkvo_bias,
            attention_mask_type=AttentionMaskType.causal,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            tp_rank=mapping.tp_rank,
            dtype=dtype,
            cross_attention=True,
            # Cross attention has no relative attention bias.
            relative_attention=False,
            max_distance=max_distance,
            num_buckets=num_buckets,
            position_embedding_type=PositionEmbeddingType.learned_absolute,
            skip_cross_qkv=skip_cross_qkv)
        self.cross_attention_layernorm = ln_type(normalized_shape=hidden_size,
                                                 eps=layernorm_eps,
                                                 dtype=dtype)

        # T5/BART MLP, Flan-T5 GatedMLP
        self.mlp_type = mlp_type
        mlp_f = mlp_map[mlp_type]
        self.mlp = mlp_f(
            hidden_size=hidden_size,
            ffn_hidden_size=ffn_hidden_size,
            hidden_act=hidden_act,
            bias=has_mlp_bias,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            dtype=dtype,
        )
        self.mlp_layernorm = ln_type(normalized_shape=hidden_size,
                                     eps=layernorm_eps,
                                     dtype=dtype)

        self.residual_scaling = residual_scaling
        # T5-series model(e.g. t5-large, t5-3b, flan-t5-small) has accuracy issue due to fp16 overflow
        # after residual add. We add workaround for clamping fp16 range [-64000, 64000] after every
        # residual add to avoid accuracy drop.
        self.fp16_clamping = fp16_clamping

    def _clamp_fp16(self, hidden_states):
        """Clamp ``hidden_states`` to [-64000, 64000] if fp16 clamping is on."""
        if self.fp16_clamping:
            hidden_states = maximum(-64000.0, hidden_states)
            hidden_states = minimum(64000.0, hidden_states)
        return hidden_states

    def forward(self,
                hidden_states: Tensor,
                encoder_output: Optional[Tensor] = None,
                attention_mask=None,
                cross_attention_mask=None,
                use_cache=False,
                kv_cache_params=None,
                attention_params=None,
                lora_layer_params=None,
                cross_kv_cache_gen: Optional[Tensor] = None,
                cross_qkv_reuse: Optional[Tensor] = None):
        """Run self-attention, cross-attention and MLP on ``hidden_states``.

        Args:
            hidden_states: Decoder activations entering the block.
            encoder_output: Encoder activations consumed by cross-attention.
            attention_mask: Mask passed to self-attention.
            cross_attention_mask: Mask passed to cross-attention.
            use_cache: When True, both attention layers return
                ``(output, presents)`` pairs and the presents are returned.
            kv_cache_params, attention_params: Forwarded unchanged to both
                attention layers.
            lora_layer_params: Forwarded to both attention layers and the MLP.
            cross_kv_cache_gen, cross_qkv_reuse: Forwarded to cross-attention
                only.

        Returns:
            ``hidden_states``, or ``(hidden_states, presents_self,
            presents_cross)`` when ``use_cache`` is True.
        """
        assert isinstance(hidden_states, Tensor)
        # Compare against None explicitly: truth-testing a graph Tensor does
        # not express "an encoder output was supplied".
        if encoder_output is not None:
            assert isinstance(encoder_output, Tensor)

        is_pre_norm = self.layernorm_position == LayerNormPositionType.pre_layernorm
        is_post_norm = self.layernorm_position == LayerNormPositionType.post_layernorm

        # self-attention
        residual = hidden_states * self.residual_scaling
        if is_pre_norm:
            hidden_states = self.self_attention_layernorm(hidden_states)
        attention_output = self.self_attention(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            use_cache=use_cache,
            kv_cache_params=kv_cache_params,
            attention_params=attention_params,
            lora_layer_params=lora_layer_params)
        if use_cache:
            attention_output, presents_self = attention_output
        self.register_network_output('self_attention_output', attention_output)
        hidden_states = self._clamp_fp16(residual + attention_output)
        if is_post_norm:
            hidden_states = self.self_attention_layernorm(hidden_states)

        # cross attention
        residual = hidden_states * self.residual_scaling
        if is_pre_norm:
            hidden_states = self.cross_attention_layernorm(hidden_states)
        attention_output = self.cross_attention(
            hidden_states=hidden_states,
            attention_mask=cross_attention_mask,
            encoder_output=encoder_output,
            use_cache=use_cache,
            kv_cache_params=kv_cache_params,
            attention_params=attention_params,
            lora_layer_params=lora_layer_params,
            cross_kv_cache_gen=cross_kv_cache_gen,
            cross_qkv_reuse=cross_qkv_reuse)
        if use_cache:
            attention_output, presents_cross = attention_output
        self.register_network_output('cross_attention_output', attention_output)
        hidden_states = self._clamp_fp16(residual + attention_output)
        if is_post_norm:
            hidden_states = self.cross_attention_layernorm(hidden_states)

        # MLP
        residual = hidden_states * self.residual_scaling
        if is_pre_norm:
            hidden_states = self.mlp_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states,
                                 lora_layer_params=lora_layer_params)
        self.register_network_output('mlp_output', hidden_states)
        hidden_states = self._clamp_fp16(residual + hidden_states)
        if is_post_norm:
            hidden_states = self.mlp_layernorm(hidden_states)

        if use_cache:
            return (hidden_states, presents_self, presents_cross)
        return hidden_states
class EncoderModel(PretrainedModel):
def __init__(self, config: PretrainedConfig):
    """Build the encoder modules owned by this pipeline-parallel rank.

    The first PP rank owns the embedding, every rank owns the encoder layers
    yielded by ``mapping.pp_layers``, and the last PP rank owns the optional
    final normalization layer.

    Args:
        config: Model configuration; missing fields are defaulted in place
            by ``check_config`` before the base class is initialized.
    """
    # Fill in defaults first so every self.config.<field> read below exists.
    self.check_config(config)
    super().__init__(config)
    self.mapping = self.config.mapping
    self.has_position_embedding = self.config.has_position_embedding
    type_vocab_size = self.config.type_vocab_size
    # Decides (in prepare_inputs) whether a token_type_ids input is declared.
    self.has_token_type_embedding = False if type_vocab_size is None else True
    # e.g. BART regular, T5 RMS
    self.layernorm_type = self.config.layernorm_type
    ln_type = layernorm_map[self.layernorm_type]
    # e.g. BART true, T5 false
    self.has_attention_qkvo_bias = self.config.has_attention_qkvo_bias
    self.has_mlp_bias = self.config.has_mlp_bias
    # e.g. BART false, T5 true
    self.has_model_final_layernorm = self.config.has_model_final_layernorm
    self._dtype = self.config.dtype
    self.total_num_layers = self.config.num_hidden_layers
    self.num_layers = self.config.num_hidden_layers // self.mapping.pp_size
    self.hidden_size = self.config.hidden_size
    self.num_heads = self.config.num_attention_heads
    # NOTE(review): num_kv_heads starts from num_attention_heads and the
    # fallback re-reads the same config value, so the encoder always uses
    # num_kv_heads == num_attention_heads. If GQA/MQA encoders are intended,
    # this presumably should read a separate kv-heads field — confirm.
    num_kv_heads = self.num_heads
    if num_kv_heads is None or num_kv_heads <= 0:
        num_kv_heads = self.config.num_attention_heads
    self.num_kv_heads = num_kv_heads
    # Explicit config.head_size wins; otherwise hidden_size // num_heads.
    self.head_size = self.hidden_size // self.num_heads if self.config.head_size is None else self.config.head_size
    # Clamping is only enabled for fp16 T5 models.
    self.fp16_clamping = (self.config.dtype
                          == 'float16') and (self.config.model_type == 't5')
    # check_config already sets mlp_type, so this hasattr fallback is defensive.
    self.mlp_type = MLPType.MLP if not hasattr(
        self.config, "mlp_type") else self.config.mlp_type
    # Token ids enter the pipeline only on the first PP rank.
    if self.mapping.is_first_pp_rank():
        self.embedding = EncDecEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            max_position_embeddings=self.config.max_position_embeddings,
            has_position_embedding=self.has_position_embedding,
            type_vocab_size=type_vocab_size,
            has_embedding_layernorm=self.config.has_embedding_layernorm,
            has_embedding_scale=self.config.has_embedding_scale,
            layernorm_eps=self.config.norm_epsilon,
            layernorm_type=self.layernorm_type,
            dtype=self.config.dtype,
            use_parallel_embedding=self.config.use_parallel_embedding,
            embedding_sharding_dim=self.config.embedding_sharding_dim,
            mapping=self.mapping)
    # One EncoderLayer per entry of mapping.pp_layers(total_num_layers)
    # (presumably the layers assigned to this PP rank — confirm).
    self.encoder_layers = ModuleList([
        EncoderLayer(
            hidden_size=self.hidden_size,
            ffn_hidden_size=self.config.intermediate_size,
            num_attention_heads=self.num_heads,
            num_kv_heads=num_kv_heads,
            head_size=self.head_size,
            max_position_embeddings=self.config.max_position_embeddings,
            q_scaling=self.config.q_scaling,
            has_attention_qkvo_bias=self.has_attention_qkvo_bias,
            has_mlp_bias=self.has_mlp_bias,
            layernorm_position=self.config.layernorm_position,
            layernorm_eps=self.config.norm_epsilon,
            layernorm_type=self.layernorm_type,
            hidden_act=self.config.hidden_act,
            mlp_type=self.mlp_type,
            mapping=self.mapping,
            dtype=self.config.dtype,
            # check_config already sets residual_scaling; fallback is defensive.
            residual_scaling=1.0
            if not hasattr(self.config, "residual_scaling") else
            self.config.residual_scaling,
            relative_attention=self.config.relative_attention,
            max_distance=self.config.max_distance,
            num_buckets=self.config.num_buckets,
            fp16_clamping=self.fp16_clamping)
        for _ in self.mapping.pp_layers(self.total_num_layers)
    ])
    # The final norm (if any) is applied only where the encoder output leaves.
    if self.mapping.is_last_pp_rank():
        if self.has_model_final_layernorm:
            self.final_layernorm = ln_type(
                normalized_shape=self.config.hidden_size,
                eps=self.config.norm_epsilon,
                dtype=self.config.dtype)
def check_config(self, config: PretrainedConfig):
    """Default every encoder/decoder config field this model reads.

    Each ``(field, default)`` pair below is applied through
    ``config.set_if_not_exist``, which (by its name) only assigns fields the
    config does not already have. The config is modified in place and None
    is returned.

    The previous version set ``residual_scaling`` twice; the defaults are now
    listed once each, in the original order.
    """
    defaults = (
        ('has_position_embedding', False),
        ('type_vocab_size', None),
        ('rescale_before_lm_head', False),
        ('layernorm_type', LayerNormType.LayerNorm),
        ('layernorm_position', LayerNormPositionType.pre_layernorm),
        ('has_attention_qkvo_bias', False),
        ('has_mlp_bias', False),
        ('has_model_final_layernorm', False),
        ('encoder_hidden_size', None),
        ('encoder_num_heads', None),
        ('encoder_num_kv_heads', None),
        ('encoder_head_size', None),
        ('model_type', 't5'),
        ('skip_cross_qkv', False),
        ('mlp_type', MLPType.MLP),
        ('has_embedding_scale', False),
        ('residual_scaling', 1.0),
        ('has_lm_head_bias', False),
        ('num_buckets', None),
        ('max_distance', None),
        ('relative_attention', False),
    )
    for field_name, default_value in defaults:
        config.set_if_not_exist(field_name, default_value)
def forward(self,
            input_ids: Tensor,
            input_lengths=None,
            position_ids=None,
            token_type_ids=None,
            hidden_states=None,
            max_input_length=None,
            prompt_embedding_table=None,
            prompt_tasks=None,
            prompt_vocab_size=None,
            attention_mask=None,
            lora_params: LoraParams = None):
    """Run this pipeline stage of the encoder and mark its graph output.

    On the first PP rank the activations come from the embedding of
    ``input_ids``. On other ranks they are received from the previous rank,
    using ``hidden_states`` as the receive buffer.

    The last PP rank applies the optional final norm and marks the result as
    ``encoder_output``. Every other rank sends its activations to the next
    rank and marks them as ``hidden_states_output``.

    Returns:
        The stage's output tensor.
    """
    if not self.mapping.is_first_pp_rank():
        # Intermediate/last stages consume the previous stage's activations.
        hidden_states = recv(hidden_states, self.mapping.prev_pp_rank())
    else:
        # Prompt-tuning arguments are only passed when a table is provided.
        if prompt_embedding_table is None:
            extra_embedding_args = []
        else:
            extra_embedding_args = [
                prompt_embedding_table, prompt_tasks, prompt_vocab_size
            ]
        hidden_states = self.embedding(input_ids, position_ids,
                                       token_type_ids, *extra_embedding_args)
        self.register_network_output('embedding_layer_output',
                                     hidden_states)

    for idx, layer in enumerate(self.encoder_layers):
        use_layer_lora = (lora_params is not None
                          and lora_params.lora_ranks is not None)
        layer_lora = lora_params.get_layer_params(
            idx) if use_layer_lora else None
        hidden_states = layer(hidden_states=hidden_states,
                              attention_mask=attention_mask,
                              input_lengths=input_lengths,
                              max_input_length=max_input_length,
                              lora_layer_params=layer_lora)

    if self.mapping.is_last_pp_rank():
        if self.has_model_final_layernorm:
            hidden_states = self.final_layernorm(hidden_states)
        output_name = 'encoder_output'
    else:
        hidden_states = send(hidden_states, self.mapping.next_pp_rank())
        output_name = 'hidden_states_output'
    hidden_states.mark_output(output_name, self._dtype)
    return hidden_states
def prepare_inputs(self,
                   max_batch_size,
                   max_input_len,
                   prompt_embedding_table_size: int = 0,
                   lora_target_modules: List[str] = None,
                   *args,
                   **kwargs):
    """@brief: Prepare inputs Tensors for the model, the given sizes are used to determine the
        ranges of the dimensions of when using TRT dynamic shapes.

        Every dynamic dimension is described by ``[min, opt, max]`` and each
        ``dim_range`` entry holds one such range per optimization profile.

        @param max_batch_size: largest batch size the engine must accept.
        @param max_input_len: longest encoder input sequence.
        @param prompt_embedding_table_size: number of rows of the
            prompt-tuning table; 0 disables the prompt-tuning inputs.
        @param lora_target_modules: LoRA module names (e.g. 'attn_q'); only
            read when the LoRA plugin is enabled. ``None`` means no modules.
        @return: a dict mapping self.forward() argument names to the input
            Tensors (None for inputs unused in this configuration).
    """
    hidden_size = self.hidden_size

    bs_range = [1, (max_batch_size + 1) // 2, max_batch_size]
    inlen_range = [1, (max_input_len + 1) // 2, max_input_len]
    num_tokens_range = [
        1,
        (max_input_len * max_batch_size + 1) // 2,
        max_input_len * max_batch_size,
    ]

    input_ids, position_ids, token_type_ids, hidden_states = None, None, None, None
    remove_input_padding = default_net().plugin_config.remove_input_padding
    use_lora_plugin = default_net().plugin_config.lora_plugin

    attention_mask = None
    if remove_input_padding:
        # Packed layout: all sequences share a single token axis.
        if self.mapping.is_first_pp_rank():
            input_ids = Tensor(
                name="input_ids",
                dtype=trt.int32,
                shape=[-1],
                dim_range=OrderedDict([("num_tokens", [num_tokens_range])]),
            )
            if self.has_position_embedding:
                position_ids = Tensor(
                    name='position_ids',
                    dtype=trt.int32,
                    shape=[-1],
                    dim_range=OrderedDict([('num_tokens',
                                            [num_tokens_range])]),
                )
            if self.has_token_type_embedding:
                token_type_ids = Tensor(
                    name='token_type_ids',
                    dtype=trt.int32,
                    shape=[-1],
                    dim_range=OrderedDict([('num_tokens',
                                            [num_tokens_range])]),
                )
        else:
            hidden_states = Tensor(name='hidden_states_input',
                                   dtype=self._dtype,
                                   shape=[-1, hidden_size],
                                   dim_range=OrderedDict([
                                       ('num_tokens', [num_tokens_range]),
                                       ('hidden_size', [hidden_size]),
                                   ]))
    else:
        # Padded layout: [batch_size, input_len].
        if self.mapping.is_first_pp_rank():
            input_ids = Tensor(
                name="input_ids",
                dtype=trt.int32,
                shape=[-1, -1],
                dim_range=OrderedDict([("batch_size", [bs_range]),
                                       ("input_len", [inlen_range])]),
            )
            if self.has_position_embedding:
                position_ids = Tensor(
                    name='position_ids',
                    dtype=trt.int32,
                    shape=[-1, -1],
                    dim_range=OrderedDict([('batch_size', [bs_range]),
                                           ('input_len', [inlen_range])]),
                )
            if self.has_token_type_embedding:
                token_type_ids = Tensor(
                    name='token_type_ids',
                    dtype=trt.int32,
                    shape=[-1, -1],
                    dim_range=OrderedDict([('batch_size', [bs_range]),
                                           ('input_len', [inlen_range])]),
                )
        else:
            hidden_states = Tensor(name='hidden_states_input',
                                   dtype=self._dtype,
                                   shape=[-1, -1, hidden_size],
                                   dim_range=OrderedDict([
                                       ('batch_size', [bs_range]),
                                       ('input_len', [inlen_range]),
                                       ('hidden_size', [hidden_size]),
                                   ]))

        # Without the BERT attention plugin an explicit mask is required.
        if not default_net().plugin_config.bert_attention_plugin:
            attention_mask = Tensor(
                name='attention_mask',
                dtype=trt.int32,
                shape=[-1, -1],
                dim_range=OrderedDict([
                    ('batch_size', [bs_range]),
                    ('input_len', [inlen_range]),
                ]),
            )

    # if self.mapping.tp_size > 1:
    #     current_all_reduce_helper().set_workspace_tensor(self.mapping, 1)
    # FIXME(TRTLLM-996): Support custom allreduce for encoder models on C++ runtime

    input_lengths = Tensor(
        name="input_lengths",
        dtype=trt.int32,
        shape=[-1],
        dim_range=OrderedDict([("batch_size", [bs_range])]),
    )
    # Only the (dynamic) shape of this tensor carries information: its
    # length is bounded by inlen_range.
    max_input_length = Tensor(
        name="max_input_length",
        dtype=trt.int32,
        shape=[-1],
        dim_range=OrderedDict([("max_input_length", [inlen_range])]),
    )

    prompt_embedding_table = None
    tasks = None
    prompt_vocab_size = None
    if self.mapping.is_first_pp_rank() and prompt_embedding_table_size > 0:
        # Same (x + 1) // 2 midpoint as the other ranges so that opt >= min
        # even for a single-row table.
        p_embedding_range = [[
            1, (prompt_embedding_table_size + 1) // 2,
            prompt_embedding_table_size
        ]]

        prompt_embedding_table = Tensor(name='prompt_embedding_table',
                                        dtype=self._dtype,
                                        shape=[-1, hidden_size],
                                        dim_range=OrderedDict([
                                            ('prompt_embedding_table_size',
                                             p_embedding_range),
                                            ('hidden_size', [hidden_size]),
                                        ]))
        if remove_input_padding:
            tasks = Tensor(name='tasks',
                           dtype=trt.int32,
                           shape=[-1],
                           dim_range=OrderedDict([('input_len_task',
                                                   [num_tokens_range])]))
        else:
            # Each dim_range value is a list of per-profile ranges, so the
            # batch range must be wrapped like every other batch_size input.
            tasks = Tensor(name='tasks',
                           dtype=trt.int32,
                           shape=[-1, 1],
                           dim_range=OrderedDict([
                               ('batch_size', [bs_range]),
                               ('broadcast_dim', [1]),
                           ]))
        prompt_vocab_size = Tensor(name='prompt_vocab_size',
                                   dtype=trt.int32,
                                   shape=[1],
                                   dim_range=OrderedDict([('size', [1])]))

    # LoRA plugin related inputs:
    #   lora_target_modules for BART-encoder: ['attn_q', 'attn_v']
    #   For BART-decoder, the lora_target_modules is different.
    #   See comments in the DecoderModel.prepare_inputs() for more details.
    lora_weights_pointers = None
    lora_ranks = None
    lora_params = None
    if use_lora_plugin:
        target_modules = list(lora_target_modules or [])
        lora_weights_pointers = []
        lora_ranks = []
        # In current design, q_lora_params, k_lora_params and v_lora_params
        # should be all enabled or all disabled at the same time. However,
        # BART lora modules only contain two of them, so we use zero tensor
        # to fill the missing ones.
        qkv_modules = ["attn_q", "attn_k", "attn_v"]
        missing_qkv_modules = []
        if any(m in target_modules for m in qkv_modules):
            missing_qkv_modules = [
                m for m in qkv_modules if m not in target_modules
            ]

        layers_range = self.mapping.pp_layers(self.total_num_layers)
        for i in layers_range:
            lora_weight_pointer_dict = {}
            lora_rank_dict = {}
            for lora_module in target_modules + missing_qkv_modules:
                lora_weight_pointer_dict[
                    f'{lora_module}_lora_weights_pointers'] = Tensor(
                        name=f'{lora_module}_lora_weights_pointers_{i}',
                        dtype=trt.int64,
                        shape=[-1, 2],
                        dim_range=OrderedDict([('batch_size', [bs_range]),
                                               ('in_out', [2])]))
                lora_rank_dict[f'{lora_module}_lora_ranks'] = Tensor(
                    name=f'{lora_module}_lora_ranks_{i}',
                    dtype=trt.int32,
                    shape=[-1],
                    dim_range=OrderedDict([('batch_size', [bs_range])]))
            lora_weights_pointers.append(lora_weight_pointer_dict)
            lora_ranks.append(lora_rank_dict)

        host_request_types = Tensor(name='host_request_types',
                                    dtype=trt.int32,
                                    shape=[-1],
                                    dim_range=OrderedDict([('batch_size',
                                                            [bs_range])]))

        host_context_lengths = None
        if remove_input_padding:
            host_context_lengths = Tensor(name='host_context_lengths',
                                          dtype=trt.int32,
                                          shape=[-1],
                                          dim_range=OrderedDict([
                                              ('batch_size', [bs_range])
                                          ]))

        lora_params = LoraParams(
            lora_ranks=lora_ranks,
            lora_weights_pointers=lora_weights_pointers,
            max_context_length=max_input_len,
            host_request_types=host_request_types,
            host_context_lengths=host_context_lengths,
        )

    result = {
        'input_ids': input_ids,
        'input_lengths': input_lengths,
        'position_ids': position_ids,
        'token_type_ids': token_type_ids,
        'hidden_states': hidden_states,
        'max_input_length': max_input_length,
        'prompt_embedding_table': prompt_embedding_table,
        'prompt_tasks': tasks,
        'prompt_vocab_size': prompt_vocab_size,
        'attention_mask': attention_mask,
        'lora_params': lora_params,
    }

    return result
def use_lora(self, lora_config: LoraConfig):
    """Apply ``lora_config`` to this encoder model.

    Delegates to the module-level ``use_lora`` helper. A method's own name is
    not in scope inside its body, so the call below resolves to that global
    helper, not to this method. No explicit TRT-LLM -> HF module-name
    mapping is passed here.
    """
    use_lora(self, lora_config)
def use_prompt_tuning(self):
    """Swap the vocabulary embedding for a prompt-tuning-capable one.

    Builds a ``PromptTuningEmbedding`` with the same size, dtype and
    tensor-parallel layout as the current ``self.embedding.vocab_embedding``,
    installs it in its place, and copies the existing weights across.
    """
    current = self.embedding.vocab_embedding
    layout = dict(
        num_embeddings=current.num_embeddings,
        embedding_dim=current.embedding_dim,
        dtype=current.dtype,
        tp_size=current.tp_size,
        tp_group=current.tp_group,
        sharding_dim=current.sharding_dim,
        tp_rank=current.tp_rank,
    )
    self.embedding.vocab_embedding = PromptTuningEmbedding(**layout)
    # Carry over whatever weights were loaded into the original table.
    self.embedding.vocab_embedding.weight.value = current.weight.raw_value
def precompute_relative_attention_bias(self, build_config):
    """Do nothing for the encoder and return None.

    ``build_config`` is accepted only so the signature matches the decoder's
    hook of the same name; it is not read here.
    """
    return None
class DecoderModel(PretrainedModel):
def __init__(self, config: PretrainedConfig):
self.check_config(config)
super().__init__(config)
self.mapping = self.config.mapping
self.has_position_embedding = self.config.has_position_embedding
type_vocab_size = self.config.type_vocab_size
self.has_token_type_embedding = (type_vocab_size is not None)
self.rescale_before_lm_head = self.config.rescale_before_lm_head
# e.g. BART regular, T5 RMS
self.layernorm_type = self.config.layernorm_type
ln_type = layernorm_map[self.layernorm_type]
# e.g. BART true, T5 false
self.has_attention_qkvo_bias = self.config.has_attention_qkvo_bias
self.has_mlp_bias = self.config.has_mlp_bias
# e.g. BART false, T5 true
self.has_model_final_layernorm = self.config.has_model_final_layernorm
self._dtype = self.config.dtype
# no quantization considered for now
self._kv_dtype = self._dtype
self._logits_dtype = self.config.logits_dtype
self.total_num_layers = self.config.num_hidden_layers
self.num_layers = self.config.num_hidden_layers // self.mapping.pp_size
self.hidden_size = self.config.hidden_size
self.num_heads = self.config.num_attention_heads
num_kv_heads = self.num_heads
if num_kv_heads is None or num_kv_heads <= 0:
num_kv_heads = self.num_heads
self.num_kv_heads = num_kv_heads
self.head_size = self.hidden_size // self.num_heads if self.config.head_size is None else self.config.head_size
self.encoder_hidden_size = self.config.encoder_hidden_size
self.encoder_num_heads = self.config.encoder_num_heads
encoder_num_kv_heads = None if not hasattr(
self.config,
"encoder_num_kv_heads") else self.config.encoder_num_kv_heads
if encoder_num_kv_heads is None or encoder_num_kv_heads <= 0:
encoder_num_kv_heads = self.encoder_num_heads
self.encoder_num_kv_heads = encoder_num_kv_heads
self.encoder_head_size = self.encoder_hidden_size // self.num_heads if self.config.encoder_head_size is None else self.config.encoder_head_size
self.has_position_embedding = self.config.has_position_embedding
self.has_token_type_embedding = type_vocab_size is not None
self.fp16_clamping = (self.config.dtype
== 'float16') and (self.config.model_type
in ['t5', 'pix2struct'])
self.skip_cross_qkv = self.config.skip_cross_qkv
self.mlp_type = MLPType.MLP if not hasattr(
self.config, "mlp_type") else self.config.mlp_type
self.use_implicit_relative_attention = self.config.use_implicit_relative_attention if hasattr(
self.config, "use_implicit_relative_attention") else False
if self.mapping.is_first_pp_rank():
self.embedding = EncDecEmbedding(
self.config.vocab_size,
self.config.hidden_size,
max_position_embeddings=self.config.max_position_embeddings,
has_position_embedding=self.config.has_position_embedding,
type_vocab_size=type_vocab_size,
has_embedding_layernorm=self.config.has_embedding_layernorm,
has_embedding_scale=self.config.has_embedding_scale,
layernorm_eps=self.config.norm_epsilon,
layernorm_type=self.config.layernorm_type,
dtype=self._dtype,
use_parallel_embedding=self.config.use_parallel_embedding,
embedding_sharding_dim=self.config.embedding_sharding_dim,
mapping=self.mapping)
layers_range = self.mapping.pp_layers(self.total_num_layers)
self.decoder_layers = ModuleList([
DecoderLayer(
local_layer_idx=layer_idx - layers_range[0],
hidden_size=self.config.hidden_size,
ffn_hidden_size=self.config.intermediate_size,
num_attention_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
max_position_embeddings=self.config.max_position_embeddings,
q_scaling=self.config.q_scaling,
has_attention_qkvo_bias=self.config.has_attention_qkvo_bias,
has_mlp_bias=self.config.has_mlp_bias,
layernorm_position=self.config.layernorm_position,
layernorm_eps=self.config.norm_epsilon,
layernorm_type=self.config.layernorm_type,
hidden_act=self.config.hidden_act,
mlp_type=self.mlp_type,
mapping=self.mapping,
dtype=self._dtype,
residual_scaling=self.config.residual_scaling,
relative_attention=self.config.relative_attention,
max_distance=self.config.max_distance,
num_buckets=self.config.num_buckets,
fp16_clamping=self.fp16_clamping,
skip_cross_qkv=self.skip_cross_qkv,
use_implicit_relative_attention=self.
use_implicit_relative_attention) for layer_idx in layers_range
])
if self.mapping.is_last_pp_rank():
if self.has_model_final_layernorm:
self.final_layernorm = ln_type(
normalized_shape=self.config.hidden_size,
eps=self.config.norm_epsilon,
dtype=self.config.dtype)
self.lm_head = ColumnLinear(
self.config.hidden_size,
self.config.vocab_size,
bias=False if not hasattr(self.config, "has_lm_head_bias") else
self.config.has_lm_head_bias,
dtype=self.config.dtype,
tp_group=self.config.mapping.tp_group,
tp_size=self.config.mapping.tp_size,
gather_output=True,
)
self.trtllm_modules_to_hf_modules = {
**get_default_trtllm_modules_to_hf_modules(),
"attn_q": "self_attn.q_proj",
"attn_k": "self_attn.k_proj",
"attn_v": "self_attn.v_proj",
"attn_dense": "self_attn.o_proj",
"cross_attn_q": "encoder_attn.q_proj",
"cross_attn_k": "encoder_attn.k_proj",
"cross_attn_v": "encoder_attn.v_proj",
"cross_attn_dense": "encoder_attn.o_proj",
}
if self.config.relative_attention and not self.use_implicit_relative_attention:
self.rel_attn_table = Parameter(
shape=(self.config.num_attention_heads // self.mapping.tp_size,
self.config.num_buckets),
dtype=self._dtype)
def check_config(self, config: PretrainedConfig):
config.set_if_not_exist('has_position_embedding', False)
config.set_if_not_exist('type_vocab_size', None)
config.set_if_not_exist('rescale_before_lm_head', False)
config.set_if_not_exist('layernorm_type', LayerNormType.LayerNorm)
config.set_if_not_exist('layernorm_position',
LayerNormPositionType.pre_layernorm)
config.set_if_not_exist('has_attention_qkvo_bias', False)
config.set_if_not_exist('has_mlp_bias', False)
config.set_if_not_exist('has_model_final_layernorm', False)
config.set_if_not_exist('encoder_hidden_size', None)
config.set_if_not_exist('encoder_num_heads', None)
config.set_if_not_exist('encoder_num_kv_heads', None)
config.set_if_not_exist('encoder_head_size', None)
config.set_if_not_exist('model_type', 't5')
config.set_if_not_exist('skip_cross_qkv', False)
config.set_if_not_exist('mlp_type', MLPType.MLP)
config.set_if_not_exist('has_embedding_scale', False)
config.set_if_not_exist('residual_scaling', 1.0)
config.set_if_not_exist('has_lm_head_bias', False)
config.set_if_not_exist('num_buckets', None)
config.set_if_not_exist('max_distance', None)
config.set_if_not_exist('relative_attention', False)
config.set_if_not_exist('residual_scaling', 1.0)
def forward(self,
decoder_input_ids: Tensor,
encoder_output: Tensor,
position_ids=None,
token_type_ids=None,
use_cache=False,
attention_mask=None,
cross_attention_mask=None,
last_token_ids=None,
kv_cache_params=None,
attention_params=None,
hidden_states=None,
lora_params: LoraParams = None,
cross_kv_cache_gen: Optional[Tensor] = None,
cross_qkv_reuse: Optional[Tensor] = None):
if self.mapping.is_first_pp_rank():
assert isinstance(decoder_input_ids, Tensor)
else:
assert isinstance(hidden_states, Tensor)
# In PP, layer 0 has ids as inputs, all other layers have hidden_states as inputs
if self.mapping.is_first_pp_rank():
hidden_states = self.embedding(decoder_input_ids, position_ids,
token_type_ids)
self.register_network_output('embedding_layer_output',
hidden_states)
else:
hidden_states = recv(hidden_states, self.mapping.prev_pp_rank())
kv_cache_params.fill_none_tensor_list(len(self.decoder_layers))
if use_cache:
presents = []
for i, (decoder_layer, past) in enumerate(
zip(self.decoder_layers, kv_cache_params.past_key_value)):
lora_layer_params = None
if lora_params is not None and lora_params.lora_ranks is not None:
lora_layer_params = lora_params.get_layer_params(i)
hidden_states = decoder_layer(
hidden_states,
encoder_output=encoder_output,
attention_mask=attention_mask,
cross_attention_mask=cross_attention_mask,
use_cache=use_cache,
kv_cache_params=KeyValueCacheParams(
past_key_value=past,
host_past_key_value_lengths=kv_cache_params.
host_past_key_value_lengths,
host_max_attention_window_sizes=kv_cache_params.
host_max_attention_window_sizes,
host_sink_token_length=kv_cache_params.
host_sink_token_length,
cache_indirection=kv_cache_params.cache_indirection,
kv_cache_block_offsets=kv_cache_params.
kv_cache_block_offsets,
host_kv_cache_block_offsets=kv_cache_params.
host_cross_kv_cache_block_offsets,
host_kv_cache_pool_pointers=kv_cache_params.
host_kv_cache_pool_pointers,
cross_kv_cache_block_offsets=kv_cache_params.
cross_kv_cache_block_offsets,
host_cross_kv_cache_block_offsets=kv_cache_params.
host_cross_kv_cache_block_offsets,
host_cross_kv_cache_pool_pointers=kv_cache_params.
host_cross_kv_cache_pool_pointers),
attention_params=attention_params,
lora_layer_params=lora_layer_params,
cross_kv_cache_gen=cross_kv_cache_gen,
cross_qkv_reuse=cross_qkv_reuse)
if use_cache:
presents_self, presents_cross = hidden_states[1], hidden_states[
2]
presents.append((presents_self, presents_cross))
hidden_states = hidden_states[0]
self.register_network_output(f'decoder_layer_{i}_output',
hidden_states)
if self.mapping.is_last_pp_rank():
if self.has_model_final_layernorm:
hidden_states = self.final_layernorm(hidden_states)
# [bs, seq, hidden_size] or [num_tokens, hidden_size] -> [bs, hidden_size]
hidden_states = gather_last_token_logits(
hidden_states, last_token_ids,
default_net().plugin_config.remove_input_padding)
self.register_network_output('logits_before_lmhead', hidden_states)
# Rescale output before projecting on vocab (for T5)
# See https://github.com/huggingface/transformers/blob/0b192de1f353b0e04dad4813e02e2c672de077be/src/transformers/models/t5/modeling_t5.py#L1769-L1772
# Note: this is specific for T5, to make it more generic, one can pass in a config:
# self.config.tie_word_embeddings - default to be True for T5
# openai whisper model didn't use this rescale
if self.rescale_before_lm_head:
hidden_states = hidden_states * (self.hidden_size**-0.5)
# [bs, hidden_size] -> [bs, vocab_size]
lm_logits = self.lm_head(hidden_states)
lm_logits.mark_output('logits', self._logits_dtype)
else:
hidden_states = send(hidden_states, self.mapping.next_pp_rank())
hidden_states.mark_output('hidden_states_output', self._dtype)
if use_cache and default_net().plugin_config.paged_kv_cache == False:
for i, present in zip(self.mapping.pp_layers(self.total_num_layers),
presents):
present[0].mark_output(f'present_key_value_{i}', self._kv_dtype)
if default_net().plugin_config.gpt_attention_plugin:
present[1].mark_output(f'cross_present_key_value_{i}',
self._kv_dtype)
if self.mapping.is_last_pp_rank():
return (lm_logits, tuple(presents))
return (hidden_states, tuple(presents))
else:
if self.mapping.is_last_pp_rank():
return lm_logits
return hidden_states
def prepare_inputs(self,
max_batch_size,
max_beam_width,
max_decoder_input_len,
max_seq_len,
max_encoder_input_len,
gather_context_logits: bool = False,
gather_generation_logits: bool = False,
lora_target_modules: List[str] = None,
use_cache=True,
*args,
**kwargs):
'''@brief: Prepare inputs Tensors for the model, the given sizes are used to determine the
ranges of the dimensions of when using TRT dynamic shapes.
@return: a list contains values which can be fed into the self.forward()
'''
# Prepare inputs
max_output_len = max_decoder_input_len + max_seq_len
head_size = self.head_size
num_kv_heads = (self.num_kv_heads + self.mapping.tp_size -
1) // self.mapping.tp_size
encoder_head_size = self.encoder_head_size
encoder_num_kv_heads = (self.encoder_num_kv_heads + self.mapping.tp_size
- 1) // self.mapping.tp_size
bb_range = [
1, (max_batch_size * max_beam_width + 1) // 2,
max_batch_size * max_beam_width
]
bs_range = [1, (max_batch_size + 1) // 2, max_batch_size]
beam_width_range = [1, (max_beam_width + 1) // 2, max_beam_width]
inlen_range = [
1, 1, max_decoder_input_len
] # context phase >= 1 (if forced_input_ids), generation phase = 1
encoder_inlen_range = [
1, (max_encoder_input_len + 1) // 2, max_encoder_input_len
]
mask_len_range = [1, (max_output_len + 1) // 2 + 1, max_output_len + 1]
max_output_len_range = [0, (max_output_len + 1) // 2, max_output_len]
encoder_num_tokens_range = [
0, # 0 for generation phase, >0 for context phase
(max_encoder_input_len * max_batch_size + 1) // 2,
max_encoder_input_len * max_batch_size,
]
decoder_num_tokens_range = [
1,
max_batch_size * max_beam_width,
max(max_decoder_input_len * max_batch_size,
max_beam_width * max_batch_size),
]
# No enable_two_optimization_profiles support yet
encoder_input_len_range = [
0, # 0 for generation phase, >0 for context phase
(max_encoder_input_len + 1) // 2,
max_encoder_input_len
]
past_key_value = []
sequence_length = None
host_past_key_value_lengths = None
runtime_perf_knobs = None
attention_mask = None
cross_attention_mask = None
use_gpt_attention_plugin = default_net(
).plugin_config.gpt_attention_plugin
remove_input_padding = default_net().plugin_config.remove_input_padding
paged_kv_cache = default_net().plugin_config.paged_kv_cache
tokens_per_block = default_net().plugin_config.tokens_per_block
use_lora_plugin = default_net().plugin_config.lora_plugin
input_ids, position_ids, token_type_ids, hidden_states = None, None, None, None
if remove_input_padding:
if self.mapping.is_first_pp_rank():
input_ids = Tensor(name='input_ids',
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([
('decoder_num_tokens',
[decoder_num_tokens_range]),
]))
if self.has_position_embedding:
position_ids = Tensor(name='position_ids',
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([
('decoder_num_tokens',
[decoder_num_tokens_range]),
]))
if self.has_token_type_embedding:
token_type_ids = Tensor(
name='token_type_ids',
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([('decoder_num_tokens',
[decoder_num_tokens_range])]),
)
else:
hidden_states = Tensor(name='hidden_states_input',
dtype=self._dtype,
shape=[-1, self.hidden_size],
dim_range=OrderedDict([
('decoder_num_tokens',
[decoder_num_tokens_range]),
('hidden_size', [self.hidden_size]),
]))
else:
if self.mapping.is_first_pp_rank():
input_ids = Tensor(name='input_ids',
dtype=trt.int32,
shape=[-1, -1],
dim_range=OrderedDict([
('batch_size_beam_width', [bb_range]),
('input_len', [inlen_range]),
]))
if self.has_position_embedding:
position_ids = Tensor(name='position_ids',
dtype=trt.int32,
shape=[-1, -1],
dim_range=OrderedDict([
('batch_size_beam_width',
[bb_range]),
('input_len', [inlen_range]),
]))
if self.has_token_type_embedding:
token_type_ids = Tensor(
name='token_type_ids',
dtype=trt.int32,
shape=[-1, -1],
dim_range=OrderedDict([('batch_size_beam_width',
[bb_range]),
('input_len', [inlen_range])]),
)
else:
hidden_states = Tensor(name='hidden_states_input',
dtype=self._dtype,
shape=[-1, -1, self.hidden_size],
dim_range=OrderedDict([
('batch_size_beam_width', [bb_range
]),
('input_len', [inlen_range]),
('hidden_size', [self.hidden_size]),
]))
encoder_input_lengths = Tensor(
name="encoder_input_lengths",
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([("batch_size_beam_width", [bb_range])]),
)
encoder_max_input_length = Tensor(
name="encoder_max_input_length",
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([("encoder_max_input_length",
[encoder_inlen_range])]),
)
encoder_output = None
if remove_input_padding:
encoder_output = Tensor(
name="encoder_output",
dtype=self._dtype,
shape=[-1, self.encoder_hidden_size],
dim_range=OrderedDict([
("encoder_num_tokens", [encoder_num_tokens_range]),
("encoder_hidden_size", [self.encoder_hidden_size]),
]),
)
else:
encoder_output = Tensor(
name="encoder_output",
dtype=self._dtype,
shape=[-1, -1, self.encoder_hidden_size],
dim_range=OrderedDict([
("batch_size_beam_width_encoder", [bb_range]),
("encoder_input_len", [encoder_input_len_range]),
("encoder_hidden_size", [self.encoder_hidden_size]),
]),
)
if use_gpt_attention_plugin:
host_past_key_value_lengths = Tensor(
name='host_past_key_value_lengths',
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([('batch_size_beam_width', [bb_range])]),
)
context_lengths = None
host_context_lengths = None
host_request_types = None
if use_gpt_attention_plugin and remove_input_padding:
host_context_lengths = Tensor(name='host_context_lengths',
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([
('batch_size_beam_width',
[bb_range])
]))
if use_gpt_attention_plugin:
sequence_length = Tensor(
name='sequence_length',
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([('batch_size_beam_width', [bb_range])]),
)
context_lengths = Tensor(name='context_lengths',
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([
('batch_size_beam_width', [bb_range])
]))
host_request_types = Tensor(name='host_request_types',
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([
('batch_size_beam_width',
[bb_range])
]))
runtime_perf_knobs = Tensor(name='host_runtime_perf_knobs',
dtype=trt.int64,
shape=[16],
dim_range=OrderedDict([
('perf_knob_size', [16])
]))
last_token_ids = None
if self.mapping.is_last_pp_rank() and not gather_context_logits:
last_token_ids = Tensor(
name="last_token_ids",
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([("batch_size_last_token_ids", [bb_range])
]),
)
if not use_gpt_attention_plugin:
attention_mask = Tensor(
name='attention_mask',
dtype=trt.int32,
shape=[-1, -1],
dim_range=OrderedDict([
('batch_size_beam_width', [bb_range]),
('mask_len', [mask_len_range]),
]),
)
cross_attention_mask = Tensor(
name='cross_attention_mask',
dtype=trt.int32,
shape=[-1, -1, -1],
dim_range=OrderedDict([
('batch_size_beam_width', [bb_range]),
('query_len', [1]),
('encoder_input_len', [encoder_input_len_range]),
]),
)
cache_indirection = Tensor(
name='cache_indirection',
dtype=trt.int32,
shape=[-1, -1, -1],
dim_range=OrderedDict([
('batch_size_cache', [bs_range]),
('beam_width', [beam_width_range]),
('max_seq_len', [max_output_len_range]),
]),
)
if self.mapping.tp_size > 1:
current_all_reduce_helper().set_workspace_tensor(self.mapping, 1)
layers_range = self.mapping.pp_layers(self.total_num_layers)
num_pp_layers = len(layers_range)
host_max_attention_window_sizes = None
host_sink_token_length = None
if use_gpt_attention_plugin:
host_max_attention_window_sizes = Tensor(
name=f'host_max_attention_window_sizes',
dtype=trt.int32,
shape=[num_pp_layers],
dim_range=OrderedDict([('num_layers', [num_pp_layers])]))
host_sink_token_length = Tensor(name='host_sink_token_length',
dtype=trt.int32,
shape=[1],
dim_range=OrderedDict([('scalar',
[1])]))
'''
LoRA plugin related inputs:
lora_target_modules for BART-decoder:
['attn_q', 'cross_attn_q',
'attn_v', 'cross_attn_v']
This is NOT directly loaded from the adapter-config file
We make it this way because BART has LoRA weights for both self-attention and cross-attention in decoder
'''
lora_weights_pointers = None
lora_ranks = None
lora_params = None
if use_lora_plugin:
lora_weights_pointers = []
lora_ranks = []
# In current design, q_lora_params, k_lora_params and v_lora_params should be all enabled or all disabled at the same time.
# However, BART lora modules only contain two of them, so we use zero tensor to fill the missing ones.
missing_qkv_modules = []
if any(x in lora_target_modules
for x in ["attn_q", "attn_k", "attn_v"]):
for lora_module in [
"attn_q",
"attn_k",
"attn_v",
]:
if lora_module not in lora_target_modules:
missing_qkv_modules.append(lora_module)
if any(x in lora_target_modules
for x in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]):
for lora_module in [
"cross_attn_q", "cross_attn_k", "cross_attn_v"
]:
if lora_module not in lora_target_modules:
missing_qkv_modules.append(lora_module)
for i in layers_range:
lora_weight_pointer_dict = {}
lora_rank_dict = {}
for lora_module in (lora_target_modules + missing_qkv_modules):
lora_weight_pointer = Tensor(
name=f'{lora_module}_lora_weights_pointers_{i}',
dtype=trt.int64,
shape=[-1, 2],
dim_range=OrderedDict([('batch_size_beam_width',
[bb_range]), ('in_out', [2])]))
lora_weight_pointer_dict.update({
f'{lora_module}_lora_weights_pointers':
lora_weight_pointer
})
lora_rank = Tensor(name=f'{lora_module}_lora_ranks_{i}',
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([
('batch_size_beam_width', [bb_range])
]))
lora_rank_dict.update(
{f'{lora_module}_lora_ranks': lora_rank})
lora_weights_pointers.append(lora_weight_pointer_dict)
lora_ranks.append(lora_rank_dict)
# For cross attention, we need to use encoder_input_lengths (in CPU) to pass
# as the host_context_lengths to the lora_plugin. But for self attention, we
# should keep using the original host_context_lengths. Therefore, we keep both
# of them in the lora_params.
host_encoder_input_lengths = None
if remove_input_padding:
host_encoder_input_lengths = Tensor(
name="host_encoder_input_lengths",
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([("batch_size_beam_width", [bb_range])
]),
)
lora_params = LoraParams(
lora_ranks=lora_ranks,
lora_weights_pointers=lora_weights_pointers,
host_context_lengths=host_context_lengths,
max_context_length=max_decoder_input_len,
max_encoder_context_length=max_encoder_input_len,
host_request_types=host_request_types,
host_encoder_input_lengths=host_encoder_input_lengths,
)
kv_cache_block_offsets = None
host_kv_cache_block_offsets = None
host_kv_cache_pool_pointers = None
cross_kv_cache_block_offsets = None
host_cross_kv_cache_block_offsets = None
host_cross_kv_cache_pool_pointers = None
if use_cache:
if not paged_kv_cache:
for i in layers_range:
kv_dim_range = OrderedDict([
('batch_size_beam_width', [bb_range]),
('kv', [2]),
('num_heads', [num_kv_heads]),
('past_key_len', [max_output_len_range]),
('head_size', [head_size]),
])
kv = Tensor(name=f'past_key_value_{i}',
dtype=self._kv_dtype,
shape=[-1, 2, num_kv_heads, -1, head_size],
dim_range=kv_dim_range)
if use_gpt_attention_plugin:
cross_kv_dim_range = OrderedDict([
('batch_size_beam_width', [bb_range]),
('kv', [2]),
('cross_num_heads', [encoder_num_kv_heads]),
('cross_past_key_len', [encoder_input_len_range]),
('cross_head_size', [encoder_head_size]),
])
cross_kv = Tensor(name=f'cross_past_key_value_{i}',
dtype=self._kv_dtype,
shape=[
-1, 2, encoder_num_kv_heads, -1,
encoder_head_size
],
dim_range=cross_kv_dim_range)
past_key_value.append((kv, cross_kv))
else:
# use encoder_output directly, no need to save cross_past_key_value
past_key_value.append((kv, ))
# TODO: Remove this when TRT fix the named dimension
if not remove_input_padding:
assertion(
shape(
input_ids if self.mapping.is_first_pp_rank() else
hidden_states, 0) == shape(kv, 0), 'batch size')
else: # paged_kv_cache == True
# PagedKV setup for KV cache of self-attention
max_blocks_per_seq_range = [[
math.ceil(max_output_len_range[0] / tokens_per_block),
math.ceil(max_output_len_range[1] / tokens_per_block),
math.ceil(max_output_len_range[2] / tokens_per_block)
]]
max_blocks_per_seq_range = [[
x for x in max_blocks_per_seq_range[0]
]]
# PagedKV setup for KV cache of cross-attention
max_cross_blocks_per_seq_range = [[
math.ceil(encoder_input_len_range[0] / tokens_per_block),
math.ceil(encoder_input_len_range[1] / tokens_per_block),
math.ceil(encoder_input_len_range[2] / tokens_per_block)
]]
max_cross_blocks_per_seq_range = [[
x for x in max_cross_blocks_per_seq_range[0]
]]
kv_cache_block_offsets = Tensor(name=f'kv_cache_block_offsets',
dtype=trt.int32,
shape=[-1, 2, -1],
dim_range=OrderedDict([
('batch_size_beam_width',
[bb_range]),
('kv', [2]),
('max_blocks_per_seq',
max_blocks_per_seq_range),
]))
host_kv_cache_block_offsets = Tensor(
name=f'host_kv_cache_block_offsets',
dtype=trt.int32,
shape=[-1, 2, -1],
dim_range=OrderedDict([
('batch_size_beam_width', [bb_range]),
('kv', [2]),
('max_blocks_per_seq', max_blocks_per_seq_range),
]))
host_kv_cache_pool_pointers = Tensor(
name=f'host_kv_cache_pool_pointers',
dtype=trt.int64,
shape=[2],
dim_range=OrderedDict([
('num_pools', [2]),
]))
# paged blocks for cross kv
cross_kv_cache_block_offsets = Tensor(
name=f'cross_kv_cache_block_offsets',
dtype=trt.int32,
shape=[-1, 2, -1],
dim_range=OrderedDict([
('batch_size_beam_width', [bb_range]),
('kv', [2]),
('max_cross_blocks_per_seq',
max_cross_blocks_per_seq_range),
]))
host_cross_kv_cache_block_offsets = Tensor(
name=f'host_cross_kv_cache_block_offsets',
dtype=trt.int32,
shape=[-1, 2, -1],
dim_range=OrderedDict([
('batch_size_beam_width', [bb_range]),
('kv', [2]),
('max_cross_blocks_per_seq',
max_cross_blocks_per_seq_range),
]))
host_cross_kv_cache_pool_pointers = Tensor(
name=f'host_cross_kv_cache_pool_pointers',
dtype=trt.int64,
shape=[2],
dim_range=OrderedDict([
('num_pools', [2]),
]))
for i in layers_range:
past_key_value.append(None)
kv_cache_params = KeyValueCacheParams(
past_key_value=past_key_value,
host_past_key_value_lengths=host_past_key_value_lengths,
host_max_attention_window_sizes=host_max_attention_window_sizes,
host_sink_token_length=host_sink_token_length,
cache_indirection=cache_indirection,
kv_cache_block_offsets=kv_cache_block_offsets,
host_kv_cache_block_offsets=host_kv_cache_block_offsets,
host_kv_cache_pool_pointers=host_kv_cache_pool_pointers,
cross_kv_cache_block_offsets=cross_kv_cache_block_offsets,
host_cross_kv_cache_block_offsets=
host_cross_kv_cache_block_offsets,
host_cross_kv_cache_pool_pointers=
host_cross_kv_cache_pool_pointers,
)
attention_params = AttentionParams(
sequence_length=sequence_length,
context_lengths=context_lengths,
host_context_lengths=host_context_lengths,
max_context_length=max_decoder_input_len,
host_request_types=host_request_types,
encoder_input_lengths=encoder_input_lengths,
encoder_max_input_length=encoder_max_input_length,
host_runtime_perf_knobs=runtime_perf_knobs)
cross_kv_cache_gen = Tensor(name='cross_kv_cache_gen',
dtype=trt.bool,
shape=[1],
dim_range=OrderedDict([
('boolean', [1]),
]))
cross_qkv_reuse = None
num_heads = (self.num_heads + self.mapping.tp_size -
1) // self.mapping.tp_size
cross_qkv_out_dim = num_heads * self.head_size + 2 * num_kv_heads * self.head_size
if self.skip_cross_qkv:
if remove_input_padding:
cross_qkv_reuse = Tensor(
name="cross_qkv_reuse",
dtype=self._dtype,
shape=[-1, cross_qkv_out_dim],
dim_range=OrderedDict([
("encoder_num_tokens", [encoder_num_tokens_range]),
("encoder_qkv_size", [cross_qkv_out_dim]),
]),
)
else:
cross_qkv_reuse = Tensor(
name="cross_qkv_reuse",
dtype=self._dtype,
shape=[-1, -1, cross_qkv_out_dim],
dim_range=OrderedDict([
("batch_size_beam_width_encoder", [bb_range]),
("encoder_input_len", [encoder_input_len_range]),
("encoder_qkv_size", [cross_qkv_out_dim]),
]),
)
result = {
'decoder_input_ids': input_ids,
'encoder_output': encoder_output,
'position_ids': position_ids,
'token_type_ids': token_type_ids,
'use_cache': True,
'attention_mask': attention_mask,
'cross_attention_mask': cross_attention_mask,
'last_token_ids': last_token_ids,
'kv_cache_params': kv_cache_params,
'attention_params': attention_params,
'hidden_states': hidden_states,
'lora_params': lora_params,
'cross_kv_cache_gen': cross_kv_cache_gen,
'cross_qkv_reuse': cross_qkv_reuse,
}
return result
def use_lora(self, lora_config: LoraConfig):
    """Enable LoRA on this model.

    Delegates to the module-level ``use_lora`` helper. A method's own name is
    not in scope inside its body, so the call below resolves to the global
    function, not to this method. The mapping stored in
    ``self.trtllm_modules_to_hf_modules`` is passed through as the third
    argument (presumably TRT-LLM -> HF module names; confirm against the
    helper).

    Args:
        lora_config: LoRA configuration, forwarded unchanged.
    """
    module_mapping = self.trtllm_modules_to_hf_modules
    use_lora(self, lora_config, module_mapping)
def precompute_relative_attention_bias(self, build_config):
    """Precompute the relative-attention bias on CUDA and hand it to layers.

    Does nothing unless ``self.config.relative_attention`` is truthy and
    ``self.use_implicit_relative_attention`` is falsy. Otherwise it proceeds
    in three steps:

    1. Allocate a zero tensor of shape
       ``(num_attention_heads // tp_size, max_seq_len + 1, max_seq_len + 1)``
       on ``'cuda'``.
    2. Fill it with the ``torch.ops.tensorrt_llm.relative_attention_bias``
       op from ``self.rel_attn_table``.
    3. Install the result on the self-attention of each of the first
       ``self.num_layers`` decoder layers.

    Args:
        build_config: object providing ``max_seq_len``; only read when the
            precomputation actually runs.
    """
    # Guard clause (De Morgan of the enabling condition); short-circuits so
    # use_implicit_relative_attention / build_config are not touched when
    # relative attention is disabled.
    if not self.config.relative_attention or self.use_implicit_relative_attention:
        return

    bias_op = torch.ops.tensorrt_llm.relative_attention_bias
    heads_per_rank = self.config.num_attention_heads // self.mapping.tp_size
    seq_len = build_config.max_seq_len

    bias = torch.zeros(
        (heads_per_rank, seq_len + 1, seq_len + 1),
        dtype=str_dtype_to_torch(self.config.dtype),
        device='cuda',
    )
    table = numpy_to_torch(self.rel_attn_table.raw_value).to('cuda')

    # The op fills `bias` in place. The positional False is presumably the
    # op's bidirectional flag (decoder side) -- TODO confirm op signature.
    bias_op(
        bias,
        table,
        heads_per_rank,
        seq_len,
        self.config.num_buckets,
        False,
        self.config.max_distance,
    )

    for idx in range(self.num_layers):
        attention = self.decoder_layers[idx].self_attention
        attention.set_rel_attn_table(seq_len, bias)
class WhisperEncoder(PretrainedModel):
    """Whisper audio encoder built from TensorRT-LLM layers.

    Graph built by ``forward``:

    1. conv1 -> GELU -> conv2 (stride 2) -> GELU.
    2. Add a slice of the positional-embedding parameter.
    3. Run ``config.num_hidden_layers`` ``EncoderLayer`` blocks.
    4. Apply the final ``LayerNorm``; the result is marked as the network
       output ``encoder_output``.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        self._dtype = self.config.dtype
        # Encoder conv needs to run in fp32 on Volta/Turing: keep the model
        # dtype only when the device's compute-capability major is >= 8.
        major, _ = torch.cuda.get_device_capability()
        if major >= 8:
            self._conv_dtype = self._dtype
        else:
            self._conv_dtype = "float32"
        self.conv1 = Conv1d(config.n_mels,
                            config.hidden_size,
                            kernel_size=3,
                            padding=1,
                            dtype=self._conv_dtype)
        # stride=2 downsamples the time axis by 2; forward() relies on this
        # when it slices max_audio_feature_seq_len // 2 positional rows.
        self.conv2 = Conv1d(config.hidden_size,
                            config.hidden_size,
                            kernel_size=3,
                            stride=2,
                            padding=1,
                            dtype=self._conv_dtype)
        self.positional_embedding = Parameter(shape=(config.n_audio_ctx,
                                                     config.hidden_size),
                                              dtype=self._dtype)
        self.encoder_layers = ModuleList([
            EncoderLayer(
                hidden_size=config.hidden_size,
                ffn_hidden_size=config.hidden_size * 4,
                num_attention_heads=config.num_attention_heads,
                num_kv_heads=config.num_attention_heads,
                head_size=config.hidden_size // config.num_attention_heads,
                max_position_embeddings=3000,
                q_scaling=1.0,
                has_attention_qkvo_bias=True,
                has_mlp_bias=True,
                hidden_act='gelu',
                dtype=self._dtype) for _ in range(config.num_hidden_layers)
        ])
        self.ln_post = LayerNorm(config.hidden_size, dtype=self._dtype)
        # Number of mel-feature frames per example. forward() reshapes and
        # slices with it; prepare_inputs() uses it for the input shape.
        self.max_audio_feature_seq_len = 3000

    def forward(self, input_features: Tensor, input_lengths=None):
        """Build the encoder graph and mark ``encoder_output``.

        Args:
            input_features: the ``input_features`` tensor declared by
                ``prepare_inputs``. Its layout depends on the
                ``remove_input_padding`` plugin option:

                - disabled: ``[batch, n_mels, max_audio_feature_seq_len]``;
                - enabled: ``[batch * max_audio_feature_seq_len, n_mels]``.
            input_lengths: the ``input_lengths`` tensor (shape ``[batch]`` in
                ``prepare_inputs``). It is required when
                ``remove_input_padding`` is enabled, because its leading
                dimension supplies the batch size for the reshape. It is
                forwarded to every encoder layer.

        Returns:
            The post-LayerNorm encoder output, also marked as network output
            ``encoder_output`` with dtype ``self._dtype``.

        Raises:
            ValueError: if ``remove_input_padding`` is enabled and
                ``input_lengths`` is None.
        """
        remove_input_padding = default_net().plugin_config.remove_input_padding
        if remove_input_padding:
            if input_lengths is None:
                raise ValueError(
                    "input_lengths is required when remove_input_padding "
                    "is enabled")
            # BXT,D -> B,T,D -> B,D,T
            input_features = input_features.view([
                input_lengths.shape[0], self.max_audio_feature_seq_len,
                self.config.n_mels
            ])
            input_features = transpose(input_features, 1, 2)
        # Encoder conv needs to run in fp32 on Volta/Turing; the result is
        # cast back to the incoming dtype after conv2.
        x_type = input_features.dtype
        input_features = cast(input_features, self._conv_dtype)
        x = self.conv1(input_features)
        x = gelu(x)
        x = self.conv2(x)
        x = cast(x, x_type)
        x = gelu(x)
        # B,D,T -> B,T,D
        x = transpose(x, 2, 1)
        # Add the first max_audio_feature_seq_len // 2 positional rows, i.e.
        # the time length after conv2's stride-2 downsampling.
        x = x + cast(
            slice(input=self.positional_embedding.value,
                  starts=[0, 0],
                  sizes=[
                      self.max_audio_feature_seq_len // 2,
                      self.positional_embedding.shape[1]
                  ],
                  strides=[1, 1]), x.dtype)
        if remove_input_padding:
            # B,T,D -> BxT,D
            x = x.view([-1, self.config.hidden_size])
        hidden_states = x
        for encoder_layer in self.encoder_layers:
            hidden_states = encoder_layer(hidden_states,
                                          input_lengths=input_lengths)
        x = self.ln_post(hidden_states)
        x.mark_output('encoder_output', self._dtype)
        return x

    def prepare_inputs(self, max_batch_size=16):
        """Declare the encoder's network input tensors.

        Args:
            max_batch_size: largest batch size. The batch optimization range
                is ``[1, (max_batch_size + 1) // 2, max_batch_size]``.

        Returns:
            ``{'input_features': Tensor, 'input_lengths': Tensor}``.
            ``input_features`` is ``[-1, n_mels, max_audio_feature_seq_len]``
            without ``remove_input_padding``, and ``[-1, n_mels]`` (batch and
            time flattened together) with it.
        """
        bs_range = [1, (max_batch_size + 1) // 2, max_batch_size]
        # You may change max_audio_feature_seq_len here for distill-whisper models.
        max_audio_feature_seq_len = self.max_audio_feature_seq_len
        if not default_net().plugin_config.remove_input_padding:
            x = Tensor(
                name="input_features",
                dtype=self._dtype,
                shape=[-1, self.config.n_mels, max_audio_feature_seq_len],
                dim_range=OrderedDict([
                    ("batch_size", [bs_range]),
                    ("feature_dim", [self.config.n_mels]),
                    ("feature_len_range", [max_audio_feature_seq_len]),
                ]))
        else:
            batch_seqlen_range = [
                1,
                (max_audio_feature_seq_len * max_batch_size + 1) // 2,
                max_audio_feature_seq_len * max_batch_size,
            ]
            x = Tensor(name="input_features",
                       dtype=self._dtype,
                       shape=[-1, self.config.n_mels],
                       dim_range=OrderedDict([
                           ("batch_seqlen_range", [batch_seqlen_range]),
                           ("feature_dim", [self.config.n_mels]),
                       ]))
        input_lengths = Tensor(
            name="input_lengths",
            dtype=trt.int32,
            shape=[-1],
            dim_range=OrderedDict([("batch_size", [bs_range])]),
        )
        return {'input_features': x, 'input_lengths': input_lengths}

    def precompute_relative_attention_bias(self, build_config):
        """No-op: this encoder has nothing to precompute.

        Presumably kept so build code can call this hook uniformly on any
        model -- confirm against callers.
        """
        pass
|