File size: 18,781 Bytes
5000658 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 |
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
from ..._common import default_net
from ...functional import (ACT2FN, bert_attention, cast, concat, constant,
cumsum, expand, expand_mask, index_select, matmul,
select, shape, slice, softmax, split, unsqueeze)
from ...layers import MLP, ColumnLinear, Embedding, LayerNorm, Linear, RowLinear
from ...mapping import Mapping
from ...module import Module, ModuleList
class BertEmbedding(Module):
    """Embedding front-end of the BERT encoder.

    Adds together three table lookups (vocabulary ids, position ids and
    token-type ids) and applies a LayerNorm to the sum.
    """

    def __init__(self,
                 vocab_size,
                 hidden_size,
                 max_position_embeddings,
                 type_vocab_size,
                 dtype=None):
        """Create the three embedding tables and the trailing LayerNorm.

        Args:
            vocab_size: Size handed to the vocabulary embedding table.
            hidden_size: Embedding width shared by all three tables; also the
                normalized shape of the LayerNorm.
            max_position_embeddings: Size handed to the position table.
            type_vocab_size: Size handed to the token-type table.
            dtype: Data type forwarded unchanged to every sub-layer.
        """
        super().__init__()

        # Every sub-layer receives the same dtype keyword.
        common = dict(dtype=dtype)

        self.vocab_embedding = Embedding(vocab_size, hidden_size, **common)
        self.position_embedding = Embedding(max_position_embeddings,
                                            hidden_size, **common)
        self.token_embedding = Embedding(type_vocab_size, hidden_size,
                                         **common)
        self.max_position_embeddings = max_position_embeddings
        self.embedding_ln = LayerNorm(normalized_shape=hidden_size, **common)

    def forward(self, input_ids, position_ids, token_type_ids):
        """Return the layer-normalized sum of the three embeddings.

        Args:
            input_ids: Ids looked up in the vocabulary table.
            position_ids: Ids looked up in the position table.
            token_type_ids: Ids looked up in the token-type table.
        """
        embedded = (self.vocab_embedding(input_ids) +
                    self.position_embedding(position_ids))
        embedded = embedded + self.token_embedding(token_type_ids)
        return self.embedding_ln(embedded)
class BertAttention(Module):
    """Multi-head self-attention block with a fused QKV projection.

    ``forward`` has two paths: the ``bert_attention`` plugin, or an explicit
    matmul / softmax / matmul implementation that only works on padded input.
    """

    def __init__(self,
                 hidden_size,
                 num_attention_heads,
                 max_position_embeddings,
                 dtype=None,
                 tp_group=None,
                 tp_size=1):
        """Set up per-rank sizes and the two projections.

        Args:
            hidden_size: Full (unsharded) hidden width.
            num_attention_heads: Full (unsharded) number of heads.
            max_position_embeddings: Stored on the instance only.
            dtype: Data type forwarded to both linear layers.
            tp_group: Tensor-parallel group forwarded to both linear layers.
            tp_size: Tensor-parallel degree; the stored head count and hidden
                size are divided by it.
        """
        super().__init__()

        # The head width comes from the unsharded sizes; the head count and
        # the hidden size kept on the instance are per tensor-parallel rank.
        self.attention_head_size = hidden_size // num_attention_heads
        self.num_attention_heads = num_attention_heads // tp_size
        self.hidden_size = hidden_size // tp_size
        self.max_position_embeddings = max_position_embeddings
        self.norm_factor = math.sqrt(self.attention_head_size)

        parallel = dict(dtype=dtype, tp_group=tp_group, tp_size=tp_size)
        self.qkv = ColumnLinear(hidden_size,
                                hidden_size * 3,
                                gather_output=False,
                                **parallel)
        self.dense = RowLinear(hidden_size, hidden_size, **parallel)

    def _to_heads(self, x):
        """View ``x`` as (dim0, dim1, heads, head_size) and swap dims 1 and 2."""
        per_head_shape = concat([
            shape(x, 0),
            shape(x, 1), self.num_attention_heads, self.attention_head_size
        ])
        return x.view(per_head_shape).permute([0, 2, 1, 3])

    def _attention_without_plugin(self, qkv, attention_mask):
        """Scaled dot-product attention expressed with plain network ops."""
        query, key, value = (self._to_heads(part)
                             for part in split(qkv, self.hidden_size, dim=2))

        # Swap the last two axes of the keys so that query @ key yields scores.
        scores = matmul(query, key.permute([0, 1, 3, 2])) / self.norm_factor

        if attention_mask is not None:
            # The mask is additive and is cast to the dtype of the scores.
            scores = scores + cast(attention_mask, scores.dtype)

        probs = softmax(scores, dim=-1)
        context = matmul(probs, value,
                         use_fp32_acc=False).permute([0, 2, 1, 3])

        # Fold the head axes back into one per-rank hidden axis.
        merged_shape = concat(
            [shape(context, 0),
             shape(context, 1), self.hidden_size])
        return context.view(merged_shape)

    def forward(self,
                hidden_states,
                attention_mask=None,
                input_lengths=None,
                max_input_length=None):
        """Apply self-attention followed by the output projection.

        Args:
            hidden_states: Input to the fused QKV projection.
            attention_mask: Optional additive mask; only read on the
                non-plugin path.
            input_lengths: Required when the bert_attention plugin is enabled.
            max_input_length: Forwarded to the plugin.

        Raises:
            AssertionError: If the plugin is enabled without ``input_lengths``,
                or if ``remove_input_padding`` is set while the plugin is off.
        """
        qkv = self.qkv(hidden_states)

        plugin_config = default_net().plugin_config
        if plugin_config.bert_attention_plugin:
            assert input_lengths is not None
            context = bert_attention(qkv,
                                     input_lengths,
                                     self.num_attention_heads,
                                     self.attention_head_size,
                                     q_scaling=1.0,
                                     max_input_length=max_input_length)
        else:
            assert not plugin_config.remove_input_padding, \
                "remove_input_padding requires bert_attention_plugin enabled"
            context = self._attention_without_plugin(qkv, attention_mask)

        return self.dense(context)
class BertEncoderLayer(Module):
    """One encoder layer: self-attention and an MLP.

    Each sub-layer's output is added to its input and the sum is then
    layer-normalized.
    """

    def __init__(self,
                 hidden_size,
                 num_attention_heads,
                 max_position_embeddings,
                 hidden_act='relu',
                 tp_group=None,
                 tp_size=1,
                 dtype=None):
        """Build the attention block, the MLP and their two LayerNorms.

        Args:
            hidden_size: Hidden width; the MLP inner width is four times this.
            num_attention_heads: Head count forwarded to the attention block.
            max_position_embeddings: Forwarded to the attention block.
            hidden_act: Activation name forwarded to the MLP.
            tp_group: Tensor-parallel group for the attention block and MLP.
            tp_size: Tensor-parallel degree for the attention block and MLP.
            dtype: Data type forwarded to every sub-layer.
        """
        super().__init__()

        # Attention and MLP are sharded the same way.
        sharding = dict(tp_group=tp_group, tp_size=tp_size, dtype=dtype)

        self.input_layernorm = LayerNorm(normalized_shape=hidden_size,
                                         dtype=dtype)
        self.attention = BertAttention(hidden_size, num_attention_heads,
                                       max_position_embeddings, **sharding)
        self.mlp = MLP(hidden_size=hidden_size,
                       ffn_hidden_size=hidden_size * 4,
                       hidden_act=hidden_act,
                       **sharding)
        self.post_layernorm = LayerNorm(normalized_shape=hidden_size,
                                        dtype=dtype)

    def forward(self,
                hidden_states,
                attention_mask=None,
                input_lengths=None,
                max_input_length=None):
        """Run the layer and return the transformed hidden states.

        All optional arguments are passed straight to the attention block.
        """
        attended = self.attention(hidden_states,
                                  attention_mask=attention_mask,
                                  input_lengths=input_lengths,
                                  max_input_length=max_input_length)
        # Residual add, then normalize.
        normed = self.input_layernorm(hidden_states + attended)

        projected = self.mlp(normed)
        return self.post_layernorm(normed + projected)
class BertModel(Module):
    """BERT / RoBERTa encoder: embeddings followed by a stack of layers."""

    def __init__(self,
                 num_layers,
                 num_heads,
                 hidden_size,
                 vocab_size,
                 hidden_act,
                 max_position_embeddings,
                 type_vocab_size,
                 pad_token_id=None,
                 is_roberta=False,
                 mapping=None,
                 dtype=None):
        """Build the embedding stage and ``num_layers`` encoder layers.

        Args:
            num_layers: Number of encoder layers.
            num_heads: Attention heads per layer.
            hidden_size: Hidden width of the model.
            vocab_size: Vocabulary size of the word embedding.
            hidden_act: Activation name used by every MLP.
            max_position_embeddings: Size of the position table; also the
                length of the constant id buffers created in ``forward``.
            type_vocab_size: Size of the token-type table.
            pad_token_id: Offset added to generated RoBERTa position ids.
            is_roberta: Select RoBERTa-style position id generation.
            mapping: Parallel mapping that supplies ``tp_group`` and
                ``tp_size``. ``None`` means a fresh default ``Mapping()``.
            dtype: Data type forwarded to every sub-layer.
        """
        super().__init__()

        # Build the default per instance instead of sharing one object that
        # was created when the ``def`` statement was evaluated.
        if mapping is None:
            mapping = Mapping()

        self.max_position_embeddings = max_position_embeddings
        self.padding_idx = pad_token_id
        self.is_roberta = is_roberta

        self.embedding = BertEmbedding(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            max_position_embeddings=max_position_embeddings,
            type_vocab_size=type_vocab_size,
            dtype=dtype)

        self.layers = ModuleList([
            BertEncoderLayer(hidden_size=hidden_size,
                             num_attention_heads=num_heads,
                             max_position_embeddings=max_position_embeddings,
                             hidden_act=hidden_act,
                             tp_group=mapping.tp_group,
                             tp_size=mapping.tp_size,
                             dtype=dtype) for _ in range(num_layers)
        ])

    def forward(self,
                input_ids=None,
                input_lengths=None,
                position_ids=None,
                token_type_ids=None,
                hidden_states=None,
                max_input_length=None):
        """Embed ``input_ids`` and run every encoder layer.

        In padded mode (``remove_input_padding`` disabled) the attention mask
        is derived from ``input_lengths``, and missing ``position_ids`` /
        ``token_type_ids`` are generated. In ``remove_input_padding`` mode
        nothing is generated, so the caller must supply those tensors.

        Args:
            input_ids: Token ids; B x L in padded mode.
            input_lengths: Length of each sequence. Required in padded mode.
            position_ids: Optional explicit position ids.
            token_type_ids: Optional explicit token-type ids.
            hidden_states: Accepted for interface compatibility; the value is
                never read and is replaced by the embedding output.
            max_input_length: Forwarded to every layer.

        Returns:
            The output of the last encoder layer.

        Raises:
            ValueError: In padded mode when ``input_lengths`` is missing, or
                when RoBERTa position ids must be generated but no
                ``pad_token_id`` was configured.
        """
        # remove_input_padding requires these fields as explicit input
        extended_attention_mask = None
        if not default_net().plugin_config.remove_input_padding:
            if input_lengths is None:
                raise ValueError(
                    "input_lengths is required to build the attention mask "
                    "when remove_input_padding is disabled")

            seq_len_2d = concat([1, shape(input_ids, 1)])

            # Constant row 0..max_position_embeddings-1, cut down to the
            # current sequence length and broadcast over the batch.
            position_ids_buffer = constant(
                np.expand_dims(
                    np.arange(self.max_position_embeddings).astype(np.int32),
                    0))
            arange_ids = slice(position_ids_buffer,
                               starts=[0, 0],
                               sizes=seq_len_2d)
            arange_ids = expand(arange_ids, shape(input_ids))  # BxL

            lengths = unsqueeze(input_lengths, 1)  # Bx1
            lengths = expand(lengths, shape(input_ids))  # BxL

            # 1 for real tokens, 0 for positions at or past the length.
            mask = (arange_ids < lengths).cast('int32')  # BxL

            if position_ids is None:
                if self.is_roberta:
                    if self.padding_idx is None:
                        raise ValueError(
                            "pad_token_id is required to generate RoBERTa "
                            "position_ids")
                    # see create_position_ids_from_input_ids() in https://github.com/huggingface/transformers/blob/main/src/transformers/models/roberta/modeling_roberta.py
                    position_ids = (arange_ids + 1) * mask
                    position_ids = position_ids + self.padding_idx
                else:
                    # Plain BERT uses 0..L-1, which is exactly the tensor
                    # built above; reuse it rather than slicing again.
                    position_ids = arange_ids

            # create extended_attention_mask as https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py
            extended_attention_mask = expand_mask(mask,
                                                  tgt_len=1)  # BxL -> Bx1x1xL

            # Default token types to all zeros.
            if token_type_ids is None:
                token_type_ids_buffer = constant(
                    np.expand_dims(
                        np.zeros(self.max_position_embeddings).astype(
                            np.int32), 0))
                token_type_ids = slice(token_type_ids_buffer,
                                       starts=[0, 0],
                                       sizes=seq_len_2d)
                token_type_ids = expand(token_type_ids, shape(input_ids))

        hidden_states = self.embedding(input_ids, position_ids, token_type_ids)

        for layer in self.layers:
            hidden_states = layer(hidden_states=hidden_states,
                                  input_lengths=input_lengths,
                                  attention_mask=extended_attention_mask,
                                  max_input_length=max_input_length)

        return hidden_states
class BertForQuestionAnswering(Module):
    """BERT encoder with a per-token linear head of ``num_labels`` outputs."""

    def __init__(self,
                 num_layers,
                 num_heads,
                 hidden_size,
                 vocab_size,
                 hidden_act,
                 max_position_embeddings,
                 type_vocab_size,
                 pad_token_id=None,
                 is_roberta=False,
                 num_labels=2,
                 mapping=None,
                 dtype=None):
        """Build the encoder and the output projection.

        Args:
            num_layers, num_heads, hidden_size, vocab_size, hidden_act,
            max_position_embeddings, type_vocab_size, pad_token_id,
            is_roberta, dtype: Forwarded unchanged to ``BertModel``.
            num_labels: Output width of the ``qa_outputs`` projection.
            mapping: Parallel mapping forwarded to ``BertModel``. ``None``
                means a fresh default ``Mapping()``.
        """
        super().__init__()

        # Resolve the default here so the encoder always receives a real
        # Mapping, and no instance is shared through the function signature.
        if mapping is None:
            mapping = Mapping()

        self.bert = BertModel(num_layers=num_layers,
                              num_heads=num_heads,
                              hidden_size=hidden_size,
                              vocab_size=vocab_size,
                              hidden_act=hidden_act,
                              max_position_embeddings=max_position_embeddings,
                              type_vocab_size=type_vocab_size,
                              pad_token_id=pad_token_id,
                              is_roberta=is_roberta,
                              mapping=mapping,
                              dtype=dtype)
        self.num_labels = num_labels
        self.qa_outputs = Linear(hidden_size, num_labels, dtype=dtype)

    def forward(self,
                input_ids=None,
                input_lengths=None,
                token_type_ids=None,
                position_ids=None,
                hidden_states=None,
                max_input_length=None):
        """Run the encoder and project every hidden state to ``num_labels``.

        Args:
            input_ids: Token ids.
            input_lengths: Length of each sequence.
            token_type_ids: Optional token-type ids.
            position_ids: Optional position ids.
            hidden_states: Forwarded to the encoder unchanged.
            max_input_length: Forwarded to the encoder, which hands it to the
                attention layers. Defaults to ``None``, which matches the
                previous behaviour of not passing it at all.

        Returns:
            The logits produced by ``qa_outputs``.
        """
        hidden_states = self.bert.forward(input_ids=input_ids,
                                          input_lengths=input_lengths,
                                          token_type_ids=token_type_ids,
                                          position_ids=position_ids,
                                          hidden_states=hidden_states,
                                          max_input_length=max_input_length)

        return self.qa_outputs(hidden_states)
class BertPooler(Module):
    """Pool a sequence by passing its first token through dense + tanh."""

    def __init__(self, hidden_size, dtype):
        """Create the square dense layer and look up the tanh activation."""
        super().__init__()
        self.dense = Linear(hidden_size, hidden_size, dtype=dtype)
        self.activation = ACT2FN['tanh']

    def forward(self, hidden_states, input_lengths, remove_input_padding):
        """Return the pooled representation of each sequence.

        Args:
            hidden_states: Encoder output. With ``remove_input_padding`` it is
                packed as [num_tokens, hidden_size]; otherwise index 0 along
                dim 1 is taken as the first token.
            input_lengths: Length of each sequence; only read in packed mode.
            remove_input_padding: Whether ``hidden_states`` is packed.
        """
        if remove_input_padding:
            # Packed layout: a sequence starts where all previous ones end, so
            # the start offsets are the running sum of [0, len_0, ...,
            # len_{n-2}]. For lengths [8, 5, 6] the offsets are [0, 8, 13].
            count_without_last = shape(input_lengths) - constant(
                np.array([1], dtype=np.int32))
            preceding_lengths = slice(input_lengths,
                                      starts=[0],
                                      sizes=count_without_last)
            start_offsets = cumsum(concat([0, preceding_lengths]), 0)
            first_tokens = index_select(hidden_states, 0, start_offsets)
        else:
            # Padded layout: pick index 0 along dim 1.
            first_tokens = select(hidden_states, 1, 0)

        return self.activation(self.dense(first_tokens))
class RobertaClassificationHead(Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, hidden_size, dtype, num_labels):
        """Create the hidden projection and the label projection.

        Args:
            hidden_size: Input width of both layers and output width of
                ``dense``.
            dtype: Data type used for both layers.
            num_labels: Output width of ``out_proj``.
        """
        super().__init__()
        self.dense = Linear(hidden_size, hidden_size, dtype=dtype)
        # Pass dtype here as well so out_proj is declared with the same type
        # as dense, instead of falling back to the Linear default.
        self.out_proj = Linear(hidden_size, num_labels, dtype=dtype)

    def forward(self, features, **kwargs):
        """Classify each sequence from the entry at index 0 along dim 1.

        Args:
            features: Encoder output; index 0 of dim 1 is selected, so the
                sequence axis is expected at dim 1.
            **kwargs: Accepted and ignored.

        Returns:
            The output of ``out_proj``.
        """
        first_token = select(features, 1, 0)
        hidden = ACT2FN['tanh'](self.dense(first_token))
        return self.out_proj(hidden)
class BertForSequenceClassification(Module):
    """BERT / RoBERTa encoder with a sequence-level classification head.

    BERT uses ``BertPooler`` followed by a linear classifier; RoBERTa uses
    ``RobertaClassificationHead``.
    """

    def __init__(self,
                 num_layers,
                 num_heads,
                 hidden_size,
                 vocab_size,
                 hidden_act,
                 max_position_embeddings,
                 type_vocab_size,
                 pad_token_id=None,
                 is_roberta=False,
                 num_labels=2,
                 mapping=None,
                 dtype=None):
        """Build the encoder and the head matching ``is_roberta``.

        Args:
            num_layers, num_heads, hidden_size, vocab_size, hidden_act,
            max_position_embeddings, type_vocab_size, pad_token_id, dtype:
                Forwarded unchanged to ``BertModel``.
            is_roberta: Selects the RoBERTa head instead of pooler + linear.
            num_labels: Number of output classes.
            mapping: Parallel mapping forwarded to ``BertModel``. ``None``
                means a fresh default ``Mapping()``.
        """
        super().__init__()

        # Resolve the default here so the encoder always receives a real
        # Mapping, and no instance is shared through the function signature.
        if mapping is None:
            mapping = Mapping()

        self.is_roberta = is_roberta
        self.bert = BertModel(num_layers=num_layers,
                              num_heads=num_heads,
                              hidden_size=hidden_size,
                              vocab_size=vocab_size,
                              hidden_act=hidden_act,
                              max_position_embeddings=max_position_embeddings,
                              type_vocab_size=type_vocab_size,
                              pad_token_id=pad_token_id,
                              is_roberta=is_roberta,
                              mapping=mapping,
                              dtype=dtype)
        self.num_labels = num_labels

        if not is_roberta:
            self.pooler = BertPooler(hidden_size=hidden_size, dtype=dtype)
            self.classifier = Linear(hidden_size, num_labels, dtype=dtype)
        else:
            self.classifier = RobertaClassificationHead(hidden_size=hidden_size,
                                                        num_labels=num_labels,
                                                        dtype=dtype)

    @staticmethod
    def _first_token_states(hidden_states, input_lengths):
        """Gather the first token of every sequence from packed states.

        The start offsets are the running sum of [0, len_0, ..., len_{n-2}];
        for lengths [8, 5, 6] they are [0, 8, 13].
        """
        preceding_lengths = slice(input_lengths,
                                  starts=[0],
                                  sizes=(shape(input_lengths) -
                                         constant(np.array([1],
                                                           dtype=np.int32))))
        start_offsets = cumsum(concat([0, preceding_lengths]), 0)
        return index_select(hidden_states, 0, start_offsets)

    def forward(self,
                input_ids,
                input_lengths,
                token_type_ids=None,
                position_ids=None,
                hidden_states=None,
                max_input_length=None):
        """Return classification logits, one row per sequence.

        Args:
            input_ids: Token ids.
            input_lengths: Length of each sequence.
            token_type_ids: Token-type ids; required in remove_input_padding
                mode.
            position_ids: Position ids; required in remove_input_padding mode.
            hidden_states: Forwarded to the encoder unchanged.
            max_input_length: Required in remove_input_padding mode.

        Raises:
            AssertionError: In remove_input_padding mode when any of the three
                required inputs is missing.
        """
        remove_input_padding = default_net().plugin_config.remove_input_padding

        # required as explicit input in remove_input_padding mode
        # see examples/bert/run_remove_input_padding.py for how to create them from input_ids and input_lengths
        if remove_input_padding:
            assert token_type_ids is not None and \
                position_ids is not None and \
                max_input_length is not None, \
                "token_type_ids, position_ids, max_input_length is required " \
                "in remove_input_padding mode"

        hidden_states = self.bert.forward(input_ids=input_ids,
                                          input_lengths=input_lengths,
                                          token_type_ids=token_type_ids,
                                          position_ids=position_ids,
                                          hidden_states=hidden_states,
                                          max_input_length=max_input_length)

        if not self.is_roberta:
            pooled_output = self.pooler(
                hidden_states=hidden_states,
                input_lengths=input_lengths,
                remove_input_padding=remove_input_padding)
            logits = self.classifier(pooled_output)
        else:
            features = hidden_states
            if remove_input_padding:
                # The RoBERTa head takes index 0 along dim 1. On packed
                # [num_tokens, hidden_size] states that would pick a feature
                # column, not a token. Gather the first token of each sequence
                # and restore a length-1 sequence axis so the head's selection
                # returns exactly those tokens.
                features = unsqueeze(
                    self._first_token_states(hidden_states, input_lengths), 1)
            logits = self.classifier(features)

        return logits
|