import typing

import torch
import torch.distributed as dist


class LinearCrossEntropy(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        hidden: torch.Tensor,
        weight: torch.Tensor,
        labels: torch.Tensor,
        temperature: float = 1.0,
        reduction: str = "none",
        dist_process_group: typing.Optional[dist.ProcessGroup] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
| """_summary_ |
| |
| Args: |
| ctx (_type_): _description_ |
| hidden (torch.Tensor): (batch_size, num_tokens, hidden_size) -> (batch_size * num_tokens, hidden_size) |
| weight (torch.Tensor): (vocab_size, hidden_size) |
| labels (torch.Tensor): (batch_size, num_tokens) -> (batch_size * num_tokens, ) |
| temperature (typing.Optional[float], optional): _description_. Defaults to 1.0. |
| reduction (typing.Optional[str], optional): _description_. Defaults to "none". |
| dist_process_group (typing.Optional[dist.ProcessGroup], optional): _description_. Defaults to None. |
| |
| Returns: |
| typing.List[torch.Tensor]: _description_ |
| """ |
        assert isinstance(temperature, float), f"temperature must be a float, but got {type(temperature)}"
        assert isinstance(reduction, str), f"reduction must be a str, but got {type(reduction)}"
        with torch.cuda.nvtx.range("LinearCrossEntropy-forward"):
            # Imported lazily so the module can be loaded without the compiled kernels.
            from . import kernels

            REDUCTION = kernels.get_entropy_reduction_enum_number(reduction.lower())

            # Flatten inputs to 2D/1D for the kernels; the original shape is
            # restored on the gradient in backward.
            original_hidden_shape = hidden.shape
            if len(hidden.shape) != 2:
                hidden = hidden.view(-1, hidden.shape[-1])
            if len(labels.shape) != 1:
                labels = labels.view(-1)

            logprobs, entropy, _maximum, _accumulate, _entropy_b = kernels.efficient_entropy_forward(
                hidden, weight, labels, REDUCTION, temperature, dist_process_group
            )

            # Stash the inputs plus the kernel's intermediate statistics so the
            # backward pass can reconstruct the gradients.
            ctx.save_for_backward(hidden, weight, labels, _maximum, _accumulate, _entropy_b)
            ctx.original_hidden_shape = original_hidden_shape
            ctx.REDUCTION = REDUCTION
            ctx.dist_process_group = dist_process_group
            ctx.should_return_fp32_grad = False
            ctx.temperature = temperature
        return logprobs, entropy

    @staticmethod
    def backward(
        ctx, dlogprobs: torch.Tensor, dentropy: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, None, None, None, None]:
        from . import kernels

        with torch.cuda.nvtx.range("LinearCrossEntropy-backward"):
            (hidden, weight, labels, _maximum, _accumulate, _entropy_b) = ctx.saved_tensors
            REDUCTION = ctx.REDUCTION
            dist_process_group = ctx.dist_process_group
            should_return_fp32_grad = ctx.should_return_fp32_grad
            temperature = ctx.temperature

            d_hidden, d_weight = kernels.efficient_entropy_backward(
                dlogprobs,
                dentropy,
                hidden,
                weight,
                labels,
                _maximum,
                _accumulate,
                _entropy_b,
                REDUCTION,
                should_return_fp32_grad,
                temperature,
                dist_process_group,
            )
            # Restore the gradient to the caller's original hidden shape.
            d_hidden = d_hidden.view(ctx.original_hidden_shape)

        # One gradient per forward input; non-tensor args get None.
        return (d_hidden, d_weight, None, None, None, None)


linear_cross_entropy = LinearCrossEntropy.apply
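
# Usage sketch (hypothetical shapes; assumes the companion `kernels` module is
# built and a CUDA device is available, since the op wraps CUDA NVTX ranges).
# `torch.autograd.Function.apply` takes positional arguments only:
#
#     hidden = torch.randn(4, 512, 1024, device="cuda", requires_grad=True)
#     weight = torch.randn(32000, 1024, device="cuda", requires_grad=True)
#     labels = torch.randint(0, 32000, (4, 512), device="cuda")
#     logprobs, entropy = linear_cross_entropy(hidden, weight, labels, 1.0, "none")
#     (logprobs.sum() + entropy.sum()).backward()  # grads flow to hidden and weight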