| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| from typing import Optional, Tuple |
|
|
| import torch |
| import torch.distributed as dist |
|
|
| from .comm import ( |
| get_unified_sequence_parallel_group, |
| get_unified_sequence_parallel_world_size, |
| ) |
|
|
|
|
class ReduceLoss(torch.autograd.Function):
    """Merge per-rank mean losses into one token-weighted mean over the sequence-parallel group.

    Each rank contributes ``loss * num_valid_tokens``; the contributions and the
    token counts are all-reduced over the group returned by
    ``get_unified_sequence_parallel_group()`` and the result is their ratio.
    """

    @staticmethod
    def forward(ctx: torch.autograd.Function, loss: torch.Tensor, num_valid_tokens: torch.Tensor) -> torch.Tensor:
        """Return the token-weighted mean of ``loss`` across the sequence-parallel group.

        Args:
            ctx: Autograd context used to stash the tensors needed by ``backward``.
            loss: Loss computed on this rank. It is NOT modified.
            num_valid_tokens: Weight of this rank's loss. NOTE: this tensor is
                all-reduced in place, so after the call it holds the group-wide
                total rather than the local count.

        Returns:
            A new tensor: all-reduced ``loss * num_valid_tokens`` divided by the
            all-reduced ``num_valid_tokens``.
        """
        if num_valid_tokens == 0:
            # A rank with zero weight may carry a NaN/inf loss; sanitise it so
            # that multiplying by 0 below gives 0 instead of NaN, which would
            # otherwise poison the all-reduced sum for every rank.
            loss = torch.nan_to_num(loss)

        local_num_tokens = num_valid_tokens.detach().clone()

        # Work on a private copy: multiplying / all-reducing `loss` itself would
        # overwrite the caller's tensor (and bump its autograd version counter)
        # as an undeclared side effect. The in-place multiply on the copy keeps
        # the dtype and arithmetic identical to operating on `loss` directly.
        weighted_loss = loss.clone()
        weighted_loss *= num_valid_tokens

        group = get_unified_sequence_parallel_group()
        dist.all_reduce(weighted_loss, group=group)
        # In place on the caller's tensor (see the Args note above).
        dist.all_reduce(num_valid_tokens, group=group)

        ctx.save_for_backward(local_num_tokens, num_valid_tokens)
        # NOTE: if the all-reduced count is 0 this is 0 / 0, i.e. NaN.
        return weighted_loss / num_valid_tokens

    @staticmethod
    def backward(
        ctx: torch.autograd.Function, grad_output: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return the gradient for ``loss``; ``num_valid_tokens`` gets no gradient.

        The gradient is ``grad_output * local_num_tokens / global_num_tokens``
        scaled by the sequence-parallel world size.
        NOTE(review): the world-size factor presumably offsets a later averaging
        of gradients across ranks -- confirm against the training loop.
        """
        local_num_tokens, global_num_tokens = ctx.saved_tensors
        grad_output = get_unified_sequence_parallel_world_size() * local_num_tokens * grad_output / global_num_tokens
        return grad_output, None
|
|
|
|
def reduce_sequence_parallel_loss(loss: torch.Tensor, num_valid_tokens: torch.Tensor) -> torch.Tensor:
    """Reduce a per-rank loss across the sequence-parallel group via ``ReduceLoss``.

    Args:
        loss: Loss computed on this rank.
        num_valid_tokens: Weight of this rank's loss. ``ReduceLoss.forward``
            all-reduces this tensor in place.

    Returns:
        The value produced by ``ReduceLoss.apply(loss, num_valid_tokens)``.
    """
    reduced_loss = ReduceLoss.apply(loss, num_valid_tokens)
    return reduced_loss
|
|