Lakoc commited on
Commit
e4d2de3
·
verified ·
1 Parent(s): 0377762

Update modeling_dicow.py

Browse files
Files changed (1) hide show
  1. modeling_dicow.py +1 -17
modeling_dicow.py CHANGED
@@ -25,8 +25,7 @@ from .encoder import DiCoWEncoder
25
  from .FDDT import FDDT
26
  from .layers import CustomLinear, CustomDiagonalLinear, Gate
27
  from .generation import DiCoWGenerationMixin
28
- from .contrastive_loss import ContrastiveLoss
29
- import wandb
30
  logging.set_verbosity_debug()
31
  logger = logging.get_logger("transformers")
32
 
@@ -334,21 +333,6 @@ class DiCoWForConditionalGeneration(DiCoWGenerationMixin, WhisperForConditionalG
334
  wandb.log({"ctc_loss": ctc_loss})
335
  loss = (1 - self.config.ctc_weight) * dec_loss + self.config.ctc_weight * ctc_loss
336
 
337
- if self.config.contrastive_loss_weight > 0.0:
338
- loss_fct = ContrastiveLoss(distance_metric="cosine")
339
- stno_per_spk_pair = stno_mask.view(-1, self.config.mt_num_speakers, stno_mask.shape[1], stno_mask.shape[2])
340
- positive_mask = ((stno_per_spk_pair[:, :, 1, :] + stno_per_spk_pair[:, :, 3, :]) > 0.5).flatten(1)
341
- intermediate_states = outputs.encoder_hidden_states[8].view(-1, self.config.mt_num_speakers, stno_mask.shape[2],
342
- outputs.encoder_hidden_states[8].shape[-1]).flatten(1, 2)
343
- valid_pairs = is_valid.view((-1, self.config.mt_num_speakers)).all(dim=-1)
344
- contrastive_loss = loss_fct(
345
- intermediate_states[valid_pairs],
346
- positive_mask[valid_pairs])
347
- # print(contrastive_loss)
348
- if wandb.run is not None:
349
- wandb.log({"contrastive_loss": contrastive_loss})
350
- if contrastive_loss != 0.0 and loss < 0.5:
351
- loss += self.config.contrastive_loss_weight * contrastive_loss
352
  if not return_dict:
353
  output = (dec_lm_logits,) + outputs[1:]
354
  return ((loss,) + output) if loss is not None else output
 
25
  from .FDDT import FDDT
26
  from .layers import CustomLinear, CustomDiagonalLinear, Gate
27
  from .generation import DiCoWGenerationMixin
28
+
 
29
  logging.set_verbosity_debug()
30
  logger = logging.get_logger("transformers")
31
 
 
333
  wandb.log({"ctc_loss": ctc_loss})
334
  loss = (1 - self.config.ctc_weight) * dec_loss + self.config.ctc_weight * ctc_loss
335
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
  if not return_dict:
337
  output = (dec_lm_logits,) + outputs[1:]
338
  return ((loss,) + output) if loss is not None else output