from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import torch

from ..core.cache import KVCache
from ..core.model import DecodeState
from ..generation import GenerationConfig
from ..audio.grid import delay_frames, mask_audio_logits
from .context import RuntimeContext
from .state_machine import State
from .guidance import apply_classifier_guidance, sample_audio_logits
from .sampler import sample_token
from .voice_clone import PrefixPlan
from .logger import RuntimeLogger

_GRAPH_CUBLAS_READY = False
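

# cuBLAS allocates its workspace lazily on the first matmul of the process,
# and that allocation cannot happen while a CUDA graph is being captured.
# Running one throwaway matmul (and synchronizing) up front keeps the graph
# captures in run_generation_loop safe.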
def _ensure_graph_cublas_ready(device: torch.device) -> None:
    global _GRAPH_CUBLAS_READY
    if _GRAPH_CUBLAS_READY or device.type != "cuda":
        return
    tmp = torch.empty((1, 1), device=device, dtype=torch.float32)
    torch.matmul(tmp, tmp)
    torch.cuda.synchronize()
    _GRAPH_CUBLAS_READY = True
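

# Mutable per-utterance decoding state: the transformer/depformer KV caches
# plus the step-token and audio-token buffers that each step mutates in place.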
@dataclass
class GenerationState:
    decode: DecodeState
    step_tokens: torch.Tensor
    audio_buf: torch.Tensor

    def trim_audio(self, limit: int, pad_token: int, ungenerated: int) -> torch.Tensor:
        trimmed = self.audio_buf[:, :, :limit]
        pad = torch.full_like(trimmed, pad_token)
        trimmed = torch.where(trimmed == ungenerated, pad, trimmed)
        self.audio_buf = trimmed
        return trimmed

    @property
    def transformer_cache(self) -> KVCache:
        return self.decode.transformer

    @transformer_cache.setter
    def transformer_cache(self, cache: KVCache) -> None:
        self.decode.transformer = cache

    @property
    def depformer_cache(self) -> KVCache:
        return self.decode.depformer

    @depformer_cache.setter
    def depformer_cache(self, cache: KVCache) -> None:
        self.decode.depformer = cache

    def reset_dep_cache(self) -> None:
        self.decode.depformer.reset()
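

# Statically allocated logits buffers. Kernels write into these fixed
# addresses (via copy_), which is what lets the CUDA graph path replay a
# captured step without reallocating its outputs.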
@dataclass
class NetworkBuffers:
    text: torch.Tensor
    cb0: torch.Tensor
    dep: list[torch.Tensor]
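

# One buffer for the text head, one for the first-codebook head, and one per
# depformer stage; the shapes are assumed to mirror what transformer_step and
# depformer_step emit.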
def _allocate_network_buffers(runtime: RuntimeContext, branches: int) -> NetworkBuffers:
    device = runtime.device
    logits_dtype = runtime.precision.logits
    data_cfg = runtime.config.data
    text_logits = torch.empty((branches, 1, data_cfg.action_vocab_size), dtype=logits_dtype, device=device)
    cb0_logits = torch.empty((branches, 1, data_cfg.audio_vocab_size), dtype=logits_dtype, device=device)
    dep_vocab = runtime.model.depformer.audio_vocab_limit or data_cfg.audio_vocab_size
    dep_logits = [
        torch.empty((branches, 1, 1, dep_vocab), dtype=logits_dtype, device=device)
        for _ in range(runtime.model.depformer.num_depth)
    ]
    return NetworkBuffers(text=text_logits, cb0=cb0_logits, dep=dep_logits)
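

# Builds the two-branch decoding state. Row 0 is the conditional branch; row 1
# appears to serve classifier-free guidance and is fed zero/pad text tokens
# throughout. Channel layout of step_tokens: [main text, second text,
# dep_q audio codebooks].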
def build_initial_state(
    runtime: RuntimeContext,
    *,
    prefix: PrefixPlan | None = None,
) -> GenerationState:
    dep_q = runtime.model.depformer.num_audio_channels
    channels = 2 + dep_q
    branches = 2
    token_ids = runtime.constants
    step_tokens = torch.full(
        (branches, channels, 1),
        token_ids.pad,
        dtype=torch.long,
        device=runtime.device,
    )
    step_tokens[0, 0, 0] = token_ids.bos
    step_tokens[0, 1, 0] = token_ids.pad
    step_tokens[1, 0, 0] = token_ids.zero
    step_tokens[1, 1, 0] = token_ids.pad
    # Compute the delayed prefix once; it is reused below to seed audio_buf.
    delayed: torch.Tensor | None = None
    prefix_len = 0
    if prefix is not None:
        delayed = delay_frames(prefix.aligned_tokens, runtime.audio_delays, token_ids.audio_pad).to(runtime.device)
        prefix_len = delayed.shape[1]
    limit = runtime.config.runtime.max_context_steps
    total_steps = limit + prefix_len + 1
    decode_state = runtime.model.init_state(branches, runtime.device, total_steps)
    audio_buf = torch.full(
        (branches, dep_q, total_steps),
        token_ids.ungenerated,
        dtype=torch.long,
        device=runtime.device,
    )
    if delayed is not None:
        # Broadcasts the (dep_q, prefix_len) prefix across both branches.
        audio_buf[:, :, :prefix_len] = delayed
    return GenerationState(decode_state, step_tokens, audio_buf)
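

# Copies the delayed audio tokens for `step` into the audio channels of
# step_tokens. Channels whose delay has not elapsed yet (delay > step) see the
# audio BOS token instead.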
def _fill_audio_channels(
    step_tokens: torch.Tensor,
    audio_buf: torch.Tensor,
    delays: torch.Tensor,
    step: int,
    bos_token: int,
) -> None:
    channels = delays.numel()
    if channels == 0:
        return
    target = step_tokens[:, 2 : 2 + channels, 0]
    if step < audio_buf.shape[-1]:
        target.copy_(audio_buf[:, :channels, step])
    else:
        target.fill_(bos_token)
    mask = delays > step
    if mask.any().item():
        target[:, mask] = bos_token
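

# One temporal-transformer step. Logits are copied into the static buffers so
# the op sequence is capturable as a CUDA graph; this assumes transformer_step
# is graph-friendly, i.e. its outputs and KV cache reuse stable storage across
# calls so replays keep writing to the same addresses.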
def _execute_transformer_step(
    step_tokens: torch.Tensor,
    positions_view: torch.Tensor,
    generation: GenerationState,
    transformer_step,
    buffers: NetworkBuffers,
) -> torch.Tensor:
    hidden_t, text_logits_t, cb0_logits_t, present = transformer_step(
        step_tokens,
        positions_view,
        generation.transformer_cache,
    )
    buffers.text.copy_(text_logits_t)
    buffers.cb0.copy_(cb0_logits_t)
    generation.transformer_cache = present
    return hidden_t
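

# One depformer stage. Text conditioning (main/second tokens) only applies to
# stage 0; later stages condition on the previous stage's audio token.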
def _execute_depformer_stage(
    stage_index: int,
    prev_audio: torch.Tensor,
    hidden_t: torch.Tensor,
    generation: GenerationState,
    depformer_step,
    main_tokens: Optional[torch.Tensor],
    second_tokens: Optional[torch.Tensor],
    buffers: NetworkBuffers,
) -> None:
    logits_stage, dep_present = depformer_step(
        prev_audio=prev_audio,
        transformer_out=hidden_t,
        stage_index=stage_index,
        cache=generation.depformer_cache,
        main_text=main_tokens if stage_index == 0 else None,
        second_text=second_tokens if stage_index == 0 else None,
    )
    target = buffers.dep[stage_index]
    if logits_stage.shape != target.shape:
        raise RuntimeError(
            f"depformer logits shape mismatch: {logits_stage.shape} vs {target.shape}"
        )
    target.copy_(logits_stage)
    generation.depformer_cache = dep_present
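

# Main decode loop. Returns (first_word_frame, trimmed_audio): the frame at
# which the first word started (minus the configured initial padding) and the
# audio token buffer trimmed to the generated region.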
def run_generation_loop(
    runtime: RuntimeContext,
    *,
    state: State,
    generation: GenerationState,
    config: GenerationConfig,
    start_step: int = 0,
    logger: RuntimeLogger | None = None,
) -> tuple[Optional[int], torch.Tensor]:
    step_tokens = generation.step_tokens
    audio_buf = generation.audio_buf
    branches = step_tokens.shape[0]
    max_context = runtime.config.runtime.max_context_steps
    if max_context <= 0:
        raise ValueError("Runtime configuration must specify a positive max_context_steps")
    positions = torch.empty(1, 1, dtype=torch.long, device=runtime.device)
    main_tokens = torch.empty(branches, dtype=torch.long, device=runtime.device)
    aux_tokens = torch.empty(branches, dtype=torch.long, device=runtime.device)
    cfg_active = config.cfg_scale != 1.0
    token_ids = runtime.constants
    delay_tensor = runtime.audio_delay_tensor
    max_delay = int(delay_tensor.max().item()) if delay_tensor.numel() else 0
    flush_tail = max_delay + getattr(runtime.machine, "max_padding", 0)
    first_word_frame: Optional[int] = None
    eos_cutoff: Optional[int] = None
    last_step = start_step - 1
    use_graph = bool(config.use_cuda_graph and runtime.device.type == "cuda")
    transformer_step = runtime.transformer_step
    depformer_step = runtime.depformer_step
    buffers = _allocate_network_buffers(runtime, branches)
    positions_view = positions.expand(branches, -1)
    transformer_capture = None
    dep_captures: list[dict] | None = None
    if use_graph:
        _ensure_graph_cublas_ready(runtime.device)
    processed_steps = 0
    report_interval = 12
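
    # Each iteration decodes one frame: run the temporal transformer, choose a
    # text token through the state machine, then sample one token per audio
    # codebook from the depformer stages.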
    with torch.inference_mode():
        for offset in range(max_context):
            t = start_step + offset
            if eos_cutoff is not None and t >= eos_cutoff:
                break
            if t + 1 >= audio_buf.shape[-1]:
                break
            generation.reset_dep_cache()
            positions.fill_(t)
            _fill_audio_channels(step_tokens, audio_buf, delay_tensor, t, token_ids.audio_bos)
            if branches > 1:
                step_tokens[1:, 0, 0] = token_ids.zero
                step_tokens[1:, 1, 0] = token_ids.pad
            if use_graph:
                if transformer_capture is None:
                    torch.cuda.synchronize()
                    graph = torch.cuda.CUDAGraph()
                    with torch.cuda.graph(graph):
                        hidden_ref = _execute_transformer_step(
                            step_tokens,
                            positions_view,
                            generation,
                            transformer_step,
                            buffers,
                        )
                    transformer_capture = (graph, hidden_ref)
                    if runtime.model.depformer.num_depth > 0:
                        dep_captures = []
                        for idx in range(runtime.model.depformer.num_depth):
                            capture = {
                                "graph": torch.cuda.CUDAGraph(),
                                "captured": False,
                                "prev_audio": torch.empty((branches,), dtype=torch.long, device=runtime.device),
                                "main_tokens": torch.empty((branches,), dtype=torch.long, device=runtime.device) if idx == 0 else None,
                                "second_tokens": torch.empty((branches,), dtype=torch.long, device=runtime.device) if idx == 0 else None,
                            }
                            dep_captures.append(capture)
                # Capture only records the kernels; it does not execute them.
                # Replay on every iteration, including the one that captured.
                transformer_capture[0].replay()
                hidden_t = transformer_capture[1]
            else:
                hidden_t = _execute_transformer_step(
                    step_tokens,
                    positions_view,
                    generation,
                    transformer_step,
                    buffers,
                )
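
            # Text head: apply classifier-free guidance across the two
            # branches, then sample from the conditional row only.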
            guided_text = apply_classifier_guidance(buffers.text, cfg_active, config.cfg_scale, config.cfg_filter_k)
            if guided_text.shape[0] > 1:
                guided_text = guided_text[:1]
            text_token = sample_token(
                guided_text,
                temp=config.text.temperature,
                top_k=config.text.top_k,
            ).item()

            main_token, aux_token, _ = runtime.machine.process(t, state, text_token)
            second_token = aux_token if aux_token != -1 else token_ids.pad
            if first_word_frame is None and main_token == token_ids.new_word:
                first_word_frame = t - config.initial_padding
            step_tokens[:, 0, 0] = main_token
            step_tokens[:, 1, 0] = second_token

            guided_cb0 = apply_classifier_guidance(buffers.cb0, cfg_active, config.cfg_scale, config.cfg_filter_k)
            if guided_cb0.shape[0] > 1:
                guided_cb0 = guided_cb0[:1]
            masked_cb0 = mask_audio_logits(guided_cb0, token_ids.audio_pad, token_ids.audio_bos)
            codebook_token = sample_audio_logits(masked_cb0, config.audio.temperature, config.audio.top_k)
            audio_buf[:, 0, t + 1] = codebook_token

            prev_audio = codebook_token.expand(branches)
            main_tokens.fill_(main_token)
            aux_tokens.fill_(second_token)
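
            # Depformer stages: stage k predicts codebook k + 1, conditioned on
            # the token just sampled for the previous codebook.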
            for stage in range(runtime.model.depformer.num_depth):
                if use_graph and dep_captures is not None:
                    capture = dep_captures[stage]
                    capture["prev_audio"].copy_(prev_audio)
                    if stage == 0 and capture["main_tokens"] is not None:
                        capture["main_tokens"].copy_(main_tokens)
                        capture["second_tokens"].copy_(aux_tokens)
                    if not capture["captured"]:
                        torch.cuda.synchronize()
                        with torch.cuda.graph(capture["graph"]):
                            _execute_depformer_stage(
                                stage_index=stage,
                                prev_audio=capture["prev_audio"],
                                hidden_t=hidden_t,
                                generation=generation,
                                depformer_step=depformer_step,
                                main_tokens=capture["main_tokens"],
                                second_tokens=capture["second_tokens"],
                                buffers=buffers,
                            )
                        capture["captured"] = True
                    # As above: replay even right after capture, since capture
                    # alone produces no outputs.
                    capture["graph"].replay()
                else:
                    _execute_depformer_stage(
                        stage_index=stage,
                        prev_audio=prev_audio,
                        hidden_t=hidden_t,
                        generation=generation,
                        depformer_step=depformer_step,
                        main_tokens=main_tokens,
                        second_tokens=aux_tokens,
                        buffers=buffers,
                    )
                dep_logits = apply_classifier_guidance(buffers.dep[stage], cfg_active, config.cfg_scale, config.cfg_filter_k)
                if dep_logits.shape[0] > 1:
                    dep_logits = dep_logits[:1]
                stage_token = sample_audio_logits(
                    dep_logits,
                    config.audio.temperature,
                    config.audio.top_k,
                )
                audio_buf[:, stage + 1, t + 1] = stage_token
                prev_audio = stage_token.expand(branches)
            last_step = t
            if eos_cutoff is None and state.end_step is not None:
                eos_cutoff = state.end_step + flush_tail
            processed_steps = offset + 1
            if logger and processed_steps % report_interval == 0:
                logger.progress(processed_steps, max_context)

    if logger and processed_steps and processed_steps % report_interval != 0:
        logger.progress(processed_steps, max_context)

    if first_word_frame is None:
        first_word_frame = start_step
    if last_step < start_step:
        limit = min(start_step + 1, audio_buf.shape[-1])
    else:
        limit = min(last_step + 2, audio_buf.shape[-1])
    trimmed = generation.trim_audio(limit, token_ids.audio_pad, token_ids.ungenerated)
    return first_word_frame, trimmed
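

# Decodes audio tokens to PCM with the Mimi codec. Assumes `tokens` is shaped
# (batch, codebooks, frames); the first batch item and channel is returned.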
def decode_audio(runtime: RuntimeContext, tokens: torch.Tensor) -> torch.Tensor:
    if tokens.shape[-1] == 0:
        return torch.zeros(0, device=runtime.device)
    with torch.inference_mode():
        pcm = runtime.mimi.decode(tokens.to(runtime.device))
    return pcm[0, 0]
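

# Teacher-forces an aligned voice-clone prefix through the transformer so the
# KV cache covers it, driving the state machine with forced word boundaries.
# Returns the step index at which free-running generation should resume.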
def warmup_with_prefix(
    runtime: RuntimeContext,
    plan: PrefixPlan,
    state: State,
    generation: GenerationState,
) -> int:
    step_tokens = generation.step_tokens
    model_state = generation.decode
    branches = step_tokens.shape[0]
    device = runtime.device
    tokens = plan.aligned_tokens.to(device)
    new_word_steps = set(plan.new_word_steps)
    positions = torch.empty(1, 1, dtype=torch.long, device=device)

    with torch.inference_mode():
        for t in range(plan.aligned_frames):
            positions.fill_(t)
            channels = tokens.shape[0]
            for cb in range(channels):
                delay = runtime.audio_delays[cb] if cb < len(runtime.audio_delays) else 0
                idx = t - delay
                value = tokens[cb, idx] if idx >= 0 else runtime.constants.audio_bos
                step_tokens[:, 2 + cb, 0] = value
            # Only the updated KV cache is needed here; the logits are unused.
            _, _, _, present = runtime.model.transformer.forward_step(
                step_tokens,
                positions.expand(branches, -1),
                model_state.transformer,
            )
            model_state.transformer = present

            forced = runtime.constants.new_word if t in new_word_steps else runtime.constants.pad
            main_token, aux_token, _ = runtime.machine.process(t, state, forced, is_forced=True)
            second_token = runtime.constants.pad if aux_token == -1 else aux_token
            step_tokens[0, 0, 0] = main_token
            step_tokens[0, 1, 0] = second_token
            if branches > 1:
                step_tokens[1:, 0, 0] = runtime.constants.zero
                step_tokens[1:, 1, 0] = runtime.constants.pad

    return max(plan.aligned_frames - 1, 0)


__all__ = [
    "build_initial_state",
    "run_generation_loop",
    "decode_audio",
    "warmup_with_prefix",
    "GenerationState",
]