"""Base classes defining the predict / preprocess / forward model interface."""

from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn


class BaseModel:
    """Base class for audio models.

    Subclasses are expected to override ``_load_weights``,
    ``preprocess_audios`` and ``forward``. ``predict`` chains the latter two.
    """

    def __init__(self, weights: str) -> None:
        """Store the result of ``_load_weights(weights)`` on ``self.model``.

        Args:
            weights: Path (or huggingface path) passed to ``_load_weights``.
        """
        self.model = self._load_weights(weights)

    def _load_weights(self, weights: str) -> Optional[torch.nn.Module]:
        """Load model weights from the specified path or huggingface path.

        The base implementation loads nothing and returns ``None``, so
        ``self.model`` is ``None`` unless a subclass overrides this hook.

        Args:
            weights: Path or huggingface path of the weights to load.

        Returns:
            A PyTorch model or Huggingface model with loaded weights when
            overridden; ``None`` in this base implementation.
        """
        return None

    def predict(self, audios: np.ndarray) -> List[float]:
        """Preprocess ``audios`` and return the result of ``forward``.

        Args:
            audios: Raw audio input handed to ``preprocess_audios``.

        Returns:
            Whatever ``forward`` returns for the preprocessed input.

        Raises:
            NotImplementedError: If ``preprocess_audios`` or ``forward`` has
                not been overridden.
        """
        audios = self.preprocess_audios(audios)
        return self.forward(audios)

    def preprocess_audios(self, audios: np.ndarray) -> torch.Tensor:
        """Batched preprocessing.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        # Raise (not return) so a missing override fails loudly instead of
        # passing the exception class on to ``forward`` as if it were data.
        raise NotImplementedError(
            f"{type(self).__name__} must implement preprocess_audios()"
        )

    def forward(self, audios: torch.Tensor) -> torch.Tensor:
        """Batched forward pass.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement forward()"
        )


class BaseMultimodalModel(BaseModel):
    """Base class for models that take audio plus optional text.

    NOTE(review): the audio hook here is ``preprocess_audio`` (singular),
    while the inherited hook is ``preprocess_audios`` (plural); both exist on
    this class. Confirm which one subclasses are meant to override.
    """

    def predict(
        self, audios: List[np.ndarray], texts: Optional[List[str]] = None
    ) -> List[float]:
        """Return the result of ``forward(audios, texts)``.

        The inputs are passed to ``forward`` as given; this method does not
        call ``preprocess_audio`` or ``preprocess_text``.

        Args:
            audios: Audio inputs forwarded unchanged.
            texts: Optional text inputs forwarded unchanged.

        Returns:
            Whatever ``forward`` returns.

        Raises:
            NotImplementedError: If ``forward`` has not been overridden.
        """
        return self.forward(audios, texts)

    def preprocess_audio(self, audios: np.ndarray) -> torch.Tensor:
        """Batched preprocessing (resampling, spectrogram etc.).

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement preprocess_audio()"
        )

    def preprocess_text(self, texts: List[str]) -> torch.Tensor:
        """Batched preprocessing (tokenization etc.).

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement preprocess_text()"
        )

    def forward(self, audios: torch.Tensor, texts: torch.Tensor) -> torch.Tensor:
        """Batched forward pass.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement forward()"
        )