Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,506 Bytes
1315cad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
from __future__ import annotations
from collections import deque
from dataclasses import dataclass, field
from typing import Deque, Iterable, List, Sequence, Tuple
@dataclass
class TokenIds:
    """Vocabulary ids for the special tokens the state machine emits/compares.

    Only `new_word`, `pad`, and `zero` are referenced by the visible
    StateMachine logic; the remaining ids are presumably consumed by code
    outside this view (audio stream handling, speaker turns) — TODO confirm.
    """

    card: int        # NOTE(review): presumably the text-token cardinality — not used in this view; confirm
    new_word: int    # emitted at a word boundary; triggers consuming the next Entry
    pad: int         # text padding token; fills steps between words
    bos: int         # beginning-of-sequence marker (unused in this view)
    zero: int        # special "zero" token accepted verbatim by _select_output_token
    spk1: int        # speaker-1 marker (unused in this view)
    spk2: int        # speaker-2 marker (unused in this view)
    audio_pad: int   # audio-stream padding id (unused in this view)
    audio_bos: int   # audio-stream BOS id (unused in this view)
    ungenerated: int = -2  # sentinel for positions not yet generated (unused in this view)
@dataclass
class Entry:
    """One queued text item awaiting consumption by the state machine."""

    tokens: List[int]  # token ids queued onto State.pending_tokens when consumed;
                       # an empty list marks a padding-only entry
    text: str          # source text, appended to State.transcript (only when tokens is non-empty)
    padding: int = 0   # pad steps forced after this entry is consumed (becomes State.forced_padding)
@dataclass
class State:
    entries: Deque[Entry]
    padding_budget: int
    forced_padding: int
    pending_tokens: Deque[int] = field(default_factory=deque)
    lookahead_tokens: Deque[int] = field(default_factory=deque)
    end_step: int | None = None
    consumption_times: List[int] = field(default_factory=list)
    transcript: List[Tuple[str, int]] = field(default_factory=list)

    def peek_tokens(self, count: int) -> List[int]:
        """Return the token list of the `count`-th upcoming non-empty entry.

        Entries with an empty token list are skipped. Used for the
        second-stream lookahead. Returns [] when fewer than `count`
        non-empty entries remain; the queue itself is not modified.
        """
        assert count > 0
        remaining = count
        for upcoming in self.entries:
            if not upcoming.tokens:
                continue
            remaining -= 1
            if remaining == 0:
                return upcoming.tokens
        return []
class StateMachine:
    """Drives per-step text-token emission over a queue of `Entry` items.

    Each call to `process` takes the raw token proposed by a model,
    sanitizes it (0/1 are remapped to pad/new-word), enforces the
    padding constraints, consumes the next queued entry on a word
    boundary, and returns the token(s) to emit. When
    `second_stream_ahead` is non-zero, a second, lookahead text stream
    is multiplexed alongside the main one.
    """

    def __init__(
        self,
        token_ids: TokenIds,
        *,
        second_stream_ahead: int = 0,
        max_padding: int = 6,
        initial_padding: int = 0,
    ) -> None:
        # second_stream_ahead: number of non-empty entries to look ahead for
        # the second stream; 0 disables the second stream entirely.
        self.token_ids = token_ids
        self.second_stream_ahead = second_stream_ahead
        # max_padding: pad budget refilled each time a non-empty entry is
        # consumed; once spent, _enforce_token_constraints forces a new word.
        self.max_padding = max_padding
        # initial_padding: pads both budgeted and forced before the first word.
        self.initial_padding = initial_padding

    def new_state(self, entries: Iterable[Entry]) -> State:
        """Create a fresh State over `entries`, seeded with the initial padding."""
        return State(
            entries=deque(entries),
            padding_budget=self.initial_padding,
            forced_padding=self.initial_padding,
        )

    def process(
        self,
        step: int,
        state: State,
        token: int,
        is_forced: bool = False,
    ) -> Tuple[int, int, bool]:
        """Advance the stream by one step.

        Args:
            step: current timestep, recorded in `state.consumption_times`
                and `state.transcript` when an entry is consumed.
            token: raw model token (0 and 1 are remapped; anything that is
                not pad/new-word collapses to pad).
            is_forced: when True, bypass the padding-budget constraints
                (pending-token draining still takes priority).

        Returns:
            (main_token, second_token, consumed_new_word). When the second
            stream is disabled, main and second are the same token.
        """
        token = self._sanitize_token(token)
        token = self._enforce_token_constraints(state, token, is_forced)
        token, consumed_new_word = self._handle_new_word(step, state, token)
        output_token = self._select_output_token(state, token)
        final_main, final_second = self._maybe_multiplex_second_stream(
            state, output_token
        )
        return final_main, final_second, consumed_new_word

    def _sanitize_token(self, token: int) -> int:
        """Map a raw token to the canonical ids: 1 -> new_word, 0 -> pad.

        Any other id that is not already new_word/pad collapses to pad,
        so downstream helpers only ever see those two values.
        """
        if token == 1:
            token = self.token_ids.new_word
        elif token == 0:
            token = self.token_ids.pad
        if token not in (self.token_ids.new_word, self.token_ids.pad):
            return self.token_ids.pad
        return token

    def _enforce_token_constraints(
        self, state: State, token: int, is_forced: bool
    ) -> int:
        """Override the model's choice when padding rules demand it."""
        # While queued word tokens remain, emit pads so they drain first
        # (the pad path in _select_output_token pops pending_tokens).
        if state.pending_tokens:
            return self.token_ids.pad
        if is_forced:
            return token
        # Mandatory padding (from Entry.padding / initial_padding) wins
        # over whatever the model proposed.
        if state.forced_padding > 0:
            if token != self.token_ids.pad:
                token = self.token_ids.pad
            return token
        # Pad budget exhausted: force a word boundary.
        if state.padding_budget <= 0 and token != self.token_ids.new_word:
            return self.token_ids.new_word
        return token

    def _handle_new_word(
        self, step: int, state: State, token: int
    ) -> Tuple[int, bool]:
        """On a new-word token, consume the next entry (or mark end of input).

        Returns (possibly rewritten token, consumed_new_word).
        """
        if token != self.token_ids.new_word:
            return token, False
        if state.entries:
            entry = state.entries.popleft()
            state.consumption_times.append(step)
            if entry.tokens:
                state.transcript.append((entry.text, step))
                state.pending_tokens.extend(entry.tokens)
                if self.second_stream_ahead:
                    # Queue the lookahead word for the second stream.
                    state.lookahead_tokens.extend(
                        state.peek_tokens(self.second_stream_ahead)
                    )
                # Refill the pad budget for the gap before the next word.
                state.padding_budget = self.max_padding
            else:
                # Padding-only entry: no word boundary is emitted for it.
                token = self.token_ids.pad
            state.forced_padding = entry.padding
            return token, True
        # No entries left: pad from here on, and remember when we ran out.
        token = self.token_ids.pad
        if self.second_stream_ahead and state.end_step is None:
            # NOTE(review): with a second stream, one extra new_word is emitted
            # at the first exhausted step — presumably to flush the lookahead
            # stream; confirm against the consumer of the second stream.
            token = self.token_ids.new_word
        if state.end_step is None:
            state.end_step = step
        return token, False

    def _select_output_token(self, state: State, token: int) -> int:
        """Turn the constrained token into the actual main-stream output."""
        if token == self.token_ids.pad:
            # A pad step spends both budgets (when positive) ...
            if state.padding_budget > 0:
                state.padding_budget -= 1
            if state.forced_padding > 0:
                state.forced_padding -= 1
            # ... but emits a queued word token when one is waiting.
            if state.pending_tokens:
                return state.pending_tokens.popleft()
            return self.token_ids.pad
        if token == self.token_ids.new_word:
            return self.token_ids.new_word
        # NOTE(review): _sanitize_token collapses everything else to pad, so
        # this branch is reachable only if zero equals pad/new_word or the
        # helpers are called out of order — confirm whether it is still needed.
        if token == self.token_ids.zero:
            return token
        raise RuntimeError(f"Invalid token {token}")

    def _maybe_multiplex_second_stream(
        self, state: State, output: int
    ) -> Tuple[int, int]:
        """Produce (main, second) tokens; second mirrors main when disabled."""
        if not self.second_stream_ahead:
            return output, output
        second = -1  # placeholder; every branch below overwrites it
        if output == self.token_ids.new_word:
            # Word boundary goes to the second stream, while the main stream
            # immediately starts (or pads for) the word's tokens.
            second = self.token_ids.new_word
            if state.pending_tokens:
                output = state.pending_tokens.popleft()
            else:
                output = self.token_ids.pad
        elif state.lookahead_tokens:
            second = state.lookahead_tokens.popleft()
        else:
            second = self.token_ids.pad
        return output, second
|