from dataclasses import dataclass from typing import Callable, List, Tuple from huggingface_hub import hf_hub_download import safetensors.torch as st import torch import torchaudio from torchcodec.decoders import AudioDecoder from autoencoder import DAC, build_ae from model import EchoDiT def load_model_from_hf(repo_id: str = "jordand/echo-tts-base", device: str = "cuda", dtype: torch.dtype | None = torch.bfloat16, compile: bool = False, token: str | None = None, delete_blockwise_modules: bool = False) -> EchoDiT: with torch.device("meta"): model = EchoDiT( latent_size=80, model_size=2048, num_layers=24, num_heads=16, intermediate_size=5888, norm_eps=1e-5, text_vocab_size=256, text_model_size=1280, text_num_layers=14, text_num_heads=10, text_intermediate_size=3328, speaker_patch_size=4, speaker_model_size=1280, speaker_num_layers=14, speaker_num_heads=10, speaker_intermediate_size=3328, timestep_embed_size=512, adaln_rank=256, ) w_path = hf_hub_download(repo_id, "pytorch_model.safetensors", token=token) state = st.load_file(w_path, device="cpu") if delete_blockwise_modules: state = {k: v for k, v in state.items() if not ( k.startswith("latent_encoder.") or k.startswith("latent_norm") or ".wk_latent" in k or ".wv_latent" in k )} if dtype is not None: state = {k: v.to(dtype=dtype) for k, v in state.items()} state = {k: v.to(device=device) for k, v in state.items()} model.load_state_dict(state, strict=False, assign=True) model = model.eval() if compile: model = compile_model(model) return model def compile_model(model: EchoDiT) -> EchoDiT: model = torch.compile(model) model.get_kv_cache_text = torch.compile(model.get_kv_cache_text) model.get_kv_cache_speaker = torch.compile(model.get_kv_cache_speaker) model.get_kv_cache_latent = torch.compile(model.get_kv_cache_latent) return model def load_fish_ae_from_hf(repo_id: str = "jordand/fish-s1-dac-min", device: str = "cuda", dtype: torch.dtype | None = torch.float32, compile: bool = False, token: str | None = None) -> DAC: with torch.device("meta"): fish_ae = build_ae() w_path = hf_hub_download(repo_id, "pytorch_model.safetensors", token=token) if dtype is not None and dtype != torch.float32: state = st.load_file(w_path, device="cpu") state = {k: v.to(dtype=dtype) for k, v in state.items()} state = {k: v.to(device=device) for k, v in state.items()} fish_ae.load_state_dict(state, strict=False, assign=True) else: state = st.load_file(w_path, device=device) fish_ae.load_state_dict(state, strict=False, assign=True) fish_ae = fish_ae.eval().to(device) if compile: fish_ae = compile_fish_ae(fish_ae) return fish_ae def compile_fish_ae(fish_ae: DAC) -> DAC: fish_ae.quantizer.upsample = torch.compile(fish_ae.quantizer.upsample) fish_ae.quantizer.downsample = torch.compile(fish_ae.quantizer.downsample) fish_ae.quantizer.pre_module = torch.compile(fish_ae.quantizer.pre_module) fish_ae.quantizer.post_module = torch.compile(fish_ae.quantizer.post_module) return fish_ae @dataclass class PCAState: pca_components: torch.Tensor pca_mean: torch.Tensor latent_scale: float def load_pca_state_from_hf(repo_id: str = "jordand/echo-tts-base", device: str = "cuda", filename: str = "pca_state.safetensors", token: str | None = None) -> PCAState: p_path = hf_hub_download(repo_id, filename, token=token) t = st.load_file(p_path, device=device) return PCAState( pca_components=t["pca_components"], pca_mean=t["pca_mean"], latent_scale=float(t["latent_scale"].item()), ) # ________ def load_audio(path: str, max_duration: int = 300) -> torch.Tensor: decoder = AudioDecoder(path) sr = decoder.metadata.sample_rate audio = decoder.get_samples_played_in_range(0, max_duration) audio = audio.data.mean(dim=0).unsqueeze(0) audio = torchaudio.functional.resample(audio, sr, 44_100) audio = audio / torch.maximum(audio.abs().max(), torch.tensor(1.)) # is this better than clipping? should we target a specific energy level? return audio def tokenizer_encode(text: str, append_bos: bool = True, normalize: bool = True, return_normalized_text: bool = False) -> torch.Tensor | Tuple[torch.Tensor, str]: if normalize: text = text.replace("…", "...") text = text.replace('’', "'") text = text.replace('”', '"') text = text.replace('”', '"') text = text.replace("\n", " ") text = text.replace(":", ",") text = text.replace(";", ",") text = text.replace("—", ", ") if not text.startswith("[") and not text.startswith("(") and 'S1' not in text and 'S2' not in text: text = "[S1] " + text b = list(text.encode("utf-8")) if append_bos: b.insert(0, 0) if return_normalized_text: return torch.tensor(b), text return torch.tensor(b) def get_text_input_ids_and_mask(text_arr: List[str], max_length: int | None, device: str | None = None, normalize: bool = True, return_normalized_text: bool = False, pad_to_max: bool = True) -> Tuple[torch.Tensor, torch.Tensor] | Tuple[torch.Tensor, torch.Tensor, List[str]]: encoded_texts = [tokenizer_encode(text, normalize=normalize, return_normalized_text=True) for text in text_arr] if max_length is None: max_length = max(len(enc) for enc, _ in encoded_texts) tokens = torch.zeros((len(text_arr), max_length), dtype=torch.int32) mask = torch.zeros((len(text_arr), max_length), dtype=torch.bool) for i, (encoded, _) in enumerate(encoded_texts): length = min(len(encoded), max_length) tokens[i, :length] = encoded[:length] mask[i, :length] = 1 if not pad_to_max and max_length is not None: tokens, mask = tokens[:, :max_length], mask[:, :max_length] if device is not None: tokens, mask = tokens.to(device), mask.to(device) if return_normalized_text: return tokens, mask, [text for _, text in encoded_texts] return tokens, mask # ________ @torch.inference_mode() def ae_encode(fish_ae: DAC, pca_state: PCAState, audio: torch.Tensor) -> torch.Tensor: assert audio.ndim == 3 and audio.shape[1] == 1 # (b, 1, length) z_q = fish_ae.encode_zq(audio).float() z_q = (z_q.transpose(1, 2) - pca_state.pca_mean) @ pca_state.pca_components.T z_q = z_q * pca_state.latent_scale return z_q @torch.inference_mode() def ae_decode(fish_ae: DAC, pca_state: PCAState, z_q: torch.Tensor) -> torch.Tensor: z_q = (z_q / pca_state.latent_scale) @ pca_state.pca_components + pca_state.pca_mean return fish_ae.decode_zq(z_q.transpose(1, 2).to(fish_ae.dtype)).float() @torch.inference_mode() def ae_reconstruct(fish_ae: DAC, pca_state: PCAState, audio: torch.Tensor) -> torch.Tensor: assert audio.ndim == 3 and audio.shape[1] == 1 # (b, 1, length) z_q = ae_encode(fish_ae, pca_state, audio.to(fish_ae.dtype)) return ae_decode(fish_ae, pca_state, z_q) # ________ @torch.inference_mode() def get_speaker_latent_and_mask( fish_ae: DAC, pca_state: PCAState, audio: torch.Tensor, # (1, length) max_speaker_latent_length: int = 6400, # pretrained max length audio_chunk_size: int = 640 * 2048, # (~30 seconds, 1/10 max speaker condition size; max chunk seen in training) pad_to_max: bool = False, divis_by_patch_size: int | None = 4, ) -> Tuple[torch.Tensor, torch.Tensor]: # gets speaker latent and mask from audio, computes in chunks and concatenates (similar to training setup) AE_DOWNSAMPLE_FACTOR = 2048 max_audio_len_length = max_speaker_latent_length * AE_DOWNSAMPLE_FACTOR assert audio.ndim == 2 and audio.shape[0] == 1 # (1, length) audio = audio[:, :max_audio_len_length] latent_arr = [] for i in range(0, audio.shape[1], audio_chunk_size): audio_chunk = audio[:, i:i + audio_chunk_size] if audio_chunk.shape[1] < audio_chunk_size: audio_chunk = torch.nn.functional.pad(audio_chunk, (0, audio_chunk_size - audio_chunk.shape[1])) latent_chunk = ae_encode(fish_ae, pca_state, audio_chunk.unsqueeze(0)) latent_arr.append(latent_chunk) speaker_latent = torch.cat(latent_arr, dim=1) actual_latent_length = audio.shape[1] // AE_DOWNSAMPLE_FACTOR speaker_mask = (torch.arange(speaker_latent.shape[1], device=speaker_latent.device) < actual_latent_length).unsqueeze(0) if pad_to_max and speaker_latent.shape[1] < max_speaker_latent_length: speaker_latent = torch.nn.functional.pad(speaker_latent, (0, 0, 0, max_speaker_latent_length - speaker_latent.shape[1])) speaker_mask = torch.nn.functional.pad(speaker_mask, (0, max_speaker_latent_length - speaker_mask.shape[1])) elif not pad_to_max: speaker_latent = speaker_latent[:, :actual_latent_length] speaker_mask = speaker_mask[:, :actual_latent_length] if divis_by_patch_size is not None: speaker_latent = speaker_latent[:, :speaker_latent.shape[1] // divis_by_patch_size * divis_by_patch_size] speaker_mask = speaker_mask[:, :speaker_mask.shape[1] // divis_by_patch_size * divis_by_patch_size] return speaker_latent, speaker_mask # ________ def find_flattening_point(data, target_value=0.0, window_size=20, std_threshold=0.05): # simple heuristic to find end of latent generations; slow and can be improved # (data is (length, 80)) padded_data = torch.cat([data, torch.zeros(window_size, *data.shape[1:], device=data.device, dtype=data.dtype)]) for i in range(len(padded_data) - window_size): window = padded_data[i:i + window_size] if window.std() < std_threshold and abs(window.mean() - target_value) < 0.1: return i return len(data) def crop_audio_to_flattening_point(audio: torch.Tensor, latent: torch.Tensor) -> torch.Tensor: # (audio is (..., length), latent is (length, 80)) flattening_point = find_flattening_point(latent) return audio[..., :flattening_point * 2048] SampleFn = Callable[ [EchoDiT, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int], torch.Tensor ] @torch.inference_mode() def sample_pipeline( model: EchoDiT, fish_ae: DAC, pca_state: PCAState, sample_fn: SampleFn, text_prompt: str, speaker_audio: torch.Tensor | None, rng_seed: int, pad_to_max_speaker_latent_length: int | None = None, pad_to_max_text_length: int | None = None, normalize_text: bool = True, ) -> Tuple[torch.Tensor, str]: MAX_SPEAKER_LATENT_LENGTH = 6400 # max seen during training, though maybe can go higher? MAX_TEXT_LENGTH = 768 device, dtype = model.device, model.dtype text_input_ids, text_mask, normalized_text = get_text_input_ids_and_mask([text_prompt], max_length=min(pad_to_max_text_length or MAX_TEXT_LENGTH, MAX_TEXT_LENGTH), device=device, normalize=normalize_text, return_normalized_text=True, pad_to_max=(pad_to_max_text_length is not None)) if speaker_audio is None: speaker_latent = torch.zeros((1, pad_to_max_speaker_latent_length or 4, 80), device=device, dtype=dtype) speaker_mask = torch.zeros((1, pad_to_max_speaker_latent_length or 4), device=device, dtype=torch.bool) else: speaker_latent, speaker_mask = get_speaker_latent_and_mask( fish_ae, pca_state, speaker_audio.to(fish_ae.dtype).to(device), max_speaker_latent_length=pad_to_max_speaker_latent_length or MAX_SPEAKER_LATENT_LENGTH, pad_to_max=(pad_to_max_speaker_latent_length is not None) ) latent_out = sample_fn(model, speaker_latent, speaker_mask, text_input_ids, text_mask, rng_seed) audio_out = ae_decode(fish_ae, pca_state, latent_out) audio_out = crop_audio_to_flattening_point(audio_out, latent_out[0]) return audio_out, normalized_text[0] # ________ KVCache = List[Tuple[torch.Tensor, torch.Tensor]] def _concat_kv_caches(*caches: KVCache) -> KVCache: # helper that concatenates multiple KV caches along the batch dimension num_layers = len(caches[0]) result = [] for i in range(num_layers): k = torch.cat([c[i][0] for c in caches], dim=0) v = torch.cat([c[i][1] for c in caches], dim=0) result.append((k, v)) return result def _multiply_kv_cache(cache: KVCache, scale: float, max_layers: int | None = None) -> None: # helper that multiplies KV cache values in-place, for kv speaker scaling num_layers = len(cache) if max_layers is None else min(max_layers, len(cache)) for i in range(num_layers): k, v = cache[i] k.mul_(scale) v.mul_(scale) def _temporal_score_rescale( v_pred: torch.Tensor, x_t: torch.Tensor, t: float, rescale_k: float, rescale_sigma: float ) -> torch.Tensor: # for https://arxiv.org/pdf/2510.01184 if t < 1: snr = (1 - t) ** 2 / (t ** 2) ratio = (snr * rescale_sigma ** 2 + 1) / (snr * rescale_sigma ** 2 / rescale_k + 1) return 1 / (1 - t) * (ratio * ((1 - t) * v_pred + x_t) - x_t) return v_pred @torch.inference_mode() def sample_euler_cfg_independent_guidances( model: EchoDiT, speaker_latent: torch.Tensor, speaker_mask: torch.Tensor, text_input_ids: torch.Tensor, text_mask: torch.Tensor, rng_seed: int, num_steps: int, cfg_scale_text: float, cfg_scale_speaker: float, cfg_min_t: float, cfg_max_t: float, truncation_factor: float | None, rescale_k: float | None, rescale_sigma: float | None, speaker_kv_scale: float | None, speaker_kv_max_layers: int | None, speaker_kv_min_t: float | None, sequence_length: int | None = None, ) -> torch.Tensor: if sequence_length is None: sequence_length = 640 # max sequence length during training INIT_SCALE = 0.999 # so that we can apply rescale to first step device, dtype = model.device, model.dtype batch_size = text_input_ids.shape[0] rng = torch.Generator(device=device).manual_seed(rng_seed) t_schedule = torch.linspace(1., 0., num_steps + 1, device=device) * INIT_SCALE text_mask_uncond = torch.zeros_like(text_mask) speaker_mask_uncond = torch.zeros_like(speaker_mask) kv_text_cond = model.get_kv_cache_text(text_input_ids, text_mask) kv_speaker_cond = model.get_kv_cache_speaker(speaker_latent.to(dtype)) if speaker_kv_scale is not None: _multiply_kv_cache(kv_speaker_cond, speaker_kv_scale, speaker_kv_max_layers) # masks prevent decoder from attending to unconds: kv_text_full = _concat_kv_caches(kv_text_cond, kv_text_cond, kv_text_cond) kv_speaker_full = _concat_kv_caches(kv_speaker_cond, kv_speaker_cond, kv_speaker_cond) full_text_mask = torch.cat([text_mask, text_mask_uncond, text_mask], dim=0) full_speaker_mask = torch.cat([speaker_mask, speaker_mask, speaker_mask_uncond], dim=0) x_t = torch.randn((batch_size, sequence_length, 80), device=device, dtype=torch.float32, generator=rng) if truncation_factor is not None: x_t = x_t * truncation_factor for i in range(num_steps): t, t_next = t_schedule[i], t_schedule[i + 1] has_cfg = ((t >= cfg_min_t) * (t <= cfg_max_t)).item() if has_cfg: v_cond, v_uncond_text, v_uncond_speaker = model( x=torch.cat([x_t, x_t, x_t], dim=0).to(dtype), t=(torch.ones((batch_size * 3,), device=device) * t).to(dtype), text_mask=full_text_mask, speaker_mask=full_speaker_mask, kv_cache_text=kv_text_full, kv_cache_speaker=kv_speaker_full, ).float().chunk(3, dim=0) v_pred = v_cond + cfg_scale_text * (v_cond - v_uncond_text) + cfg_scale_speaker * (v_cond - v_uncond_speaker) # can also use a single, joint unconditional for fewer NFE else: v_pred = model( x=x_t.to(dtype), t=(torch.ones((batch_size,), device=device) * t).to(dtype), text_mask=text_mask, speaker_mask=speaker_mask, kv_cache_text=kv_text_cond, kv_cache_speaker=kv_speaker_cond, ).float() # optional temporal score rescaling: https://arxiv.org/pdf/2510.01184 if rescale_k is not None and rescale_sigma is not None: v_pred = _temporal_score_rescale(v_pred, x_t, t, rescale_k, rescale_sigma) # optional kv speaker scaling if speaker_kv_scale is not None and t_next < speaker_kv_min_t and t >= speaker_kv_min_t: _multiply_kv_cache(kv_speaker_cond, 1. / speaker_kv_scale, speaker_kv_max_layers) kv_speaker_full = _concat_kv_caches(kv_speaker_cond, kv_speaker_cond, kv_speaker_cond) x_t = x_t + v_pred * (t_next - t) return x_t # ___________________________________________________________ # simple example if __name__ == "__main__": model = load_model_from_hf(delete_blockwise_modules=True) fish_ae = load_fish_ae_from_hf() pca_state = load_pca_state_from_hf() speaker_audio_path = "/path/to/speaker/audio.wav" speaker_audio = load_audio(speaker_audio_path).cuda() speaker_latent, speaker_mask = get_speaker_latent_and_mask(fish_ae, pca_state, speaker_audio) text = "[S1] Alright, I'm going to demo this new model called Echo TTS. Hopefully this works, I'm super excited to try this and see what it can do." text_input_ids, text_mask = get_text_input_ids_and_mask([text], max_length=None, device="cuda") latent_out = sample_euler_cfg_independent_guidances( model=model, speaker_latent=speaker_latent, speaker_mask=speaker_mask, text_input_ids=text_input_ids, text_mask=text_mask, rng_seed=0, num_steps=40, cfg_scale_text=3.0, cfg_scale_speaker=8.0, cfg_min_t=0.5, cfg_max_t=1.0, truncation_factor=0.8, rescale_k=None, rescale_sigma=None, speaker_kv_scale=None, speaker_kv_max_layers=None, speaker_kv_min_t=None, sequence_length=640, # (max 640. shorter lengths will generate prefixes, not necessarily full generations) ) audio_out = ae_decode(fish_ae, pca_state, latent_out) audio_out = crop_audio_to_flattening_point(audio_out, latent_out[0]) torchaudio.save("output.wav", audio_out[0].cpu(), 44100)