YongganFu committed on
Commit
be9e4da
·
verified ·
1 Parent(s): 7a07ecc

Update modeling_nvrdiff.py: comment out the loss computation in DiffEncoderModel (loss stays None)

Browse files
Files changed (1) hide show
  1. modeling_nvrdiff.py +35 -35
modeling_nvrdiff.py CHANGED
@@ -486,45 +486,45 @@ class DiffEncoderModel(Qwen3PreTrainedModel, GenerationMixin):
486
  logits = logits[:, :input_ids_len]
487
 
488
  loss = None
489
- if labels is not None:
490
- if self.config.dlm_paradigm == 'autoregressive':
491
- shift_logits = logits[..., :-1, :].contiguous()
492
- shift_labels = labels[..., 1:].contiguous()
493
 
494
- if loss_mask is None:
495
- loss_fct = CrossEntropyLoss()
496
- shift_logits = shift_logits.view(-1, shift_logits.size(-1))
497
- shift_labels = shift_labels.view(-1)
498
- loss = loss_fct(shift_logits, shift_labels)
499
-
500
- else:
501
- loss_mask = loss_mask[..., 1:].contiguous()
502
-
503
- loss_fct = CrossEntropyLoss(reduction='none')
504
- shift_logits = shift_logits.view(-1, shift_logits.size(-1))
505
- shift_labels = shift_labels.view(-1)
506
- shift_labels = shift_labels.to(shift_logits.device)
507
 
508
- token_losses = loss_fct(shift_logits, shift_labels)
509
 
510
- loss = token_losses[loss_mask].sum() / loss_mask.sum()
511
-
512
- else:
513
- # Handle DREAM vs LLADA style losses
514
- if hasattr(self.config, 'dlm_type') and self.config.dlm_type == 'dream':
515
- logits = logits[..., :-1, :].contiguous()
516
- labels = labels[..., 1:].contiguous()
517
- masked_indices = masked_indices[:, 1:]
518
- p_mask = p_mask[:, 1:]
519
-
520
- # Calculate token-wise cross entropy loss for masked positions in B
521
- token_loss = torch.nn.functional.cross_entropy(
522
- logits[masked_indices],
523
- labels[masked_indices],
524
- reduction='none'
525
- ) / p_mask[masked_indices]
526
 
527
- loss = token_loss.sum() / masked_indices.sum()
528
 
529
  return CausalLMOutputWithPast(
530
  loss=loss if not is_teacher else logits,
 
486
  logits = logits[:, :input_ids_len]
487
 
488
  loss = None
489
+ # if labels is not None:
490
+ # if self.config.dlm_paradigm == 'autoregressive':
491
+ # shift_logits = logits[..., :-1, :].contiguous()
492
+ # shift_labels = labels[..., 1:].contiguous()
493
 
494
+ # if loss_mask is None:
495
+ # loss_fct = CrossEntropyLoss()
496
+ # shift_logits = shift_logits.view(-1, shift_logits.size(-1))
497
+ # shift_labels = shift_labels.view(-1)
498
+ # loss = loss_fct(shift_logits, shift_labels)
499
+
500
+ # else:
501
+ # loss_mask = loss_mask[..., 1:].contiguous()
502
+
503
+ # loss_fct = CrossEntropyLoss(reduction='none')
504
+ # shift_logits = shift_logits.view(-1, shift_logits.size(-1))
505
+ # shift_labels = shift_labels.view(-1)
506
+ # shift_labels = shift_labels.to(shift_logits.device)
507
 
508
+ # token_losses = loss_fct(shift_logits, shift_labels)
509
 
510
+ # loss = token_losses[loss_mask].sum() / loss_mask.sum()
511
+
512
+ # else:
513
+ # # Handle DREAM vs LLADA style losses
514
+ # if hasattr(self.config, 'dlm_type') and self.config.dlm_type == 'dream':
515
+ # logits = logits[..., :-1, :].contiguous()
516
+ # labels = labels[..., 1:].contiguous()
517
+ # masked_indices = masked_indices[:, 1:]
518
+ # p_mask = p_mask[:, 1:]
519
+
520
+ # # Calculate token-wise cross entropy loss for masked positions in B
521
+ # token_loss = torch.nn.functional.cross_entropy(
522
+ # logits[masked_indices],
523
+ # labels[masked_indices],
524
+ # reduction='none'
525
+ # ) / p_mask[masked_indices]
526
 
527
+ # loss = token_loss.sum() / masked_indices.sum()
528
 
529
  return CausalLMOutputWithPast(
530
  loss=loss if not is_teacher else logits,