# xjsc0's picture
# 1
# 64ec292
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from src.YingMusicSinger.melody.Gconform import Gmidi_conform
# midi decoding utils
def decode_gaussian_blurred_probs(probs, vmin, vmax, deviation, threshold):
    """Decode per-frame values from Gaussian-blurred bin probabilities.

    For every frame the most probable bin is located. Then a probability-
    weighted mean of bin values is taken over a window of +/- 3 * deviation
    (converted to bins) around that peak.

    :param probs: per-bin probabilities; indexed below as [B, T, N]
    :param vmin: value represented by bin 0
    :param vmax: value represented by the last bin
    :param deviation: blur standard deviation, in value units
    :param threshold: frames whose peak probability is below this are rests
    :return: (values [B, T], rest bool mask [B, T])
    """
    n_bins = int(probs.shape[-1])
    step = (vmax - vmin) / (n_bins - 1)
    half_window = int(3 * deviation / step)  # 3 sigma, expressed in bins

    bins = torch.arange(n_bins, device=probs.device)[None, None, :]  # [1, 1, N]
    bin_values = bins * step + vmin

    peak = probs.argmax(dim=-1, keepdim=True)  # [B, T, 1]
    lo = (peak - half_window).clip(min=0)
    hi = (peak + half_window + 1).clip(max=n_bins)
    in_window = (bins >= lo) & (bins < hi)  # [B, T, N]

    windowed = probs * in_window
    numerator = (windowed * bin_values).sum(dim=2)
    denominator = windowed.sum(dim=2)
    # Where the window carries no probability mass, divide by 1 instead of 0.
    values = numerator / (denominator + (denominator == 0))

    rest = probs.max(dim=-1).values < threshold
    return values, rest
def decode_bounds_to_alignment(bounds, use_diff=True):
    """Turn per-frame boundary scores into a 1-based item index per frame.

    The running sum of ``bounds`` along dim 1 is rounded to integers. A new item
    starts on the first frame and wherever that rounded sum increases.

    :param bounds: [B, T] boundary scores (e.g. probabilities)
    :param use_diff: compute increases with ``torch.diff`` (True) or by
        comparing shifted slices (False)
    :return: [B, T] long tensor of item indices starting at 1
    """
    step_ids = torch.cumsum(bounds, dim=1).round().long()
    if not use_diff:
        rises = step_ids[:, 1:] > step_ids[:, :-1]
        starts = F.pad(rises, [1, 0], value=True)  # first frame always opens an item
    else:
        # Seeding the diff with -1 guarantees the first frame counts as a start.
        seed = torch.full(
            (bounds.shape[0], 1),
            fill_value=-1,
            dtype=step_ids.dtype,
            device=step_ids.device,
        )
        starts = torch.diff(step_ids, dim=1, prepend=seed) > 0
    return torch.cumsum(starts.long(), dim=1)
def decode_note_sequence(frame2item, values, masks, threshold=0.5):
    """Aggregate frame-level values into per-item (note) values.

    For each item, a histogram of its rounded, unmasked frame values is built
    over 128 integer bins (0..127), and its most populated bin is taken as the
    item's centre. The item value is the mean of the unmasked frames lying
    within +/- 0.5 of that centre.

    :param frame2item: [B, T] long tensor of item indices per frame, e.g.
        [1, 1, 1, 1, 2, 2, 3, 3, 3]; slot 0 is accumulated but dropped from
        every output
    :param values: [B, T] float frame values; quantized to 0..127 for the histogram
    :param masks: [B, T] bool tensor, True for frames that should count
    :param threshold: minimum ratio of unmasked frames required to be regarded
        as an unmasked item
    :return: item_values [B, I] (float), item_dur [B, I] (frame counts),
        item_masks [B, I] (bool), where I = frame2item.max()
    """
    num_bins = 128  # histogram resolution: one bin per integer value 0..127
    b = frame2item.shape[0]
    space = int(frame2item.max()) + 1

    item_dur = frame2item.new_zeros(b, space, dtype=frame2item.dtype).scatter_add(
        1, frame2item, torch.ones_like(frame2item)
    )[:, 1:]
    item_unmasked_dur = frame2item.new_zeros(
        b, space, dtype=frame2item.dtype
    ).scatter_add(1, frame2item, masks.long())[:, 1:]
    # Items absent from a row give 0/0 = nan, and nan >= threshold is False.
    item_masks = item_unmasked_dur / item_dur >= threshold

    # Clamp so out-of-range values can neither spill into a neighbouring item's
    # bins nor index past the end of the flattened histogram.
    values_quant = values.round().long().clamp(0, num_bins - 1)
    histogram = (
        frame2item.new_zeros(b, space * num_bins, dtype=frame2item.dtype)
        .scatter_add(
            1, frame2item * num_bins + values_quant, torch.ones_like(frame2item) * masks
        )
        .unflatten(1, [space, num_bins])[:, 1:, :]
    )
    item_values_center = histogram.float().argmax(dim=2).to(dtype=values.dtype)
    # Re-insert a dummy slot 0 so gather can index with the 1-based frame2item.
    values_center = torch.gather(F.pad(item_values_center, [1, 0]), 1, frame2item)
    values_near_center = (
        masks & (values >= values_center - 0.5) & (values <= values_center + 0.5)
    )
    item_valid_dur = frame2item.new_zeros(b, space, dtype=frame2item.dtype).scatter_add(
        1, frame2item, values_near_center.long()
    )[:, 1:]
    # Items with no frame near their centre divide by 1 and end up as 0.
    item_values = values.new_zeros(b, space, dtype=values.dtype).scatter_add(
        1, frame2item, values * values_near_center
    )[:, 1:] / (item_valid_dur + (item_valid_dur == 0))
    return item_values, item_dur, item_masks
def expand_batch_padded(feature_tensor, counts_tensor, padding_value=0.0):
    """Repeat each feature by its count per row and right-pad rows to equal length.

    :param feature_tensor: [B, I] per-item features
    :param counts_tensor: [B, I] non-negative repeat counts
    :param padding_value: fill value for rows shorter than the longest one
    :return: (padded [B, max_len], lengths [B]) where lengths is the per-row
        sum of counts
    """
    assert feature_tensor.dim() == 2 and counts_tensor.dim() == 2
    lengths = counts_tensor.sum(dim=1)
    # Expand the whole batch in one flat pass, then cut it back into rows.
    flat = feature_tensor.reshape(-1).repeat_interleave(counts_tensor.reshape(-1))
    rows = flat.split(lengths.tolist())
    padded = pad_sequence(rows, batch_first=True, padding_value=padding_value)
    return padded, lengths
class midi_loss(nn.Module):
    """Two binary cross-entropy terms: one for MIDI output, one for boundaries.

    Both predictions go straight into ``nn.BCELoss``, so they must already be
    probabilities in [0, 1], not logits.
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.BCELoss()

    def forward(self, x, target):
        """Compute the losses.

        :param x: pair (midi_pred, boundary_pred)
        :param target: pair (midi_target, boundary_target)
        :return: (midi_loss_value, boundary_loss_value)
        """
        midi_pred, bound_pred = x
        midi_true, bound_true = target
        bound_term = self.loss(bound_pred, bound_true)
        midi_term = self.loss(midi_pred, midi_true)
        return midi_term, bound_term
class MIDIExtractor(nn.Module):
    """Wraps a ``Gmidi_conform`` network and decodes its outputs into notes.

    The network is built from the fixed configuration in ``__init__``.
    ``postprocess`` applies sigmoid to the raw outputs and runs the module-level
    decoding helpers.
    """

    def __init__(self, in_dim=None, out_dim=None):
        """Build the wrapped network.

        :param in_dim: if given, replaces the default ``indim`` (80) passed to
            ``Gmidi_conform``
        :param out_dim: if given, replaces the default ``outdim`` (128) passed to
            ``Gmidi_conform``
        """
        super().__init__()
        cfg = {
            "attention_drop": 0.1,
            "attention_heads": 8,
            "attention_heads_dim": 64,
            "conv_drop": 0.1,
            "dim": 512,
            "ffn_latent_drop": 0.1,
            "ffn_out_drop": 0.1,
            "kernel_size": 31,
            "lay": 8,
            "use_lay_skip": True,
            "indim": 80 if in_dim is None else in_dim,
            "outdim": 128 if out_dim is None else out_dim,
        }
        self.midi_conform = Gmidi_conform(**cfg)
        # Decoding constants consumed by postprocess().
        self.midi_min = 0
        self.midi_max = 127
        self.midi_deviation = 1.0
        self.rest_threshold = 0.1

    def _load_form_ckpt(self, ckpt_path, device="cpu"):
        """Load ``midi_conform`` weights from a checkpoint file.

        Keys prefixed with "model.model." in the checkpoint's "state_dict" are
        renamed to the "midi_conform." prefix and loaded strictly. All other keys
        are ignored.

        NOTE(review): torch.load unpickles the file, so only load trusted
        checkpoints.

        :param ckpt_path: checkpoint path; must not be None
        :param device: accepted but not used. Tensors are mapped to CPU, and the
            module is not moved.
        :raises ValueError: if ``ckpt_path`` is None
        """
        from collections import OrderedDict

        if ckpt_path is None:
            raise ValueError("midi_extractor_path is required")
        raw_state = torch.load(ckpt_path, map_location="cpu")["state_dict"]
        src_prefix = "model.model."
        dst_prefix = "midi_conform."
        remapped = OrderedDict(
            (key.replace(src_prefix, dst_prefix), tensor)
            for key, tensor in raw_state.items()
            if key.startswith(src_prefix)
        )
        self.load_state_dict(remapped, strict=True)

    def forward(self, x, mask=None):
        """Run the network and return its two outputs as a ``(midi, bound)`` pair."""
        midi_out, bound_out = self.midi_conform(x, mask)
        return midi_out, bound_out

    def postprocess(self, midi, bounds, with_expand=False):
        """Decode raw network outputs into a note sequence.

        :param midi: pre-sigmoid pitch scores; sigmoid is applied here and the
            last dim is treated as pitch bins
        :param bounds: pre-sigmoid boundary scores; a trailing singleton dim is
            squeezed away
        :param with_expand: if True, expand note pitches back to frame level
        :return: ``(note_midi, note_dur)``; or, when ``with_expand`` is True,
            ``(frame_level_midi_padded, None)``
        """
        bound_probs = torch.squeeze(torch.sigmoid(bounds), -1)
        frame_mask = torch.ones_like(bound_probs).bool()
        # Out-of-place multiplies keep the sigmoid outputs intact for autograd.
        pitch_probs = torch.sigmoid(midi) * frame_mask[..., None]
        bound_probs = bound_probs * frame_mask

        frame2note = decode_bounds_to_alignment(bound_probs) * frame_mask
        frame_midi, frame_rest = decode_gaussian_blurred_probs(
            pitch_probs,
            vmin=self.midi_min,
            vmax=self.midi_max,
            deviation=self.midi_deviation,
            threshold=self.rest_threshold,
        )
        note_midi, note_dur, _ = decode_note_sequence(
            frame2note, frame_midi, ~frame_rest & frame_mask
        )
        if with_expand:
            expanded, _ = expand_batch_padded(note_midi, note_dur)
            return expanded, None
        return note_midi, note_dur