File size: 6,419 Bytes
7155cf2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified by Xinrui Wu
import functools
import os

import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, MSELoss

from .loss_deformable_detr import DeformableDetrForObjectDetectionLoss, DeformableDetrForSegmentationLoss
from .loss_for_object_detection import ForObjectDetectionLoss, ForSegmentationLoss
from .loss_rt_detr import RTDetrForObjectDetectionLoss
def fixed_cross_entropy(source, target, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs):
    """Cross-entropy whose normaliser can be supplied by the caller.

    Args:
        source: prediction scores passed straight to ``nn.functional.cross_entropy``.
        target: class indices passed straight to ``nn.functional.cross_entropy``.
        num_items_in_batch: when given, per-target losses are summed and the sum is
            divided by this value. When ``None``, the regular ``"mean"`` reduction
            over non-ignored targets is used.
        ignore_index: target value excluded from the loss.
        **kwargs: accepted for interface compatibility and ignored.

    Returns:
        Scalar loss tensor.
    """
    if num_items_in_batch is None:
        return nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="mean")
    # Sum first so the caller-provided count, not the local token count, normalises.
    total = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="sum")
    return total / num_items_in_batch
# Target token ids whose per-token loss ForCausalLMLoss re-weights: 151665..151671 inclusive.
# NOTE(review): these were hard-coded by the original author; presumably they are
# tokenizer-specific added special tokens -- confirm against the tokenizer in use.
_DEFAULT_SPECIAL_TOKEN_IDS = tuple(range(151665, 151672))


@functools.lru_cache(maxsize=16)
def _parse_special_token_weight(raw: str):
    """Translate a ``special_token_loss`` setting into a loss weight.

    Accepted forms: ``"F"`` -> 0, ``"T"`` -> 1, and ``"T<float>"`` (or a bare float
    string) -> that float, e.g. ``"T0.5"`` -> 0.5.

    Cached per distinct string so the setting is parsed and reported once instead of
    on every loss call.

    Raises:
        ValueError: if ``raw`` is not one of the accepted forms.
    """
    if raw == "F":
        weight = 0
    elif raw == "T":
        weight = 1
    else:
        try:
            weight = float(raw.split("T")[-1])
        except ValueError as err:
            raise ValueError(
                f"Invalid special_token_loss value {raw!r}; expected 'F', 'T' or 'T<float>'."
            ) from err
    print(f"special_token_weight: {weight}")
    return weight


def ForCausalLMLoss(
    logits,
    labels,
    vocab_size: int,
    num_items_in_batch: int = None,
    ignore_index: int = -100,
    *,
    special_token_ids=None,
    special_token_weight: float = None,
    **kwargs,
):
    """Next-token cross-entropy with a configurable weight on special target tokens.

    Args:
        logits: prediction scores. ``logits[..., :-1, :]`` is flattened to
            ``(-1, vocab_size)``, so the last dimension must be the vocabulary.
        labels: target ids. ``labels[..., 1:]`` is flattened to 1-D (shift by one so
            tokens < n predict n).
        vocab_size: size of the last logits dimension.
        num_items_in_batch: if given, the weighted loss sum is divided by this value.
            Otherwise it is divided by the number of non-ignored target tokens.
        ignore_index: label value excluded from the loss.
        special_token_ids: keyword-only. Target ids whose per-token loss is scaled.
            Defaults to ``_DEFAULT_SPECIAL_TOKEN_IDS``.
        special_token_weight: keyword-only. Scale applied to those tokens. Defaults to
            the value parsed from the ``special_token_loss`` environment variable
            (``"F"`` -> 0, ``"T"`` -> 1, ``"T<float>"`` -> float; unset means ``"T"``).
        **kwargs: accepted for interface compatibility and ignored.

    Returns:
        Scalar loss tensor. Special tokens remain in the denominator even when their
        weight is 0, so both reduction paths share the same normalisation.

    Raises:
        ValueError: if the ``special_token_loss`` environment variable is malformed.
    """
    if special_token_ids is None:
        special_token_ids = _DEFAULT_SPECIAL_TOKEN_IDS
    if special_token_weight is None:
        special_token_weight = _parse_special_token_weight(os.getenv("special_token_loss", "T"))

    # Upcast to float to avoid potential precision issues in the loss computation.
    logits = logits.float()
    # Shift so that tokens < n predict n, then flatten the tokens.
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    # Move labels next to the logits to support model parallelism.
    shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)

    per_token_loss = nn.functional.cross_entropy(
        shift_logits, shift_labels, ignore_index=ignore_index, reduction="none"
    )

    special_ids = list(special_token_ids)
    if special_ids and special_token_weight != 1:
        special_mask = torch.isin(
            shift_labels,
            torch.tensor(special_ids, dtype=shift_labels.dtype, device=shift_labels.device),
        )
        if special_token_weight == 0:
            # masked_fill (instead of multiplying by 0) keeps an inf loss from turning into nan.
            per_token_loss = per_token_loss.masked_fill(special_mask, 0.0)
        else:
            # Out-of-place so autograd never sees an in-place write on the loss tensor.
            per_token_loss = torch.where(special_mask, per_token_loss * special_token_weight, per_token_loss)

    if num_items_in_batch is not None:
        return per_token_loss.sum() / num_items_in_batch
    # Ignored positions contribute 0 to the sum; average over real targets only, the same
    # normalisation cross_entropy's "mean" reduction (and fixed_cross_entropy) uses.
    num_targets = (shift_labels != ignore_index).sum()
    return per_token_loss.sum() / num_targets
def ForSequenceClassificationLoss(labels, pooled_logits, config, **kwargs):
    """Compute the sequence-classification loss selected by ``config.problem_type``.

    If ``config.problem_type`` is ``None`` it is inferred and written back to ``config``:
    ``"regression"`` when ``num_labels == 1``, ``"single_label_classification"`` when
    ``num_labels > 1`` and labels are ``torch.long``/``torch.int``, otherwise
    ``"multi_label_classification"``.

    Args:
        labels: target tensor.
        pooled_logits: prediction tensor.
        config: object exposing ``num_labels`` and ``problem_type``. ``problem_type``
            may be mutated as described above.
        **kwargs: forwarded to ``fixed_cross_entropy`` (single-label case only).

    Returns:
        Scalar loss tensor: MSE for regression, cross-entropy for single-label,
        BCE-with-logits for multi-label.

    Raises:
        ValueError: if ``config.problem_type`` is not one of the three supported values.
    """
    num_labels = config.num_labels
    if config.problem_type is None:
        if num_labels == 1:
            config.problem_type = "regression"
        elif num_labels > 1 and labels.dtype in (torch.long, torch.int):
            config.problem_type = "single_label_classification"
        else:
            config.problem_type = "multi_label_classification"

    problem_type = config.problem_type
    if problem_type == "regression":
        loss_fct = MSELoss()
        if num_labels == 1:
            # squeeze drops size-1 dims so predictions and targets compare element-wise.
            return loss_fct(pooled_logits.squeeze(), labels.squeeze())
        return loss_fct(pooled_logits, labels)
    if problem_type == "single_label_classification":
        return fixed_cross_entropy(pooled_logits.view(-1, num_labels), labels.view(-1), **kwargs)
    if problem_type == "multi_label_classification":
        return BCEWithLogitsLoss()(pooled_logits, labels)
    # Previously fell through and failed with UnboundLocalError on `loss`.
    raise ValueError(
        f"Unsupported problem_type {problem_type!r}; expected 'regression', "
        "'single_label_classification' or 'multi_label_classification'."
    )
def ForQuestionAnsweringLoss(start_logits, end_logits, start_positions, end_positions, **kwargs):
    """Average of the start- and end-position cross-entropy losses for span QA.

    Args:
        start_logits: start-position scores; ``size(1)`` is the number of positions.
        end_logits: end-position scores over the same positions.
        start_positions: gold start indices, or ``None``.
        end_positions: gold end indices, or ``None``.
        **kwargs: forwarded to ``fixed_cross_entropy``.

    Returns:
        ``(start_loss + end_loss) / 2``, or ``None`` when either position tensor is ``None``.
    """
    if start_positions is None or end_positions is None:
        return None

    # On multi-GPU, splitting may add a trailing dimension; drop it.
    if start_positions.dim() > 1:
        start_positions = start_positions.squeeze(-1)
    if end_positions.dim() > 1:
        end_positions = end_positions.squeeze(-1)
    # Always place targets next to their logits so 1-D targets on another device work too.
    start_positions = start_positions.to(start_logits.device)
    end_positions = end_positions.to(end_logits.device)

    # Clamp into [0, ignored_index]: positions past the model input become ignored_index,
    # which cross-entropy then excludes from the loss.
    ignored_index = start_logits.size(1)
    start_positions = start_positions.clamp(0, ignored_index)
    end_positions = end_positions.clamp(0, ignored_index)

    start_loss = fixed_cross_entropy(start_logits, start_positions, ignore_index=ignored_index, **kwargs)
    end_loss = fixed_cross_entropy(end_logits, end_positions, ignore_index=ignored_index, **kwargs)
    return (start_loss + end_loss) / 2
def ForTokenClassification(logits, labels, config, **kwargs):
    """Token-level cross-entropy over ``config.num_labels`` classes.

    Logits are flattened to ``(-1, config.num_labels)`` and labels to 1-D. Logits are
    upcast with ``.float()`` before delegating to ``fixed_cross_entropy``.

    Args:
        logits: per-token scores.
        labels: per-token class indices.
        config: object exposing ``num_labels``.
        **kwargs: forwarded to ``fixed_cross_entropy``.
    """
    # Flatten the tokens.
    flat_logits = logits.view(-1, config.num_labels)
    flat_labels = labels.view(-1)
    # Upcast to float to avoid potential precision issues in the loss computation.
    return fixed_cross_entropy(flat_logits.float(), flat_labels, **kwargs)
# Registry from a string key to the loss function used for it.
# NOTE(review): keys look like model-class suffixes / architecture names; presumably they
# are matched against model class names elsewhere -- confirm against the caller.
# Several detection architectures deliberately share the Deformable DETR losses.
LOSS_MAPPING = {
    # Generic task-head losses defined in this module or imported above.
    "ForCausalLM": ForCausalLMLoss,
    "ForQuestionAnswering": ForQuestionAnsweringLoss,
    "ForSequenceClassification": ForSequenceClassificationLoss,
    "ForTokenClassification": ForTokenClassification,
    "ForSegmentation": ForSegmentationLoss,
    "ForObjectDetection": ForObjectDetectionLoss,
    # Architecture-specific detection / segmentation losses.
    "DeformableDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
    "ConditionalDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
    "GroundingDinoForObjectDetection": DeformableDetrForObjectDetectionLoss,
    "ConditionalDetrForSegmentation": DeformableDetrForSegmentationLoss,
    "RTDetrForObjectDetection": RTDetrForObjectDetectionLoss,
}
|