Approximetal committed on
Commit
1914d13
·
verified ·
1 Parent(s): c9c7e92

Update lemas_tts/infer/edit_multilingual.py

Browse files
Files changed (1) hide show
  1. lemas_tts/infer/edit_multilingual.py +25 -39
lemas_tts/infer/edit_multilingual.py CHANGED
@@ -1,9 +1,3 @@
1
- """
2
- Multilingual speech editing helpers for LEMAS-TTS.
3
-
4
- This is adapted from F5-TTS's `speech_edit_multilingual.py`, but uses the
5
- `lemas_tts.api.TTS` API instead of `F5TTS`.
6
- """
7
 
8
  from __future__ import annotations
9
 
@@ -59,6 +53,7 @@ def gen_wav_multilingual(
59
  sr: int,
60
  target_text: str,
61
  parts_to_edit: List[Tuple[float, float]],
 
62
  nfe_step: int = 64,
63
  cfg_strength: float = 5.0,
64
  sway_sampling_coef: float = 3.0,
@@ -103,40 +98,32 @@ def gen_wav_multilingual(
103
 
104
  audio = audio.to(device)
105
 
106
- # Build edit mask over mel frames
107
- offset = 0.0
108
- edit_mask = torch.zeros(1, 0, dtype=torch.bool, device=device)
 
 
 
 
 
109
  for (start, end) in parts_to_edit:
110
  # small safety margin around the region to edit
111
- start = max(start - 0.1, 0.0)
112
- end = min(end + 0.1, audio.shape[-1] / target_sr)
113
- part_dur_sec = end - start
114
- part_dur_samples = int(round(part_dur_sec * target_sr))
115
- start_samples = int(round(start * target_sr))
116
-
117
- # frames before edited span: keep original (mask=True)
118
- num_keep_frames = int(round((start_samples - offset) / hop_length))
119
- # frames inside edited span: to be regenerated (mask=False)
120
- num_edit_frames = int(round(part_dur_samples / hop_length))
121
-
122
- if num_keep_frames > 0:
123
- edit_mask = torch.cat(
124
- [edit_mask, torch.ones(1, num_keep_frames, dtype=torch.bool, device=device)],
125
- dim=-1,
126
- )
127
- if num_edit_frames > 0:
128
- edit_mask = torch.cat(
129
- [edit_mask, torch.zeros(1, num_edit_frames, dtype=torch.bool, device=device)],
130
- dim=-1,
131
- )
132
-
133
- offset = end * target_sr
134
-
135
- # Pad mask to full sequence length (True = keep original)
136
- total_frames = audio.shape[-1] // hop_length
137
- if edit_mask.shape[-1] < total_frames + 1:
138
- pad_len = total_frames + 1 - edit_mask.shape[-1]
139
- edit_mask = F.pad(edit_mask, (0, pad_len), value=True)
140
 
141
  duration = total_frames
142
 
@@ -181,4 +168,3 @@ def gen_wav_multilingual(
181
  wav_out = wav_out * rms / target_rms
182
 
183
  return wav_out.squeeze(0), generated_mel
184
-
 
 
 
 
 
 
 
1
 
2
  from __future__ import annotations
3
 
 
53
  sr: int,
54
  target_text: str,
55
  parts_to_edit: List[Tuple[float, float]],
56
+ speed: float = 1.0,
57
  nfe_step: int = 64,
58
  cfg_strength: float = 5.0,
59
  sway_sampling_coef: float = 3.0,
 
98
 
99
  audio = audio.to(device)
100
 
101
+ total_frames = audio.shape[-1] // hop_length
102
+ # Start from "keep everything", then carve out spans to re-generate.
103
+ edit_mask = torch.ones(1, total_frames + 1, dtype=torch.bool, device=device)
104
+
105
+ # Clamp speed and interpret it as: >1 → faster (shorter edited span),
106
+ # <1 → slower (longer edited span).
107
+ speed_safe = max(float(speed), 1e-3)
108
+
109
  for (start, end) in parts_to_edit:
110
  # small safety margin around the region to edit
111
+ start_sec = max(start - 0.1, 0.0)
112
+ end_sec = min(end + 0.1, audio.shape[-1] / target_sr)
113
+
114
+ start_frame = int(round(start_sec * target_sr / hop_length))
115
+ end_frame = int(round(end_sec * target_sr / hop_length))
116
+ start_frame = max(0, min(start_frame, total_frames - 1))
117
+ end_frame = max(start_frame + 1, min(end_frame, total_frames))
118
+
119
+ orig_len = end_frame - start_frame
120
+ scaled_len = max(1, int(round(orig_len / speed_safe)))
121
+
122
+ center = (start_frame + end_frame) // 2
123
+ new_start = max(0, center - scaled_len // 2)
124
+ new_end = min(total_frames, new_start + scaled_len)
125
+
126
+ edit_mask[:, new_start:new_end] = False
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
  duration = total_frames
129
 
 
168
  wav_out = wav_out * rms / target_rms
169
 
170
  return wav_out.squeeze(0), generated_mel