from __future__ import annotations

from typing import List, Tuple

import torch
import torch.nn.functional as F
import torchaudio

from lemas_tts.api import TTS


def build_tokens_from_text(tts: TTS, text: str) -> List[List[str]]:
    """
    Convert raw text into token sequence(s) consistent with the multilingual
    LEMAS-TTS training pipeline.

    We reuse the same frontend logic as in `TTS.infer`:
      - frontend.dtype == "phone" -> TextNorm.text2phn -> split on '|'
      - frontend.dtype == "char"  -> TextNorm.text2norm -> language tag + chars
      - frontend is None          -> plain character sequence as fallback.
    """
    text_proc = text.strip()
    # Ensure the text ends with sentence-final punctuation (ASCII or CJK).
    if not text_proc.endswith((".", "。", "!", "?", "？", "！")):
        text_proc = text_proc + "."

    # No frontend attached: fall back to a raw character sequence.
    if getattr(tts, "frontend", None) is None:
        return [list(text_proc)]

    dtype = getattr(tts.frontend, "dtype", "phone")
    if dtype == "phone":
        # Phonemize, normalize the language tag, and split on the '|' separator.
        phones = tts.frontend.text2phn(text_proc + " ")
        phones = phones.replace("(cmn)", "(zh)")
        tokens = [tok for tok in phones.split("|") if tok]
        return [tokens]
    if dtype == "char":
        # Character frontend: prepend a language tag, then split into characters.
        lang, norm = tts.frontend.text2norm(text_proc + " ")
        lang_tag = f"({lang.replace('cmn', 'zh')})"
        return [[lang_tag] + list(norm)]

    # Unknown dtype: fall back to character-level tokens.
    return [list(text_proc)]
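

# Hypothetical usage sketch (kept commented so the module stays import-safe).
# The `TTS` constructor arguments below are assumptions; only
# `build_tokens_from_text` itself is defined in this module. The function
# always returns a list holding a single token list: one string per phone
# (phone frontend) or per character (char frontend / fallback).
#
#     tts = TTS("path/to/checkpoint")  # assumed constructor signature
#     token_lists = build_tokens_from_text(tts, "Hello world")
#     print(len(token_lists), token_lists[0][:8])  # 1, first few tokens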


def gen_wav_multilingual(
    tts: TTS,
    segment_audio: torch.Tensor,
    sr: int,
    target_text: str,
    parts_to_edit: List[Tuple[float, float]],
    speed: float = 1.0,
    nfe_step: int = 64,
    cfg_strength: float = 5.0,
    sway_sampling_coef: float = 3.0,
    ref_ratio: float = 1.0,
    no_ref_audio: bool = False,
    use_acc_grl: bool = False,
    use_prosody_encoder_flag: bool = False,
    seed: int | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Core editing routine:
      - build an edit mask over the mel frames;
      - run CFM.sample with that mask and the new text;
      - decode the mel spectrogram to a waveform via the vocoder.
    """
    device = tts.device
    model = tts.ema_model
    vocoder = tts.vocoder

    mel_spec = getattr(model, "mel_spec", None)
    if mel_spec is None:
        raise RuntimeError("CFM model has no attached MelSpec; check your checkpoint.")
    target_sr = int(mel_spec.target_sample_rate)
    hop_length = int(mel_spec.hop_length)
    target_rms = 0.1

    if segment_audio.dim() == 1:
        audio = segment_audio.unsqueeze(0)
    else:
        audio = segment_audio

    # RMS normalization: boost quiet inputs up to the target level.
    rms = torch.sqrt(torch.mean(torch.square(audio)))
    if rms < target_rms:
        audio = audio * target_rms / rms

    # Resample to the model's sample rate if needed.
    if sr != target_sr:
        resampler = torchaudio.transforms.Resample(sr, target_sr)
        audio = resampler(audio)
    audio = audio.to(device)

    total_frames = audio.shape[-1] // hop_length
    # Start from "keep everything", then carve out spans to re-generate.
    edit_mask = torch.ones(1, total_frames + 1, dtype=torch.bool, device=device)

    # Clamp speed and interpret it as: >1 → faster (shorter edited span),
    # <1 → slower (longer edited span).
    speed_safe = max(float(speed), 1e-3)

    for start, end in parts_to_edit:
        # Small safety margin around the region to edit.
        start_sec = max(start - 0.1, 0.0)
        end_sec = min(end + 0.1, audio.shape[-1] / target_sr)
        # Seconds -> mel frames, clamped to valid frame indices.
        start_frame = int(round(start_sec * target_sr / hop_length))
        end_frame = int(round(end_sec * target_sr / hop_length))
        start_frame = max(0, min(start_frame, total_frames - 1))
        end_frame = max(start_frame + 1, min(end_frame, total_frames))
        # Rescale the span length by speed, keeping it centered on the original.
        orig_len = end_frame - start_frame
        scaled_len = max(1, int(round(orig_len / speed_safe)))
        center = (start_frame + end_frame) // 2
        new_start = max(0, center - scaled_len // 2)
        new_end = min(total_frames, new_start + scaled_len)
        edit_mask[:, new_start:new_end] = False
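
    # Worked example (hypothetical rates): with target_sr = 24000 and
    # hop_length = 256 there are 24000 / 256 = 93.75 mel frames per second,
    # so an edit span of (1.0, 2.0) s widens to (0.9, 2.1) s and maps to
    # frames ~84..197 (113 frames); at speed = 2.0 the carved-out region
    # shrinks to ~56 frames centered near frame 140.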

    duration = total_frames

    # Text tokens via the multilingual frontend.
    final_text_list = build_tokens_from_text(tts, target_text)

    # For multilingual models trained with `separate_langs=True`, post-process
    # the phone sequence so that each non-punctuation token is prefixed with
    # its language id, consistent with training and the main API.
    if hasattr(tts, "process_phone_list") and len(final_text_list) > 0:
        final_text_list = [tts.process_phone_list(final_text_list[0])]
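    # Illustrative only (the exact format depends on `process_phone_list`):
    # a sequence like ['(zh)', 'n', 'i'] would become language-prefixed
    # tokens along the lines of ['(zh)n', '(zh)i'].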
    print("final_text_list:", final_text_list)

    with torch.inference_mode():
        generated, _ = model.sample(
            cond=audio,
            text=final_text_list,
            duration=duration,
            steps=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            seed=seed,
            edit_mask=edit_mask,
            use_acc_grl=use_acc_grl,
            use_prosody_encoder=use_prosody_encoder_flag,
            ref_ratio=ref_ratio,
            no_ref_audio=no_ref_audio,
        )

    generated = generated.to(torch.float32)
    generated_mel = generated.permute(0, 2, 1)  # [B, T_mel, C] -> [B, C, T_mel]
    mel_for_vocoder = generated_mel.to(device)

    # Decode mel -> waveform with the configured vocoder.
    if tts.mel_spec_type == "vocos":
        wav_out = vocoder.decode(mel_for_vocoder)
    elif tts.mel_spec_type == "bigvgan":
        wav_out = vocoder(mel_for_vocoder)
    else:
        raise ValueError(f"Unsupported vocoder type: {tts.mel_spec_type}")

    # Undo the input RMS boost so output loudness matches the source.
    if rms < target_rms:
        wav_out = wav_out * rms / target_rms

    return wav_out.squeeze(0), generated_mel
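

# Hypothetical end-to-end sketch (kept commented so the module stays
# import-safe). The checkpoint path, input file, and `TTS` constructor
# signature are assumptions; `gen_wav_multilingual` and its keyword
# arguments come from this module.
#
#     tts = TTS("path/to/checkpoint")        # assumed constructor signature
#     wav, sr = torchaudio.load("clip.wav")  # wav: [channels, samples]
#     edited, mel = gen_wav_multilingual(
#         tts,
#         wav.mean(dim=0),                   # mono [samples]
#         sr,
#         target_text="The corrected sentence to splice in.",
#         parts_to_edit=[(1.2, 2.4)],        # seconds to regenerate
#         seed=0,
#     )
#     sr_out = int(tts.ema_model.mel_spec.target_sample_rate)
#     torchaudio.save("edited.wav", edited.detach().cpu().reshape(1, -1), sr_out)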