# rd_hubert / modeling_distiller.py
"""
Builder for Distiller
Author: Heng-Jui Chang (https://github.com/vectominist)
"""
import torch
from torch import nn
from .configuration_distiller import DistillerConfig
from .distiller_w2v2_modules import (
ConvFeatureExtractionModel,
GradMultiply,
)
from .distiller_modules import (
TransformerEncoder,
SplitLinear,
)
class DistillerModel(nn.Module):
"""
Distiller Model
"""
def __init__(self, config: DistillerConfig):
super().__init__()
self.config = config
        # The conv spec is stored as a Python expression string (e.g.
        # "[(512, 10, 5)] + [(512, 3, 2)] * 4 + ..."), so it is eval'd into a
        # list of (dim, kernel_size, stride) tuples.
        self.conv_layers = eval(config.extractor_conv_feature_layers)
feat_emb_dim = self.conv_layers[-1][0]
self.feature_extractor = ConvFeatureExtractionModel(
self.conv_layers,
dropout=config.extractor_dropout,
mode=config.extractor_mode,
conv_bias=False,
)
self.feature_grad_mult = config.feature_grad_mult
self.n_tasks = config.n_tasks
self.task_emb_type = config.task_emb_type
final_emb_size = config.encoder_embed_dim
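        # Task embeddings tell the student which teacher layer(s) to predict.
        # Variants handled below:
        #   "add"         - add a learned embedding to the extracted features
        #   "concat"      - concatenate an embedding to the features before projection
        #   "concat-last" - widen the output-layer input so an embedding can be
        #                   concatenated after the encoder
        #   "expand-last" - one widened prediction head covering all pred_layer_id layers
        #   "self-hidden" - use the encoder's own hidden layers as the predictions
        #   "none"        - no task embedding (single prediction target)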
if self.task_emb_type == "add":
self.task_embedding = nn.Embedding(config.n_tasks, config.encoder_embed_dim)
nn.init.normal_(self.task_embedding.weight, 0.0, 0.1)
elif self.task_emb_type == "concat":
assert config.task_emb_size > 0
feat_emb_dim += config.task_emb_size
self.task_embedding = nn.Embedding(config.n_tasks, config.task_emb_size)
elif self.task_emb_type == "concat-last":
assert config.task_emb_size > 0
self.task_embedding = nn.Embedding(config.n_tasks, config.task_emb_size)
final_emb_size += config.task_emb_size
elif self.task_emb_type == "expand-last":
self.pred_layer_id = config.pred_layer_id
assert self.n_tasks == len(self.pred_layer_id)
print(
f"[DistillerModel] - Expands the output dimension by {self.n_tasks} times"
)
print(f"[DistillerModel] - Pred layers: {self.pred_layer_id}")
elif self.task_emb_type == "self-hidden":
self.pred_layer_id = config.pred_layer_id
assert self.n_tasks == len(self.pred_layer_id)
assert self.n_tasks == config.encoder_layers + 1
print("[DistillerModel] - Predicting with self-hidden layers")
print(f"[DistillerModel] - Pred layers: {self.pred_layer_id}")
elif self.task_emb_type == "none":
print(
f"[DistillerModel] - Disabled task embedding (predicts only layer {self.n_tasks})"
)
else:
raise NotImplementedError(f"Unknown task emb type {self.task_emb_type}")
self.post_extract_proj = (
nn.Linear(feat_emb_dim, config.encoder_embed_dim)
if feat_emb_dim != config.encoder_embed_dim
else None
)
if config.encoder_layers > 0:
self.encoder = TransformerEncoder(config)
else:
self.encoder = nn.GELU()
final_dim = config.final_dim * (
1 if self.task_emb_type != "expand-last" else self.n_tasks
)
inter_dim = config.out_layer_inter_dim
inter_dim = inter_dim if inter_dim > 0 else final_emb_size
print(f"[DistillerModel] - Out layer type: {config.out_layer_type}")
if config.out_layer_type == "expand-last":
assert self.task_emb_type == "expand-last"
print(f"[DistillerModel] - Inter dim = {inter_dim}")
self.output_layer = nn.Sequential(
nn.Linear(final_emb_size, inter_dim * self.n_tasks),
nn.GELU(),
SplitLinear(inter_dim, self.n_tasks, config.final_dim),
)
elif config.out_layer_type in {"none", "self-hidden"}:
self.output_layer = None
else:
raise NotImplementedError(f"Unknown out layer type {config.out_layer_type}")
def forward_feature(self, wave, pad_mask):
"""Forward feature extractor"""
if self.feature_grad_mult > 0:
feat = self.feature_extractor(wave)
if self.feature_grad_mult != 1.0:
feat = GradMultiply.apply(feat, self.feature_grad_mult)
else:
with torch.no_grad():
feat = self.feature_extractor(wave)
feat = feat.transpose(1, 2) # B x T x D
pad_mask = self.cal_pad_mask(pad_mask, feat.shape[1])
return feat, pad_mask
def forward(self, wave, pad_mask, task_id=None, get_hidden=False, no_pred=False):
"""
Forward function
Input:
wave (FloatTensor): B x T_wave
pad_mask (BoolTensor): B x T_wave
task_id (LongTensor): N >= 1
"""
feat, pad_mask = self.forward_feature(wave, pad_mask)
if self.task_emb_type not in ["none", "expand-last", "self-hidden"]:
if task_id is None:
task_id = self.generate_task_id(feat.device)
elif isinstance(task_id, list):
task_id = torch.LongTensor(task_id).to(feat.device)
task_embs = self.task_embedding(task_id)
# N x D
n_sz = len(task_id)
else:
n_sz = 1
b_sz, t_sz, _ = feat.shape
if self.task_emb_type == "add":
# Add embs to feature
if self.post_extract_proj is not None:
feat_final = self.post_extract_proj(feat)
else:
feat_final = feat
feat_final = feat_final.unsqueeze(1) + task_embs.unsqueeze(0).unsqueeze(2)
elif self.task_emb_type == "concat":
# Concatenates embs to feature
feat_final = torch.cat(
[
feat.unsqueeze(1).expand(-1, n_sz, -1, -1),
task_embs.unsqueeze(0).unsqueeze(2).expand(b_sz, -1, t_sz, -1),
],
dim=-1,
)
if self.post_extract_proj is not None:
feat_final = self.post_extract_proj(feat_final)
else:
if self.post_extract_proj is not None:
feat_final = self.post_extract_proj(feat)
else:
feat_final = feat
feat_final = feat_final.unsqueeze(1)
# feat_final: B x N x T x D or B x 1 x T x D
pad_mask = pad_mask.unsqueeze(1).expand(-1, n_sz, -1).reshape(b_sz * n_sz, t_sz)
# BN x T
feat_final = feat_final.reshape(b_sz * n_sz, t_sz, -1)
# BN x T x D
layer_hiddens = []
if self.config.encoder_layers > 0:
            get_hidden_tmp = get_hidden or (self.task_emb_type == "self-hidden")
hidden, layer_hiddens = self.encoder(
feat_final, ~pad_mask.bool(), get_hidden=get_hidden_tmp
)
else:
hidden = self.encoder(feat_final)
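        # hidden: BN x T x D encoder output; layer_hiddens holds the intermediate
        # hidden states when they were requested (e.g. for "self-hidden").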
if not no_pred:
if self.task_emb_type == "self-hidden":
pred = torch.stack([feat_final] + layer_hiddens, dim=1)
else:
pred = self.output_layer(hidden).reshape(b_sz, n_sz, t_sz, -1)
# B x N x T x D
else:
pred = None
if (not no_pred) and self.task_emb_type == "expand-last":
assert n_sz == 1, n_sz
pred = (
pred.squeeze(1)
.reshape(b_sz, t_sz, self.n_tasks, -1)
.permute(0, 2, 1, 3)
)
# B x N x T x D
if get_hidden:
return feat, feat_final, pred, pad_mask, layer_hiddens
else:
return feat, feat_final, pred, pad_mask
def cal_pad_mask(self, pad_mask, max_len):
"""Calculates pad mask after conv."""
pad_len = (pad_mask > 0).sum(1).long()
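        # Standard conv output-length formula (no padding, dilation 1):
        # L_out = (L_in - kernel_size) // stride + 1, applied per conv layer.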
for _, k_size, s_size in self.conv_layers:
pad_len = (pad_len - k_size) // s_size + 1
new_pad_mask = torch.ones(
(pad_mask.shape[0], max_len), dtype=pad_mask.dtype, device=pad_mask.device
)
for idx in range(pad_len.shape[0]):
new_pad_mask[idx, pad_len[idx] :] = 0
return new_pad_mask
def generate_task_id(self, device):
return torch.arange(self.n_tasks, device=device, dtype=torch.long)
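

# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the upstream file).
# It assumes DistillerConfig (configuration_distiller.py) can be constructed
# with fields matching the attributes read in __init__ above; check that file
# for the actual names and defaults.
#
#   config = DistillerConfig()                         # or DistillerConfig(**kwargs)
#   model = DistillerModel(config)
#   wave = torch.randn(2, 16000)                       # B x T_wave, raw audio
#   pad_mask = torch.ones(2, 16000, dtype=torch.bool)  # True = real sample
#   feat, feat_final, pred, out_mask = model(wave, pad_mask)
#   # pred: B x N x T x D predictions over the selected teacher layers
# -----------------------------------------------------------------------------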