PhoneticXeus / src /espnet_import /subsampling.py
Shikhar
Deploy PhoneticXeus Gradio demo (CPU)
84f8437
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Subsampling layer definition."""
import torch
from src.espnet_import.embedding import PositionalEncoding
class TooShortUttError(Exception):
    """Signal that an utterance is shorter than a subsampling layer can handle.

    Args:
        message (str): Human-readable description of the failure.
        actual_size (int): Length of the utterance that was rejected.
        limit (int): Minimum length the subsampling layer requires.
    """

    def __init__(self, message, actual_size, limit):
        """Keep the rejected size and the required limit next to the message."""
        super().__init__(message)
        self.limit = limit
        self.actual_size = actual_size
def check_short_utt(ins, size):
    """Report whether ``size`` input frames are too few for subsampler ``ins``.

    Args:
        ins (torch.nn.Module): A subsampling module instance.
        size (int): Number of input frames.

    Returns:
        tuple: ``(True, minimum_length)`` when ``ins`` is one of the known
        subsampling classes and ``size`` is below its minimum, otherwise
        ``(False, -1)``.
    """
    # Smallest input length for which each layer still emits at least one
    # output frame. Checked in order; the tuple is built inside the function
    # so the class names are resolved at call time.
    minimum_lengths = (
        (Conv1dSubsampling1, 5),
        (Conv1dSubsampling2, 5),
        (Conv1dSubsampling3, 7),
        (Conv2dSubsampling1, 5),
        (Conv2dSubsampling2, 7),
        (Conv2dSubsampling, 7),
        (Conv2dSubsampling6, 11),
        (Conv2dSubsampling8, 15),
    )
    for module_cls, min_len in minimum_lengths:
        if isinstance(ins, module_cls) and size < min_len:
            return True, min_len
    return False, -1
def _upgrade_legacy_subsampling_state_dict(state_dict, prefix):
    """Rename checkpoint keys saved when ``out`` was an ``nn.Sequential``.

    The legacy layout stored the projection as ``out.0.*`` and the positional
    encoding as ``out.1.*``. Both are rewritten in place to ``out.*`` and
    ``pos_enc.*``. When a key already exists under its new name, that entry is
    kept and the legacy entry is simply dropped.

    Args:
        state_dict (dict): Parameter-name -> tensor mapping; mutated in place.
        prefix (str): Key prefix of the subsampling module being loaded.
    """
    # Linear projection: out.0.{weight,bias} -> out.{weight,bias}
    for param in ("weight", "bias"):
        legacy_key = prefix + "out.0." + param
        if legacy_key not in state_dict:
            continue
        value = state_dict.pop(legacy_key)
        current_key = prefix + "out." + param
        if current_key not in state_dict:
            state_dict[current_key] = value

    # Positional encoding: out.1.* -> pos_enc.*
    legacy_pos = prefix + "out.1."
    current_pos = prefix + "pos_enc."
    legacy_keys = [key for key in state_dict if key.startswith(legacy_pos)]
    for key in legacy_keys:
        value = state_dict.pop(key)
        state_dict.setdefault(current_pos + key[len(legacy_pos) :], value)
class Conv1dSubsampling1(torch.nn.Module):
    """Convolutional 1D front-end that keeps the frame rate (no subsampling).

    Two kernel-3, stride-1 ``Conv1d`` layers each shorten the sequence by 2
    frames, so the output has ``time - 4`` frames.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer. When None,
            ``PositionalEncoding(odim, dropout_rate)`` is created.
    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct an Conv1dSubsampling1 object."""
        super(Conv1dSubsampling1, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv1d(idim, odim, 3, 1),
            torch.nn.ReLU(),
            torch.nn.Conv1d(odim, odim, 3, 1),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Linear(odim, odim)
        self.pos_enc = (
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate)
        )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Remap legacy ``out.0.*`` / ``out.1.*`` keys, then load normally."""
        _upgrade_legacy_subsampling_state_dict(state_dict, prefix)
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, x, x_mask, prefix_embeds=None):
        """Encode x with two stride-1 convolutions (no subsampling).

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor or None): Input mask (#batch, 1, time).
            prefix_embeds (torch.Tensor or None): Prefix token embeddings
                (#batch, prefix_len, odim), prepended after the convolutions.

        Returns:
            torch.Tensor: Output tensor (#batch, prefix_len + time', odim),
                where time' = time - 4.
            torch.Tensor or None: Output mask (#batch, 1, prefix_len + time'),
                or None when ``x_mask`` is None. Prefix positions are ones.
        """
        x = x.transpose(2, 1)  # (#batch, idim, time)
        x = self.conv(x)  # (#batch, odim, time - 4)
        x = self.out(x.transpose(1, 2).contiguous())  # (#batch, time - 4, odim)
        if x_mask is not None:
            # Each kernel-3, stride-1 conv shortens the sequence by 2 frames;
            # trim the mask tail so it stays aligned with the output.
            x_mask = x_mask[:, :, :-4]
        if prefix_embeds is not None:
            x = torch.cat([prefix_embeds, x], dim=1)
            if x_mask is not None:
                # Prefix tokens are always treated as valid positions.
                prefix_mask = torch.ones(
                    x_mask.shape[0],
                    1,
                    prefix_embeds.size(1),
                    dtype=x_mask.dtype,
                    device=x_mask.device,
                )
                x_mask = torch.cat([prefix_mask, x_mask], dim=-1)
        x = self.pos_enc(x)
        return x, x_mask

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positioning encoding. Only ``key == -1`` is supported.
        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.pos_enc
class Conv1dSubsampling2(torch.nn.Module):
    """1D convolutional front-end that roughly halves the frame rate.

    Args:
        idim (int): Size of each input feature vector.
        odim (int): Size of each output vector (also the conv channel count).
        dropout_rate (float): Dropout rate for the default positional encoding.
        pos_enc (torch.nn.Module): Optional positional-encoding module; when
            omitted, ``PositionalEncoding(odim, dropout_rate)`` is used.
    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Build the conv stack, the output projection and the positional encoding."""
        super().__init__()
        layers = [
            torch.nn.Conv1d(idim, odim, 3, 1),  # trims 2 frames
            torch.nn.ReLU(),
            torch.nn.Conv1d(odim, odim, 3, 2),  # halves the frame rate
            torch.nn.ReLU(),
        ]
        self.conv = torch.nn.Sequential(*layers)
        self.out = torch.nn.Linear(odim, odim)
        if pos_enc is None:
            pos_enc = PositionalEncoding(odim, dropout_rate)
        self.pos_enc = pos_enc

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Translate legacy checkpoint keys before delegating to nn.Module."""
        _upgrade_legacy_subsampling_state_dict(state_dict, prefix)
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, x, x_mask, prefix_embeds=None):
        """Subsample ``x`` in time by roughly a factor of two.

        Args:
            x (torch.Tensor): Features shaped (#batch, time, idim).
            x_mask (torch.Tensor or None): Mask shaped (#batch, 1, time).
            prefix_embeds (torch.Tensor or None): Optional embeddings
                (#batch, prefix_len, odim) prepended to the subsampled sequence.

        Returns:
            torch.Tensor: Encoded tensor (#batch, prefix_len + time', odim),
                with time' = (time - 3) // 2 (about time // 2).
            torch.Tensor or None: Matching mask (#batch, 1, prefix_len + time'),
                or None when ``x_mask`` is None.
        """
        hidden = self.conv(x.transpose(2, 1))  # (#batch, odim, time')
        x = self.out(hidden.transpose(1, 2).contiguous())  # (#batch, time', odim)
        if x_mask is not None:
            x_mask = x_mask[:, :, :-2:1][:, :, :-2:2]
        if prefix_embeds is None:
            return self.pos_enc(x), x_mask
        x = torch.cat([prefix_embeds, x], dim=1)
        if x_mask is not None:
            # new_ones inherits dtype and device from x_mask.
            ones = x_mask.new_ones((x_mask.shape[0], 1, prefix_embeds.size(1)))
            x_mask = torch.cat([ones, x_mask], dim=-1)
        return self.pos_enc(x), x_mask

    def __getitem__(self, key):
        """Expose the positional encoding as item ``-1`` for reset_parameters().

        Any other key raises NotImplementedError.
        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.pos_enc
class Conv1dSubsampling3(torch.nn.Module):
    """Convolutional 1D subsampling (to 1/3 length).

    A kernel-3, stride-1 ``Conv1d`` followed by a kernel-5, stride-3
    ``Conv1d`` maps ``time`` input frames to ``(time - 4) // 3`` frames.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer. When None,
            ``PositionalEncoding(odim, dropout_rate)`` is created.
    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct an Conv1dSubsampling3 object."""
        super(Conv1dSubsampling3, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv1d(idim, odim, 3, 1),
            torch.nn.ReLU(),
            torch.nn.Conv1d(odim, odim, 5, 3),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Linear(odim, odim)
        self.pos_enc = (
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate)
        )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Remap legacy ``out.0.*`` / ``out.1.*`` keys, then load normally."""
        _upgrade_legacy_subsampling_state_dict(state_dict, prefix)
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, x, x_mask, prefix_embeds=None):
        """Subsample x to roughly one third of its length.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor or None): Input mask (#batch, 1, time).
            prefix_embeds (torch.Tensor or None): Prefix token embeddings
                (#batch, prefix_len, odim), prepended after subsampling.

        Returns:
            torch.Tensor: Subsampled tensor (#batch, prefix_len + time', odim),
                where time' = (time - 4) // 3 (about time // 3).
            torch.Tensor or None: Subsampled mask (#batch, 1, prefix_len + time'),
                or None when ``x_mask`` is None. Prefix positions are ones.
        """
        x = x.transpose(2, 1)  # (#batch, idim, time)
        x = self.conv(x)  # (#batch, odim, time')
        x = self.out(x.transpose(1, 2).contiguous())  # (#batch, time', odim)
        if x_mask is not None:
            # Mirror the convs on the mask: the first conv trims 2 frames, the
            # second (kernel 5, stride 3) keeps every third remaining frame
            # except the last 4.
            x_mask = x_mask[:, :, :-2:1][:, :, :-4:3]
        if prefix_embeds is not None:
            x = torch.cat([prefix_embeds, x], dim=1)
            if x_mask is not None:
                # Prefix tokens are always treated as valid positions.
                prefix_mask = torch.ones(
                    x_mask.shape[0],
                    1,
                    prefix_embeds.size(1),
                    dtype=x_mask.dtype,
                    device=x_mask.device,
                )
                x_mask = torch.cat([prefix_mask, x_mask], dim=-1)
        x = self.pos_enc(x)
        return x, x_mask

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positioning encoding. Only ``key == -1`` is supported.
        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.pos_enc
class Conv2dSubsampling(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/4 length).

    Two kernel-3, stride-2 ``Conv2d`` layers shrink both the time and the
    feature axis; ``time`` input frames become ``((time - 1) // 2 - 1) // 2``.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer. When None,
            ``PositionalEncoding(odim, dropout_rate)`` is created.
    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct an Conv2dSubsampling object."""
        super(Conv2dSubsampling, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        # Feature bins left after the two kernel-3, stride-2 convs, flattened
        # together with the channel axis.
        self.out = torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
        self.pos_enc = (
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate)
        )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Remap legacy ``out.0.*`` / ``out.1.*`` keys, then load normally."""
        _upgrade_legacy_subsampling_state_dict(state_dict, prefix)
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, x, x_mask, prefix_embeds=None):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor or None): Input mask (#batch, 1, time).
            prefix_embeds (torch.Tensor or None): Prefix token embeddings
                (#batch, prefix_len, odim), prepended after subsampling.

        Returns:
            torch.Tensor: Subsampled tensor (#batch, prefix_len + time', odim),
                where time' = ((time - 1) // 2 - 1) // 2 (about time // 4).
            torch.Tensor or None: Subsampled mask (#batch, 1, prefix_len + time'),
                or None when ``x_mask`` is None. Prefix positions are ones.
        """
        x = x.unsqueeze(1)  # (b, 1, t, f)
        x = self.conv(x)  # (b, odim, t', f')
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is not None:
            # One stride-2 slice per stride-2 conv keeps the mask aligned.
            x_mask = x_mask[:, :, :-2:2][:, :, :-2:2]
        if prefix_embeds is not None:
            x = torch.cat([prefix_embeds, x], dim=1)
            if x_mask is not None:
                # Prefix tokens are always treated as valid positions.
                prefix_mask = torch.ones(
                    x_mask.shape[0],
                    1,
                    prefix_embeds.size(1),
                    dtype=x_mask.dtype,
                    device=x_mask.device,
                )
                x_mask = torch.cat([prefix_mask, x_mask], dim=-1)
        x = self.pos_enc(x)
        return x, x_mask

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positioning encoding. Only ``key == -1`` is supported;
        the positional encoding now lives in ``self.pos_enc`` rather than at
        index 1 of a Sequential ``self.out``.
        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.pos_enc
class Conv2dSubsampling1(torch.nn.Module):
    """2D convolutional front-end that leaves the frame rate unchanged.

    Same layout as the 1/4 subsampler, but both convolutions use stride 1, so
    only the kernel borders are trimmed (4 frames and 4 feature bins).

    Args:
        idim (int): Size of each input feature vector.
        odim (int): Number of conv channels and size of each output vector.
        dropout_rate (float): Dropout rate for the default positional encoding.
        pos_enc (torch.nn.Module): Optional positional-encoding module; when
            omitted, ``PositionalEncoding(odim, dropout_rate)`` is used.
    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Create the stride-1 conv stack, projection and positional encoding."""
        super().__init__()
        conv_layers = [
            torch.nn.Conv2d(1, odim, 3, 1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 1),
            torch.nn.ReLU(),
        ]
        self.conv = torch.nn.Sequential(*conv_layers)
        freq_after_conv = idim - 4  # each kernel-3, stride-1 conv removes 2 bins
        self.out = torch.nn.Linear(odim * freq_after_conv, odim)
        if pos_enc is None:
            pos_enc = PositionalEncoding(odim, dropout_rate)
        self.pos_enc = pos_enc

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Translate legacy checkpoint keys before delegating to nn.Module."""
        _upgrade_legacy_subsampling_state_dict(state_dict, prefix)
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, x, x_mask, prefix_embeds=None):
        """Run x through two stride-1 Conv2d layers (no subsampling).

        Args:
            x (torch.Tensor): Features shaped (#batch, time, idim).
            x_mask (torch.Tensor or None): Mask shaped (#batch, 1, time).
            prefix_embeds (torch.Tensor or None): Optional embeddings
                (#batch, prefix_len, odim) prepended to the output sequence.

        Returns:
            torch.Tensor: Encoded tensor (#batch, prefix_len + time', odim),
                with time' = time - 4.
            torch.Tensor or None: Matching mask (#batch, 1, prefix_len + time'),
                or None when ``x_mask`` is None.
        """
        feats = self.conv(x.unsqueeze(1))  # (batch, odim, time - 4, idim - 4)
        n_batch, n_chan, n_frames, n_freq = feats.size()
        flat = feats.transpose(1, 2).contiguous().view(n_batch, n_frames, n_chan * n_freq)
        x = self.out(flat)
        if x_mask is not None:
            x_mask = x_mask[:, :, :-4]  # two stride-1 convs trim 4 frames
        if prefix_embeds is None:
            return self.pos_enc(x), x_mask
        x = torch.cat([prefix_embeds, x], dim=1)
        if x_mask is not None:
            # new_ones inherits dtype and device from x_mask.
            prefix_mask = x_mask.new_ones((x_mask.shape[0], 1, prefix_embeds.size(1)))
            x_mask = torch.cat([prefix_mask, x_mask], dim=-1)
        return self.pos_enc(x), x_mask

    def __getitem__(self, key):
        """Expose the positional encoding as item ``-1`` for reset_parameters().

        Any other key raises NotImplementedError.
        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.pos_enc
class Conv2dSubsampling2(torch.nn.Module):
    """2D convolutional front-end that roughly halves the frame rate.

    Args:
        idim (int): Size of each input feature vector.
        odim (int): Number of conv channels and size of each output vector.
        dropout_rate (float): Dropout rate for the default positional encoding.
        pos_enc (torch.nn.Module): Optional positional-encoding module; when
            omitted, ``PositionalEncoding(odim, dropout_rate)`` is used.
    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Create the conv stack, projection and positional encoding."""
        super().__init__()
        conv_layers = [
            torch.nn.Conv2d(1, odim, 3, 2),  # stride 2: halves time and freq
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 1),  # stride 1: trims 2 more
            torch.nn.ReLU(),
        ]
        self.conv = torch.nn.Sequential(*conv_layers)
        freq_after_conv = (idim - 1) // 2 - 2
        self.out = torch.nn.Linear(odim * freq_after_conv, odim)
        if pos_enc is None:
            pos_enc = PositionalEncoding(odim, dropout_rate)
        self.pos_enc = pos_enc

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Translate legacy checkpoint keys before delegating to nn.Module."""
        _upgrade_legacy_subsampling_state_dict(state_dict, prefix)
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, x, x_mask, prefix_embeds=None):
        """Subsample ``x`` in time by roughly a factor of two.

        Args:
            x (torch.Tensor): Features shaped (#batch, time, idim).
            x_mask (torch.Tensor or None): Mask shaped (#batch, 1, time).
            prefix_embeds (torch.Tensor or None): Optional embeddings
                (#batch, prefix_len, odim) prepended to the subsampled sequence.

        Returns:
            torch.Tensor: Encoded tensor (#batch, prefix_len + time', odim),
                with time' = (time - 1) // 2 - 2 (about time // 2).
            torch.Tensor or None: Matching mask (#batch, 1, prefix_len + time'),
                or None when ``x_mask`` is None.
        """
        feats = self.conv(x.unsqueeze(1))  # (batch, odim, time', freq')
        n_batch, n_chan, n_frames, n_freq = feats.size()
        flat = feats.transpose(1, 2).contiguous().view(n_batch, n_frames, n_chan * n_freq)
        x = self.out(flat)
        if x_mask is not None:
            x_mask = x_mask[:, :, :-2:2][:, :, :-2:1]
        if prefix_embeds is None:
            return self.pos_enc(x), x_mask
        x = torch.cat([prefix_embeds, x], dim=1)
        if x_mask is not None:
            # new_ones inherits dtype and device from x_mask.
            prefix_mask = x_mask.new_ones((x_mask.shape[0], 1, prefix_embeds.size(1)))
            x_mask = torch.cat([prefix_mask, x_mask], dim=-1)
        return self.pos_enc(x), x_mask

    def __getitem__(self, key):
        """Expose the positional encoding as item ``-1`` for reset_parameters().

        Any other key raises NotImplementedError.
        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.pos_enc
class Conv2dSubsampling6(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/6 length).

    A kernel-3, stride-2 ``Conv2d`` followed by a kernel-5, stride-3
    ``Conv2d``; ``time`` input frames become ``((time - 1) // 2 - 2) // 3``.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer. When None,
            ``PositionalEncoding(odim, dropout_rate)`` is created.
    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct an Conv2dSubsampling6 object."""
        super(Conv2dSubsampling6, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 5, 3),
            torch.nn.ReLU(),
        )
        # Feature bins left after both convs, flattened with the channel axis.
        self.out = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim)
        self.pos_enc = (
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate)
        )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Remap legacy ``out.0.*`` / ``out.1.*`` keys, then load normally."""
        _upgrade_legacy_subsampling_state_dict(state_dict, prefix)
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, x, x_mask, prefix_embeds=None):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor or None): Input mask (#batch, 1, time).
            prefix_embeds (torch.Tensor or None): Prefix token embeddings
                (#batch, prefix_len, odim), prepended after subsampling.

        Returns:
            torch.Tensor: Subsampled tensor (#batch, prefix_len + time', odim),
                where time' = ((time - 1) // 2 - 2) // 3 (about time // 6).
            torch.Tensor or None: Subsampled mask (#batch, 1, prefix_len + time'),
                or None when ``x_mask`` is None. Prefix positions are ones.
        """
        x = x.unsqueeze(1)  # (b, 1, t, f)
        x = self.conv(x)  # (b, odim, t', f')
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is not None:
            # Mirror the stride-2 (kernel 3) and stride-3 (kernel 5) convs.
            x_mask = x_mask[:, :, :-2:2][:, :, :-4:3]
        if prefix_embeds is not None:
            x = torch.cat([prefix_embeds, x], dim=1)
            if x_mask is not None:
                # Prefix tokens are always treated as valid positions.
                prefix_mask = torch.ones(
                    x_mask.shape[0],
                    1,
                    prefix_embeds.size(1),
                    dtype=x_mask.dtype,
                    device=x_mask.device,
                )
                x_mask = torch.cat([prefix_mask, x_mask], dim=-1)
        x = self.pos_enc(x)
        return x, x_mask

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positioning encoding. Only ``key == -1`` is supported.
        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.pos_enc
class Conv2dSubsampling8(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/8 length).

    Three kernel-3, stride-2 ``Conv2d`` layers; ``time`` input frames become
    ``(((time - 1) // 2 - 1) // 2 - 1) // 2``.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer. When None,
            ``PositionalEncoding(odim, dropout_rate)`` is created.
    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct an Conv2dSubsampling8 object."""
        super(Conv2dSubsampling8, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        # Feature bins left after three stride-2 convs, flattened with channels.
        self.out = torch.nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
        self.pos_enc = (
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate)
        )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Remap legacy ``out.0.*`` / ``out.1.*`` keys, then load normally."""
        _upgrade_legacy_subsampling_state_dict(state_dict, prefix)
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, x, x_mask, prefix_embeds=None):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor or None): Input mask (#batch, 1, time).
            prefix_embeds (torch.Tensor or None): Prefix token embeddings
                (#batch, prefix_len, odim), prepended after subsampling.

        Returns:
            torch.Tensor: Subsampled tensor (#batch, prefix_len + time', odim),
                where time' = (((time - 1) // 2 - 1) // 2 - 1) // 2
                (about time // 8).
            torch.Tensor or None: Subsampled mask (#batch, 1, prefix_len + time'),
                or None when ``x_mask`` is None. Prefix positions are ones.
        """
        x = x.unsqueeze(1)  # (b, 1, t, f)
        x = self.conv(x)  # (b, odim, t', f')
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is not None:
            # One stride-2 slice per stride-2 conv keeps the mask aligned.
            x_mask = x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2]
        if prefix_embeds is not None:
            x = torch.cat([prefix_embeds, x], dim=1)
            if x_mask is not None:
                # Prefix tokens are always treated as valid positions.
                prefix_mask = torch.ones(
                    x_mask.shape[0],
                    1,
                    prefix_embeds.size(1),
                    dtype=x_mask.dtype,
                    device=x_mask.device,
                )
                x_mask = torch.cat([prefix_mask, x_mask], dim=-1)
        x = self.pos_enc(x)
        return x, x_mask

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positioning encoding. Only ``key == -1`` is supported.
        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.pos_enc