File size: 1,615 Bytes
a9df28d 04c17ac a9df28d 04c17ac a9df28d 04c17ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
class BaseModel:
    """Abstract base class for audio models.

    Subclasses must implement ``_load_weights``, ``preprocess_audios`` and
    ``forward``; ``predict`` composes them into the inference pipeline.
    """

    def __init__(self, weights: str) -> None:
        # Eagerly load weights so a bad path fails at construction time.
        self.model = self._load_weights(weights)

    def _load_weights(self, weights: str) -> torch.nn.Module:
        """
        Load model weights from the specified path or huggingface path.

        Args:
            weights: Filesystem path or Huggingface model identifier.

        Returns:
            A PyTorch model or Huggingface model with loaded weights.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        # Was `pass`, which silently left self.model as None in subclasses
        # that forgot to override; failing loudly is safer.
        raise NotImplementedError

    def predict(self, audios: np.ndarray) -> List[float]:
        """Preprocess a batch of raw audios and run the forward pass.

        Args:
            audios: Batch of raw audio samples.

        Returns:
            Model outputs for the batch.
        """
        audios = self.preprocess_audios(audios)
        return self.forward(audios)

    def preprocess_audios(self, audios: np.ndarray) -> torch.Tensor:
        """
        Batched preprocessing.

        Annotated np.ndarray (not torch.Tensor) to match what predict()
        actually passes in.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        # Was `return NotImplementedError` — that hands the exception
        # *class* back to the caller instead of raising it.
        raise NotImplementedError

    def forward(self, audios: torch.Tensor) -> torch.Tensor:
        """
        Batched forward pass.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError
class BaseMultimodalModel(BaseModel):
    """Abstract base class for models that consume audio plus optional text.

    Subclasses must implement ``preprocess_audio``, ``preprocess_text`` and
    ``forward``.
    """

    def predict(self, audios: List[np.ndarray], texts: Optional[List[str]] = None) -> List[float]:
        """Run the forward pass on a batch of audios and optional texts.

        NOTE(review): unlike BaseModel.predict, no preprocessing is applied
        here — presumably subclasses preprocess inside forward(); confirm.

        Args:
            audios: Batch of raw audio arrays.
            texts: Optional batch of text strings paired with the audios.

        Returns:
            Model outputs for the batch.
        """
        return self.forward(audios, texts)

    def preprocess_audio(self, audios: np.ndarray) -> torch.Tensor:
        """
        Batched audio preprocessing (resampling, spectrogram etc.)

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        # Was `return NotImplementedError` — returned the exception class
        # instead of raising it.
        raise NotImplementedError

    def preprocess_text(self, texts: List[str]) -> torch.Tensor:
        """
        Batched text preprocessing (tokenization etc.)

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    def forward(self, audios: torch.Tensor, texts: torch.Tensor) -> torch.Tensor:
        """
        Batched forward pass over audio and text inputs.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError
|