from __future__ import annotations

from typing import List, Tuple

import torch
import torchaudio

from lemas_tts.api import TTS


def build_tokens_from_text(tts: TTS, text: str) -> List[List[str]]:
    """
    Convert raw text into token sequence(s) consistent with the multilingual
    LEMAS-TTS training pipeline.

    We reuse the same frontend logic as in `TTS.infer`:
      - frontend.dtype == "phone" -> TextNorm.text2phn -> split on '|'
      - frontend.dtype == "char"  -> TextNorm.text2norm -> language tag + chars
      - frontend is None          -> simple character sequence as fallback.
    """
    text_proc = text.strip()
    if not text_proc.endswith((".", "。", "!", "?", "？", "！")):
        text_proc = text_proc + "."

    if getattr(tts, "frontend", None) is None:
        tokens = list(text_proc)
        return [tokens]

    dtype = getattr(tts.frontend, "dtype", "phone")

    if dtype == "phone":
        phones = tts.frontend.text2phn(text_proc + " ")
        phones = phones.replace("(cmn)", "(zh)")
        tokens = [tok for tok in phones.split("|") if tok]
        return [tokens]

    if dtype == "char":
        lang, norm = tts.frontend.text2norm(text_proc + " ")
        lang_tag = f"({lang.replace('cmn', 'zh')})"
        tokens = [lang_tag] + list(norm)
        return [tokens]

    # Fallback: character-level
    tokens = list(text_proc)
    return [tokens]
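
# Illustrative call (hypothetical output -- real tokens depend on the loaded
# frontend and its phone inventory / normalization rules):
#
#   >>> build_tokens_from_text(tts, "Hello world")
#   [['(en)', 'HH', 'AH0', 'L', 'OW1', ...]]    # dtype == "phone"
#   [['(en)', 'H', 'e', 'l', 'l', 'o', ...]]    # dtype == "char"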


def gen_wav_multilingual(
    tts: TTS,
    segment_audio: torch.Tensor,
    sr: int,
    target_text: str,
    parts_to_edit: List[Tuple[float, float]],
    speed: float = 1.0,
    nfe_step: int = 64,
    cfg_strength: float = 5.0,
    sway_sampling_coef: float = 3.0,
    ref_ratio: float = 1.0,
    no_ref_audio: bool = False,
    use_acc_grl: bool = False,
    use_prosody_encoder_flag: bool = False,
    seed: int | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Core editing routine:
      - build an edit mask over the mel frames;
      - run CFM.sample with that mask and the new text;
      - decode mel to waveform via the vocoder.
    """
    device = tts.device
    model = tts.ema_model
    vocoder = tts.vocoder

    mel_spec = getattr(model, "mel_spec", None)
    if mel_spec is None:
        raise RuntimeError("CFM model has no attached MelSpec; check your checkpoint.")

    target_sr = int(mel_spec.target_sample_rate)
    hop_length = int(mel_spec.hop_length)
    target_rms = 0.1

    if segment_audio.dim() == 1:
        audio = segment_audio.unsqueeze(0)
    else:
        audio = segment_audio

    # RMS normalization
    rms = torch.sqrt(torch.mean(torch.square(audio)))
    if rms < target_rms:
        audio = audio * target_rms / rms

    # Resample if needed
    if sr != target_sr:
        resampler = torchaudio.transforms.Resample(sr, target_sr)
        audio = resampler(audio)

    audio = audio.to(device)

    total_frames = audio.shape[-1] // hop_length
    # Start from "keep everything", then carve out spans to re-generate.
    edit_mask = torch.ones(1, total_frames + 1, dtype=torch.bool, device=device)

    # Clamp speed and interpret it as: >1 → faster (shorter edited span),
    # <1 → slower (longer edited span).
    speed_safe = max(float(speed), 1e-3)

    for (start, end) in parts_to_edit:
        # small safety margin around the region to edit
        start_sec = max(start - 0.1, 0.0)
        end_sec = min(end + 0.1, audio.shape[-1] / target_sr)

        start_frame = int(round(start_sec * target_sr / hop_length))
        end_frame = int(round(end_sec * target_sr / hop_length))
        start_frame = max(0, min(start_frame, total_frames - 1))
        end_frame = max(start_frame + 1, min(end_frame, total_frames))

        orig_len = end_frame - start_frame
        scaled_len = max(1, int(round(orig_len / speed_safe)))

        center = (start_frame + end_frame) // 2
        new_start = max(0, center - scaled_len // 2)
        new_end = min(total_frames, new_start + scaled_len)

        edit_mask[:, new_start:new_end] = False
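    # Worked example (illustrative; assumes target_sr = 24000, hop_length = 256,
    # speed = 1.0): editing (1.0 s, 2.0 s) widens to (0.9 s, 2.1 s) and masks
    # frames round(0.9 * 24000 / 256) = 84 .. round(2.1 * 24000 / 256) = 197,
    # so ~113 mel frames are regenerated and all other frames are kept as-is.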

    duration = total_frames

    # Text tokens using multilingual frontend
    final_text_list = build_tokens_from_text(tts, target_text)

    # For multilingual models trained with `separate_langs=True`, we need to
    # post-process the phone sequence so that each non-punctuation token is
    # prefixed with its language id, consistent with training and the main API.
    if hasattr(tts, "process_phone_list") and len(final_text_list) > 0:
        final_text_list = [tts.process_phone_list(final_text_list[0])]
    print("final_text_list:", final_text_list)

    with torch.inference_mode():
        generated, _ = model.sample(
            cond=audio,
            text=final_text_list,
            duration=duration,
            steps=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            seed=seed,
            edit_mask=edit_mask,
            use_acc_grl=use_acc_grl,
            use_prosody_encoder=use_prosody_encoder_flag,
            ref_ratio=ref_ratio,
            no_ref_audio=no_ref_audio,
        )

    generated = generated.to(torch.float32)
    generated_mel = generated.permute(0, 2, 1)  # [B, C, T_mel]

    mel_for_vocoder = generated_mel.to(device)
    if tts.mel_spec_type == "vocos":
        wav_out = vocoder.decode(mel_for_vocoder)
    elif tts.mel_spec_type == "bigvgan":
        wav_out = vocoder(mel_for_vocoder)
    else:
        raise ValueError(f"Unsupported vocoder type: {tts.mel_spec_type}")

    if rms < target_rms:
        wav_out = wav_out * rms / target_rms

    return wav_out.squeeze(0), generated_mel
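

if __name__ == "__main__":
    # Minimal usage sketch. How `TTS` is constructed is project-specific, and
    # the file name, edit span, and text below are placeholders -- this block
    # is illustrative, not part of the module's API.
    tts = TTS()  # assumption: a default constructor that loads a checkpoint
    wav, in_sr = torchaudio.load("demo.wav")

    edited_wav, _ = gen_wav_multilingual(
        tts,
        segment_audio=wav.squeeze(0),
        sr=in_sr,
        target_text="Hello, this is the corrected sentence.",
        parts_to_edit=[(0.8, 1.6)],  # re-synthesize the span 0.8 s .. 1.6 s
    )

    out_sr = int(tts.ema_model.mel_spec.target_sample_rate)
    # torchaudio.save expects [channels, frames]; assumes a 1-D waveform here.
    torchaudio.save("edited.wav", edited_wav.detach().cpu().unsqueeze(0), out_sr)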