Hanrui / SpecForge /specforge /core /eagle3_adapters.py
Lekr0's picture
Add files using upload-large-folder tool
7a60a87 verified
from __future__ import annotations
from dataclasses import dataclass
from typing import Tuple
import torch
import torch.distributed as dist
import torch.distributed.nn.functional as dist_nn
from specforge.distributed import get_draft_sp_group, get_sp_ulysses_group
@dataclass
class StepState:
    """Tensors describing a single step ``idx``, as returned by ``step_view``.

    Every backend adapter's ``step_view`` builds one of these.
    Each field is either the caller's tensor passed straight through or a slice of it.
    Which one depends on the adapter.
    """

    # Token ids for this step (SdpaLikeAdapter: unchanged; UspAdapter: first local chunk).
    input_ids: torch.Tensor
    # Hidden states for this step, sliced along dim 1 by UspAdapter.
    hidden_states: torch.Tensor
    # Position ids for this step.
    position_ids: torch.Tensor
    # Attention mask for this step.
    attention_mask: torch.Tensor
    # Window of ``target_p_padded`` starting at ``idx`` (presumably target
    # probabilities — confirm against the caller).
    target_p: torch.Tensor
    # Position mask for this step.
    position_mask: torch.Tensor
    # Loss mask for this step.
    loss_mask: torch.Tensor
class BackendAdapter:
    """Base class mapping full training inputs to a per-step ``StepState``.

    Subclasses must provide ``step_view``.
    The default ``reduce_metrics`` and ``reduce_loss`` hand their inputs back unchanged.
    They do no cross-process communication.
    """

    def __init__(self, model: "OnlineEagle3Model"):
        # Keep a reference to the owning model.
        self.m = model

    def step_view(
        self,
        *,
        idx: int,
        ttt_length: int,
        global_input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: torch.Tensor,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        target_p_padded: torch.Tensor,
        position_mask: torch.Tensor,
        seq_length: int,
    ) -> StepState:
        """Build the ``StepState`` for step ``idx``.

        Raises:
            NotImplementedError: always; concrete adapters override this.
        """
        raise NotImplementedError

    def reduce_metrics(
        self, *, local_correct: torch.Tensor, local_denom: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(local_correct, local_denom)`` as-is (no reduction)."""
        pair = (local_correct, local_denom)
        return pair

    def reduce_loss(self, loss: torch.Tensor) -> torch.Tensor:
        """Return ``loss`` as-is (no reduction)."""
        return loss
class SdpaLikeAdapter(BackendAdapter):
    """Adapter that forwards inputs unchanged apart from a shifted target window.

    Metric and loss reductions are inherited from ``BackendAdapter``.
    """

    def step_view(
        self,
        *,
        idx: int,
        ttt_length: int,
        global_input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: torch.Tensor,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        target_p_padded: torch.Tensor,
        position_mask: torch.Tensor,
        seq_length: int,
    ) -> StepState:
        """Return the step-``idx`` view of the inputs.

        Only ``target_p_padded`` is sliced.
        It takes ``seq_length`` entries along dim 1 starting at ``idx``, copied to a contiguous tensor.
        Every other input is returned unchanged.
        ``ttt_length`` is accepted for interface compatibility but unused here.
        """
        # Moving the window by ``idx`` selects this step's target slice.
        window = slice(idx, idx + seq_length)
        shifted_target = target_p_padded[:, window, :].contiguous()
        return StepState(
            input_ids=global_input_ids,
            hidden_states=hidden_states,
            position_ids=position_ids,
            attention_mask=attention_mask,
            position_mask=position_mask,
            loss_mask=loss_mask,
            target_p=shifted_target,
        )
class UspAdapter(BackendAdapter):
    """Adapter for the USP backend that trims inputs and reduces over process groups.

    ``step_view`` keeps only the first ``seq_length - ttt_length`` local positions.
    ``reduce_metrics`` sums both metric tensors across ``get_draft_sp_group()``.
    ``reduce_loss`` averages the loss across that same group.
    """

    def __init__(self, model: "OnlineEagle3Model"):
        super().__init__(model)
        # Draft SP group: used for every reduction below.
        sp_group = get_draft_sp_group()
        self.sp_group = sp_group
        self.sp_world_size = dist.get_world_size(sp_group)
        # The Ulysses group size scales how many position ids are kept per step.
        ulysses_group = get_sp_ulysses_group()
        self.ulysses_pg = ulysses_group
        self.sp_ulysses_degree = dist.get_world_size(ulysses_group)

    def step_view(
        self,
        *,
        idx: int,
        ttt_length: int,
        global_input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: torch.Tensor,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        target_p_padded: torch.Tensor,
        position_mask: torch.Tensor,
        seq_length: int,
    ) -> StepState:
        """Return the step-``idx`` view restricted to the usable local chunk.

        The usable chunk length is ``seq_length - ttt_length``.
        - Every per-token input is cut to that many entries along dim 1.
        - ``position_ids`` keeps ``chunk * sp_ulysses_degree`` entries
          (presumably because position ids cover the Ulysses group's
          combined sequence; confirm).
        - ``target_p`` is the window ``[idx, idx + chunk)`` of ``target_p_padded``.
          Unlike ``SdpaLikeAdapter``, it is not made contiguous.

        Raises:
            ValueError: if ``seq_length`` is not larger than ``ttt_length``.
        """
        local_len = seq_length - ttt_length
        if local_len <= 0:
            raise ValueError(
                f"USP local seq_length ({seq_length}) must be larger than "
                f"ttt_length ({ttt_length})"
            )

        shifted_target = target_p_padded[:, idx : idx + local_len, :]
        keep = slice(None, local_len)
        ulysses_keep = slice(None, local_len * self.sp_ulysses_degree)
        # Note: despite its name, ``global_input_ids`` is trimmed to the local chunk too.
        return StepState(
            input_ids=global_input_ids[:, keep],
            hidden_states=hidden_states[:, keep, :],
            position_ids=position_ids[:, ulysses_keep],
            attention_mask=attention_mask[:, keep],
            target_p=shifted_target,
            position_mask=position_mask[:, keep, :],
            loss_mask=loss_mask[:, keep, :],
        )

    def reduce_metrics(
        self, *, local_correct: torch.Tensor, local_denom: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sum ``local_correct`` then ``local_denom`` across ``self.sp_group``.

        Uses the out-of-place ``torch.distributed.nn.functional.all_reduce``.
        The collectives are issued in a fixed order (correct first, then denom).
        Every rank must issue them in that same order.
        """
        summed = [
            dist_nn.all_reduce(metric, op=dist.ReduceOp.SUM, group=self.sp_group)
            for metric in (local_correct, local_denom)
        ]
        return summed[0], summed[1]

    def reduce_loss(self, loss: torch.Tensor) -> torch.Tensor:
        """Return the mean of ``loss`` over ``self.sp_group`` (sum / world size)."""
        total = dist_nn.all_reduce(loss, op=dist.ReduceOp.SUM, group=self.sp_group)
        return total / self.sp_world_size