"""State machine that turns raw new-word / pad decisions into aligned token
streams, consuming queued word entries while enforcing per-word padding budgets."""

from __future__ import annotations

from collections import deque
from dataclasses import dataclass, field
from typing import Deque, Iterable, List, Tuple


@dataclass
class TokenIds:
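    """Ids of the special tokens the state machine emits and recognizes."""
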
    card: int
    new_word: int
    pad: int
    bos: int
    zero: int
    spk1: int
    spk2: int
    audio_pad: int
    audio_bos: int
    ungenerated: int = -2


@dataclass
class Entry:
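    """A single word to emit: its text, its token ids, and padding forced after it."""
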
    tokens: List[int]
    text: str
    padding: int = 0


@dataclass
class State:
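    """Mutable per-sequence state: remaining entries, padding budgets, queued
    output tokens, and bookkeeping (consumption times, transcript, end step)."""
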
    entries: Deque[Entry]
    padding_budget: int
    forced_padding: int
    pending_tokens: Deque[int] = field(default_factory=deque)
    lookahead_tokens: Deque[int] = field(default_factory=deque)
    end_step: int | None = None
    consumption_times: List[int] = field(default_factory=list)
    transcript: List[Tuple[str, int]] = field(default_factory=list)

    def peek_tokens(self, count: int) -> List[int]:
        """Return tokens from upcoming entries (used for second-stream lookahead)."""
        assert count > 0
        for entry in self.entries:
            if entry.tokens:
                count -= 1
                if count == 0:
                    return entry.tokens
        return []


class StateMachine:
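    """Drives token emission for a sequence of entries: consumes one Entry per
    new-word token, enforces padding budgets between words, and optionally
    multiplexes a second stream that runs ``second_stream_ahead`` words ahead."""
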
    def __init__(
        self,
        token_ids: TokenIds,
        *,
        second_stream_ahead: int = 0,
        max_padding: int = 6,
        initial_padding: int = 0,
    ) -> None:
        self.token_ids = token_ids
        self.second_stream_ahead = second_stream_ahead
        self.max_padding = max_padding
        self.initial_padding = initial_padding

    def new_state(self, entries: Iterable[Entry]) -> State:
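        """Create the initial State for a fresh sequence of entries."""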
        return State(
            entries=deque(entries),
            padding_budget=self.initial_padding,
            forced_padding=self.initial_padding,
        )

    def process(
        self,
        step: int,
        state: State,
        token: int,
        is_forced: bool = False,
    ) -> Tuple[int, int, bool]:
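        """Advance one step: sanitize the raw token, apply constraints, consume
        a new word if requested, and return (main, second, consumed_new_word)."""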
        token = self._sanitize_token(token)
        token = self._enforce_token_constraints(state, token, is_forced)
        token, consumed_new_word = self._handle_new_word(step, state, token)
        output_token = self._select_output_token(state, token)
        final_main, final_second = self._maybe_multiplex_second_stream(
            state, output_token
        )
        return final_main, final_second, consumed_new_word

    def _sanitize_token(self, token: int) -> int:
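        """Map raw 0/1 decisions to pad/new_word and coerce anything else to pad."""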
        if token == 1:
            token = self.token_ids.new_word
        elif token == 0:
            token = self.token_ids.pad
        if token not in (self.token_ids.new_word, self.token_ids.pad):
            return self.token_ids.pad
        return token

    def _enforce_token_constraints(
        self, state: State, token: int, is_forced: bool
    ) -> int:
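        """Override the sampled token when pending tokens, forced padding, or an
        exhausted padding budget dictate what must be emitted next."""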
        if state.pending_tokens:
            return self.token_ids.pad
        if is_forced:
            return token
        if state.forced_padding > 0:
            if token != self.token_ids.pad:
                token = self.token_ids.pad
            return token
        if state.padding_budget <= 0 and token != self.token_ids.new_word:
            return self.token_ids.new_word
        return token

    def _handle_new_word(
        self, step: int, state: State, token: int
    ) -> Tuple[int, bool]:
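        """On a new-word token, pop the next entry, queue its text tokens, and
        reset padding; once entries are exhausted, record the end step."""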
        if token != self.token_ids.new_word:
            return token, False
        if state.entries:
            entry = state.entries.popleft()
            state.consumption_times.append(step)
            if entry.tokens:
                state.transcript.append((entry.text, step))
                state.pending_tokens.extend(entry.tokens)
                if self.second_stream_ahead:
                    state.lookahead_tokens.extend(
                        state.peek_tokens(self.second_stream_ahead)
                    )
                state.padding_budget = self.max_padding
            else:
                token = self.token_ids.pad
            state.forced_padding = entry.padding
            return token, True
        token = self.token_ids.pad
        if self.second_stream_ahead and state.end_step is None:
            token = self.token_ids.new_word
        if state.end_step is None:
            state.end_step = step
        return token, False

    def _select_output_token(self, state: State, token: int) -> int:
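        """Turn the constrained token into the emitted one: pad steps flush a
        pending text token when available, other tokens pass through."""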
        if token == self.token_ids.pad:
            if state.padding_budget > 0:
                state.padding_budget -= 1
            if state.forced_padding > 0:
                state.forced_padding -= 1
            if state.pending_tokens:
                return state.pending_tokens.popleft()
            return self.token_ids.pad
        if token == self.token_ids.new_word:
            return self.token_ids.new_word
        if token == self.token_ids.zero:
            return token
        raise RuntimeError(f"Invalid token {token}")

    def _maybe_multiplex_second_stream(
        self, state: State, output: int
    ) -> Tuple[int, int]:
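        """With a second stream enabled, emit new-word markers and lookahead
        tokens on it while the main stream carries the text tokens."""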
        if not self.second_stream_ahead:
            return output, output
        second = -1
        if output == self.token_ids.new_word:
            second = self.token_ids.new_word
            if state.pending_tokens:
                output = state.pending_tokens.popleft()
            else:
                output = self.token_ids.pad
        elif state.lookahead_tokens:
            second = state.lookahead_tokens.popleft()
        else:
            second = self.token_ids.pad
        return output, second
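

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module proper: every id value below
    # is a made-up placeholder, and the decisions list stands in for a model's
    # raw new-word (1) / pad (0) outputs, purely to show how the pieces fit.
    ids = TokenIds(
        card=2048,
        new_word=2,
        pad=3,
        bos=0,
        zero=1,
        spk1=4,
        spk2=5,
        audio_pad=6,
        audio_bos=7,
    )
    machine = StateMachine(ids, max_padding=4)
    state = machine.new_state(
        [
            Entry(tokens=[100, 101], text="hello"),
            Entry(tokens=[102], text="world", padding=2),
        ]
    )
    for step, decision in enumerate([1, 0, 0, 1, 0, 0, 1, 0]):
        main, second, consumed = machine.process(step, state, decision)
        print(f"step={step} main={main} second={second} new_word={consumed}")
    # With second_stream_ahead=0 both streams are identical; the transcript
    # records which step consumed each word.
    print("transcript:", state.transcript)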