# test/src/heartlib/heartcodec/modeling_heartcodec.py
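"""HeartCodec: converts discrete codec tokens back to audio.

Pairs a flow-matching latent generator (`FlowMatching`) with a scalar
waveform codec (`ScalarModel`): `detokenize` runs flow-matching inference
over overlapping code segments, decodes the latents to waveforms, and
crossfades the overlaps. Summary inferred from the code in this file.
"""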
import math

import numpy as np
import torch
from transformers.modeling_utils import PreTrainedModel

from .configuration_heartcodec import HeartCodecConfig
from .models.flow_matching import FlowMatching
from .models.sq_codec import ScalarModel


class HeartCodec(PreTrainedModel):
    config_class = HeartCodecConfig

    def __init__(
        self,
        config: HeartCodecConfig,
    ):
        super().__init__(config)
        self.config = config
self.flow_matching = FlowMatching(
dim=config.dim,
codebook_size=config.codebook_size,
decay=config.decay,
commitment_weight=config.commitment_weight,
threshold_ema_dead_code=config.threshold_ema_dead_code,
use_cosine_sim=config.use_cosine_sim,
codebook_dim=config.codebook_dim,
num_quantizers=config.num_quantizers,
attention_head_dim=config.attention_head_dim,
in_channels=config.in_channels,
norm_type=config.norm_type,
num_attention_heads=config.num_attention_heads,
num_layers=config.num_layers,
num_layers_2=config.num_layers_2,
out_channels=config.out_channels,
)
self.scalar_model = ScalarModel(
num_bands=config.num_bands,
sample_rate=config.sample_rate,
causal=config.causal,
num_samples=config.num_samples,
downsample_factors=config.downsample_factors,
downsample_kernel_sizes=config.downsample_kernel_sizes,
upsample_factors=config.upsample_factors,
upsample_kernel_sizes=config.upsample_kernel_sizes,
latent_hidden_dim=config.latent_hidden_dim,
default_kernel_size=config.default_kernel_size,
delay_kernel_size=config.delay_kernel_size,
init_channel=config.init_channel,
res_kernel_size=config.res_kernel_size,
)
self.post_init()
self.sample_rate = config.sample_rate

    @torch.inference_mode()
def detokenize(
self,
codes,
duration=29.76,
num_steps=10,
disable_progress=False,
guidance_scale=1.25,
):
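        """Decode discrete codec tokens back to a waveform.

        Argument descriptions are inferred from how the values are used
        in this method, not from upstream docs:
            codes: LongTensor of code indices, shape (num_quantizers, T)
                at 12.5 frames per second; a batch dim is added internally.
            duration: per-segment duration in seconds for chunked decoding.
            num_steps: number of flow-matching integration steps.
            disable_progress: hide the progress bar during inference.
            guidance_scale: classifier-free guidance strength.
        """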
        codes = codes.unsqueeze(0).to(self.device)  # add batch dim: (1, Q, T)
        # Noise prior for the first segment; latents run at 25 frames/sec.
        first_latent = torch.randn(
            codes.shape[0], int(duration * 25), 256,
            dtype=self.dtype, device=self.device,
        )  # (B, T_latent, 256)
        first_latent_length = 0
        first_latent_codes_length = 0
        # Segment bookkeeping in code frames (codes run at 12.5 frames/sec):
        # each segment is `min_samples` frames long, consecutive segments
        # advance by `hop_samples` and overlap by `ovlp_samples`. Latents run
        # at twice the code rate, so the latent-frame overlap is `ovlp_frames`.
        min_samples = int(duration * 12.5)
        hop_samples = min_samples // 93 * 80
        ovlp_samples = min_samples - hop_samples
        ovlp_frames = ovlp_samples * 2
        codes_len = codes.shape[-1]
        # Target output length in audio samples for the un-padded input codes.
        target_len = int(
            (codes_len - first_latent_codes_length) / 12.5 * self.sample_rate
        )
        # If the input is shorter than one segment, tile it by repetition and
        # truncate to exactly one segment.
        if codes_len < min_samples:
            while codes.shape[-1] < min_samples:
                codes = torch.cat([codes, codes], -1)
            codes = codes[:, :, 0:min_samples]
            codes_len = codes.shape[-1]
        # Pad (by repeating the codes) so the total length is a whole number
        # of hops plus one overlap; both quantities here are in code frames,
        # so the overlap is `ovlp_samples` (not the latent-frame `ovlp_frames`).
        if (codes_len - ovlp_samples) % hop_samples > 0:
            len_codes = (
                math.ceil((codes_len - ovlp_samples) / float(hop_samples)) * hop_samples
                + ovlp_samples
            )
            while codes.shape[-1] < len_codes:
                codes = torch.cat([codes, codes], -1)
            codes = codes[:, :, 0:len_codes]
        latent_length = int(duration * 25)  # latent frames per full segment
        latent_list = []
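        # Decode segment by segment. The first segment is generated from pure
        # noise; each subsequent segment conditions on the tail of the previous
        # segment's latents so neighbouring chunks stay coherent.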
for sinx in range(0, codes.shape[-1] - hop_samples + 1, hop_samples):
            codes_input = [codes[:, :, sinx : sinx + min_samples]]
            if sinx == 0 or ovlp_frames == 0:
                # First segment (or no overlap): generate from pure noise.
                incontext_length = first_latent_length
latents = self.flow_matching.inference_codes(
codes_input,
first_latent,
latent_length,
incontext_length,
guidance_scale=guidance_scale,
num_steps=num_steps,
disable_progress=disable_progress,
scenario="other_seg",
)
latent_list.append(latents)
            else:
                # Reuse the last `ovlp_frames` latents of the previous segment
                # as the in-context prefix, then pad with fresh noise up to a
                # full segment.
                true_latent = latent_list[-1][:, -ovlp_frames:, :]
                len_add_to_latent = latent_length - true_latent.shape[1]
                incontext_length = true_latent.shape[1]
                true_latent = torch.cat(
                    [
                        true_latent,
                        torch.randn(
                            true_latent.shape[0],
                            len_add_to_latent,
                            true_latent.shape[-1],
                            dtype=self.dtype,
                            device=self.device,
                        ),
                    ],
                    1,
                )
latents = self.flow_matching.inference_codes(
codes_input,
true_latent,
latent_length,
incontext_length,
guidance_scale=guidance_scale,
num_steps=num_steps,
disable_progress=disable_progress,
scenario="other_seg",
)
latent_list.append(latents)
        # Drop the (zero-length) in-context prefix from the first segment.
        latent_list[0] = latent_list[0][:, first_latent_length:, :]
        # Re-derive segment/hop/overlap lengths in waveform samples.
        min_samples = int(duration * self.sample_rate)
        hop_samples = min_samples // 93 * 80
        ovlp_samples = min_samples - hop_samples
        output = None
        for latent in latent_list:
            # Split the 256-dim latent into two 128-dim streams and fold them
            # into the batch dim: (B, T, 256) -> (2B, T, 128). Each stream is
            # decoded independently by the scalar codec.
            latent = latent.reshape(
                latent.shape[0], latent.shape[1], 2, latent.shape[2] // 2
            ).permute(0, 2, 1, 3)
            latent = latent.reshape(
                latent.shape[0] * 2, latent.shape[2], latent.shape[3]
            )
            cur_output = (
                self.scalar_model.decode(latent.transpose(1, 2)).squeeze(0).squeeze(1)
            )
            cur_output = cur_output[:, 0:min_samples].detach().cpu()  # (2*B, T)
            if cur_output.dim() == 3:
                cur_output = cur_output[0]
            if output is None:
                output = cur_output
            elif ovlp_samples == 0:
                output = torch.cat([output, cur_output], -1)
            else:
                # Linear crossfade across the overlap: the previous segment
                # fades out (1 -> 0) while the new segment fades in (0 -> 1).
                # Cast the window to the output dtype to avoid a silent
                # float64 upcast from numpy.
                ov_win = torch.from_numpy(
                    np.linspace(0, 1, ovlp_samples)[None, :]
                ).to(output.dtype)
                ov_win = torch.cat([ov_win, 1 - ov_win], -1)
                output[:, -ovlp_samples:] = (
                    output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:]
                    + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples]
                )
                output = torch.cat([output, cur_output[:, ovlp_samples:]], -1)
        # Trim padding introduced by code repetition back to the true length.
        output = output[:, 0:target_len]
        return output
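
# Usage sketch (hypothetical): loading a pretrained checkpoint and decoding a
# `codes` tensor of shape (num_quantizers, T) at 12.5 frames/sec. The
# checkpoint path and the source of `codes` are placeholders, not part of this
# repo; `from_pretrained` is inherited from `PreTrainedModel`.
#
#     codec = HeartCodec.from_pretrained("path/to/heartcodec").to("cuda")
#     wav = codec.detokenize(codes, duration=29.76, num_steps=10)
#     # `wav` is a (channels, samples) tensor at `codec.sample_rate`.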