File size: 4,265 Bytes
7a60a87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from __future__ import annotations

from dataclasses import dataclass
from typing import Tuple

import torch
import torch.distributed as dist
import torch.distributed.nn.functional as dist_nn

from specforge.distributed import get_draft_sp_group, get_sp_ulysses_group


@dataclass(init=True, repr=True, eq=True)
class StepState:
    """Bundle of tensors describing one step, as returned by ``step_view``.

    A plain mutable record: values are stored exactly as passed in, with no
    copying or validation. Positional construction follows the field order
    declared below.
    """

    # Token ids for this step.
    input_ids: torch.Tensor
    # Hidden states for this step.
    hidden_states: torch.Tensor
    # Position ids for this step.
    position_ids: torch.Tensor
    # Attention mask for this step.
    attention_mask: torch.Tensor
    # The adapter's per-step slice of ``target_p_padded``.
    target_p: torch.Tensor
    # Position mask for this step.
    position_mask: torch.Tensor
    # Loss mask for this step.
    loss_mask: torch.Tensor


class BackendAdapter:
    """Base adapter: per-step tensor views plus metric/loss reduction hooks.

    Subclasses must implement :meth:`step_view`. The reduction hooks in this
    base class return their inputs unchanged; subclasses may override them
    (``UspAdapter`` does, with all-reduces over its process group).
    """

    def __init__(self, model: "OnlineEagle3Model"):
        # Keep a handle on the owning model, exposed as ``self.m``.
        self.m = model

    def step_view(
        self,
        *,
        idx: int,
        ttt_length: int,
        global_input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: torch.Tensor,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        target_p_padded: torch.Tensor,
        position_mask: torch.Tensor,
        seq_length: int,
    ) -> StepState:
        """Build the :class:`StepState` for step ``idx``.

        Not implemented in the base class.

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError

    def reduce_metrics(
        self, *, local_correct: torch.Tensor, local_denom: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(local_correct, local_denom)`` as-is (no reduction)."""
        passthrough = (local_correct, local_denom)
        return passthrough

    def reduce_loss(self, loss: torch.Tensor) -> torch.Tensor:
        """Return ``loss`` as-is (no reduction)."""
        unreduced = loss
        return unreduced


class SdpaLikeAdapter(BackendAdapter):
    """Adapter whose step view reuses the caller's tensors unchanged.

    Only ``target_p`` varies per step. It is the window
    ``[idx, idx + seq_length)`` of ``target_p_padded`` along dim 1, copied
    into contiguous memory.
    """

    def step_view(
        self,
        *,
        idx: int,
        ttt_length: int,
        global_input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: torch.Tensor,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        target_p_padded: torch.Tensor,
        position_mask: torch.Tensor,
        seq_length: int,
    ) -> StepState:
        """Return the step-``idx`` view.

        All inputs except ``target_p_padded`` are placed in the result
        unchanged. ``ttt_length`` is accepted for interface compatibility
        but not used here.
        """
        window = slice(idx, idx + seq_length)
        return StepState(
            input_ids=global_input_ids,
            hidden_states=hidden_states,
            position_ids=position_ids,
            attention_mask=attention_mask,
            target_p=target_p_padded[:, window, :].contiguous(),
            position_mask=position_mask,
            loss_mask=loss_mask,
        )


class UspAdapter(BackendAdapter):
    """Adapter for the USP backend.

    Each step keeps only the first ``seq_length - ttt_length`` local
    positions. Metrics and loss are sum-all-reduced over the group returned
    by ``get_draft_sp_group()``.
    """

    def __init__(self, model: "OnlineEagle3Model"):
        super().__init__(model)
        # Group used by reduce_metrics / reduce_loss.
        draft_group = get_draft_sp_group()
        self.sp_group = draft_group
        self.sp_world_size = dist.get_world_size(draft_group)
        # Its size scales the position_ids slice in step_view.
        ulysses_group = get_sp_ulysses_group()
        self.ulysses_pg = ulysses_group
        self.sp_ulysses_degree = dist.get_world_size(ulysses_group)

    def step_view(
        self,
        *,
        idx: int,
        ttt_length: int,
        global_input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: torch.Tensor,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        target_p_padded: torch.Tensor,
        position_mask: torch.Tensor,
        seq_length: int,
    ) -> StepState:
        """Return the step-``idx`` view truncated to ``seq_length - ttt_length``.

        With ``chunk = seq_length - ttt_length``:

        * ``target_p`` is ``target_p_padded[:, idx : idx + chunk, :]``, a
          view that is not forced contiguous.
        * ``position_ids`` keeps ``chunk * sp_ulysses_degree`` columns.
        * Every other tensor keeps its first ``chunk`` entries along dim 1.

        Raises:
            ValueError: if ``seq_length <= ttt_length``.
        """
        chunk = seq_length - ttt_length
        if chunk <= 0:
            raise ValueError(
                f"USP local seq_length ({seq_length}) must be larger than "
                f"ttt_length ({ttt_length})"
            )
        step_target = target_p_padded[:, idx : idx + chunk, :]
        head = slice(None, chunk)
        return StepState(
            input_ids=global_input_ids[:, head],
            hidden_states=hidden_states[:, head, :],
            position_ids=position_ids[:, : chunk * self.sp_ulysses_degree],
            attention_mask=attention_mask[:, head],
            target_p=step_target,
            position_mask=position_mask[:, head, :],
            loss_mask=loss_mask[:, head, :],
        )

    def reduce_metrics(
        self, *, local_correct: torch.Tensor, local_denom: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sum both metric tensors over ``self.sp_group``.

        Uses ``torch.distributed.nn.functional.all_reduce``, which returns a
        new tensor. ``local_correct`` is reduced before ``local_denom``.
        """

        def _group_sum(t: torch.Tensor) -> torch.Tensor:
            return dist_nn.all_reduce(t, op=dist.ReduceOp.SUM, group=self.sp_group)

        return _group_sum(local_correct), _group_sum(local_denom)

    def reduce_loss(self, loss: torch.Tensor) -> torch.Tensor:
        """Average ``loss`` over ``self.sp_group``.

        Sums with an all-reduce, then divides by the group's world size.
        """
        group_total = dist_nn.all_reduce(loss, op=dist.ReduceOp.SUM, group=self.sp_group)
        return group_total / self.sp_world_size