| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| from typing import Tuple |
|
|
| import torch |
| import torch.distributed as dist |
| import torch.distributed.nn.functional as dist_nn |
|
|
| from specforge.distributed import get_draft_sp_group, get_sp_ulysses_group |
|
|
|
|
@dataclass
class StepState:
    """Per-step bundle of tensors returned by a backend adapter's ``step_view``.

    This is a plain container. No validation or copying happens here. Each
    field holds whatever tensor (or slice of one) the adapter chose for the
    step.
    """

    # Token ids the draft model sees for this step.
    input_ids: torch.Tensor
    # Hidden states fed alongside the token ids.
    hidden_states: torch.Tensor
    # Position ids. NOTE(review): a USP adapter may keep a longer span here
    # than for the other fields; confirm against the consumer.
    position_ids: torch.Tensor
    # Attention mask for this step.
    attention_mask: torch.Tensor
    # Target distribution window aligned with this step.
    target_p: torch.Tensor
    # Per-position mask applied together with target_p.
    position_mask: torch.Tensor
    # Mask selecting which positions contribute to the loss.
    loss_mask: torch.Tensor
|
|
|
|
class BackendAdapter:
    """Base class for attention-backend adapters.

    A subclass decides which slice of each tensor a given TTT step uses
    (``step_view``). It may also combine metrics and loss across ranks. The
    defaults defined here do no communication and return their inputs as-is.
    """

    def __init__(self, model: "OnlineEagle3Model"):
        # Handle to the owning model, stored for subclasses.
        self.m = model

    def step_view(
        self,
        *,
        idx: int,
        ttt_length: int,
        global_input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: torch.Tensor,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        target_p_padded: torch.Tensor,
        position_mask: torch.Tensor,
        seq_length: int,
    ) -> StepState:
        """Build the :class:`StepState` for step ``idx``.

        Raises:
            NotImplementedError: always. Concrete adapters must override this.
        """
        raise NotImplementedError

    def reduce_metrics(
        self, *, local_correct: torch.Tensor, local_denom: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(local_correct, local_denom)`` untouched (no reduction)."""
        pair = (local_correct, local_denom)
        return pair

    def reduce_loss(self, loss: torch.Tensor) -> torch.Tensor:
        """Return ``loss`` untouched (no reduction)."""
        return loss
|
|
|
|
class SdpaLikeAdapter(BackendAdapter):
    """Adapter for backends whose step view needs no sequence slicing.

    All inputs are passed straight through except the target distribution.
    That tensor is cut to ``seq_length`` positions starting at ``idx``.
    """

    def step_view(
        self,
        *,
        idx: int,
        ttt_length: int,
        global_input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: torch.Tensor,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        target_p_padded: torch.Tensor,
        position_mask: torch.Tensor,
        seq_length: int,
    ) -> StepState:
        """Return the step view for step ``idx``.

        ``ttt_length`` is accepted only to match the base-class signature and
        is not used here.
        """
        window = slice(idx, idx + seq_length)
        # Materialise the window as its own contiguous tensor rather than a
        # strided view into target_p_padded.
        step_target = target_p_padded[:, window, :].contiguous()
        return StepState(
            input_ids=global_input_ids,
            hidden_states=hidden_states,
            position_ids=position_ids,
            attention_mask=attention_mask,
            target_p=step_target,
            position_mask=position_mask,
            loss_mask=loss_mask,
        )
|
|
|
|
class UspAdapter(BackendAdapter):
    """Adapter for the USP (sequence-parallel) backend.

    On each rank, every step uses only the first ``seq_length - ttt_length``
    local positions. Loss and metrics are summed over the draft
    sequence-parallel process group.
    """

    def __init__(self, model: "OnlineEagle3Model"):
        super().__init__(model)
        # Group over which loss and metrics are reduced.
        self.sp_group = get_draft_sp_group()
        self.sp_world_size = dist.get_world_size(self.sp_group)
        # Ulysses sub-group. Its size scales the position_ids window in
        # step_view.
        self.ulysses_pg = get_sp_ulysses_group()
        self.sp_ulysses_degree = dist.get_world_size(self.ulysses_pg)

    def _sum_over_sp(self, tensor: torch.Tensor) -> torch.Tensor:
        """Sum ``tensor`` across ``self.sp_group`` using the autograd-aware collective."""
        return dist_nn.all_reduce(tensor, op=dist.ReduceOp.SUM, group=self.sp_group)

    def step_view(
        self,
        *,
        idx: int,
        ttt_length: int,
        global_input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: torch.Tensor,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        target_p_padded: torch.Tensor,
        position_mask: torch.Tensor,
        seq_length: int,
    ) -> StepState:
        """Return the step view for step ``idx`` on this rank.

        Raises:
            ValueError: if ``seq_length`` does not exceed ``ttt_length``,
                which would leave no positions for the step.
        """
        chunk = seq_length - ttt_length
        if chunk <= 0:
            raise ValueError(
                f"USP local seq_length ({seq_length}) must be larger than "
                f"ttt_length ({ttt_length})"
            )
        head = slice(None, chunk)
        # position_ids keeps chunk * ulysses_degree entries. Presumably this
        # is because they span the whole Ulysses group; confirm with the
        # attention backend.
        return StepState(
            input_ids=global_input_ids[:, head],
            hidden_states=hidden_states[:, head, :],
            position_ids=position_ids[:, : chunk * self.sp_ulysses_degree],
            attention_mask=attention_mask[:, head],
            target_p=target_p_padded[:, idx : idx + chunk, :],
            position_mask=position_mask[:, head, :],
            loss_mask=loss_mask[:, head, :],
        )

    def reduce_metrics(
        self, *, local_correct: torch.Tensor, local_denom: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sum both metric tensors across the SP group.

        Every rank must issue the collectives in the same order: correct,
        then denom.
        """
        summed_correct = self._sum_over_sp(local_correct)
        summed_denom = self._sum_over_sp(local_denom)
        return summed_correct, summed_denom

    def reduce_loss(self, loss: torch.Tensor) -> torch.Tensor:
        """Return the mean of the per-rank losses over the SP group."""
        total = self._sum_over_sp(loss)
        return total / self.sp_world_size
|
|