visnet / auto_avsr /lightning.py
okregent's picture
Add Gradio app with full VSR pipeline
b5ece2c
import torch
import torchaudio
from cosine import WarmupCosineScheduler
from datamodule.transforms import TextTransform
from espnet.nets.batch_beam_search import BatchBeamSearch
from espnet.nets.pytorch_backend.e2e_asr_conformer import E2E
from espnet.nets.scorers.length_bonus import LengthBonus
from pytorch_lightning import LightningModule
def compute_word_level_distance(seq1, seq2):
    """Return the case-insensitive word-level edit distance between two strings.

    Each string is lower-cased and split on whitespace; the two resulting word
    lists are then compared with ``torchaudio.functional.edit_distance``.

    Args:
        seq1: First sentence as a string.
        seq2: Second sentence as a string.

    Returns:
        Whatever ``torchaudio.functional.edit_distance`` returns for the two
        word lists.
    """
    words_a = seq1.lower().split()
    words_b = seq2.lower().split()
    return torchaudio.functional.edit_distance(words_a, words_b)
class ModelModule(LightningModule):
    """Lightning module that trains, evaluates and decodes with an ESPnet ``E2E`` model.

    The module owns the text transform (and its token list), the ``E2E``
    network, the optimiser/scheduler configuration, train/validation loss
    logging, and a beam-search based test loop that logs word error rate
    under the key ``"wer"``.
    """

    def __init__(self, args):
        """Build the model and optionally load pretrained weights.

        Args:
            args: Configuration object. ``modality`` is read directly;
                ``ctc_weight`` (default ``0.1``), ``pretrained_model_path``,
                ``transfer_frontend`` and ``transfer_encoder`` are read with
                ``getattr`` and may therefore be absent. ``lr``,
                ``weight_decay``, ``warmup_epochs`` and ``max_epochs`` are
                read later by ``configure_optimizers``.
        """
        super().__init__()
        self.args = args
        self.save_hyperparameters(args)

        self.modality = args.modality
        self.text_transform = TextTransform()
        self.token_list = self.text_transform.token_list
        self.model = E2E(len(self.token_list), self.modality, ctc_weight=getattr(args, "ctc_weight", 0.1))

        # -- initialise
        if getattr(args, "pretrained_model_path", None):
            # NOTE(security): torch.load unpickles the file; only point
            # ``pretrained_model_path`` at checkpoints from a trusted source.
            ckpt = torch.load(args.pretrained_model_path, map_location=lambda storage, loc: storage)
            if getattr(args, "transfer_frontend", False):
                # Keys are used as-is here (no prefix stripping).
                tmp_ckpt = {
                    k: v
                    for k, v in ckpt["model_state_dict"].items()
                    if k.startswith(("trunk.", "frontend3D."))
                }
                self.model.frontend.load_state_dict(tmp_ckpt)
                print("Pretrained weights of the frontend component are loaded successfully.")
            elif getattr(args, "transfer_encoder", False):
                self.model.frontend.load_state_dict(self._select_prefixed(ckpt, "frontend."))
                self.model.proj_encoder.load_state_dict(self._select_prefixed(ckpt, "proj_encoder."))
                self.model.encoder.load_state_dict(self._select_prefixed(ckpt, "encoder."))
                print("Pretrained weights of the frontend, proj_encoder and encoder component are loaded successfully.")
            else:
                self.model.load_state_dict(ckpt)
                print("Pretrained weights of the full model are loaded successfully.")

    @staticmethod
    def _select_prefixed(state_dict, prefix):
        """Return the entries of ``state_dict`` whose key starts with ``prefix``, prefix removed."""
        # Slice rather than ``str.replace`` so that only the leading prefix is
        # dropped; ``replace`` would also delete later occurrences in the key.
        return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}

    def configure_optimizers(self):
        """Create the AdamW optimiser and a per-step warmup-cosine LR scheduler.

        Returns:
            A ``([optimizer], [scheduler_config])`` pair, where the scheduler
            config asks Lightning to step the scheduler every training step.
        """
        optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.args.lr,
            weight_decay=self.args.weight_decay,
            betas=(0.9, 0.98),
        )
        # Number of training batches divided across devices and nodes.
        iters_per_epoch = (
            len(self.trainer.datamodule.train_dataloader())
            / self.trainer.num_devices
            / self.trainer.num_nodes
        )
        scheduler = WarmupCosineScheduler(
            optimizer,
            self.args.warmup_epochs,
            self.args.max_epochs,
            iters_per_epoch,
        )
        scheduler = {"scheduler": scheduler, "interval": "step"}
        return [optimizer], [scheduler]

    def _transcribe(self, sample):
        """Encode one sample and return the text of the first beam-search hypothesis.

        Expects ``self.beam_search`` to have been set by the caller.

        Args:
            sample: Input tensor; a leading dimension is added before it is
                passed to the frontend (presumably a single un-batched
                example -- confirm against the data pipeline).

        Returns:
            The post-processed text with every ``"<eos>"`` substring removed.

        Raises:
            IndexError: If the beam search returns no hypotheses.
        """
        x = self.model.frontend(sample.unsqueeze(0))
        x = self.model.proj_encoder(x)
        enc_feat, _ = self.model.encoder(x, None)
        enc_feat = enc_feat.squeeze(0)

        # Only the first returned hypothesis is used.
        first_hyp = self.beam_search(enc_feat)[0].asdict()
        # The first element of ``yseq`` is skipped (presumably the start token).
        predicted_token_id = torch.tensor(list(map(int, first_hyp["yseq"][1:])))
        return self.text_transform.post_process(predicted_token_id).replace("<eos>", "")

    def forward(self, sample):
        """Transcribe a single input sample with beam search.

        Args:
            sample: Input tensor for one example (see ``_transcribe``).

        Returns:
            The predicted text.
        """
        # Build the decoder once and reuse it instead of reconstructing it on
        # every call.
        if getattr(self, "beam_search", None) is None:
            self.beam_search = get_beam_search_decoder(self.model, self.token_list)
        return self._transcribe(sample)

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation losses for one batch; returns the loss."""
        return self._step(batch, batch_idx, step_type="val")

    def test_step(self, sample, sample_idx):
        """Decode one test sample and accumulate word-level edit statistics.

        Args:
            sample: Mapping with an ``"input"`` tensor and a ``"target"``
                token-id sequence.
            sample_idx: Index of the sample (unused).
        """
        predicted = self._transcribe(sample["input"])

        actual_token_id = sample["target"]
        actual = self.text_transform.post_process(actual_token_id)

        self.total_edit_distance += compute_word_level_distance(actual, predicted)
        self.total_length += len(actual.split())
        return

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch, rescaled across processes.

        The loss from ``_step`` is multiplied by
        ``number of processes / sum of per-process batch sizes``.

        Returns:
            The rescaled loss tensor.
        """
        loss = self._step(batch, batch_idx, "train")
        batch_size = batch["inputs"].size(0)
        # ``all_gather`` may hand back a 0-dim result when there is only one
        # process; flatten so the element count is always available.
        batch_sizes = torch.as_tensor(self.all_gather(batch_size)).reshape(-1)
        loss *= batch_sizes.numel() / batch_sizes.sum()  # world size / batch size
        # "monitoring_step" is already logged by ``_step`` for the train phase.
        return loss

    def _step(self, batch, batch_idx, step_type):
        """Run the model on a batch and log its losses and decoder accuracy.

        Args:
            batch: Mapping with ``"inputs"``, ``"input_lengths"`` and
                ``"targets"``.
            batch_idx: Index of the batch (unused).
            step_type: ``"train"`` selects the training metric names and also
                logs ``"monitoring_step"``; any other value uses the
                ``*_val`` metric names.

        Returns:
            The total loss returned by the model.
        """
        loss, loss_ctc, loss_att, acc = self.model(batch["inputs"], batch["input_lengths"], batch["targets"])
        batch_size = len(batch["inputs"])

        if step_type == "train":
            self.log("loss", loss, on_step=True, on_epoch=True, batch_size=batch_size)
            self.log("loss_ctc", loss_ctc, on_step=False, on_epoch=True, batch_size=batch_size, sync_dist=True)
            self.log("loss_att", loss_att, on_step=False, on_epoch=True, batch_size=batch_size, sync_dist=True)
            self.log("decoder_acc", acc, on_step=True, on_epoch=True, batch_size=batch_size, sync_dist=True)
            self.log("monitoring_step", torch.tensor(self.global_step, dtype=torch.float32))
        else:
            self.log("loss_val", loss, batch_size=batch_size, sync_dist=True)
            self.log("loss_ctc_val", loss_ctc, batch_size=batch_size, sync_dist=True)
            self.log("loss_att_val", loss_att, batch_size=batch_size, sync_dist=True)
            self.log("decoder_acc_val", acc, batch_size=batch_size, sync_dist=True)

        return loss

    def on_test_epoch_start(self):
        """Reset the WER accumulators and (re)build the text transform and decoder."""
        self.total_length = 0
        self.total_edit_distance = 0
        self.text_transform = TextTransform()
        self.beam_search = get_beam_search_decoder(self.model, self.token_list)

    def on_test_epoch_end(self):
        """Log the word error rate accumulated over the test epoch as ``"wer"``."""
        if self.total_length == 0:
            # No reference words were seen, so the ratio is undefined; log NaN
            # instead of raising ZeroDivisionError at the end of the run.
            self.log("wer", float("nan"))
            return
        self.log("wer", self.total_edit_distance / self.total_length)
def get_beam_search_decoder(
    model,
    token_list,
    rnnlm=None,
    rnnlm_conf=None,
    penalty=0,
    ctc_weight=0.1,
    lm_weight=0.0,
    beam_size=40,
):
    """Build a ``BatchBeamSearch`` decoder from ``model``'s scorers.

    Args:
        model: Object exposing ``odim`` and a ``scorers()`` method.
        token_list: Vocabulary; its length is used as the vocabulary size.
        rnnlm: Accepted but not used by this function.
        rnnlm_conf: Accepted but not used by this function.
        penalty: Weight applied to the ``"length_bonus"`` scorer.
        ctc_weight: Weight of the ``"ctc"`` scorer; the ``"decoder"`` scorer
            gets ``1.0 - ctc_weight``.
        lm_weight: Weight of the ``"lm"`` scorer slot (the slot itself is
            always set to ``None`` here).
        beam_size: Beam width passed to the search.

    Returns:
        The configured ``BatchBeamSearch`` instance.
    """
    vocab_size = len(token_list)
    # The same index (last output unit) is used for both sos and eos.
    boundary_id = model.odim - 1

    scorer_table = model.scorers()
    scorer_table["lm"] = None
    scorer_table["length_bonus"] = LengthBonus(vocab_size)

    scorer_weights = dict(
        decoder=1.0 - ctc_weight,
        ctc=ctc_weight,
        lm=lm_weight,
        length_bonus=penalty,
    )

    # With a pure-CTC configuration there is no decoder-based pre-beam key.
    if ctc_weight == 1.0:
        pre_beam_key = None
    else:
        pre_beam_key = "decoder"

    return BatchBeamSearch(
        beam_size=beam_size,
        vocab_size=vocab_size,
        weights=scorer_weights,
        scorers=scorer_table,
        sos=boundary_id,
        eos=boundary_id,
        token_list=token_list,
        pre_beam_score_key=pre_beam_key,
    )