| from typing import TypedDict |
|
|
| import torch |
| import torchaudio |
|
|
|
|
| class AudioDict(TypedDict): |
| """Comfy's representation of AUDIO data.""" |
|
|
| sample_rate: int |
| waveform: torch.Tensor |
|
|
|
|
| AudioData = AudioDict | list[AudioDict] |
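
# Throughout this module (and in Comfy's AUDIO type) "waveform" is assumed to
# be a 3D tensor shaped (batch, channels, samples); that layout is what the
# `shape[1]` channel checks and the `[:, :, ...]` slicing below rely on. A one
# second mono clip would look like:
#
#   one_second_mono: AudioDict = {
#       "sample_rate": 44100,
#       "waveform": torch.zeros((1, 1, 44100)),
#   }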
|
|
|
|
| class MtbAudio: |
| """Base class for audio processing.""" |
|
|
| @classmethod |
| def is_stereo( |
| cls, |
| audios: AudioData, |
| ) -> bool: |
| if isinstance(audios, list): |
| return any(cls.is_stereo(audio) for audio in audios) |
| else: |
| return audios["waveform"].shape[1] == 2 |
|
|
| @staticmethod |
| def resample(audio: AudioDict, common_sample_rate: int) -> AudioDict: |
| if audio["sample_rate"] != common_sample_rate: |
| resampler = torchaudio.transforms.Resample( |
| orig_freq=audio["sample_rate"], new_freq=common_sample_rate |
| ) |
| return { |
| "sample_rate": common_sample_rate, |
| "waveform": resampler(audio["waveform"]), |
| } |
| else: |
| return audio |
|
|
| @staticmethod |
| def to_stereo(audio: AudioDict) -> AudioDict: |
| if audio["waveform"].shape[1] == 1: |
| return { |
| "sample_rate": audio["sample_rate"], |
| "waveform": torch.cat( |
| [audio["waveform"], audio["waveform"]], dim=1 |
| ), |
| } |
| else: |
| return audio |
|
|
| @classmethod |
| def preprocess_audios( |
| cls, audios: list[AudioDict] |
| ) -> tuple[list[AudioDict], bool, int]: |
| max_sample_rate = max([audio["sample_rate"] for audio in audios]) |
|
|
| resampled_audios = [ |
| cls.resample(audio, max_sample_rate) for audio in audios |
| ] |
|
|
        is_stereo = cls.is_stereo(resampled_audios)
        if is_stereo:
            # upmix any mono input so every waveform has two channels
            resampled_audios = [
                cls.to_stereo(audio) for audio in resampled_audios
            ]


        return (resampled_audios, is_stereo, max_sample_rate)
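
# A rough sketch of what `preprocess_audios` does with mixed inputs (the
# numbers are illustrative, not from a real run): given a 22050 Hz mono clip
# and a 44100 Hz stereo clip, the mono clip is resampled up to 44100 Hz and
# upmixed to stereo, and the call returns (clips, True, 44100).
#
#   clips, is_stereo, rate = MtbAudio.preprocess_audios([
#       {"sample_rate": 22050, "waveform": torch.zeros((1, 1, 22050))},
#       {"sample_rate": 44100, "waveform": torch.zeros((1, 2, 44100))},
#   ])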
|
|
|
|
| class MTB_AudioCut(MtbAudio): |
| """Basic audio cutter, values are in ms.""" |
|
|
| @classmethod |
| def INPUT_TYPES(cls): |
| return { |
| "required": { |
| "audio": ("AUDIO",), |
| "length": ( |
| ("FLOAT"), |
| { |
| "default": 1000.0, |
| "min": 0.0, |
| "max": 999999.0, |
| "step": 1, |
| }, |
| ), |
| "offset": ( |
| ("FLOAT"), |
| {"default": 0.0, "min": 0.0, "max": 999999.0, "step": 1}, |
| ), |
| }, |
| } |
|
|
| RETURN_TYPES = ("AUDIO",) |
| RETURN_NAMES = ("cut_audio",) |
| CATEGORY = "mtb/audio" |
| FUNCTION = "cut" |
|
|
| def cut(self, audio: AudioDict, length: float, offset: float): |
| sample_rate = audio["sample_rate"] |
        # convert the millisecond offset/length into sample indices
        start_idx = int(offset * sample_rate / 1000)
| end_idx = min( |
| start_idx + int(length * sample_rate / 1000), |
| audio["waveform"].shape[-1], |
| ) |
| cut_waveform = audio["waveform"][:, :, start_idx:end_idx] |
|
|
| return ( |
| { |
| "sample_rate": sample_rate, |
| "waveform": cut_waveform, |
| }, |
| ) |
|
|
|
|
| class MTB_AudioStack(MtbAudio): |
| """Stack/Overlay audio inputs (dynamic inputs). |
| |
    - pad audios to the longest input.
| - resample audios to the highest sample rate in the inputs. |
| - convert them all to stereo if one of the inputs is. |
| """ |
|
|
| @classmethod |
| def INPUT_TYPES(cls): |
| return {"required": {}} |
|
|
| RETURN_TYPES = ("AUDIO",) |
| RETURN_NAMES = ("stacked_audio",) |
| CATEGORY = "mtb/audio" |
| FUNCTION = "stack" |
|
|
| def stack(self, **kwargs: AudioDict) -> tuple[AudioDict]: |
| audios, is_stereo, max_rate = self.preprocess_audios( |
| list(kwargs.values()) |
| ) |
|
|
| max_length = max([audio["waveform"].shape[-1] for audio in audios]) |
|
|
| padded_audios: list[torch.Tensor] = [] |
| for audio in audios: |
            # right-pad with silence so every waveform reaches max_length
            padding = torch.zeros(
| ( |
| 1, |
| 2 if is_stereo else 1, |
| max_length - audio["waveform"].shape[-1], |
| ) |
| ) |
| padded_audio = torch.cat([audio["waveform"], padding], dim=-1) |
| padded_audios.append(padded_audio) |
|
|
        # overlay: sum the time-aligned waveforms sample by sample
        stacked_waveform = torch.stack(padded_audios, dim=0).sum(dim=0)
|
|
| return ( |
| { |
| "sample_rate": max_rate, |
| "waveform": stacked_waveform, |
| }, |
| ) |
|
|
|
|
| class MTB_AudioSequence(MtbAudio): |
| """Sequence audio inputs (dynamic inputs). |
| |
    - add silence_duration (in seconds) between segments; it can also be
      negative to overlap the clips, safely bounded to the input length.
| - resample audios to the highest sample rate in the inputs. |
| - convert them all to stereo if one of the inputs is. |
| """ |
|
|
| @classmethod |
| def INPUT_TYPES(cls): |
| return { |
| "required": { |
| "silence_duration": ( |
| ("FLOAT"), |
| {"default": 0.0, "min": -999.0, "max": 999, "step": 0.01}, |
| ) |
| }, |
| } |
|
|
| RETURN_TYPES = ("AUDIO",) |
| RETURN_NAMES = ("sequenced_audio",) |
| CATEGORY = "mtb/audio" |
| FUNCTION = "sequence" |
|
|
| def sequence(self, silence_duration: float, **kwargs: AudioDict): |
| audios, is_stereo, max_rate = self.preprocess_audios( |
| list(kwargs.values()) |
| ) |
|
|
| sequence: list[torch.Tensor] = [] |
| for i, audio in enumerate(audios): |
| if i > 0: |
| if silence_duration > 0: |
                    # insert a block of silence between consecutive clips
                    silence = torch.zeros(
| ( |
| 1, |
| 2 if is_stereo else 1, |
| int(silence_duration * max_rate), |
| ) |
| ) |
| sequence.append(silence) |
| elif silence_duration < 0: |
                    # negative duration: overlap the clips by summing the tail
                    # of the previous segment with the head of the current one
                    overlap = int(abs(silence_duration) * max_rate)
| previous_audio = sequence[-1] |
| overlap = min( |
| overlap, |
| previous_audio.shape[-1], |
| audio["waveform"].shape[-1], |
| ) |
| if overlap > 0: |
| overlap_part = ( |
| previous_audio[:, :, -overlap:] |
| + audio["waveform"][:, :, :overlap] |
| ) |
| sequence[-1] = previous_audio[:, :, :-overlap] |
| sequence.append(overlap_part) |
| audio["waveform"] = audio["waveform"][:, :, overlap:] |
|
|
| sequence.append(audio["waveform"]) |
|
|
| sequenced_waveform = torch.cat(sequence, dim=-1) |
| return ( |
| { |
| "sample_rate": max_rate, |
| "waveform": sequenced_waveform, |
| }, |
| ) |
|
|
|
|
| __nodes__ = [MTB_AudioSequence, MTB_AudioStack, MTB_AudioCut] |
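

# Minimal smoke test, assuming the (batch, channels, samples) layout described
# at the top of the file. It is only a sketch for local experimentation;
# ComfyUI itself picks the nodes up through `__nodes__` and never runs this.
if __name__ == "__main__":
    _mono = {"sample_rate": 22050, "waveform": torch.rand((1, 1, 22050))}
    _stereo = {"sample_rate": 44100, "waveform": torch.rand((1, 2, 44100))}

    (stacked,) = MTB_AudioStack().stack(a=_mono, b=_stereo)
    print("stacked:", stacked["sample_rate"], stacked["waveform"].shape)

    (cut,) = MTB_AudioCut().cut(_stereo, length=250.0, offset=100.0)
    print("cut:", cut["sample_rate"], cut["waveform"].shape)

    (seq,) = MTB_AudioSequence().sequence(-0.1, a=_mono, b=_stereo)
    print("sequence:", seq["sample_rate"], seq["waveform"].shape)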
|
|