---
gated: true
extra_gated_prompt: >
  **Terms of Academic, Non-Commercial Use**

  I ACKNOWLEDGE THAT I HAVE READ [THE LICENSE AGREEMENT](./LICENSE), UNDERSTAND IT AND AGREE TO BE BOUND BY ITS TERMS AND CONDITIONS.

  It is important that you read the entirety of and understand the License Agreement. There are, however, a few key points that we need to emphasize again:

  - **ACADEMIC, NON-COMMERCIAL USE**: The license granted is for academic, non-commercial purposes only. The term "academic, non-commercial" means academic or other scholarly research which (a) is not undertaken for any direct or indirect for-profit purposes, and (b) is not intended to produce works, services, or data for commercial use.

  - **INTERNAL USE**: The license granted is for your own internal use only. You are not allowed to sublicense, distribute, transfer, disclose or make available the software to any third party.

  - **NO WARRANTY**: The Software is provided "as is" and any express or implied warranties are disclaimed.
extra_gated_button_content: Agree and access
pipeline_tag: audio-to-audio
---

# CAST 0.7B — Speech-to-Speech

[![arXiv](https://img.shields.io/badge/arXiv-2509.26276-b31b1b.svg)](https://arxiv.org/abs/2509.26276) [![Demo](https://img.shields.io/badge/Demo-speech--cast-blue)](https://mortezaro.github.io/speech-cast/) [![Codec Dependency](https://img.shields.io/badge/Codec-cast--wavtokenizer--24k--40tps-informational)](https://huggingface.co/KrauthammerLab/cast-wavtokenizer-24k-40tps)

This repository contains the final CAST 0.7B checkpoint files.

CAST 0.7B is a speech-to-speech language model built on a 0.7B-parameter Gemma3-style LM. Given a spoken prompt, it generates a natural spoken continuation. Encoding and decoding audio requires the companion [CAST WavTokenizer](https://huggingface.co/KrauthammerLab/cast-wavtokenizer-24k-40tps) codec (`KrauthammerLab/cast-wavtokenizer-24k-40tps`).
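The codec name indicates a 24 kHz codec running at 40 tokens per second, which matches the `codes_per_second = 40` and 24 kHz resampling used in the usage examples. Under that assumption, one speech token covers roughly 600 samples (25 ms), which makes it easy to budget prompt and continuation lengths. This is a minimal arithmetic sketch; the helper names are hypothetical and not part of any CAST API.

```python
# Assumption: 24 kHz audio and 40 codec tokens per second (from the codec repo name).
SAMPLE_RATE = 24_000
TOKENS_PER_SECOND = 40
SAMPLES_PER_TOKEN = SAMPLE_RATE // TOKENS_PER_SECOND  # 600 samples, i.e. 25 ms


def seconds_to_tokens(seconds: float) -> int:
    """Hypothetical helper: speech tokens needed to cover `seconds` of audio (at least 1)."""
    return max(1, round(seconds * TOKENS_PER_SECOND))


def tokens_to_seconds(n_tokens: int) -> float:
    """Hypothetical helper: approximate duration of `n_tokens` decoded tokens."""
    return n_tokens / TOKENS_PER_SECOND


def tokens_to_samples(n_tokens: int) -> int:
    """Hypothetical helper: approximate number of 24 kHz samples for `n_tokens` tokens."""
    return n_tokens * SAMPLES_PER_TOKEN


print(seconds_to_tokens(3.0))   # 120 new tokens for a 3 s continuation
print(tokens_to_seconds(400))   # 10.0 seconds for a 400-token prompt
print(tokens_to_samples(40))    # 24000 samples, i.e. one second
```

Decoded lengths can differ from `n_tokens * 600` by a frame depending on codec padding, so treat these values as estimates.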
---

## Demo

Interactive samples and usage examples: **https://mortezaro.github.io/speech-cast/**

---

## Paper

**Optimizing Speech Language Models for Acoustic Consistency**

arXiv: **2509.26276** (https://arxiv.org/abs/2509.26276)

> We study speech language models that incorporate semantic initialization and planning losses to achieve robust and consistent generation. Our approach initializes speech tokens with self-supervised features, applies a light alignment loss, and trains with thinning and auxiliary objectives that target robustness and content planning. We train three models: a 0.7B speech-only model, a 1.0B speech-only model, and a 1.0B interleaved model with both text and speech. Acoustic studies show that the speech-only models achieve the highest consistency across speaker, gender, sentiment, room, and background factors, surpassing larger systems. Interleaving improves lexical and syntactic probes and semantic–acoustic alignment but reduces consistency. Linear probes show that our initialization biases the model toward content structure while trading off prosody detail. These results show that LM-side design and training mix control the balance between acoustic stability and semantic grounding without changes to the tokenizer or runtime architecture. A demo and model weights are available for exploration.
## Installation

```bash
pip install torch torchaudio transformers accelerate soundfile huggingface_hub
pip install git+https://github.com/jishengpeng/WavTokenizer.git
# The examples import `from decoder.pretrained import WavTokenizer`, so the
# WavTokenizer package (repo above) must be importable in your environment.
```

1) Resynthesis (codec round trip: encode a prompt and decode it back)

```python
import torch
import torchaudio
import soundfile as sf
from huggingface_hub import hf_hub_download
from decoder.pretrained import WavTokenizer  # provided by the WavTokenizer repo

WT_REPO = "KrauthammerLab/cast-wavtokenizer-24k-40tps"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Download the codec checkpoint (and config, if the repo ships one) from the Hub
wt_ckpt = hf_hub_download(WT_REPO, filename="wavtokenizer_large_unify_600_24k.ckpt")
try:
    wt_cfg = hf_hub_download(WT_REPO, filename="config.yaml")
except Exception:
    wt_cfg = None  # config is optional in some setups

# Load the WavTokenizer codec
wt = WavTokenizer.from_pretrained0802(wt_cfg, wt_ckpt).to(device)

# Load a 16 kHz prompt (mono recommended) and downmix to mono if needed
wav16, sr = torchaudio.load("prompt_16k.wav")
assert sr == 16000, f"Expected 16 kHz input, got {sr} Hz"
if wav16.size(0) > 1:
    wav16 = wav16.mean(dim=0, keepdim=True)

# The codec runs at 24 kHz, so resample 16 kHz -> 24 kHz before encoding
wav24 = torchaudio.functional.resample(wav16, orig_freq=16000, new_freq=24000).to(device)

# Encode: `feats` are continuous features for decode(), `codes` are discrete
# codebook indices (one stream, ~40 frames per second)
bandwidth_id = torch.tensor([0], device=device)
feats, codes = wt.encode_infer(wav24, bandwidth_id=bandwidth_id)

# Decode back to a 24 kHz waveform
recon24 = wt.decode(feats, bandwidth_id=bandwidth_id)
if recon24.dim() == 3:  # [1, 1, T] -> [1, T]
    recon24 = recon24.squeeze(0)

# Save the 24 kHz round-trip audio
sf.write("recon_24k.wav", recon24.squeeze(0).detach().cpu().numpy(), 24000)
print("Wrote recon_24k.wav")
```

2) Speech generation (continue a spoken prompt)

```python
import math
from typing import List, Set

import torch
import torchaudio
import soundfile as sf
from huggingface_hub import hf_hub_download
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LogitsProcessor,
    LogitsProcessorList,
)
from decoder.pretrained import WavTokenizer  # provided by the WavTokenizer repo

LM_REPO = "KrauthammerLab/cast-0.7b-s2s"
WT_REPO = "KrauthammerLab/cast-wavtokenizer-24k-40tps"
device = "cuda" if torch.cuda.is_available() else "cpu"

codes_per_second = 40   # codec frame rate: ~40 tokens/s
codebook_size = 4096    # speech tokens [Sp1]..[Sp4096]
speech_prefix = "[Speech]"
sample_rate = 24000
samples_per_code = sample_rate // codes_per_second  # ~600 samples per code

# ---------- helpers ----------
def equal_power_crossfade(prev_24k: torch.Tensor, cont_24k: torch.Tensor,
                          fade_ms: int = 40, sr: int = 24000) -> torch.Tensor:
    """Equal-power crossfade between prev and cont (both [1, T] at 24 kHz)."""
    fade = max(1, int(sr * fade_ms / 1000))
    prev_24k = prev_24k.to(device)
    cont_24k = cont_24k.to(device)
    if prev_24k.size(1) < fade or cont_24k.size(1) < fade:
        return torch.cat([prev_24k, cont_24k], dim=1)
    a = prev_24k[:, -fade:]
    b = cont_24k[:, :fade]
    t = torch.linspace(0, 1, fade, device=device).view(1, -1)
    # cos^2 + sin^2 = 1, so the summed power stays constant across the fade
    mix = torch.cos(t * 0.5 * math.pi) * a + torch.sin(t * 0.5 * math.pi) * b
    return torch.cat([prev_24k[:, :-fade], mix, cont_24k[:, fade:]], dim=1)


class SpeechOnlyLogitsProcessor(LogitsProcessor):
    """Mask logits so only [Sp#] tokens (and EOS) can be sampled."""

    def __init__(self, allowed: Set[int]):
        super().__init__()
        self.allowed = sorted(allowed)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        mask = torch.full_like(scores, float("-inf"))
        mask[..., self.allowed] = 0.0
        return scores + mask

# ---------- load codec ----------
wt_ckpt = hf_hub_download(WT_REPO, filename="wavtokenizer_large_unify_600_24k.ckpt")
try:
    wt_cfg = hf_hub_download(WT_REPO, filename="config.yaml")
except Exception:
    wt_cfg = None
wt = WavTokenizer.from_pretrained0802(wt_cfg, wt_ckpt).to(device)

# ---------- load LM + tokenizer ----------
tok = AutoTokenizer.from_pretrained(LM_REPO)
lm_dtype = torch.bfloat16 if device == "cuda" else torch.float32
lm = AutoModelForCausalLM.from_pretrained(LM_REPO, torch_dtype=lm_dtype).to(device).eval()

# Build the speech token id table: "[Sp1]".."[Sp4096]" must each be a single token
speech_token_ids: List[int] = []
for i in range(1, codebook_size + 1):
    ids = tok(f"[Sp{i}]", add_special_tokens=False)["input_ids"]
    if len(ids) != 1:
        raise RuntimeError(f"[Sp{i}] is not a single token; tokenizer mismatch.")
    speech_token_ids.append(ids[0])

# For mapping back: token_id -> code_index (0-based)
id2code = {tid: i for i, tid in enumerate(speech_token_ids)}
eos_id = tok.eos_token_id
allowed_ids = set(speech_token_ids + ([eos_id] if eos_id is not None else []))

# ---------- load prompt audio (16 kHz) and encode to codes ----------
wav16, sr = torchaudio.load("prompt_16k.wav")
assert sr == 16000, f"Expected 16 kHz input, got {sr} Hz"
if wav16.size(0) > 1:
    wav16 = wav16.mean(dim=0, keepdim=True)
wav24 = torchaudio.functional.resample(wav16, orig_freq=16000, new_freq=24000).to(device)

bw = torch.tensor([0], device=device)
feats, codes = wt.encode_infer(wav24, bandwidth_id=bw)

# Flatten codes to a single stream of T ints, each in [0, 4095]
codes_list = codes.reshape(-1, codes.size(-1))[0].long().tolist()

# Round-trip decode of the prompt (left side of the final stitch)
recon24 = wt.decode(feats, bandwidth_id=bw)
if recon24.dim() == 3:
    recon24 = recon24.squeeze(0)

# ---------- build LM prefix string ----------
prefix_text = speech_prefix + "".join(f"[Sp{c + 1}]" for c in codes_list)
enc = tok(prefix_text, return_tensors="pt")
input_ids = enc["input_ids"].to(device)
attn_mask = enc.get("attention_mask", None)
if attn_mask is not None:
    attn_mask = attn_mask.to(device)

# ---------- generate a continuation of about `seconds` seconds ----------
seconds = 3.0
max_new_tokens = max(1, int(round(seconds * codes_per_second)))
lp = LogitsProcessorList([SpeechOnlyLogitsProcessor(allowed_ids)])
gen = lm.generate(
    input_ids=input_ids,
    attention_mask=attn_mask,
    max_new_tokens=max_new_tokens,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.1,
    eos_token_id=eos_id,
    pad_token_id=(tok.pad_token_id if tok.pad_token_id is not None else eos_id),
    logits_processor=lp,
)

# Strip the prompt and anything from EOS onward
gen_tail = gen[0][input_ids.size(1):].tolist()
if eos_id is not None and eos_id in gen_tail:
    gen_tail = gen_tail[:gen_tail.index(eos_id)]

# Map token ids -> code indices (0-based)
new_codes = [id2code[t] for t in gen_tail if t in id2code]

# Prepend the last ~1 s of prompt codes so the codec has left context (avoids a hard seam)
keep_sec = 1.0
keep = min(len(codes_list), max(0, int(round(keep_sec * codes_per_second))))
tail_codes = codes_list[-keep:] if keep > 0 else []
decode_codes = tail_codes + new_codes

# Decode to audio (24 kHz); codes_to_features expects [n_q, B, T]
tok_tensor = torch.tensor(decode_codes, dtype=torch.long, device=device).view(1, 1, -1)
cont24 = wt.decode(wt.codes_to_features(tok_tensor), bandwidth_id=bw)
if cont24.dim() == 3:
    cont24 = cont24.squeeze(0)

# The first ~keep * samples_per_code samples re-synthesize the prompt tail.
# Keep only `fade` samples of that overlap for the crossfade, so the prompt's
# last second is not played twice in the stitched output.
fade_ms = 60
fade = max(1, int(sample_rate * fade_ms / 1000))
overlap = keep * samples_per_code
new_only24 = cont24[:, overlap:]
cont_for_stitch = cont24[:, max(0, overlap - fade):]

# Stitch prompt and continuation with an equal-power crossfade
stitched = equal_power_crossfade(recon24, cont_for_stitch, fade_ms=fade_ms, sr=sample_rate)

# Save files
sf.write("recon_24k.wav", recon24.squeeze(0).detach().cpu().numpy(), sample_rate)
sf.write("continuation.wav", new_only24.squeeze(0).detach().cpu().numpy(), sample_rate)
sf.write("stitched_24k.wav", stitched.squeeze(0).detach().cpu().numpy(), sample_rate)
print("Wrote recon_24k.wav, continuation.wav, stitched_24k.wav")
```
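The generation example relies on a fixed text convention: codec index `c` (0-based) is written as the token `[Sp{c+1}]`, and the prompt starts with `[Speech]`. The pure-Python sketch below isolates that mapping so it can be checked without loading any model. The helper names are hypothetical, and the regex-based parser is just one way to read back decoded text; the example above maps generated token ids through `id2code` instead.

```python
import re
from typing import List

SPEECH_PREFIX = "[Speech]"   # same prefix string as `speech_prefix` in the example above
CODEBOOK_SIZE = 4096         # codes 0..4095 map to tokens [Sp1]..[Sp4096]
_SP_PATTERN = re.compile(r"\[Sp(\d+)\]")  # matches [Sp<digits>] but not [Speech]


def codes_to_prompt(codes: List[int]) -> str:
    """Hypothetical helper: render 0-based codec indices as a speech prompt string."""
    for c in codes:
        if not 0 <= c < CODEBOOK_SIZE:
            raise ValueError(f"code {c} outside [0, {CODEBOOK_SIZE})")
    return SPEECH_PREFIX + "".join(f"[Sp{c + 1}]" for c in codes)


def prompt_to_codes(text: str) -> List[int]:
    """Hypothetical helper: recover 0-based codec indices from [Sp#] text, ignoring other tokens."""
    return [int(m.group(1)) - 1 for m in _SP_PATTERN.finditer(text)]


prompt = codes_to_prompt([0, 41, 4095])
print(prompt)                   # [Speech][Sp1][Sp42][Sp4096]
print(prompt_to_codes(prompt))  # [0, 41, 4095]
```

Keeping the off-by-one shift in a single pair of functions makes it harder to feed the codec an index that is one slot away from the intended code.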