Update modeling_dicow.py
Browse files- modeling_dicow.py +1 -17
modeling_dicow.py
CHANGED
|
@@ -25,8 +25,7 @@ from .encoder import DiCoWEncoder
|
|
| 25 |
from .FDDT import FDDT
|
| 26 |
from .layers import CustomLinear, CustomDiagonalLinear, Gate
|
| 27 |
from .generation import DiCoWGenerationMixin
|
| 28 |
-
|
| 29 |
-
import wandb
|
| 30 |
logging.set_verbosity_debug()
|
| 31 |
logger = logging.get_logger("transformers")
|
| 32 |
|
|
@@ -334,21 +333,6 @@ class DiCoWForConditionalGeneration(DiCoWGenerationMixin, WhisperForConditionalG
|
|
| 334 |
wandb.log({"ctc_loss": ctc_loss})
|
| 335 |
loss = (1 - self.config.ctc_weight) * dec_loss + self.config.ctc_weight * ctc_loss
|
| 336 |
|
| 337 |
-
if self.config.contrastive_loss_weight > 0.0:
|
| 338 |
-
loss_fct = ContrastiveLoss(distance_metric="cosine")
|
| 339 |
-
stno_per_spk_pair = stno_mask.view(-1, self.config.mt_num_speakers, stno_mask.shape[1], stno_mask.shape[2])
|
| 340 |
-
positive_mask = ((stno_per_spk_pair[:, :, 1, :] + stno_per_spk_pair[:, :, 3, :]) > 0.5).flatten(1)
|
| 341 |
-
intermediate_states = outputs.encoder_hidden_states[8].view(-1, self.config.mt_num_speakers, stno_mask.shape[2],
|
| 342 |
-
outputs.encoder_hidden_states[8].shape[-1]).flatten(1, 2)
|
| 343 |
-
valid_pairs = is_valid.view((-1, self.config.mt_num_speakers)).all(dim=-1)
|
| 344 |
-
contrastive_loss = loss_fct(
|
| 345 |
-
intermediate_states[valid_pairs],
|
| 346 |
-
positive_mask[valid_pairs])
|
| 347 |
-
# print(contrastive_loss)
|
| 348 |
-
if wandb.run is not None:
|
| 349 |
-
wandb.log({"contrastive_loss": contrastive_loss})
|
| 350 |
-
if contrastive_loss != 0.0 and loss < 0.5:
|
| 351 |
-
loss += self.config.contrastive_loss_weight * contrastive_loss
|
| 352 |
if not return_dict:
|
| 353 |
output = (dec_lm_logits,) + outputs[1:]
|
| 354 |
return ((loss,) + output) if loss is not None else output
|
|
|
|
| 25 |
from .FDDT import FDDT
|
| 26 |
from .layers import CustomLinear, CustomDiagonalLinear, Gate
|
| 27 |
from .generation import DiCoWGenerationMixin
|
| 28 |
+
|
|
|
|
| 29 |
logging.set_verbosity_debug()
|
| 30 |
logger = logging.get_logger("transformers")
|
| 31 |
|
|
|
|
| 333 |
wandb.log({"ctc_loss": ctc_loss})
|
| 334 |
loss = (1 - self.config.ctc_weight) * dec_loss + self.config.ctc_weight * ctc_loss
|
| 335 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
if not return_dict:
|
| 337 |
output = (dec_lm_logits,) + outputs[1:]
|
| 338 |
return ((loss,) + output) if loss is not None else output
|