firered-vad-stream / module_firered_vad_stream.py
MigoXV's picture
Upload FireRed VAD Stream safetensors model
8896961 verified
from __future__ import annotations
from collections.abc import Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
class FireRedVadStreamModule(nn.Module):
    """DFSMN encoder followed by a linear head with sigmoid activation.

    The module forwards ``input_features`` through a :class:`DFSMN` stack,
    projects the ``H``-wide result to ``odim`` values with ``self.out`` and
    squashes them with a sigmoid. The per-layer caches produced by the DFSMN
    stack are returned so the caller can feed them back on the next call.
    """

    def __init__(
        self,
        idim: int,
        odim: int,
        R: int,
        M: int,
        H: int,
        P: int,
        N1: int,
        S1: int,
        N2: int = 0,
        S2: int = 1,
        dropout: float = 0.05,
    ) -> None:
        """Build the DFSMN stack and the output projection.

        Args:
            idim: Size of the last dimension of the input features.
            odim: Number of values emitted per frame by the output layer.
            R, M, H, P, N1, S1, N2, S2, dropout: Passed through unchanged to
                :class:`DFSMN` (``H`` is also the input width of ``self.out``).
        """
        super().__init__()
        self.dfsmn = DFSMN(
            D=idim,
            R=R,
            M=M,
            H=H,
            P=P,
            N1=N1,
            S1=S1,
            N2=N2,
            S2=S2,
            dropout=dropout,
        )
        self.out = nn.Linear(H, odim)

    @classmethod
    def from_config(cls, config) -> "FireRedVadStreamModule":
        """Create a module from an object exposing the constructor fields.

        ``config`` must provide the attributes ``idim``, ``odim``, ``R``,
        ``M``, ``H``, ``P``, ``N1``, ``S1``, ``N2``, ``S2`` and ``dropout``.
        """
        field_names = (
            "idim",
            "odim",
            "R",
            "M",
            "H",
            "P",
            "N1",
            "S1",
            "N2",
            "S2",
            "dropout",
        )
        return cls(**{name: getattr(config, name) for name in field_names})

    def forward(
        self,
        input_features: torch.Tensor,
        caches: Sequence[torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """Run the network on one chunk of features.

        Args:
            input_features: Feature tensor whose last dimension is ``idim``.
            caches: Optional per-layer caches returned by a previous call;
                ``None`` starts without history.

        Returns:
            A tuple ``(probs, new_caches)`` where ``probs`` is the sigmoid of
            the output projection and ``new_caches`` holds one cache tensor
            per memory layer of the DFSMN stack.
        """
        hidden, updated_caches = self.dfsmn(input_features, caches=caches)
        return torch.sigmoid(self.out(hidden)), updated_caches
class DFSMN(nn.Module):
    """Stack of memory layers framed by feed-forward layers.

    Layout: ``fc1`` (D -> H) and ``fc2`` (H -> P) project the input, one
    :class:`FSMN` layer plus ``R - 1`` :class:`DFSMNBlock` layers process the
    projection, and ``dnns`` (a P -> H layer followed by ``M - 1`` H -> H
    layers) produces the output. Every linear layer in ``fc1``, ``fc2`` and
    ``dnns`` is followed by ReLU and dropout.
    """

    def __init__(
        self,
        D: int,
        R: int,
        M: int,
        H: int,
        P: int,
        N1: int,
        S1: int,
        N2: int = 0,
        S2: int = 1,
        dropout: float = 0.05,
    ) -> None:
        """Create the layers.

        Args:
            D: Input feature width.
            R: Total number of memory layers (one FSMN plus ``R - 1`` blocks).
            M: Number of linear layers in the trailing ``dnns`` stack.
            H: Hidden width of the feed-forward layers.
            P: Width of the projection processed by the memory layers.
            N1, S1, N2, S2: Filter sizes and strides forwarded to the memory
                layers.
            dropout: Dropout probability used after every ReLU.
        """
        super().__init__()
        self.fc1 = nn.Sequential(nn.Linear(D, H), nn.ReLU(), nn.Dropout(dropout))
        self.fc2 = nn.Sequential(nn.Linear(H, P), nn.ReLU(), nn.Dropout(dropout))
        self.fsmn1 = FSMN(P, N1, S1, N2, S2)

        blocks = []
        for _ in range(R - 1):
            blocks.append(DFSMNBlock(H, P, N1, S1, N2, S2, dropout))
        self.fsmns = nn.ModuleList(blocks)

        # The first DNN layer consumes the P-wide memory output; the
        # remaining M - 1 layers keep the hidden width H.
        layers: list[nn.Module] = []
        for in_width in [P] + [H] * (M - 1):
            layers.extend((nn.Linear(in_width, H), nn.ReLU(), nn.Dropout(dropout)))
        self.dnns = nn.Sequential(*layers)

    def forward(
        self,
        inputs: torch.Tensor,
        input_lengths: torch.Tensor | None = None,
        caches: Sequence[torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """Run the stack on ``inputs``.

        Args:
            inputs: Feature tensor whose last dimension is ``D``.
            input_lengths: Optional per-sequence lengths; when given, a
                padding mask is derived from them and handed to every memory
                layer.
            caches: Optional sequence holding one cache per memory layer, in
                layer order; ``None`` runs every layer without a cache.

        Returns:
            ``(output, new_caches)`` with the H-wide output of ``dnns`` and
            the list of caches returned by the memory layers, in layer order.
        """
        if input_lengths is None:
            mask = None
        else:
            mask = get_mask_from_lengths(input_lengths)

        memory = self.fc2(self.fc1(inputs))

        new_caches = []
        for index, layer in enumerate((self.fsmn1, *self.fsmns)):
            layer_cache = caches[index] if caches is not None else None
            memory, layer_state = layer(memory, mask=mask, cache=layer_cache)
            new_caches.append(layer_state)

        return self.dnns(memory), new_caches
def get_mask_from_lengths(lengths: torch.Tensor) -> torch.Tensor:
    """Build a padding mask from per-sequence lengths.

    Args:
        lengths: 1-D tensor with one length per batch item.

    Returns:
        A ``torch.uint8`` tensor of shape ``(batch, max(lengths))`` on the
        same device as ``lengths``. Entry ``[i, t]`` is 1 when
        ``t >= lengths[i]`` (a padded position) and 0 otherwise. An empty
        ``lengths`` yields an empty ``(0, 0)`` mask instead of raising.
    """
    if lengths.numel() == 0:
        # torch.max has no identity for an empty tensor and would raise.
        return torch.zeros(0, 0, dtype=torch.uint8, device=lengths.device)

    max_length = int(lengths.max().item())
    positions = torch.arange(max_length, device=lengths.device)
    # Broadcast (1, max_length) against (batch, 1): one comparison replaces
    # the per-row Python loop.
    padded = positions.unsqueeze(0) >= lengths.unsqueeze(1)
    # The uint8 dtype is kept for callers that rely on it; cast to bool
    # before using the result with masked_fill.
    return padded.to(torch.uint8)
class DFSMNBlock(nn.Module):
    """Memory layer with a skip connection around it.

    The block expands its P-wide input to H (linear + ReLU + dropout),
    projects it back to P with a bias-free linear layer, applies an
    :class:`FSMN` and adds the block input to the result.
    """

    def __init__(
        self,
        H: int,
        P: int,
        N1: int,
        S1: int,
        N2: int = 0,
        S2: int = 1,
        dropout: float = 0.05,
    ) -> None:
        """Create the layers.

        Args:
            H: Hidden width of the expansion layer.
            P: Width of the block input, projection and output.
            N1, S1, N2, S2: Forwarded unchanged to :class:`FSMN`.
            dropout: Dropout probability applied after the ReLU.
        """
        super().__init__()
        expand = nn.Linear(P, H)
        self.fc1 = nn.Sequential(expand, nn.ReLU(), nn.Dropout(dropout))
        self.fc2 = nn.Linear(H, P, bias=False)
        self.fsmn = FSMN(P, N1, S1, N2, S2)

    def forward(
        self,
        inputs: torch.Tensor,
        mask: torch.Tensor | None = None,
        cache: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the block.

        Args:
            inputs: P-wide input; it is also used as the skip connection.
            mask: Optional padding mask handed to the memory layer.
            cache: Optional cache handed to the memory layer.

        Returns:
            ``(output, new_cache)``: the memory-layer output plus ``inputs``,
            and the cache returned by the memory layer.
        """
        projected = self.fc2(self.fc1(inputs))
        memory, new_cache = self.fsmn(projected, mask=mask, cache=cache)
        return memory + inputs, new_cache
class FSMN(nn.Module):
    """Memory layer built from depthwise 1-D convolutions.

    A lookback filter (kernel ``N1``, dilation ``S1``) mixes each frame with
    earlier frames, and an optional lookahead filter (kernel ``N2``, dilation
    ``S2``) mixes it with later frames. Both filters use ``groups=P``, so
    every channel is filtered independently. The filter outputs are added to
    the input.

    For chunked use, ``forward`` returns the trailing
    ``lookback_padding = (N1 - 1) * S1`` frames it has seen; passing them back
    as ``cache`` on the next call gives the lookback filter the history it
    needs.
    """

    def __init__(
        self,
        P: int,
        N1: int,
        S1: int,
        N2: int = 0,
        S2: int = 1,
    ) -> None:
        """Create the filters.

        Args:
            P: Number of channels (last dimension of the layer input).
            N1: Lookback kernel size; must be at least 1.
            S1: Lookback dilation.
            N2: Lookahead kernel size; 0 disables the lookahead filter.
            S2: Lookahead dilation.

        Raises:
            ValueError: If ``N1`` is smaller than 1.
        """
        super().__init__()
        if N1 < 1:
            raise ValueError("N1 must be greater than or equal to 1")
        self.N1, self.S1, self.N2, self.S2 = N1, S1, N2, S2
        self.lookback_padding = (N1 - 1) * S1
        self.lookback_filter = nn.Conv1d(
            in_channels=P,
            out_channels=P,
            kernel_size=N1,
            stride=1,
            padding=self.lookback_padding,
            dilation=S1,
            groups=P,
            bias=False,
        )
        if self.N2 > 0:
            self.lookahead_filter = nn.Conv1d(
                in_channels=P,
                out_channels=P,
                kernel_size=N2,
                stride=1,
                padding=(N2 - 1) * S2,
                dilation=S2,
                groups=P,
                bias=False,
            )
        else:
            self.lookahead_filter = nn.Identity()

    def forward(
        self,
        inputs: torch.Tensor,
        mask: torch.Tensor | None = None,
        cache: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the memory filters to one chunk.

        Args:
            inputs: Tensor of shape ``(batch, time, P)``.
            mask: Optional ``(batch, time)`` mask; non-zero entries mark
                frames that are zeroed on input and on output. Boolean and
                integer masks are both accepted.
            cache: Optional ``(batch, P, frames)`` history returned by a
                previous call; it is prepended to the chunk along time.

        Returns:
            ``(memory, new_cache)``: ``memory`` has the shape of ``inputs``;
            ``new_cache`` holds the last ``lookback_padding`` frames of the
            (cache-prefixed) input in ``(batch, P, frames)`` layout and is
            empty along time when ``N1 == 1``.
        """
        sequence_length = inputs.size(1)
        if mask is not None:
            # masked_fill expects a boolean mask; the cast accepts the uint8
            # masks produced elsewhere in this file as well.
            mask = mask.unsqueeze(-1).to(torch.bool)
            inputs = inputs.masked_fill(mask, 0.0)

        # Conv1d expects (batch, channels, time).
        inputs = inputs.permute((0, 2, 1)).contiguous()
        residual = inputs
        if cache is not None:
            inputs = torch.cat((cache, inputs), dim=2)

        if self.lookback_padding > 0:
            new_cache = inputs[:, :, -self.lookback_padding :]
        else:
            # A slice starting at -0 would return the whole tensor and make
            # the cache grow on every call; no history is needed here.
            new_cache = inputs[:, :, :0]

        lookback = self.lookback_filter(inputs)
        if self.N1 > 1:
            # Symmetric padding yields lookback_padding extra trailing
            # frames; dropping them leaves a filter that only sees the past.
            lookback = lookback[:, :, : -(self.N1 - 1) * self.S1]
        if cache is not None:
            # Discard the outputs that belong to the cached frames.
            lookback = lookback[:, :, cache.size(2) :]
        memory = residual + lookback

        # The lookahead taps reach S2 .. N2 * S2 frames ahead, so a chunk of
        # at most S2 frames has nothing to look at.
        if self.N2 > 0 and sequence_length > self.S2:
            # Only frames of the current chunk can be looked ahead to, so the
            # filter runs on the chunk without the cache prefix; this keeps
            # its output aligned with ``memory``.
            lookahead = self.lookahead_filter(residual)
            memory += F.pad(lookahead[:, :, self.N2 * self.S2 :], (0, self.S2))

        memory = memory.permute((0, 2, 1)).contiguous()
        if mask is not None:
            memory = memory.masked_fill(mask, 0.0)
        return memory, new_cache