File size: 4,265 Bytes
7a60a87 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 | from __future__ import annotations
from dataclasses import dataclass
from typing import Tuple
import torch
import torch.distributed as dist
import torch.distributed.nn.functional as dist_nn
from specforge.distributed import get_draft_sp_group, get_sp_ulysses_group
@dataclass
class StepState:
    """Bundle of tensors consumed by the draft model for one TTT step.

    Instances are built by the ``step_view`` method of a backend adapter,
    which decides whether (and how) each tensor is sliced for the step.
    All fields are required; there are no defaults.
    """

    # Token ids for the step.
    input_ids: torch.Tensor
    # Hidden states passed to the draft model for the step.
    hidden_states: torch.Tensor
    # Position ids for the step (their length may differ from input_ids;
    # the USP adapter slices them with a wider window).
    position_ids: torch.Tensor
    # Attention mask for the step.
    attention_mask: torch.Tensor
    # Window of the padded target tensor selected for this step
    # (presumably a target probability distribution — confirm with caller).
    target_p: torch.Tensor
    # Position mask for the step, as provided by the caller.
    position_mask: torch.Tensor
    # Loss mask for the step, as provided by the caller.
    loss_mask: torch.Tensor
class BackendAdapter:
    """Base adapter that tells the training loop how to view and reduce step data.

    Subclasses must override :meth:`step_view`. The metric and loss
    reduction hooks default to identity (they return their inputs unchanged);
    subclasses that split work across ranks override them.
    """

    def __init__(self, model: "OnlineEagle3Model"):
        # Keep a reference to the owning model; the base class does not use it.
        self.m = model

    def step_view(
        self,
        *,
        idx: int,
        ttt_length: int,
        global_input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: torch.Tensor,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        target_p_padded: torch.Tensor,
        position_mask: torch.Tensor,
        seq_length: int,
    ) -> StepState:
        """Build the :class:`StepState` used at TTT step ``idx``.

        Args:
            idx: Step index; implementations offset ``target_p_padded`` by it.
            ttt_length: Number of TTT steps (presumably — confirm with caller).
            global_input_ids, attention_mask, loss_mask, position_ids,
            hidden_states, target_p_padded, position_mask: Full tensors that
                the implementation may slice for this step.
            seq_length: Sequence length used to size the per-step windows.

        Raises:
            NotImplementedError: Always; concrete adapters must override this.
        """
        # Name the concrete class so a missing override is easy to diagnose.
        raise NotImplementedError(
            f"{type(self).__name__} must implement step_view()"
        )

    def reduce_metrics(
        self, *, local_correct: torch.Tensor, local_denom: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(local_correct, local_denom)`` unchanged (no cross-rank reduction)."""
        return local_correct, local_denom

    def reduce_loss(self, loss: torch.Tensor) -> torch.Tensor:
        """Return ``loss`` unchanged (no cross-rank reduction)."""
        return loss
class SdpaLikeAdapter(BackendAdapter):
    """Adapter that forwards the full tensors and only windows the target.

    Apart from ``target_p``, every field of the returned :class:`StepState`
    is the exact tensor object that was passed in. ``ttt_length`` is unused.
    Metric and loss reductions are inherited unchanged from the base class.
    """

    def step_view(
        self,
        *,
        idx: int,
        ttt_length: int,
        global_input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: torch.Tensor,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        target_p_padded: torch.Tensor,
        position_mask: torch.Tensor,
        seq_length: int,
    ) -> StepState:
        """Return the step view for step ``idx``.

        ``target_p`` is ``target_p_padded[:, idx:idx + seq_length, :]``,
        copied into contiguous memory. All other tensors pass through as-is.
        """
        # Window along dim 1, shifted by the step index.
        step_window = slice(idx, idx + seq_length)
        windowed_target = target_p_padded[:, step_window, :].contiguous()

        return StepState(
            input_ids=global_input_ids,
            hidden_states=hidden_states,
            position_ids=position_ids,
            attention_mask=attention_mask,
            position_mask=position_mask,
            loss_mask=loss_mask,
            target_p=windowed_target,
        )
class UspAdapter(BackendAdapter):
    """Adapter for the USP sequence-parallel backend.

    Every step keeps only the first ``seq_length - ttt_length`` positions of
    the local tensors. Metrics and loss are SUM-all-reduced over the draft SP
    group through ``torch.distributed.nn.functional``; the loss is then
    divided by the group size.
    """

    def __init__(self, model: "OnlineEagle3Model"):
        super().__init__(model)
        # Draft sequence-parallel group, used for every reduction below.
        self.sp_group = get_draft_sp_group()
        self.sp_world_size = dist.get_world_size(self.sp_group)
        # Ulysses group; its size widens the position_ids window in step_view.
        self.ulysses_pg = get_sp_ulysses_group()
        self.sp_ulysses_degree = dist.get_world_size(self.ulysses_pg)

    def step_view(
        self,
        *,
        idx: int,
        ttt_length: int,
        global_input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: torch.Tensor,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        target_p_padded: torch.Tensor,
        position_mask: torch.Tensor,
        seq_length: int,
    ) -> StepState:
        """Return the truncated per-step view for step ``idx``.

        Raises:
            ValueError: If ``seq_length`` is not larger than ``ttt_length``.
        """
        chunk = seq_length - ttt_length
        if chunk <= 0:
            raise ValueError(
                f"USP local seq_length ({seq_length}) must be larger than "
                f"ttt_length ({ttt_length})"
            )

        # Leading ``chunk`` positions of each local tensor.
        head = slice(None, chunk)
        # position_ids keep ``sp_ulysses_degree`` times as many positions;
        # presumably they are laid out over the Ulysses-gathered sequence —
        # confirm against the attention backend.
        position_head = slice(None, chunk * self.sp_ulysses_degree)

        return StepState(
            input_ids=global_input_ids[:, head],
            hidden_states=hidden_states[:, head, :],
            position_ids=position_ids[:, position_head],
            attention_mask=attention_mask[:, head],
            target_p=target_p_padded[:, idx : idx + chunk, :],
            position_mask=position_mask[:, head, :],
            loss_mask=loss_mask[:, head, :],
        )

    def _sp_sum(self, tensor: torch.Tensor) -> torch.Tensor:
        """SUM-all-reduce ``tensor`` over the draft SP group (autograd-aware)."""
        return dist_nn.all_reduce(tensor, op=dist.ReduceOp.SUM, group=self.sp_group)

    def reduce_metrics(
        self, *, local_correct: torch.Tensor, local_denom: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sum ``local_correct`` and then ``local_denom`` across the SP group."""
        # Collective order (correct first, then denom) must match on all ranks.
        summed_correct = self._sp_sum(local_correct)
        summed_denom = self._sp_sum(local_denom)
        return summed_correct, summed_denom

    def reduce_loss(self, loss: torch.Tensor) -> torch.Tensor:
        """Return the mean of ``loss`` over the SP group (sum / group size)."""
        return self._sp_sum(loss) / self.sp_world_size
|