CA_NSBS / src /models /base_model.py
1NEYRON1's picture
Update src/models/base_model.py
04c17ac verified
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
class BaseModel:
    """Base class for audio models built around a preprocess -> forward pipeline.

    Subclasses are expected to override ``_load_weights``, ``preprocess_audios``
    and ``forward``; ``predict`` chains the last two.
    """

    def __init__(self, weights: str) -> None:
        """Build the model by delegating weight loading to ``_load_weights``.

        Args:
            weights: Local path or Hugging Face identifier of the weights,
                passed unchanged to ``_load_weights``.
        """
        self.model = self._load_weights(weights)

    def _load_weights(self, weights: str) -> torch.nn.Module:
        """Load model weights from the specified path or Hugging Face path.

        Subclasses should override this. The base implementation loads nothing
        and returns ``None``, so the base class (and subclasses that assign
        ``self.model`` themselves after ``super().__init__``) can still be
        constructed.

        Args:
            weights: Local path or Hugging Face identifier.

        Returns:
            A PyTorch model or Hugging Face model with loaded weights.
        """
        return None

    def predict(self, audios: np.ndarray) -> List[float]:
        """Preprocess a batch of audio and run the forward pass on it.

        Args:
            audios: Batch of raw audio, handed to ``preprocess_audios``.

        Returns:
            Whatever ``forward`` returns for the preprocessed batch.

        Raises:
            NotImplementedError: If ``preprocess_audios`` or ``forward`` has
                not been overridden.
        """
        audios = self.preprocess_audios(audios)
        return self.forward(audios)

    def preprocess_audios(self, audios: torch.Tensor) -> torch.Tensor:
        """Batched preprocessing; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        # Raise (not return) so a missing override fails loudly instead of
        # passing the exception class downstream as if it were data.
        raise NotImplementedError(
            f"{type(self).__name__} must implement preprocess_audios()"
        )

    def forward(self, audios: torch.Tensor) -> torch.Tensor:
        """Batched forward pass; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement forward()"
        )
class BaseMultimodalModel(BaseModel):
    """Base class for models that consume audio together with optional text.

    Subclasses are expected to override ``forward`` and, where they use them,
    ``preprocess_audio`` / ``preprocess_text``.
    """

    def predict(
        self, audios: List[np.ndarray], texts: Optional[List[str]] = None
    ) -> List[float]:
        """Run the forward pass on a batch of audio and optional text.

        Unlike ``BaseModel.predict``, no preprocessing hook is called here:
        the raw ``audios`` and ``texts`` are passed straight to ``forward``.
        NOTE(review): ``preprocess_audio`` / ``preprocess_text`` are never
        invoked by this class -- presumably subclasses call them from
        ``forward``; confirm against the concrete models.

        Args:
            audios: Batch of raw audio arrays.
            texts: Optional batch of text inputs; ``None`` when absent.

        Returns:
            Whatever ``forward`` returns.

        Raises:
            NotImplementedError: If ``forward`` has not been overridden.
        """
        return self.forward(audios, texts)

    def preprocess_audio(self, audios: np.ndarray) -> torch.Tensor:
        """Batched audio preprocessing (resampling, spectrogram etc.).

        Raises:
            NotImplementedError: Always, in the base class.
        """
        # Raise (not return) so a missing override fails loudly instead of
        # passing the exception class downstream as if it were data.
        raise NotImplementedError(
            f"{type(self).__name__} must implement preprocess_audio()"
        )

    def preprocess_text(self, texts: List[str]) -> torch.Tensor:
        """Batched text preprocessing (tokenization etc.).

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement preprocess_text()"
        )

    def forward(self, audios: torch.Tensor, texts: torch.Tensor) -> torch.Tensor:
        """Batched forward pass over audio and text inputs.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement forward()"
        )