Spaces:
Running on Zero
Running on Zero
| import torch | |
| import torch.nn as nn | |
class MIDIFuzzDisturb(nn.Module):
    """Applies fuzzing perturbations to MIDI latent representations.

    The raw MIDI teacher model output preserves good prosody but causes
    pronunciation interference. This module mitigates that by applying
    blur, temporal dropout, and noise to the melody latent.

    ``forward`` runs the stages in a fixed order: sigmoid, temporal blur,
    frame dropout, additive Gaussian noise. A stage configured with ``0`` is
    skipped. The perturbations do not depend on ``self.training``; they are
    applied in both train and eval mode.

    Args:
        dim: Expected size of the last (feature) axis of the input.
        drop_prob: For ``drop_type="random"``, the probability that a whole
            frame is zeroed (``0`` disables dropout). For
            ``drop_type="equal_space"``, a ``[drop, keep]`` pair of frame
            counts, e.g. ``[1, 1]`` means 1 frame dropped, 1 frame kept.
        noise_scale: Standard deviation of the added Gaussian noise
            (``0`` disables noise).
        blur_kernel: Odd width of the moving-average window applied along
            the time axis (``0`` disables blur).
        drop_type: ``"random"`` or ``"equal_space"``.

    Raises:
        ValueError: If ``drop_type`` is not one of the supported values.
        AssertionError: If ``drop_prob`` or ``drop_type`` is ``None``, or if
            ``blur_kernel`` is non-zero and even.
    """

    def __init__(
        self, dim=128, drop_prob=0.3, noise_scale=0.1, blur_kernel=3, drop_type="random"
    ):
        super().__init__()
        # A stage left as None is disabled and skipped in forward().
        self.blur = None
        self.drop_prob = None
        self.noise_scale = None
        self.dim = dim
        self.drop_type = drop_type

        assert drop_prob is not None
        assert drop_type is not None

        if drop_type == "random":
            # drop_prob is a float; 0 leaves dropout disabled.
            if drop_prob != 0:
                self.drop_prob = drop_prob
        elif drop_type == "equal_space":
            # drop_prob is a [drop, keep] list of frame counts.
            self.drop_prob = drop_prob
        else:
            raise ValueError(f"Unknown drop_type: {drop_type}")

        if noise_scale != 0:
            self.noise_scale = noise_scale

        if blur_kernel != 0:
            assert blur_kernel % 2 == 1, f"blur_kernel {blur_kernel} must be odd"
            # stride=1 with padding=k//2 and odd k keeps the time length unchanged.
            self.blur = nn.AvgPool1d(
                kernel_size=blur_kernel, stride=1, padding=blur_kernel // 2
            )

    def _create_equal_space_mask(self, batch_size, seq_len, device, dtype=None):
        """Create an equally-spaced mask cycling [drop, keep] frames.

        Args:
            batch_size: Size of the leading axis of the returned mask.
            seq_len: Number of frames (time steps) to cover.
            device: Device on which the mask is created.
            dtype: Floating dtype of the mask; defaults to torch's default
                dtype when omitted.

        Returns:
            A ``[batch_size, seq_len, 1]`` tensor holding 0 for dropped
            frames and 1 for kept frames. Each cycle starts with the dropped
            frames.

        Raises:
            ValueError: If a frame count is negative or both counts are 0.
        """
        drop_frames, keep_frames = self.drop_prob
        cycle_len = drop_frames + keep_frames
        if drop_frames < 0 or keep_frames < 0 or cycle_len == 0:
            raise ValueError(
                "equal_space drop_prob must be non-negative [drop, keep] frame "
                f"counts with a positive sum, got {self.drop_prob}"
            )
        if dtype is None:
            dtype = torch.get_default_dtype()

        # Position of each frame inside its cycle: the first `drop_frames`
        # positions are dropped (0), the remaining ones are kept (1).
        cycle_pos = torch.arange(seq_len, device=device) % cycle_len
        mask = (cycle_pos >= drop_frames).to(dtype)  # [T]

        # Expand to [B, T, 1] so it broadcasts over the feature axis.
        return mask.view(1, seq_len, 1).expand(batch_size, -1, -1)

    def forward(self, x):
        """Perturb a batch of MIDI latents.

        Args:
            x: ``[B, T, D]`` pre-sigmoid logits with ``D == self.dim``.

        Returns:
            A tensor of the same shape as ``x``. Masks are built in the
            dtype of the sigmoid output, so a half-precision input is not
            promoted to float32 by the dropout stage.

        Raises:
            AssertionError: If the last axis of ``x`` does not equal ``dim``.
            ValueError: If ``drop_type`` is unknown, or the equal-space
                pattern is invalid.
        """
        x = torch.sigmoid(x)
        assert x.shape[-1] == self.dim, (
            f"MIDIFuzzDisturb: expected dim={self.dim}, got {x.shape[-1]}"
        )

        if self.blur is not None:
            # AvgPool1d pools over the last axis, so move time there and back.
            x = self.blur(x.transpose(1, 2)).transpose(1, 2)

        if self.drop_prob:
            if self.drop_type == "random":
                # One draw per (batch, frame): a dropped frame loses all features.
                keep_mask = (
                    torch.rand(x.shape[0], x.shape[1], 1, device=x.device)
                    > self.drop_prob
                )
                x = x * keep_mask.to(x.dtype)
            elif self.drop_type == "equal_space":
                keep_mask = self._create_equal_space_mask(
                    x.shape[0], x.shape[1], x.device, dtype=x.dtype
                )
                x = x * keep_mask
            else:
                raise ValueError(f"Unknown drop_type: {self.drop_type}")

        if self.noise_scale:
            # Noise is added after dropout, so dropped frames end up as pure
            # noise rather than exact zeros.
            x = x + torch.randn_like(x) * self.noise_scale

        return x
class MIDIDigitalEmbedding(nn.Module):
    """Looks up learned embeddings for quantized continuous MIDI values.

    A continuous MIDI value (roughly in [0, 127]) is multiplied by
    ``mark_distinguish_scale``, rounded to an integer bin, clamped to the
    embedding table's index range, and then used as the row index into a
    learned embedding table.
    """

    def __init__(self, embed_dim=128, num_classes=128, mark_distinguish_scale=2):
        super().__init__()
        self.mark_distinguish_scale = mark_distinguish_scale
        # Two extra classes beyond the [0, 127] input range are reserved for
        # special tokens.
        self.num_classes = num_classes + 2
        # Each class is subdivided into `mark_distinguish_scale` bins.
        self.embedding_input_num_class = self.num_classes * self.mark_distinguish_scale
        self.embedding = nn.Embedding(
            num_embeddings=self.embedding_input_num_class,
            embedding_dim=embed_dim,
        )

    def midi_to_class(self, midi_values):
        """Quantize continuous MIDI values into embedding row indices.

        Args:
            midi_values: [B, T] continuous MIDI values, roughly in [0, 127]

        Returns:
            [B, T] integer (long) indices in
            ``[0, embedding_input_num_class - 1]``.
        """
        # Snap to the nearest quantization bin,
        # e.g. with scale=2: 0->0, 0.3->1, 0.5->1, 0.8->2, 1.0->2, ...
        scaled = midi_values * self.mark_distinguish_scale
        bins = scaled.round().long()
        # Out-of-range values saturate at the first / last table row.
        last_row = self.embedding_input_num_class - 1
        return bins.clamp(min=0, max=last_row)

    def forward(self, midi_values):
        """Embed continuous MIDI values.

        Args:
            midi_values: [B, T] continuous MIDI values

        Returns:
            [B, T, embed_dim] embedding vectors
        """
        return self.embedding(self.midi_to_class(midi_values))