|
|
from abc import abstractmethod, ABC |
|
|
from typing import Sequence |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
class AutoEncoderBase(ABC): |
|
|
def __init__( |
|
|
self, downsampling_ratio: int, sample_rate: int, |
|
|
latent_shape: Sequence[int | None] |
|
|
): |
|
|
self.downsampling_ratio = downsampling_ratio |
|
|
self.sample_rate = sample_rate |
|
|
self.latent_token_rate = sample_rate // downsampling_ratio |
|
|
self.latent_shape = latent_shape |
|
|
self.time_dim = latent_shape.index(None) + 1 |
|
|
|
|
|
@abstractmethod |
|
|
def encode( |
|
|
self, waveform: torch.Tensor, waveform_lengths: torch.Tensor |
|
|
) -> tuple[torch.Tensor, torch.Tensor]: |
|
|
... |
|
|
|