Spaces:
Runtime error
Runtime error
Commit
·
2338c72
1
Parent(s):
d53a9dd
fix
Browse files
- app.py +9 -1
- diffrhythm/infer/infer.py +3 -1
- diffrhythm/infer/infer_utils.py +38 -0
- diffrhythm/model/cfm.py +6 -0
app.py
CHANGED
|
@@ -42,8 +42,15 @@ def infer_music(lrc, ref_audio_path, steps, file_type, cfg_strength, odeint_meth
|
|
| 42 |
|
| 43 |
max_frames = math.floor(duration * 21.56)
|
| 44 |
sway_sampling_coef = -1 if steps < 32 else None
|
|
|
|
| 45 |
lrc_prompt, start_time = get_lrc_token(max_frames, lrc, tokenizer, device)
|
| 46 |
-
style_prompt = get_style_prompt(muq, ref_audio_path, prompt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
negative_style_prompt = get_negative_style_prompt(device)
|
| 48 |
latent_prompt = get_reference_latent(device, max_frames)
|
| 49 |
print(">0")
|
|
@@ -59,6 +66,7 @@ def infer_music(lrc, ref_audio_path, steps, file_type, cfg_strength, odeint_meth
|
|
| 59 |
sway_sampling_coef=sway_sampling_coef,
|
| 60 |
start_time=start_time,
|
| 61 |
file_type=file_type,
|
|
|
|
| 62 |
odeint_method=odeint_method,
|
| 63 |
)
|
| 64 |
devicetorch.empty_cache(torch)
|
|
|
|
| 42 |
|
| 43 |
max_frames = math.floor(duration * 21.56)
|
| 44 |
sway_sampling_coef = -1 if steps < 32 else None
|
| 45 |
+
vocal_flag = False
|
| 46 |
lrc_prompt, start_time = get_lrc_token(max_frames, lrc, tokenizer, device)
|
| 47 |
+
# style_prompt = get_style_prompt(muq, ref_audio_path, prompt)
|
| 48 |
+
|
| 49 |
+
if prompt is not None:
|
| 50 |
+
style_prompt = get_text_style_prompt(muq, text_prompt)
|
| 51 |
+
else:
|
| 52 |
+
style_prompt, vocal_flag = get_audio_style_prompt(muq, ref_audio_path)
|
| 53 |
+
|
| 54 |
negative_style_prompt = get_negative_style_prompt(device)
|
| 55 |
latent_prompt = get_reference_latent(device, max_frames)
|
| 56 |
print(">0")
|
|
|
|
| 66 |
sway_sampling_coef=sway_sampling_coef,
|
| 67 |
start_time=start_time,
|
| 68 |
file_type=file_type,
|
| 69 |
+
vocal_flag=vocal_flag,
|
| 70 |
odeint_method=odeint_method,
|
| 71 |
)
|
| 72 |
devicetorch.empty_cache(torch)
|
diffrhythm/infer/infer.py
CHANGED
|
@@ -16,6 +16,7 @@ from diffrhythm.infer.infer_utils import (
|
|
| 16 |
get_reference_latent,
|
| 17 |
get_lrc_token,
|
| 18 |
get_style_prompt,
|
|
|
|
| 19 |
prepare_model,
|
| 20 |
get_negative_style_prompt
|
| 21 |
)
|
|
@@ -75,7 +76,7 @@ def decode_audio(latents, vae_model, chunked=False, overlap=32, chunk_size=128):
|
|
| 75 |
y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
|
| 76 |
return y_final
|
| 77 |
|
| 78 |
-
def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, steps, cfg_strength, sway_sampling_coef, start_time, file_type, odeint_method):
|
| 79 |
|
| 80 |
with torch.inference_mode():
|
| 81 |
print(">1")
|
|
@@ -89,6 +90,7 @@ def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative
|
|
| 89 |
cfg_strength=cfg_strength,
|
| 90 |
sway_sampling_coef=sway_sampling_coef,
|
| 91 |
start_time=start_time,
|
|
|
|
| 92 |
odeint_method=odeint_method,
|
| 93 |
)
|
| 94 |
if torch.cuda.is_available():
|
|
|
|
| 16 |
get_reference_latent,
|
| 17 |
get_lrc_token,
|
| 18 |
get_style_prompt,
|
| 19 |
+
get_audio_style_prompt,
|
| 20 |
prepare_model,
|
| 21 |
get_negative_style_prompt
|
| 22 |
)
|
|
|
|
| 76 |
y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
|
| 77 |
return y_final
|
| 78 |
|
| 79 |
+
def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, steps, cfg_strength, sway_sampling_coef, start_time, file_type, vocal_flag, odeint_method):
|
| 80 |
|
| 81 |
with torch.inference_mode():
|
| 82 |
print(">1")
|
|
|
|
| 90 |
cfg_strength=cfg_strength,
|
| 91 |
sway_sampling_coef=sway_sampling_coef,
|
| 92 |
start_time=start_time,
|
| 93 |
+
vocal_flag=vocal_flag,
|
| 94 |
odeint_method=odeint_method,
|
| 95 |
)
|
| 96 |
if torch.cuda.is_available():
|
diffrhythm/infer/infer_utils.py
CHANGED
|
@@ -52,6 +52,41 @@ def get_negative_style_prompt(device):
|
|
| 52 |
|
| 53 |
return vocal_stlye
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
@torch.no_grad()
|
| 56 |
def get_style_prompt(model, wav_path, prompt):
|
| 57 |
mulan = model
|
|
@@ -129,6 +164,9 @@ def get_lrc_token(max_frames, text, tokenizer, device):
|
|
| 129 |
comma_token_id = 1
|
| 130 |
period_token_id = 2
|
| 131 |
|
|
|
|
|
|
|
|
|
|
| 132 |
lrc_with_time = parse_lyrics(text)
|
| 133 |
|
| 134 |
modified_lrc_with_time = []
|
|
|
|
| 52 |
|
| 53 |
return vocal_stlye
|
| 54 |
|
| 55 |
+
|
| 56 |
+
def get_audio_style_prompt(model, wav_path):
    """Embed a reference audio clip into a style vector with the MuLan model.

    Loads the file at 24 kHz, excerpts at most a 10-second window centred in
    the clip, and encodes it with ``model(wavs=...)``.

    Args:
        model: MuLan-style encoder; must be callable as ``model(wavs=tensor)``
            and expose a ``.device`` attribute.
        wav_path: path to the reference audio file.

    Returns:
        Tuple ``(audio_emb, vocal_flag)``: the half-precision embedding
        ([1, 512] per the upstream comment) and a flag that is True when the
        clip lasts one second or less (too short to carry style information).
    """
    mulan = model
    audio, _ = librosa.load(wav_path, sr=24000)
    duration_s = librosa.get_duration(y=audio, sr=24000)

    # Very short clips are flagged so the caller can fall back to a
    # different style prompt.
    vocal_flag = duration_s <= 1

    if duration_s > 10:
        # Take a 10-second window centred in the clip.
        excerpt_start = int(duration_s // 2 - 5)
        excerpt = audio[excerpt_start * 24000:(excerpt_start + 10) * 24000]
    else:
        excerpt = audio

    wav = torch.tensor(excerpt).unsqueeze(0).to(model.device)
    with torch.no_grad():
        audio_emb = mulan(wavs=wav)  # [1, 512]

    return audio_emb.half(), vocal_flag
|
| 79 |
+
|
| 80 |
+
def get_text_style_prompt(model, text_prompt):
    """Embed a free-text style description with the MuLan model.

    Args:
        model: MuLan-style encoder; must be callable as ``model(texts=...)``.
        text_prompt: textual style description.

    Returns:
        Half-precision text embedding ([1, 512] per the upstream comment).
    """
    mulan = model
    with torch.no_grad():
        embedding = mulan(texts=text_prompt)  # [1, 512]
    return embedding.half()
|
| 88 |
+
|
| 89 |
+
|
| 90 |
@torch.no_grad()
|
| 91 |
def get_style_prompt(model, wav_path, prompt):
|
| 92 |
mulan = model
|
|
|
|
| 164 |
comma_token_id = 1
|
| 165 |
period_token_id = 2
|
| 166 |
|
| 167 |
+
if text == "":
|
| 168 |
+
return torch.zeros((max_frames,), dtype=torch.long).unsqueeze(0).to(device), torch.tensor(0.).unsqueeze(0).to(device).half()
|
| 169 |
+
|
| 170 |
lrc_with_time = parse_lyrics(text)
|
| 171 |
|
| 172 |
modified_lrc_with_time = []
|
diffrhythm/model/cfm.py
CHANGED
|
@@ -121,6 +121,7 @@ class CFM(nn.Module):
|
|
| 121 |
start_time=None,
|
| 122 |
latent_pred_start_frame=0,
|
| 123 |
latent_pred_end_frame=2048,
|
|
|
|
| 124 |
odeint_method="euler"
|
| 125 |
):
|
| 126 |
self.eval()
|
|
@@ -199,6 +200,11 @@ class CFM(nn.Module):
|
|
| 199 |
start_time_embed, positive_text_embed, positive_text_residuals = self.transformer.forward_timestep_invariant(text, step_cond.shape[1], drop_text=False, start_time=start_time)
|
| 200 |
_, negative_text_embed, negative_text_residuals = self.transformer.forward_timestep_invariant(text, step_cond.shape[1], drop_text=True, start_time=start_time)
|
| 201 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
text_embed = torch.cat([positive_text_embed, negative_text_embed], 0)
|
| 203 |
text_residuals = [torch.cat([a, b], 0) for a, b in zip(positive_text_residuals, negative_text_residuals)]
|
| 204 |
step_cond = torch.cat([step_cond, step_cond], 0)
|
|
|
|
| 121 |
start_time=None,
|
| 122 |
latent_pred_start_frame=0,
|
| 123 |
latent_pred_end_frame=2048,
|
| 124 |
+
vocal_flag=False,
|
| 125 |
odeint_method="euler"
|
| 126 |
):
|
| 127 |
self.eval()
|
|
|
|
| 200 |
start_time_embed, positive_text_embed, positive_text_residuals = self.transformer.forward_timestep_invariant(text, step_cond.shape[1], drop_text=False, start_time=start_time)
|
| 201 |
_, negative_text_embed, negative_text_residuals = self.transformer.forward_timestep_invariant(text, step_cond.shape[1], drop_text=True, start_time=start_time)
|
| 202 |
|
| 203 |
+
if vocal_flag:
|
| 204 |
+
style_prompt = negative_style_prompt
|
| 205 |
+
negative_style_prompt = torch.zeros_like(style_prompt)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
text_embed = torch.cat([positive_text_embed, negative_text_embed], 0)
|
| 209 |
text_residuals = [torch.cat([a, b], 0) for a, b in zip(positive_text_residuals, negative_text_residuals)]
|
| 210 |
step_cond = torch.cat([step_cond, step_cond], 0)
|