import functools
import warnings

import torch
from torch import nn
from torch.nn import functional as F


class TextualHead(nn.Module):
    """Base class for textual heads; holds the common size hyperparameters."""

    def __init__(self,
                 visual_feature_size: int, vocab_size: int, hidden_size: int):
        super().__init__()
        self.visual_feature_size = visual_feature_size
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

    @property
    def textual_feature_size(self):
        return self.hidden_size


class WordAndPositionalEmbedding(nn.Module):
    """Sum of learned word and position embeddings, followed by layer norm
    and dropout."""

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        dropout: float = 0.0,
        max_caption_length: int = 30,
        padding_idx: int = 0,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.padding_idx = padding_idx

        self.words = nn.Embedding(vocab_size, hidden_size)
        # One learned embedding per position, up to the maximum caption length.
        self.positions = nn.Embedding(max_caption_length, hidden_size)
        self.layer_norm = nn.LayerNorm(
            hidden_size, eps=1e-8, elementwise_affine=True
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, tokens: torch.Tensor):
        # shape: (batch_size, max_caption_length)
        position_indices = self._create_position_indices(tokens)

        # shape: (batch_size, max_caption_length, hidden_size)
        word_embeddings = self.words(tokens)
        position_embeddings = self.positions(position_indices)

        # Add the embeddings, then normalize and apply dropout.
        embeddings = self.layer_norm(word_embeddings + position_embeddings)
        embeddings = self.dropout(embeddings)

        return embeddings

    @functools.lru_cache(maxsize=128)
    def _create_position_indices(self, tokens: torch.Tensor):
        # Create position indices of the same size as token indices. The cache
        # is keyed on tensor identity, so repeated calls with the same tensor
        # object reuse the cached result.
        batch_size, max_caption_length = tokens.size()
        positions = torch.arange(
            max_caption_length, dtype=tokens.dtype, device=tokens.device
        )
        # shape: (batch_size, max_caption_length)
        positions = positions.unsqueeze(0).expand(batch_size, max_caption_length)
        return positions
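

def _demo_word_positional_embedding():
    # Illustrative sketch only (not used by the model); the sizes below are
    # made up. Shows the expected input/output shapes of the embedding module.
    emb = WordAndPositionalEmbedding(vocab_size=100, hidden_size=16,
                                     max_caption_length=8)
    tokens = torch.randint(0, 100, (2, 8))  # (batch_size, caption_length)
    out = emb(tokens)
    assert out.shape == (2, 8, 16)
    return out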


class BertEncoderAsDecoder(nn.Module):
    """Runs a BERT-style encoder as a decoder by concatenating memory (visual
    features) and target (text) into one sequence under a composite mask."""

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def forward(self, tgt, memory,
                tgt_mask=None,
                tgt_key_padding_mask=None,
                memory_key_padding_mask=None,
                tgt_bi_valid_mask=None,
                encoder_history_states=None,
                ):
        assert tgt_key_padding_mask is None, 'not supported'
        assert tgt_mask.dim() == 2
        assert tgt_mask.shape[0] == tgt_mask.shape[1]

        # (sequence, batch, hidden) -> (batch, sequence, hidden)
        tgt = tgt.transpose(0, 1)
        memory = memory.transpose(0, 1)

        hidden_states = torch.cat((memory, tgt), dim=1)
        num_tgt = tgt.shape[1]
        num_memory = memory.shape[1]
        device = tgt.device
        dtype = tgt.dtype

        # Composite additive attention mask over the concatenated sequence
        # (rows are queries, columns are keys):
        #   memory -> memory: visible      memory -> target: blocked
        #   target -> memory: visible      target -> target: causal (tgt_mask)
        top_left = torch.zeros((num_memory, num_memory), device=device, dtype=dtype)
        top_right = torch.full((num_memory, num_tgt), float('-inf'), device=device, dtype=dtype)
        bottom_left = torch.zeros((num_tgt, num_memory), dtype=dtype, device=device)
        left = torch.cat((top_left, bottom_left), dim=0)
        right = torch.cat((top_right, tgt_mask.to(dtype)), dim=0)

        # shape: (1, num_memory + num_tgt, num_memory + num_tgt)
        full_attention_mask = torch.cat((left, right), dim=1)[None, :]

        if memory_key_padding_mask is None:
            memory_key_padding_mask = torch.full(
                (memory.shape[0], memory.shape[1]),
                fill_value=False, dtype=torch.bool, device=device,
            )
        assert memory_key_padding_mask.dtype == torch.bool
        # Fold the boolean padding mask into the memory columns as an
        # additive -inf mask for every query position.
        zero_negative_infinity = torch.zeros_like(memory_key_padding_mask, dtype=tgt.dtype)
        zero_negative_infinity[memory_key_padding_mask] = float('-inf')
        full_attention_mask = full_attention_mask.expand(
            (memory_key_padding_mask.shape[0], num_memory + num_tgt, num_memory + num_tgt)
        )
        full_attention_mask = full_attention_mask.clone()
        origin_left = full_attention_mask[:, :, :num_memory]
        update = zero_negative_infinity[:, None, :]
        full_attention_mask[:, :, :num_memory] = origin_left + update

        if tgt_bi_valid_mask is not None:
            # Target positions marked valid may be attended to by every query
            # (bidirectional attention into those positions).
            bs = full_attention_mask.shape[0]
            max_valid_target = tgt_bi_valid_mask.shape[1]
            mask = tgt_bi_valid_mask[:, None, :].expand((bs, num_memory + num_tgt, max_valid_target))
            full_attention_mask[:, :, num_memory:(num_memory + max_valid_target)][mask] = 0

        # Add a singleton head dimension: (batch, 1, L, L).
        full_attention_mask = full_attention_mask[:, None, :, :]

        if encoder_history_states is None:
            result = self.encoder(
                hidden_states=hidden_states,
                attention_mask=full_attention_mask,
                encoder_history_states=encoder_history_states,
            )
            result = list(result)
            # Keep only the target part of the output and restore the
            # (sequence, batch, hidden) layout.
            result[0] = result[0][:, num_memory:].transpose(0, 1)
            if self.encoder.output_hidden_states:
                return result[0], result[1]
            else:
                return result[0]
        else:
            # Incremental decoding: feed only the newest position, attending
            # over the cached history states.
            encoder_out = self.encoder(
                hidden_states=hidden_states[:, -1:],
                attention_mask=full_attention_mask[:, :, -1:],
                encoder_history_states=encoder_history_states,
            )
            result = encoder_out[0].transpose(0, 1)
            if self.encoder.output_hidden_states:
                return result, encoder_out[1]
            else:
                return result
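

def _demo_composite_attention_mask():
    # Illustrative sketch only (not used by the model): builds the same
    # composite additive mask that BertEncoderAsDecoder.forward assembles,
    # for 2 memory and 3 target positions with a causal target mask.
    num_memory, num_tgt = 2, 3
    tgt_mask = torch.triu(torch.ones(num_tgt, num_tgt), diagonal=1)
    tgt_mask = tgt_mask.masked_fill(tgt_mask == 1, float('-inf'))
    top_left = torch.zeros((num_memory, num_memory))
    top_right = torch.full((num_memory, num_tgt), float('-inf'))
    bottom_left = torch.zeros((num_tgt, num_memory))
    left = torch.cat((top_left, bottom_left), dim=0)
    right = torch.cat((top_right, tgt_mask), dim=0)
    # Memory rows cannot attend to the target; target rows see all of memory
    # and a causal prefix of the target.
    return torch.cat((left, right), dim=1)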


def create_transformer(decoder_type, norm_type,
                       textual_feature_size,
                       attention_heads,
                       feedforward_size,
                       dropout,
                       num_layers,
                       output_hidden_states=False,
                       use_mlp_wrapper=None,
                       use_act_checkpoint=True,
                       ):
    assert norm_type in ['post', 'pre']
    if decoder_type is None:
        LayerClass = (
            nn.TransformerDecoderLayer
            if norm_type == "post"
            else PreNormTransformerDecoderLayer
        )
        _layer = LayerClass(
            textual_feature_size,
            attention_heads,
            dim_feedforward=feedforward_size,
            dropout=dropout,
            activation="gelu",
        )
        return nn.TransformerDecoder(_layer, num_layers)
    elif decoder_type == 'bert_en':
        from .modeling_bert import BertConfig, BertEncoder
        config = BertConfig(
            vocab_size_or_config_json_file=30522,
            hidden_size=textual_feature_size,
            num_hidden_layers=num_layers,
            num_attention_heads=attention_heads,
            intermediate_size=feedforward_size,
            hidden_act="gelu",
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            layer_norm_eps=1e-12,
        )
        config.pre_norm = (norm_type == 'pre')
        config.use_mlp_wrapper = use_mlp_wrapper
        config.output_hidden_states = output_hidden_states
        encoder = BertEncoder(config, use_act_checkpoint=use_act_checkpoint)
        return BertEncoderAsDecoder(encoder)
    else:
        raise NotImplementedError(decoder_type)


class PreNormTransformerDecoderLayer(nn.TransformerDecoderLayer):
    """`nn.TransformerDecoderLayer` variant that applies layer norm *before*
    each sub-layer (pre-norm) rather than after."""

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        # Self-attention over the normalized target sequence.
        tgt2 = self.norm1(tgt)
        tgt2, _ = self.self_attn(
            tgt2, tgt2, tgt2, attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask
        )
        tgt = tgt + self.dropout1(tgt2)

        # Cross-attention over the encoder memory.
        tgt2 = self.norm2(tgt)
        tgt2, _ = self.multihead_attn(
            tgt2, memory, memory, attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask
        )
        tgt = tgt + self.dropout2(tgt2)

        # Position-wise feedforward block.
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt
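

def _demo_pre_norm_layer():
    # Illustrative sketch only (made-up sizes): one pre-norm decoder layer
    # applied to random target/memory tensors in (sequence, batch, hidden)
    # layout, the layout this module uses throughout.
    layer = PreNormTransformerDecoderLayer(
        32, 4, dim_feedforward=64, dropout=0.1, activation="gelu"
    )
    tgt = torch.randn(5, 2, 32)     # (target_len, batch, hidden)
    memory = torch.randn(7, 2, 32)  # (memory_len, batch, hidden)
    return layer(tgt, memory)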


class TransformerDecoderTextualHead(TextualHead):
    def __init__(
        self,
        object_feature_size: int,
        vocab_size: int,
        hidden_size: int,
        num_layers: int,
        attention_heads: int,
        feedforward_size: int,
        dropout: float = 0.1,
        norm_type: str = "post",
        mask_future_positions: bool = True,
        max_caption_length: int = 1024,
        padding_idx: int = 0,
        decoder_type=None,
        not_tie_weight=None,
        output_hidden_states=None,
        use_mlp_wrapper=None,
        use_act_checkpoint=True,
    ):
        super().__init__(object_feature_size, vocab_size, hidden_size)
        self.num_layers = num_layers
        self.attention_heads = attention_heads
        self.feedforward_size = feedforward_size
        self.dropout = dropout
        assert mask_future_positions
        self.padding_idx = padding_idx

        self.object_feature_projection = nn.Sequential(
            nn.Linear(object_feature_size, self.textual_feature_size),
            nn.LayerNorm(self.textual_feature_size))

        self.embedding = WordAndPositionalEmbedding(
            self.vocab_size,
            self.textual_feature_size,
            dropout=dropout,
            max_caption_length=max_caption_length,
            padding_idx=padding_idx,
        )
        self.transformer = create_transformer(
            decoder_type=decoder_type,
            norm_type=norm_type,
            textual_feature_size=self.textual_feature_size,
            attention_heads=self.attention_heads,
            feedforward_size=self.feedforward_size,
            dropout=dropout,
            num_layers=self.num_layers,
            output_hidden_states=output_hidden_states,
            use_mlp_wrapper=use_mlp_wrapper,
            use_act_checkpoint=use_act_checkpoint,
        )
        self.apply(self._init_weights)

        # Create the output linear layer and tie the input and output word
        # embeddings to reduce the parameter count.
        self.output = nn.Linear(self.textual_feature_size, vocab_size)
        if not not_tie_weight:
            self.output.weight = self.embedding.words.weight

    @staticmethod
    def _init_weights(module):
        """Initialize weights like BERT: N(0.0, 0.02); zero the padding
        embedding row."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.MultiheadAttention):
            module.in_proj_weight.data.normal_(mean=0.0, std=0.02)
            module.out_proj.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def forward(
        self,
        hidden_states,
        text_tokens,
    ):
        projected_object_features = self.object_feature_projection(hidden_states) if hidden_states is not None else None
        batch_size, max_text_length = text_tokens.size()
        text_embeddings = self.embedding(text_tokens)

        # An additive mask for masking the future (one direction).
        uni_mask_zero_neg = self._generate_future_mask(
            max_text_length, text_embeddings.dtype, text_embeddings.device
        )

        # Transpose the first two dimensions of token embeddings and visual
        # features, as required by the decoder: (sequence, batch, hidden).
        text_embeddings = text_embeddings.transpose(0, 1)
        projected_object_features = projected_object_features.transpose(0, 1)

        # shape: (max_caption_length, batch_size, hidden_size)
        trans_out = self.transformer(
            text_embeddings,
            projected_object_features,
            tgt_mask=uni_mask_zero_neg,
        )
        if isinstance(trans_out, tuple):
            textual_features = trans_out[0]
        else:
            assert isinstance(trans_out, torch.Tensor)
            textual_features = trans_out
        # Undo the transpose and bring batch back to dim 0.
        textual_features = textual_features.transpose(0, 1)

        # shape: (batch_size, max_caption_length, vocab_size)
        output_logits = self.output(textual_features)
        if isinstance(trans_out, tuple):
            return output_logits, trans_out[1]
        else:
            return output_logits

    def _generate_future_mask(
        self, size: int, dtype: torch.dtype, device: torch.device
    ):
        # Upper-triangular additive mask: position t may not attend to
        # positions > t.
        mask = torch.triu(
            torch.ones(size, size, device=device, dtype=dtype), diagonal=1
        )
        mask = mask.masked_fill(mask == 1, float("-inf"))
        return mask
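

def _demo_future_mask():
    # Illustrative sketch only: builds the same additive causal mask that
    # _generate_future_mask produces, for 3 positions.
    mask = torch.triu(torch.ones(3, 3), diagonal=1)
    mask = mask.masked_fill(mask == 1, float("-inf"))
    # mask == [[0., -inf, -inf],
    #          [0.,   0., -inf],
    #          [0.,   0.,   0.]]
    return mask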


class AutoRegressiveBeamSearch(object):
    """Beam search for autoregressive decoding of object descriptions."""

    def __init__(
        self,
        end_token_id: int,
        max_steps: int = 50,
        beam_size: int = 5,
        objectdet=True,
        per_node_beam_size: int = 2,
    ):
        self._eos_index = end_token_id
        self.max_steps = max_steps
        self.beam_size = beam_size
        self.objectdet = objectdet
        self.per_node_beam_size = per_node_beam_size or beam_size

    def search(self, begin_tokens, step):
        if self.beam_size > 1 and self.objectdet:
            only_return_best = False
        else:
            only_return_best = True

        batch_size = begin_tokens.size()[0]

        # shape: (batch_size, beam_size, length_of_begin_tokens)
        predictions = begin_tokens.unsqueeze(1).expand(
            (batch_size, self.beam_size, begin_tokens.shape[-1])
        )

        # Calculate the first timestep outside the main loop: here we go from
        # a single input to the top `beam_size` outputs, whereas inside the
        # loop each of the `beam_size` beams expands to `per_node_beam_size`
        # candidates before the top `beam_size` are kept.
        # shape: (batch_size, num_classes)
        start_class_logits = step(begin_tokens)

        # Convert logits to log probabilities.
        start_class_logprobs = F.log_softmax(start_class_logits, dim=1)

        num_classes = start_class_logprobs.size()[1]

        # shape (both): (batch_size, beam_size)
        start_top_logprobs, start_predicted_classes = start_class_logprobs.topk(
            self.beam_size
        )

        if (
            self.beam_size == 1
            and (start_predicted_classes == self._eos_index).all()
        ):
            warnings.warn(
                "Empty object description predicted. You may want to increase the "
                "beam size or ensure your step function is working properly.",
                RuntimeWarning,
            )
            if only_return_best:
                return start_predicted_classes, start_top_logprobs
            else:
                return start_predicted_classes.unsqueeze(-1), start_top_logprobs

        # The log probabilities for the last timestep.
        # shape: (batch_size, beam_size)
        last_logprobs = start_top_logprobs

        # shape: (batch_size, beam_size, sequence_length)
        predictions = torch.cat([predictions, start_predicted_classes.unsqueeze(-1)], dim=-1)

        # Distributions that mandate selecting the end token once a beam has
        # finished (the logits variant below is the one actually applied).
        # shape: (batch_size * beam_size, num_classes)
        logprobs_after_end = start_class_logprobs.new_full(
            (batch_size * self.beam_size, num_classes), float("-inf")
        )
        logprobs_after_end[:, self._eos_index] = 0.0

        logits_after_end = start_class_logprobs.new_full(
            (batch_size * self.beam_size, num_classes), float("-inf")
        )
        logits_after_end[:, self._eos_index] = 0

        while predictions.shape[-1] < self.max_steps:
            # shape: (batch_size * beam_size,)
            last_predictions = predictions[:, :, -1].reshape(batch_size * self.beam_size)

            # If every beam has produced the end token, we can stop early.
            if (last_predictions == self._eos_index).all():
                break

            predictions_so_far = predictions.view(
                batch_size * self.beam_size, -1
            )
            # shape: (batch_size * beam_size, num_classes)
            class_logits = step(predictions_so_far)

            # Penalize repeating the token that was just produced.
            class_logits = class_logits.scatter(1, predictions_so_far[:, -1].view((-1, 1)), -10000)

            # shape: (batch_size * beam_size, num_classes)
            last_predictions_expanded = last_predictions.unsqueeze(-1).expand(
                batch_size * self.beam_size, num_classes
            )

            # For beams that predicted the end token in the previous timestep,
            # replace the distribution with a one-hot on the end token so the
            # beam keeps predicting it.
            class_logits = torch.where(
                last_predictions_expanded == self._eos_index,
                logits_after_end,
                class_logits,
            )

            # Convert logits to log probabilities.
            class_logprobs = F.log_softmax(class_logits, dim=1)

            # shape (both): (batch_size * beam_size, per_node_beam_size)
            top_logprobs, predicted_classes = class_logprobs.topk(
                self.per_node_beam_size
            )

            # Expand the accumulated log probabilities so they can be added to
            # this timestep's, maintaining the total log probability of each
            # element on the beam.
            # shape: (batch_size * beam_size, per_node_beam_size)
            expanded_last_logprobs = (
                last_logprobs.unsqueeze(2)
                .expand(batch_size, self.beam_size, self.per_node_beam_size)
                .reshape(batch_size * self.beam_size, self.per_node_beam_size)
            )
            summed_top_logprobs = top_logprobs + expanded_last_logprobs

            # shape: (batch_size, beam_size * per_node_beam_size)
            reshaped_summed = summed_top_logprobs.reshape(
                batch_size, self.beam_size * self.per_node_beam_size
            )
            # shape: (batch_size, beam_size * per_node_beam_size)
            reshaped_predicted_classes = predicted_classes.reshape(
                batch_size, self.beam_size * self.per_node_beam_size
            )
            # Extend every beam with each of its candidate tokens.
            reshaped_beam = (
                predictions.view(batch_size * self.beam_size, 1, -1)
                .repeat(1, self.per_node_beam_size, 1)
                .reshape(batch_size, self.beam_size * self.per_node_beam_size, -1)
            )
            # shape: (batch_size, beam_size * per_node_beam_size, sequence_length)
            reshaped_beam = torch.cat([reshaped_beam, reshaped_predicted_classes.unsqueeze(-1)], dim=-1)

            # Keep only the top `beam_size` candidates.
            # shape (both): (batch_size, beam_size)
            restricted_beam_logprobs, restricted_beam_indices = reshaped_summed.topk(
                self.beam_size
            )
            predictions = reshaped_beam.gather(
                1, restricted_beam_indices.unsqueeze(-1).repeat(1, 1, reshaped_beam.shape[-1])
            )

            # shape: (batch_size, beam_size)
            last_logprobs = restricted_beam_logprobs

        if not torch.isfinite(last_logprobs).all():
            warnings.warn(
                "Infinite log probs encountered. Some final descriptions may not "
                "make sense. This can happen when the beam size is larger than"
                " the number of valid (non-zero probability) transitions that "
                "the step function produces.",
                RuntimeWarning,
            )

        if only_return_best:
            # shape: (batch_size, sequence_length)
            predictions = predictions[:, 0, :]
            last_logprobs = last_logprobs[:, 0]
        # Length-normalize the log probabilities: count predicted tokens up to
        # and including the first end token, excluding the begin tokens, with
        # a minimum of 1.
        num_valid = (predictions != self._eos_index).sum(dim=-1)
        num_valid += (predictions == self._eos_index).sum(dim=-1) > 0
        num_valid = num_valid - begin_tokens.shape[1]
        num_valid = num_valid.clip(min=1)

        last_logprobs = last_logprobs / num_valid

        return predictions, last_logprobs
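

def _demo_beam_search():
    # Illustrative sketch only: runs the beam search with a toy step function
    # over a 5-token vocabulary where token 0 plays the role of EOS. The
    # random logits stand in for a real decoder.
    vocab_size = 5

    def step(partial_tokens):
        # shape: (num_sequences, vocab_size)
        return torch.randn(partial_tokens.shape[0], vocab_size)

    searcher = AutoRegressiveBeamSearch(end_token_id=0, max_steps=6, beam_size=2)
    begin_tokens = torch.ones((3, 1), dtype=torch.long)  # batch of 3 sequences
    predictions, logprobs = searcher.search(begin_tokens, step)
    # predictions: (3, beam_size, sequence_length); logprobs: (3, beam_size)
    return predictions, logprobs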


class GRiTTextDecoder(nn.Module):
    def __init__(
        self,
        transformer,
        begin_token_id=101,
        beamsearch_decode=None,
        loss_type=None,
        tokenizer=None,
    ):
        super().__init__()
        self.textual = transformer
        self.padding_idx = self.textual.padding_idx

        self.begin_token_id = begin_token_id
        self.beamsearch_decode = beamsearch_decode
        self.tokenizer = tokenizer

        if loss_type is None:
            self.loss = nn.CrossEntropyLoss(ignore_index=self.padding_idx)
        elif loss_type == 'smooth':
            self.loss = SmoothLabelCrossEntropyLoss(ignore_index=self.padding_idx)
        else:
            raise NotImplementedError(loss_type)

    def forward(self, batch):
        object_features = batch['object_features']

        if self.training:
            caption_token_input = batch["text_tokens"]

            output_logits = self.textual(
                object_features,
                caption_token_input,
            )

            if 'need_predict' in batch:
                # Only compute the loss at positions marked for prediction;
                # other positions are mapped to the padding index and ignored.
                target = batch["text_tokens"].clone()
                target[batch['need_predict'] == 0] = self.padding_idx
            else:
                target = batch["text_tokens"]

            # Shift: logits at position t predict the token at position t + 1.
            feat = output_logits[:, :-1].contiguous()
            target = target[:, 1:].contiguous()
            feat = feat.view(-1, self.textual.vocab_size)
            target = target.view(-1)

            valid_mask = target != self.padding_idx
            target = target[valid_mask]
            feat = feat[valid_mask]
            loss = self.loss(feat, target)

            return loss
        else:
            output_dict = self.infer(object_features)
            return output_dict

    def infer(self, object_features):
        batch_size = object_features.size(0)
        begin_tokens = object_features.new_full(
            (batch_size, 1), self.begin_token_id
        ).long()

        decoding_step = functools.partial(
            self.decoding_step, object_features
        )

        object_description_tokens, logprobs = self.beamsearch_decode.search(
            begin_tokens, decoding_step
        )

        output_dict = {
            'predictions': object_description_tokens,
            'logprobs': logprobs,
        }

        return output_dict

    def decoding_step(self, object_features, partial_text):
        batch_size = object_features.shape[0]
        beam_size = int(partial_text.size(0) / batch_size)
        if beam_size > 1:
            # Repeat the object features for every beam:
            # (batch, tokens, channels) -> (batch * beam, tokens, channels).
            batch_size, num_token, channels = object_features.size()
            object_features = object_features.unsqueeze(1).repeat(1, beam_size, 1, 1)
            object_features = object_features.view(
                batch_size * beam_size, num_token, channels
            )

        text_lengths = torch.ones_like(partial_text)
        if len(text_lengths.size()) != 2:
            partial_text = partial_text.unsqueeze(1)

        # shape: (batch_size * beam_size, partial_caption_length, vocab_size)
        logits = self.textual(
            object_features,
            partial_text,
        )

        # Return logits for the last timestep only.
        return logits[:, -1, :].float()
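

def _demo_grit_text_decoder():
    # Illustrative wiring sketch only; the sizes and begin/end token ids are
    # made up, and decoder_type=None selects the plain nn.TransformerDecoder
    # path so no external BERT modules are required.
    textual_head = TransformerDecoderTextualHead(
        object_feature_size=32, vocab_size=100, hidden_size=32,
        num_layers=1, attention_heads=4, feedforward_size=64,
        max_caption_length=16,
    )
    decoder = GRiTTextDecoder(
        textual_head,
        begin_token_id=1,
        beamsearch_decode=AutoRegressiveBeamSearch(
            end_token_id=2, max_steps=8, beam_size=2
        ),
    )
    decoder.eval()
    object_features = torch.randn(2, 5, 32)  # (batch, num_objects, channels)
    with torch.no_grad():
        out = decoder({'object_features': object_features})
    return out['predictions'], out['logprobs']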


class SmoothLabelCrossEntropyLoss(nn.Module):
    def __init__(self, eps=0.1, log_prefix='', ignore_index=None):
        super().__init__()
        self.eps = eps
        self.log_soft = nn.LogSoftmax(dim=1)
        self.kl = nn.KLDivLoss(reduction='none')

        # Bookkeeping fields kept for logging.
        self.iter = 0
        self.max_loss = 0
        self.min_loss = 0
        self.log_prefix = log_prefix
        self.ignore_index = ignore_index

    def forward(self, feature, target):
        feature = feature.float()
        if self.ignore_index is not None:
            valid_mask = target != self.ignore_index
            target = target[valid_mask]
            feature = feature[valid_mask]
        assert target.numel() > 0
        self.iter += 1
        eps = self.eps
        n_class = feature.size(1)
        # Smoothed one-hot target: (1 - eps) on the true class, with eps
        # spread uniformly over the remaining classes.
        one_hot = torch.zeros_like(feature).scatter(1, target.view(-1, 1), 1)
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
        log_prb = self.log_soft(feature)
        loss = self.kl(log_prb, one_hot)
        return loss.sum(dim=1).mean()
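

def _demo_smooth_label_loss():
    # Illustrative sketch only: smoothed cross-entropy on random logits for a
    # 10-class problem, where index 0 plays the role of the padding index.
    criterion = SmoothLabelCrossEntropyLoss(eps=0.1, ignore_index=0)
    logits = torch.randn(4, 10)
    target = torch.tensor([3, 0, 7, 1])  # the 0 entry is ignored
    return criterion(logits, target)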