| | from abc import ABC |
| |
|
| | from .patching_utils import combine_patches, generate_patches, get_shape |
| |
|
| |
|
| | class PatchStrategy(ABC): |
| | def __init__(self, tstride, tshape, fstride, fshape, input_fdim, input_tdim): |
| | self.tstride = tstride |
| | self.tshape = tshape |
| | self.fstride = fstride |
| | self.fshape = fshape |
| | self.input_fdim = input_fdim |
| | self.input_tdim = input_tdim |
| |
|
| | def _patch(self, x): |
| | patches = generate_patches( |
| | input=x, |
| | fstride=self.fstride, |
| | tstride=self.tstride, |
| | fshape=self.fshape, |
| | tshape=self.tshape, |
| | ) |
| | return patches |
| |
|
| | def patch(self, x): |
| | return self._patch(x) |
| |
|
| | def embed(self, x, patch_embedder): |
| | return patch_embedder(x) |
| |
|
| | def patch_and_embed(self, x, patch_embedder): |
| | """ |
| | Generate patches from the input spectrogram and embed them. |
| | |
| | This method creates patches based on the frequency and temporal stride/shape |
| | parameters, and then applies the given patch embedding function. |
| | |
| | Parameters |
| | ---------- |
| | x : torch.Tensor |
| | The input spectrogram tensor to be patched and embedded. |
| | patch_embedder : Callable |
| | A function that applies embedding to the patches. |
| | |
| | Returns |
| | ------- |
| | Tuple[torch.Tensor, torch.Tensor] |
| | The generated patches and their embeddings. |
| | """ |
| | |
| | patches = generate_patches( |
| | input=x, |
| | fstride=self.fstride, |
| | tstride=self.tstride, |
| | fshape=self.fshape, |
| | tshape=self.tshape, |
| | ) |
| | x = patch_embedder(x) |
| | return patches, x |
| |
|
| | def get_patch_size(self): |
| | p_f_dim, p_t_dim = get_shape( |
| | fstride=self.fstride, |
| | tstride=self.tstride, |
| | input_fdim=self.input_fdim, |
| | input_tdim=self.input_tdim, |
| | fshape=self.fshape, |
| | tshape=self.tshape, |
| | ) |
| | return p_f_dim, p_t_dim |
| |
|
| | def combine_patches(self, patches, original_size): |
| | return combine_patches( |
| | patches, original_size, self.fstride, self.tstride, self.fshape, self.tshape |
| | ) |
| |
|
| |
|
| | class TimePatching(PatchStrategy): |
| | def __init__( |
| | self, input_tdim, tstride=2, tshape=2, fstride=128, fshape=128, input_fdim=128 |
| | ): |
| | super().__init__( |
| | tstride=tstride, |
| | tshape=tshape, |
| | fstride=fstride, |
| | fshape=fshape, |
| | input_fdim=input_fdim, |
| | input_tdim=input_tdim, |
| | ) |
| |
|
| |
|
| | class FramePatching(PatchStrategy): |
| | def __init__( |
| | self, input_tdim, tstride=16, tshape=16, fstride=16, fshape=16, input_fdim=128 |
| | ): |
| | super().__init__( |
| | tstride=tstride, |
| | tshape=tshape, |
| | fstride=fstride, |
| | fshape=fshape, |
| | input_fdim=input_fdim, |
| | input_tdim=input_tdim, |
| | ) |
| |
|