File size: 1,615 Bytes
a9df28d
 
 
 
 
 
 
 
04c17ac
 
 
 
a9df28d
 
 
 
04c17ac
a9df28d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
04c17ac
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn

class BaseModel:
    """Abstract base class for audio classification/scoring models.

    Subclasses must implement ``_load_weights``, ``preprocess_audios``,
    and ``forward``; ``predict`` composes those two steps.
    """

    def __init__(self, weights: str) -> None:
        # Eagerly load weights so a bad path fails at construction time.
        self.model = self._load_weights(weights)

    def _load_weights(self, weights: str) -> torch.nn.Module:
        """
        Load model weights from the specified path or huggingface path.

        Args:
            weights: Local filesystem path or Huggingface model identifier.

        Returns:
            A PyTorch model or Huggingface model with loaded weights.

        Raises:
            NotImplementedError: Must be overridden by subclasses.
        """
        # BUG FIX: was `pass`, which silently returned None despite the
        # `-> torch.nn.Module` annotation; raise explicitly instead.
        raise NotImplementedError

    def predict(self, audios: np.ndarray) -> List[float]:
        """Preprocess a batch of audios and run the forward pass."""
        audios = self.preprocess_audios(audios)
        return self.forward(audios)

    def preprocess_audios(self, audios: np.ndarray) -> torch.Tensor:
        """
        Batched preprocessing.

        Raises:
            NotImplementedError: Must be overridden by subclasses.
        """
        # BUG FIX: was `return NotImplementedError`, which handed callers
        # the exception *class* instead of signalling an error.
        raise NotImplementedError

    def forward(self, audios: torch.Tensor) -> torch.Tensor:
        """
        Batched forward pass.

        Raises:
            NotImplementedError: Must be overridden by subclasses.
        """
        # BUG FIX: was `return NotImplementedError` (see preprocess_audios).
        raise NotImplementedError


class BaseMultimodalModel(BaseModel):
    """Abstract base class for models that score audio together with text.

    Subclasses must implement ``preprocess_audio``, ``preprocess_text``,
    and ``forward``.
    """

    def predict(self, audios: List[np.ndarray], texts: Optional[List[str]] = None) -> List[float]:
        """Run the forward pass on a batch of audios and optional texts.

        NOTE(review): unlike ``BaseModel.predict``, this passes the raw
        inputs straight to ``forward`` without calling the preprocess
        hooks — presumably subclasses preprocess inside ``forward``;
        confirm against concrete implementations.
        """
        return self.forward(audios, texts)

    def preprocess_audio(self, audios: np.ndarray) -> torch.Tensor:
        """
        Batched preprocessing (resampling, spectrogram etc.)

        Raises:
            NotImplementedError: Must be overridden by subclasses.
        """
        # BUG FIX: was `return NotImplementedError`, which handed callers
        # the exception *class* instead of signalling an error.
        raise NotImplementedError

    def preprocess_text(self, texts: List[str]) -> torch.Tensor:
        """
        Batched preprocessing (tokenization etc.)

        Raises:
            NotImplementedError: Must be overridden by subclasses.
        """
        # BUG FIX: was `return NotImplementedError` (see preprocess_audio).
        raise NotImplementedError

    def forward(self, audios: torch.Tensor, texts: torch.Tensor) -> torch.Tensor:
        """
        Batched forward pass.

        Raises:
            NotImplementedError: Must be overridden by subclasses.
        """
        # BUG FIX: was `return NotImplementedError` (see preprocess_audio).
        raise NotImplementedError