File size: 7,046 Bytes
6f0b660 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, MSELoss
from .loss_d_fine import DFineForObjectDetectionLoss
from .loss_deformable_detr import DeformableDetrForObjectDetectionLoss, DeformableDetrForSegmentationLoss
from .loss_for_object_detection import ForObjectDetectionLoss, ForSegmentationLoss
from .loss_grounding_dino import GroundingDinoForObjectDetectionLoss
from .loss_rt_detr import RTDetrForObjectDetectionLoss
def fixed_cross_entropy(
    source: torch.Tensor,
    target: torch.Tensor,
    num_items_in_batch: Optional[torch.Tensor] = None,
    ignore_index: int = -100,
    **kwargs,
) -> torch.Tensor:
    """Cross-entropy with an optional externally supplied normalizer.

    Args:
        source: Logits forwarded to `nn.functional.cross_entropy`.
        target: Targets forwarded to `nn.functional.cross_entropy`.
        num_items_in_batch: When `None`, the loss uses the built-in `"mean"`
            reduction. Otherwise the loss is summed and divided by this value
            (a tensor or a plain number).
        ignore_index: Target value excluded from the loss.
        **kwargs: Accepted and ignored, so callers can pass extra arguments through.

    Returns:
        The scalar loss tensor.
    """
    if num_items_in_batch is None:
        return nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="mean")

    summed = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="sum")
    normalizer = num_items_in_batch
    # Custom trainers may hand over a plain int; only tensors need to follow the loss's device.
    if torch.is_tensor(normalizer):
        normalizer = normalizer.to(summed.device)
    return summed / normalizer
def ForCausalLMLoss(
    logits,
    labels,
    vocab_size: int,
    num_items_in_batch: Optional[torch.Tensor] = None,
    ignore_index: int = -100,
    shift_labels: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """Next-token cross-entropy loss.

    Args:
        logits: Model outputs; upcast to float and viewed as `(-1, vocab_size)`.
        labels: Target ids. Only used when `shift_labels` is `None`, in which case
            they are shifted left by one along the last dimension.
        vocab_size: Size of the last logits dimension after flattening.
        num_items_in_batch: Optional normalizer forwarded to `fixed_cross_entropy`.
        ignore_index: Target value excluded from the loss; also used to pad the
            position vacated by the shift.
        shift_labels: Already-shifted targets. When given, `labels` is not read.
        **kwargs: Forwarded to `fixed_cross_entropy`.

    Returns:
        The scalar loss tensor.
    """
    # Compute the loss in float to avoid precision issues with lower-precision logits.
    upcast_logits = logits.float()

    if shift_labels is None:
        # Position i must be scored against token i + 1: append one ignored slot on the
        # right, then drop the first slot so the sequence length is unchanged.
        padded_labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
        shift_labels = padded_labels[..., 1:].contiguous()

    # Collapse every leading dimension so each row is one prediction.
    flat_logits = upcast_logits.view(-1, vocab_size)
    # Targets follow the logits' device (model parallelism).
    flat_targets = shift_labels.view(-1).to(flat_logits.device)
    return fixed_cross_entropy(flat_logits, flat_targets, num_items_in_batch, ignore_index, **kwargs)
def ForMaskedLMLoss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    vocab_size: int,
    num_items_in_batch: Optional[torch.Tensor] = None,
    ignore_index: int = -100,
    **kwargs,
):
    """Cross-entropy loss over flattened logits and labels, without any label shift.

    Args:
        logits: Model outputs; upcast to float and viewed as `(-1, vocab_size)`.
        labels: Target ids; flattened to one dimension.
        vocab_size: Size of the last logits dimension after flattening.
        num_items_in_batch: Optional normalizer forwarded to `fixed_cross_entropy`.
        ignore_index: Target value excluded from the loss.
        **kwargs: Forwarded to `fixed_cross_entropy`.

    Returns:
        The scalar loss tensor.
    """
    # Upcast first (precision), then collapse the leading dimensions.
    flat_logits = logits.float().view(-1, vocab_size)
    # Targets follow the logits' device (model parallelism).
    flat_labels = labels.view(-1).to(flat_logits.device)
    return fixed_cross_entropy(flat_logits, flat_labels, num_items_in_batch, ignore_index, **kwargs)
def ForSequenceClassificationLoss(labels: torch.Tensor, pooled_logits: torch.Tensor, config, **kwargs) -> torch.Tensor:
    """Pick and apply a classification or regression loss based on `config.problem_type`.

    When `config.problem_type` is `None` it is inferred and written back onto `config`:
    one label means regression, several labels with integer targets means single-label
    classification, and anything else means multi-label classification.

    Args:
        labels: Targets; moved to the device of `pooled_logits`.
        pooled_logits: Model outputs to score.
        config: Object exposing `num_labels` and a writable `problem_type`.
        **kwargs: Forwarded to `fixed_cross_entropy` on the single-label path only.

    Returns:
        The scalar loss tensor.

    Raises:
        RuntimeError: If `config.problem_type` is not one of the three known values.
    """
    label_count = config.num_labels

    if config.problem_type is None:
        if label_count == 1:
            inferred = "regression"
        elif label_count > 1 and labels.dtype in (torch.long, torch.int):
            inferred = "single_label_classification"
        else:
            inferred = "multi_label_classification"
        # Persist the decision on the config, as later calls read it back.
        config.problem_type = inferred

    problem_type = config.problem_type
    targets = labels.to(pooled_logits.device)

    if problem_type == "regression":
        mse = MSELoss()
        if label_count == 1:
            # Drop singleton dimensions so predictions and targets line up.
            return mse(pooled_logits.squeeze(), targets.squeeze())
        return mse(pooled_logits, targets)

    if problem_type == "single_label_classification":
        return fixed_cross_entropy(pooled_logits.view(-1, label_count), targets.view(-1), **kwargs)

    if problem_type == "multi_label_classification":
        bce = BCEWithLogitsLoss()
        return bce(pooled_logits, targets)

    raise RuntimeError(f"Invalid problem type: {config.problem_type}")
def ForQuestionAnsweringLoss(start_logits, end_logits, start_positions, end_positions, **kwargs):
    """Average of the start-position and end-position cross-entropy losses.

    Args:
        start_logits: Logits for the answer start; `size(1)` is used as the ignored index.
        end_logits: Logits for the answer end.
        start_positions: Target start indices, or `None`.
        end_positions: Target end indices, or `None`.
        **kwargs: Forwarded to `fixed_cross_entropy` for both losses.

    Returns:
        The mean of the two losses, or `None` when either position tensor is missing.
    """
    if start_positions is None or end_positions is None:
        return None

    # On multi-GPU the gather may add a trailing dimension; remove it.
    if len(start_positions.size()) > 1:
        start_positions = start_positions.squeeze(-1)
    if len(end_positions.size()) > 1:
        end_positions = end_positions.squeeze(-1)

    # Always align the targets with their logits' device. This used to happen only
    # for positions with more than one dimension, leaving 1-D positions unmoved.
    start_positions = start_positions.to(start_logits.device)
    end_positions = end_positions.to(end_logits.device)

    # Positions outside the model inputs are clamped to `ignored_index`, which the
    # cross-entropy then skips.
    ignored_index = start_logits.size(1)
    start_positions = start_positions.clamp(0, ignored_index)
    end_positions = end_positions.clamp(0, ignored_index)

    start_loss = fixed_cross_entropy(start_logits, start_positions, ignore_index=ignored_index, **kwargs)
    end_loss = fixed_cross_entropy(end_logits, end_positions, ignore_index=ignored_index, **kwargs)
    return (start_loss + end_loss) / 2
def ForTokenClassification(logits: torch.Tensor, labels, config, **kwargs):
    """Cross-entropy loss over flattened logits and labels.

    Args:
        logits: Model outputs; viewed as `(-1, config.num_labels)` and upcast to float.
        labels: Targets; flattened to one dimension and moved to the logits' device.
        config: Object exposing `num_labels`.
        **kwargs: Forwarded to `fixed_cross_entropy`.

    Returns:
        The scalar loss tensor.
    """
    # Flatten the tokens so each row is one prediction.
    flat_logits = logits.view(-1, config.num_labels)
    flat_labels = labels.view(-1).to(flat_logits.device)
    # Upcast to float for the loss computation to avoid potential precision issues.
    return fixed_cross_entropy(flat_logits.float(), flat_labels, **kwargs)
# Registry of loss callables keyed by a task or model-class name string.
# NOTE(review): presumably looked up by a model class-name suffix or an explicit
# loss type; confirm against the caller. Insertion order is kept as-is.
LOSS_MAPPING = {
    # Language modeling
    "ForCausalLM": ForCausalLMLoss,
    "ForMaskedLM": ForMaskedLMLoss,
    # Question answering
    "ForQuestionAnswering": ForQuestionAnsweringLoss,
    # Sequence-level classification; every modality shares one loss
    "ForSequenceClassification": ForSequenceClassificationLoss,
    "ForImageClassification": ForSequenceClassificationLoss,
    "ForVideoClassification": ForSequenceClassificationLoss,
    "ForAudioClassification": ForSequenceClassificationLoss,
    # Token-level classification
    "ForTokenClassification": ForTokenClassification,
    # Generic segmentation / object detection
    "ForSegmentation": ForSegmentationLoss,
    "ForObjectDetection": ForObjectDetectionLoss,
    # Conditional generation reuses the causal LM loss
    "ForConditionalGeneration": ForCausalLMLoss,
    # Model-specific detection / segmentation losses
    "DeformableDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
    "ConditionalDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
    "DabDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
    "GroundingDinoForObjectDetection": GroundingDinoForObjectDetectionLoss,
    "MMGroundingDinoForObjectDetection": GroundingDinoForObjectDetectionLoss,
    "ConditionalDetrForSegmentation": DeformableDetrForSegmentationLoss,
    "RTDetrForObjectDetection": RTDetrForObjectDetectionLoss,
    "RTDetrV2ForObjectDetection": RTDetrForObjectDetectionLoss,
    "DFineForObjectDetection": DFineForObjectDetectionLoss,
    # Model-specific conditional generation
    "CsmForConditionalGeneration": ForCausalLMLoss,
}
|