import argparse |
|
|
import collections |
|
|
import json |
|
|
import logging |
|
|
import os |
|
|
import pathlib |
|
|
import random |
|
|
import re |
|
|
import subprocess |
|
|
import warnings |
|
|
from collections import defaultdict |
|
|
from contextlib import contextmanager |
|
|
from dataclasses import dataclass |
|
|
from datetime import datetime |
|
|
from pathlib import Path |
|
|
from shutil import copyfile |
|
|
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union |
|
|
|
|
|
import k2 |
|
|
import k2.version |
|
|
import kaldialign |
|
|
import sentencepiece as spm |
|
|
import torch |
|
|
import torch.distributed as dist |
|
|
import torch.nn as nn |
|
|
from lhotse.dataset.signal_transforms import time_warp as time_warp_impl |
|
|
from packaging import version |
|
|
from pypinyin import lazy_pinyin, pinyin |
|
|
from pypinyin.contrib.tone_convert import to_finals, to_finals_tone, to_initials |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
|
|
|
from icefall.checkpoint import average_checkpoints |
|
|
|
|
|
Pathlike = Union[str, Path] |
|
|
|
|
|
TORCH_VERSION = version.parse(torch.__version__) |
|
|
|
|
|
|
|
|
def create_grad_scaler(device="cuda", **kwargs): |
|
|
""" |
|
|
Creates a GradScaler compatible with both torch < 2.3.0 and >= 2.3.0. |
|
|
Accepts all kwargs like: enabled, init_scale, growth_factor, etc. |
|
|
|
|
|
/icefall/egs/librispeech/ASR/./zipformer/train.py:1451: FutureWarning: |
|
|
`torch.cuda.amp.GradScaler(args...)` is deprecated. Please use |
|
|
`torch.amp.GradScaler('cuda', args...)` instead. |
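
    A minimal usage sketch of the usual AMP loop (``compute_loss``, ``batch``
    and ``optimizer`` are placeholders, not names defined in this module)::

        scaler = create_grad_scaler(enabled=True)
        with torch_autocast(enabled=True):
            loss = compute_loss(batch)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()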
|
|
""" |
|
|
if TORCH_VERSION >= version.parse("2.3.0"): |
|
|
from torch.amp import GradScaler |
|
|
|
|
|
return GradScaler(device=device, **kwargs) |
|
|
else: |
|
|
with warnings.catch_warnings(): |
|
|
warnings.simplefilter("ignore", category=FutureWarning) |
|
|
return torch.cuda.amp.GradScaler(**kwargs) |
|
|
|
|
|
|
|
|
@contextmanager |
|
|
def torch_autocast(device_type="cuda", **kwargs): |
|
|
""" |
|
|
To fix the following warnings: |
|
|
/icefall/egs/librispeech/ASR/zipformer/model.py:323: |
|
|
FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. |
|
|
Please use `torch.amp.autocast('cuda', args...)` instead. |
|
|
with torch.cuda.amp.autocast(enabled=False): |
|
|
""" |
|
|
if TORCH_VERSION >= version.parse("2.3.0"): |
|
|
|
|
|
with torch.amp.autocast(device_type=device_type, **kwargs): |
|
|
yield |
|
|
else: |
|
|
|
|
|
with warnings.catch_warnings(): |
|
|
warnings.simplefilter("ignore", category=FutureWarning) |
|
|
with torch.cuda.amp.autocast(**kwargs): |
|
|
yield |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_jit_tracing(): |
|
|
if torch.jit.is_scripting(): |
|
|
return False |
|
|
elif torch.jit.is_tracing(): |
|
|
return True |
|
|
return False |
|
|
|
|
|
|
|
|
@contextmanager |
|
|
def get_executor(): |
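    """Yield a ``distributed.Client`` when running on the CLSP grid at JHU
    (detected via the ``.clsp.jhu.edu`` hostname suffix); otherwise yield
    ``None`` so that callers can fall back to local execution.
    """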
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
name = subprocess.check_output("hostname -f", shell=True, text=True) |
|
|
if name.strip().endswith(".clsp.jhu.edu"): |
|
|
import plz |
|
|
from distributed import Client |
|
|
|
|
|
with plz.setup_cluster() as cluster: |
|
|
cluster.scale(80) |
|
|
yield Client(cluster) |
|
|
return |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
|
|
|
yield None |
|
|
|
|
|
|
|
|
def str2bool(v): |
|
|
"""Used in argparse.ArgumentParser.add_argument to indicate |
|
|
that a type is a bool type and user can enter |
|
|
|
|
|
- yes, true, t, y, 1, to represent True |
|
|
- no, false, f, n, 0, to represent False |
|
|
|
|
|
See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa |
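
    A minimal usage sketch::

        parser = argparse.ArgumentParser()
        parser.add_argument("--use-averaged-model", type=str2bool, default=True)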
|
|
""" |
|
|
if isinstance(v, bool): |
|
|
return v |
|
|
if v.lower() in ("yes", "true", "t", "y", "1"): |
|
|
return True |
|
|
elif v.lower() in ("no", "false", "f", "n", "0"): |
|
|
return False |
|
|
else: |
|
|
raise argparse.ArgumentTypeError("Boolean value expected.") |
|
|
|
|
|
|
|
|
def setup_logger( |
|
|
log_filename: Pathlike, |
|
|
log_level: str = "info", |
|
|
use_console: bool = True, |
|
|
) -> None: |
|
|
"""Setup log level. |
|
|
|
|
|
Args: |
|
|
log_filename: |
|
|
The filename to save the log. |
|
|
log_level: |
|
|
The log level to use, e.g., "debug", "info", "warning", "error", |
|
|
"critical" |
|
|
use_console: |
|
|
True to also print logs to console. |
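
    Example (a minimal sketch; the date, and the rank in distributed mode,
    are appended to the filename automatically)::

        setup_logger("exp/log/log-train")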
|
|
""" |
|
|
now = datetime.now() |
|
|
date_time = now.strftime("%Y-%m-%d-%H-%M-%S") |
|
|
if dist.is_available() and dist.is_initialized(): |
|
|
world_size = dist.get_world_size() |
|
|
rank = dist.get_rank() |
|
|
formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s" |
|
|
log_filename = f"{log_filename}-{date_time}-{rank}" |
|
|
else: |
|
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" |
|
|
log_filename = f"{log_filename}-{date_time}" |
|
|
|
|
|
os.makedirs(os.path.dirname(log_filename), exist_ok=True) |
|
|
|
|
|
level = logging.ERROR |
|
|
if log_level == "debug": |
|
|
level = logging.DEBUG |
|
|
elif log_level == "info": |
|
|
level = logging.INFO |
|
|
elif log_level == "warning": |
|
|
level = logging.WARNING |
|
|
elif log_level == "critical": |
|
|
level = logging.CRITICAL |
|
|
|
|
|
logging.basicConfig( |
|
|
filename=log_filename, |
|
|
format=formatter, |
|
|
level=level, |
|
|
filemode="w", |
|
|
force=True, |
|
|
) |
|
|
if use_console: |
|
|
console = logging.StreamHandler() |
|
|
console.setLevel(level) |
|
|
console.setFormatter(logging.Formatter(formatter)) |
|
|
logging.getLogger("").addHandler(console) |
|
|
|
|
|
|
|
|
class AttributeDict(dict): |
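    """A dict whose keys can also be read and written as attributes.

    A minimal usage sketch::

        params = AttributeDict({"lr": 0.045})
        params.batch_size = 32
        print(params.lr, params["batch_size"])
    """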
|
|
def __getattr__(self, key): |
|
|
if key in self: |
|
|
return self[key] |
|
|
raise AttributeError(f"No such attribute '{key}'") |
|
|
|
|
|
def __setattr__(self, key, value): |
|
|
self[key] = value |
|
|
|
|
|
def __delattr__(self, key): |
|
|
if key in self: |
|
|
del self[key] |
|
|
return |
|
|
raise AttributeError(f"No such attribute '{key}'") |
|
|
|
|
|
def __str__(self, indent: int = 2): |
|
|
tmp = {} |
|
|
for k, v in self.items(): |
|
|
|
|
|
if isinstance(v, (pathlib.Path, torch.device, torch.dtype)): |
|
|
v = str(v) |
|
|
tmp[k] = v |
|
|
return json.dumps(tmp, indent=indent, sort_keys=True) |
|
|
|
|
|
|
|
|
def encode_supervisions( |
|
|
supervisions: dict, |
|
|
subsampling_factor: int, |
|
|
token_ids: Optional[List[List[int]]] = None, |
|
|
) -> Tuple[torch.Tensor, Union[List[str], List[List[int]]]]: |
|
|
""" |
|
|
Encodes Lhotse's ``batch["supervisions"]`` dict into |
|
|
    a pair of a torch Tensor and a list of transcription strings or token IDs.
|
|
|
|
|
The supervision tensor has shape ``(batch_size, 3)``. |
|
|
Its second dimension contains information about sequence index [0], |
|
|
start frames [1] and num frames [2]. |
|
|
|
|
|
The batch items might become re-ordered during this operation -- the |
|
|
returned tensor and list of strings are guaranteed to be consistent with |
|
|
each other. |
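
    A hypothetical call (field names follow
    ``lhotse.dataset.K2SpeechRecognitionDataset``)::

        supervision_segments, texts = encode_supervisions(
            batch["supervisions"], subsampling_factor=4
        )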
|
|
""" |
|
|
supervision_segments = torch.stack( |
|
|
( |
|
|
supervisions["sequence_idx"], |
|
|
torch.div( |
|
|
supervisions["start_frame"], |
|
|
subsampling_factor, |
|
|
rounding_mode="floor", |
|
|
), |
|
|
torch.div( |
|
|
supervisions["num_frames"], |
|
|
subsampling_factor, |
|
|
rounding_mode="floor", |
|
|
), |
|
|
), |
|
|
1, |
|
|
).to(torch.int32) |
|
|
|
|
|
indices = torch.argsort(supervision_segments[:, 2], descending=True) |
|
|
supervision_segments = supervision_segments[indices] |
|
|
|
|
|
if token_ids is None: |
|
|
texts = supervisions["text"] |
|
|
res = [texts[idx] for idx in indices] |
|
|
else: |
|
|
res = [token_ids[idx] for idx in indices] |
|
|
|
|
|
return supervision_segments, res |
|
|
|
|
|
|
|
|
def get_texts( |
|
|
best_paths: k2.Fsa, return_ragged: bool = False |
|
|
) -> Union[List[List[int]], k2.RaggedTensor]: |
|
|
"""Extract the texts (as word IDs) from the best-path FSAs. |
|
|
Args: |
|
|
best_paths: |
|
|
A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. |
|
|
containing multiple FSAs, which is expected to be the result |
|
|
of k2.shortest_path (otherwise the returned values won't |
|
|
be meaningful). |
|
|
return_ragged: |
|
|
True to return a ragged tensor with two axes [utt][word_id]. |
|
|
False to return a list-of-list word IDs. |
|
|
Returns: |
|
|
      Returns a list of lists of int, containing the label sequences we
      decoded. If ``return_ragged`` is True, a ``k2.RaggedTensor`` with two
      axes [utt][word_id] is returned instead.
|
|
""" |
|
|
if isinstance(best_paths.aux_labels, k2.RaggedTensor): |
|
|
|
|
|
aux_labels = best_paths.aux_labels.remove_values_leq(0) |
|
|
|
|
|
aux_shape = best_paths.arcs.shape().compose(aux_labels.shape) |
|
|
|
|
|
|
|
|
aux_shape = aux_shape.remove_axis(1) |
|
|
aux_shape = aux_shape.remove_axis(1) |
|
|
aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values) |
|
|
else: |
|
|
|
|
|
aux_shape = best_paths.arcs.shape().remove_axis(1) |
|
|
aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels) |
|
|
|
|
|
aux_labels = aux_labels.remove_values_leq(0) |
|
|
|
|
|
assert aux_labels.num_axes == 2 |
|
|
if return_ragged: |
|
|
return aux_labels |
|
|
else: |
|
|
return aux_labels.tolist() |
|
|
|
|
|
|
|
|
def encode_supervisions_otc( |
|
|
supervisions: dict, |
|
|
subsampling_factor: int, |
|
|
token_ids: Optional[List[List[int]]] = None, |
|
|
) -> Tuple[torch.Tensor, Union[List[str], List[List[int]]], List[str], List[str]]:
|
|
""" |
|
|
Encodes Lhotse's ``batch["supervisions"]`` dict into |
|
|
    a tuple of a torch Tensor, a list of transcription strings or token IDs,
    a list of sorted cut IDs, and a list of sorted verbatim texts.
|
|
|
|
|
The supervision tensor has shape ``(batch_size, 3)``. |
|
|
Its second dimension contains information about sequence index [0], |
|
|
start frames [1] and num frames [2]. |
|
|
|
|
|
The batch items might become re-ordered during this operation -- the |
|
|
returned tensor and list of strings are guaranteed to be consistent with |
|
|
each other. |
|
|
""" |
|
|
supervision_segments = torch.stack( |
|
|
( |
|
|
supervisions["sequence_idx"], |
|
|
torch.div( |
|
|
supervisions["start_frame"], |
|
|
subsampling_factor, |
|
|
rounding_mode="floor", |
|
|
), |
|
|
torch.div( |
|
|
supervisions["num_frames"], |
|
|
subsampling_factor, |
|
|
rounding_mode="floor", |
|
|
), |
|
|
), |
|
|
1, |
|
|
).to(torch.int32) |
|
|
|
|
|
indices = torch.argsort(supervision_segments[:, 2], descending=True) |
|
|
supervision_segments = supervision_segments[indices] |
|
|
|
|
|
ids = [] |
|
|
verbatim_texts = [] |
|
|
sorted_ids = [] |
|
|
sorted_verbatim_texts = [] |
|
|
|
|
|
for cut in supervisions["cut"]: |
|
|
id = cut.id |
|
|
if hasattr(cut.supervisions[0], "verbatim_text"): |
|
|
verbatim_text = cut.supervisions[0].verbatim_text |
|
|
else: |
|
|
verbatim_text = "" |
|
|
ids.append(id) |
|
|
verbatim_texts.append(verbatim_text) |
|
|
|
|
|
for index in indices.tolist(): |
|
|
sorted_ids.append(ids[index]) |
|
|
sorted_verbatim_texts.append(verbatim_texts[index]) |
|
|
|
|
|
if token_ids is None: |
|
|
texts = supervisions["text"] |
|
|
res = [texts[idx] for idx in indices] |
|
|
else: |
|
|
res = [token_ids[idx] for idx in indices] |
|
|
|
|
|
return supervision_segments, res, sorted_ids, sorted_verbatim_texts |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class KeywordResult: |
|
|
|
|
|
|
|
|
    # timestamps[k] is the frame index (after subsampling) at which
    # hyps[k] is emitted.
    timestamps: List[int]

    # Token IDs of the detected keyword.
    hyps: List[int]

    # The keyword phrase, as a string.
    phrase: str
|
|
|
|
|
|
|
|
@dataclass |
|
|
class DecodingResults: |
|
|
|
|
|
|
|
|
    # timestamps[i][k] is the frame index (after subsampling) at which
    # hyps[i][k] is decoded.
    timestamps: List[List[int]]

    # hyps[i] is the decoded token/word-ID sequence of the i-th utterance.
    hyps: Union[List[List[int]], k2.RaggedTensor]

    # scores[i][k] is the score of hyps[i][k], if available.
    scores: Optional[List[List[float]]] = None
|
|
|
|
|
|
|
|
def get_texts_with_timestamp( |
|
|
best_paths: k2.Fsa, return_ragged: bool = False |
|
|
) -> DecodingResults: |
|
|
"""Extract the texts (as word IDs) and timestamps (as frame indexes) |
|
|
from the best-path FSAs. |
|
|
Args: |
|
|
best_paths: |
|
|
A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. |
|
|
containing multiple FSAs, which is expected to be the result |
|
|
of k2.shortest_path (otherwise the returned values won't |
|
|
be meaningful). |
|
|
return_ragged: |
|
|
True to return a ragged tensor with two axes [utt][word_id]. |
|
|
False to return a list-of-list word IDs. |
|
|
Returns: |
|
|
      Returns a DecodingResults object containing the decoded label
      sequences and their frame-level timestamps.
|
|
""" |
|
|
if isinstance(best_paths.aux_labels, k2.RaggedTensor): |
|
|
all_aux_shape = ( |
|
|
best_paths.arcs.shape().remove_axis(1).compose(best_paths.aux_labels.shape) |
|
|
) |
|
|
all_aux_labels = k2.RaggedTensor(all_aux_shape, best_paths.aux_labels.values) |
|
|
|
|
|
aux_labels = best_paths.aux_labels.remove_values_leq(0) |
|
|
|
|
|
aux_shape = best_paths.arcs.shape().compose(aux_labels.shape) |
|
|
|
|
|
aux_shape = aux_shape.remove_axis(1) |
|
|
aux_shape = aux_shape.remove_axis(1) |
|
|
aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values) |
|
|
else: |
|
|
|
|
|
aux_shape = best_paths.arcs.shape().remove_axis(1) |
|
|
all_aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels) |
|
|
|
|
|
aux_labels = all_aux_labels.remove_values_leq(0) |
|
|
|
|
|
assert aux_labels.num_axes == 2 |
|
|
|
|
|
timestamps = [] |
|
|
if isinstance(best_paths.aux_labels, k2.RaggedTensor): |
|
|
for p in range(all_aux_labels.dim0): |
|
|
time = [] |
|
|
for i, arc in enumerate(all_aux_labels[p].tolist()): |
|
|
if len(arc) == 1 and arc[0] > 0: |
|
|
time.append(i) |
|
|
timestamps.append(time) |
|
|
else: |
|
|
for labels in all_aux_labels.tolist(): |
|
|
time = [i for i, v in enumerate(labels) if v > 0] |
|
|
timestamps.append(time) |
|
|
|
|
|
return DecodingResults( |
|
|
timestamps=timestamps, |
|
|
hyps=aux_labels if return_ragged else aux_labels.tolist(), |
|
|
) |
|
|
|
|
|
|
|
|
def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]: |
|
|
"""Extract labels or aux_labels from the best-path FSAs. |
|
|
|
|
|
Args: |
|
|
best_paths: |
|
|
A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. |
|
|
containing multiple FSAs, which is expected to be the result |
|
|
of k2.shortest_path (otherwise the returned values won't |
|
|
be meaningful). |
|
|
kind: |
|
|
Possible values are: "labels" and "aux_labels". Caution: When it is |
|
|
"labels", the resulting alignments contain repeats. |
|
|
Returns: |
|
|
Returns a list of lists of int, containing the token sequences we |
|
|
decoded. For `ans[i]`, its length equals to the number of frames |
|
|
after subsampling of the i-th utterance in the batch. |
|
|
|
|
|
Example: |
|
|
When `kind` is `labels`, one possible alignment example is (with |
|
|
repeats):: |
|
|
|
|
|
c c c blk a a blk blk t t t blk blk |
|
|
|
|
|
If `kind` is `aux_labels`, the above example changes to:: |
|
|
|
|
|
c blk blk blk a blk blk blk t blk blk blk blk |
|
|
|
|
|
""" |
|
|
assert kind in ("labels", "aux_labels") |
|
|
|
|
|
token_shape = best_paths.arcs.shape().remove_axis(1) |
|
|
|
|
|
tokens = k2.RaggedTensor(token_shape, getattr(best_paths, kind).contiguous()) |
|
|
tokens = tokens.remove_values_eq(-1) |
|
|
return tokens.tolist() |
|
|
|
|
|
|
|
|
def save_alignments( |
|
|
alignments: Dict[str, List[int]], |
|
|
subsampling_factor: int, |
|
|
filename: str, |
|
|
) -> None: |
|
|
"""Save alignments to a file. |
|
|
|
|
|
Args: |
|
|
alignments: |
|
|
A dict containing alignments. Keys of the dict are utterances and |
|
|
values are the corresponding framewise alignments after subsampling. |
|
|
subsampling_factor: |
|
|
The subsampling factor of the model. |
|
|
filename: |
|
|
Path to save the alignments. |
|
|
Returns: |
|
|
Return None. |
|
|
""" |
|
|
ali_dict = { |
|
|
"subsampling_factor": subsampling_factor, |
|
|
"alignments": alignments, |
|
|
} |
|
|
torch.save(ali_dict, filename) |
|
|
|
|
|
|
|
|
def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]: |
|
|
"""Load alignments from a file. |
|
|
|
|
|
Args: |
|
|
filename: |
|
|
Path to the file containing alignment information. |
|
|
The file should be saved by :func:`save_alignments`. |
|
|
Returns: |
|
|
Return a tuple containing: |
|
|
- subsampling_factor: The subsampling_factor used to compute |
|
|
the alignments. |
|
|
- alignments: A dict containing utterances and their corresponding |
|
|
framewise alignment, after subsampling. |
|
|
""" |
|
|
ali_dict = torch.load(filename, weights_only=False) |
|
|
subsampling_factor = ali_dict["subsampling_factor"] |
|
|
alignments = ali_dict["alignments"] |
|
|
return subsampling_factor, alignments |
|
|
|
|
|
|
|
|
def store_transcripts( |
|
|
filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False |
|
|
) -> None: |
|
|
"""Save predicted results and reference transcripts to a file. |
|
|
|
|
|
Args: |
|
|
filename: |
|
|
File to save the results to. |
|
|
texts: |
|
|
        An iterable of tuples. The first element is the cut_id, the second is
|
|
the reference transcript and the third element is the predicted result. |
|
|
If it is a multi-talker ASR system, the ref and hyp may also be lists of |
|
|
strings. |
|
|
Returns: |
|
|
Return None. |
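
    A hypothetical output excerpt (when ref/hyp are word lists, they are
    printed with their Python repr)::

        cut-1:  ref=['HELLO', 'WORLD']
        cut-1:  hyp=['HELLO', 'WORD']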
|
|
""" |
|
|
with open(filename, "w", encoding="utf8") as f: |
|
|
for cut_id, ref, hyp in texts: |
|
|
if char_level: |
|
|
ref = list("".join(ref)) |
|
|
hyp = list("".join(hyp)) |
|
|
print(f"{cut_id}:\tref={ref}", file=f) |
|
|
print(f"{cut_id}:\thyp={hyp}", file=f) |
|
|
|
|
|
|
|
|
def store_transcripts_and_timestamps( |
|
|
filename: Pathlike, |
|
|
texts: Iterable[Tuple[str, List[str], List[str], List[float], List[float]]], |
|
|
) -> None: |
|
|
"""Save predicted results and reference transcripts as well as their timestamps |
|
|
to a file. |
|
|
|
|
|
Args: |
|
|
filename: |
|
|
File to save the results to. |
|
|
texts: |
|
|
        An iterable of tuples. The first element is the cut_id, the second is
        the reference transcript, the third is the predicted result, and the
        fourth and fifth are the reference and predicted timestamps.
|
|
Returns: |
|
|
Return None. |
|
|
""" |
|
|
with open(filename, "w", encoding="utf8") as f: |
|
|
for cut_id, ref, hyp, time_ref, time_hyp in texts: |
|
|
print(f"{cut_id}:\tref={ref}", file=f) |
|
|
print(f"{cut_id}:\thyp={hyp}", file=f) |
|
|
|
|
|
if len(time_ref) > 0: |
|
|
if isinstance(time_ref[0], tuple): |
|
|
|
|
|
s = ( |
|
|
"[" |
|
|
+ ", ".join(["(%0.3f, %.03f)" % (i, j) for (i, j) in time_ref]) |
|
|
+ "]" |
|
|
) |
|
|
else: |
|
|
|
|
|
s = "[" + ", ".join(["%0.3f" % i for i in time_ref]) + "]" |
|
|
print(f"{cut_id}:\ttimestamp_ref={s}", file=f) |
|
|
|
|
|
if len(time_hyp) > 0: |
|
|
if isinstance(time_hyp[0], tuple): |
|
|
|
|
|
s = ( |
|
|
"[" |
|
|
+ ", ".join(["(%0.3f, %.03f)" % (i, j) for (i, j) in time_hyp]) |
|
|
+ "]" |
|
|
) |
|
|
else: |
|
|
|
|
|
s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]" |
|
|
print(f"{cut_id}:\ttimestamp_hyp={s}", file=f) |
|
|
|
|
|
|
|
|
def store_translations( |
|
|
filename: Pathlike, texts: Iterable[Tuple[str, str, str]], |
|
|
lowercase: bool = True) -> None: |
|
|
"""Save predicted results and reference transcripts to a file. |
|
|
|
|
|
Args: |
|
|
filename: |
|
|
File to save the results to. |
|
|
texts: |
|
|
        An iterable of tuples. The first element is the cut_id, the second is
        the reference transcript, the third is the reference translation, and
        the fourth is the predicted result.
|
|
Returns: |
|
|
Return None. |
|
|
""" |
|
|
    # BLEU is assumed to come from sacrebleu (its corpus_score/get_signature
    # API is used below); imported lazily so the dependency stays optional.
    from sacrebleu.metrics import BLEU

    bleu = BLEU(lowercase=lowercase)
|
|
hyp_list = [] |
|
|
ref_list = [] |
|
|
dir_ = os.path.dirname(filename) |
|
|
reftgt = os.path.join(dir_, "reftgt-" + str(os.path.basename(filename))) |
|
|
refsrc = os.path.join(dir_, "refsrc-"+str(os.path.basename(filename))) |
|
|
hyp = os.path.join(dir_, "hyp-"+str( os.path.basename(filename))) |
|
|
bleu_file = os.path.join(dir_, "bleu-"+str( os.path.basename(filename))) |
|
|
with open(filename, "w") as f, open(reftgt, "w") as f_tgt, open(hyp, "w") as f_hyp, open(refsrc, "w") as f_src: |
|
|
for cut_id, ref, ref_tgt, hyp in texts: |
|
|
ref = " ".join(ref) |
|
|
ref_tgt = " ".join(ref_tgt) |
|
|
hyp = " ".join(hyp) |
|
|
print(f"{cut_id}: ref {ref}", file=f) |
|
|
print(f"{cut_id}: ref_tgt {ref_tgt}", file=f) |
|
|
print(f"{cut_id}: hyp {hyp}", file=f) |
|
|
print("\n", file=f) |
|
|
|
|
|
|
|
|
print(f"{ref}", file=f_src) |
|
|
print(f"{ref_tgt}", file=f_tgt) |
|
|
print(f"{hyp}", file=f_hyp) |
|
|
|
|
|
hyp_list.append(hyp) |
|
|
ref_list.append(ref_tgt) |
|
|
|
|
|
with open(bleu_file, 'w') as b: |
|
|
print(str(bleu.corpus_score(hyp_list, [ref_list])), file=b) |
|
|
print(f"BLEU signiture: {str(bleu.get_signature())}", file=b) |
|
|
|
|
|
logging.info( |
|
|
f"[{bleu.corpus_score(hyp_list, [ref_list])}] " |
|
|
f"BLEU signiture: {str(bleu.get_signature())}" |
|
|
) |
|
|
|
|
|
|
|
|
def write_error_stats( |
|
|
f: TextIO, |
|
|
test_set_name: str, |
|
|
results: List[Tuple[str, str]], |
|
|
enable_log: bool = True, |
|
|
compute_CER: bool = False, |
|
|
sclite_mode: bool = False, |
|
|
) -> float: |
|
|
"""Write statistics based on predicted results and reference transcripts. |
|
|
|
|
|
It will write the following to the given file: |
|
|
|
|
|
- WER |
|
|
- number of insertions, deletions, substitutions, corrects and total |
|
|
reference words. For example:: |
|
|
|
|
|
Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606 |
|
|
reference words (2337 correct) |
|
|
|
|
|
- The difference between the reference transcript and predicted result. |
|
|
An instance is given below:: |
|
|
|
|
|
THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES |
|
|
|
|
|
The above example shows that the reference word is `EDISON`, |
|
|
but it is predicted to `ADDISON` (a substitution error). |
|
|
|
|
|
Another example is:: |
|
|
|
|
|
FOR THE FIRST DAY (SIR->*) I THINK |
|
|
|
|
|
The reference word `SIR` is missing in the predicted |
|
|
results (a deletion error). |
|
|
    Args:
      results:
|
|
An iterable of tuples. The first element is the cut_id, the second is |
|
|
the reference transcript and the third element is the predicted result. |
|
|
enable_log: |
|
|
If True, also print detailed WER to the console. |
|
|
Otherwise, it is written only to the given file. |
|
|
Returns: |
|
|
      Return the total word error rate (WER), in percent, as a float.
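
    A minimal usage sketch::

        results = [("cut-1", ["HELLO", "WORLD"], ["HELLO", "WORD"])]
        with open("errs-test.txt", "w") as f:
            wer = write_error_stats(f, "test", results)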
|
|
""" |
|
|
subs: Dict[Tuple[str, str], int] = defaultdict(int) |
|
|
ins: Dict[str, int] = defaultdict(int) |
|
|
dels: Dict[str, int] = defaultdict(int) |
|
|
|
|
|
|
|
|
|
|
|
    # words[w] = [corr, ref_sub, hyp_sub, ins, dels]: per-word counts of
    # correct matches, substitutions (as ref / as hyp), insertions, deletions.
    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
|
|
num_corr = 0 |
|
|
ERR = "*" |
|
|
|
|
|
if compute_CER: |
|
|
for i, res in enumerate(results): |
|
|
cut_id, ref, hyp = res |
|
|
ref = list("".join(ref)) |
|
|
hyp = list("".join(hyp)) |
|
|
results[i] = (cut_id, ref, hyp) |
|
|
|
|
|
for cut_id, ref, hyp in results: |
|
|
ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode) |
|
|
for ref_word, hyp_word in ali: |
|
|
if ref_word == ERR: |
|
|
ins[hyp_word] += 1 |
|
|
words[hyp_word][3] += 1 |
|
|
elif hyp_word == ERR: |
|
|
dels[ref_word] += 1 |
|
|
words[ref_word][4] += 1 |
|
|
elif hyp_word != ref_word: |
|
|
subs[(ref_word, hyp_word)] += 1 |
|
|
words[ref_word][1] += 1 |
|
|
words[hyp_word][2] += 1 |
|
|
else: |
|
|
words[ref_word][0] += 1 |
|
|
num_corr += 1 |
|
|
ref_len = sum([len(r) for _, r, _ in results]) |
|
|
sub_errs = sum(subs.values()) |
|
|
ins_errs = sum(ins.values()) |
|
|
del_errs = sum(dels.values()) |
|
|
tot_errs = sub_errs + ins_errs + del_errs |
|
|
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len) |
|
|
|
|
|
if enable_log: |
|
|
logging.info( |
|
|
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} " |
|
|
f"[{tot_errs} / {ref_len}, {ins_errs} ins, " |
|
|
f"{del_errs} del, {sub_errs} sub ]" |
|
|
) |
|
|
|
|
|
print(f"%WER = {tot_err_rate}", file=f) |
|
|
print( |
|
|
f"Errors: {ins_errs} insertions, {del_errs} deletions, " |
|
|
f"{sub_errs} substitutions, over {ref_len} reference " |
|
|
f"words ({num_corr} correct)", |
|
|
file=f, |
|
|
) |
|
|
print( |
|
|
"Search below for sections starting with PER-UTT DETAILS:, " |
|
|
"SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:", |
|
|
file=f, |
|
|
) |
|
|
|
|
|
print("", file=f) |
|
|
print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f) |
|
|
for cut_id, ref, hyp in results: |
|
|
ali = kaldialign.align(ref, hyp, ERR) |
|
|
combine_successive_errors = True |
|
|
if combine_successive_errors: |
|
|
ali = [[[x], [y]] for x, y in ali] |
|
|
for i in range(len(ali) - 1): |
|
|
if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]: |
|
|
ali[i + 1][0] = ali[i][0] + ali[i + 1][0] |
|
|
ali[i + 1][1] = ali[i][1] + ali[i + 1][1] |
|
|
ali[i] = [[], []] |
|
|
ali = [ |
|
|
[ |
|
|
list(filter(lambda a: a != ERR, x)), |
|
|
list(filter(lambda a: a != ERR, y)), |
|
|
] |
|
|
for x, y in ali |
|
|
] |
|
|
ali = list(filter(lambda x: x != [[], []], ali)) |
|
|
ali = [ |
|
|
[ |
|
|
ERR if x == [] else " ".join(x), |
|
|
ERR if y == [] else " ".join(y), |
|
|
] |
|
|
for x, y in ali |
|
|
] |
|
|
|
|
|
print( |
|
|
f"{cut_id}:\t" |
|
|
+ " ".join( |
|
|
( |
|
|
ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})" |
|
|
for ref_word, hyp_word in ali |
|
|
) |
|
|
), |
|
|
file=f, |
|
|
) |
|
|
|
|
|
print("", file=f) |
|
|
print("SUBSTITUTIONS: count ref -> hyp", file=f) |
|
|
|
|
|
for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True): |
|
|
print(f"{count} {ref} -> {hyp}", file=f) |
|
|
|
|
|
print("", file=f) |
|
|
print("DELETIONS: count ref", file=f) |
|
|
for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True): |
|
|
print(f"{count} {ref}", file=f) |
|
|
|
|
|
print("", file=f) |
|
|
print("INSERTIONS: count hyp", file=f) |
|
|
for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True): |
|
|
print(f"{count} {hyp}", file=f) |
|
|
|
|
|
print("", file=f) |
|
|
print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f) |
|
|
for _, word, counts in sorted( |
|
|
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True |
|
|
): |
|
|
(corr, ref_sub, hyp_sub, ins, dels) = counts |
|
|
tot_errs = ref_sub + hyp_sub + ins + dels |
|
|
ref_count = corr + ref_sub + dels |
|
|
hyp_count = corr + hyp_sub + ins |
|
|
|
|
|
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) |
|
|
return float(tot_err_rate) |
|
|
|
|
|
|
|
|
def write_error_stats_with_timestamps( |
|
|
f: TextIO, |
|
|
test_set_name: str, |
|
|
results: List[ |
|
|
Tuple[ |
|
|
str, |
|
|
List[str], |
|
|
List[str], |
|
|
List[Union[float, Tuple[float, float]]], |
|
|
List[Union[float, Tuple[float, float]]], |
|
|
] |
|
|
], |
|
|
enable_log: bool = True, |
|
|
with_end_time: bool = False, |
|
|
) -> Tuple[float, Union[float, Tuple[float, float]], Union[float, Tuple[float, float]]]: |
|
|
"""Write statistics based on predicted results and reference transcripts |
|
|
as well as their timestamps. |
|
|
|
|
|
It will write the following to the given file: |
|
|
|
|
|
- WER |
|
|
- number of insertions, deletions, substitutions, corrects and total |
|
|
reference words. For example:: |
|
|
|
|
|
Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606 |
|
|
reference words (2337 correct) |
|
|
|
|
|
- The difference between the reference transcript and predicted result. |
|
|
An instance is given below:: |
|
|
|
|
|
THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES |
|
|
|
|
|
The above example shows that the reference word is `EDISON`, |
|
|
but it is predicted to `ADDISON` (a substitution error). |
|
|
|
|
|
Another example is:: |
|
|
|
|
|
FOR THE FIRST DAY (SIR->*) I THINK |
|
|
|
|
|
The reference word `SIR` is missing in the predicted |
|
|
results (a deletion error). |
|
|
    Args:
      results:
        An iterable of tuples. The first element is the cut_id, the second is
        the reference transcript and the third element is the predicted result.
|
|
enable_log: |
|
|
If True, also print detailed WER to the console. |
|
|
Otherwise, it is written only to the given file. |
|
|
with_end_time: |
|
|
        Whether to use end timestamps.
|
|
|
|
|
Returns: |
|
|
      Return the total word error rate, the mean symbol delay, and the
      variance of the symbol delay.
|
|
""" |
|
|
subs: Dict[Tuple[str, str], int] = defaultdict(int) |
|
|
ins: Dict[str, int] = defaultdict(int) |
|
|
dels: Dict[str, int] = defaultdict(int) |
|
|
|
|
|
|
|
|
|
|
|
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0]) |
|
|
num_corr = 0 |
|
|
ERR = "*" |
|
|
|
|
|
all_delay = [] |
|
|
for cut_id, ref, hyp, time_ref, time_hyp in results: |
|
|
ali = kaldialign.align(ref, hyp, ERR) |
|
|
has_time = len(time_ref) > 0 and len(time_hyp) > 0 |
|
|
if has_time: |
|
|
|
|
|
p_hyp = 0 |
|
|
|
|
|
p_ref = 0 |
|
|
for ref_word, hyp_word in ali: |
|
|
if ref_word == ERR: |
|
|
ins[hyp_word] += 1 |
|
|
words[hyp_word][3] += 1 |
|
|
if has_time: |
|
|
p_hyp += 1 |
|
|
elif hyp_word == ERR: |
|
|
dels[ref_word] += 1 |
|
|
words[ref_word][4] += 1 |
|
|
if has_time: |
|
|
p_ref += 1 |
|
|
elif hyp_word != ref_word: |
|
|
subs[(ref_word, hyp_word)] += 1 |
|
|
words[ref_word][1] += 1 |
|
|
words[hyp_word][2] += 1 |
|
|
if has_time: |
|
|
p_hyp += 1 |
|
|
p_ref += 1 |
|
|
else: |
|
|
words[ref_word][0] += 1 |
|
|
num_corr += 1 |
|
|
if has_time: |
|
|
if with_end_time: |
|
|
all_delay.append( |
|
|
( |
|
|
time_hyp[p_hyp][0] - time_ref[p_ref][0], |
|
|
time_hyp[p_hyp][1] - time_ref[p_ref][1], |
|
|
) |
|
|
) |
|
|
else: |
|
|
all_delay.append(time_hyp[p_hyp] - time_ref[p_ref]) |
|
|
p_hyp += 1 |
|
|
p_ref += 1 |
|
|
if has_time: |
|
|
assert p_hyp == len(hyp), (p_hyp, len(hyp)) |
|
|
assert p_ref == len(ref), (p_ref, len(ref)) |
|
|
|
|
|
ref_len = sum([len(r) for _, r, _, _, _ in results]) |
|
|
sub_errs = sum(subs.values()) |
|
|
ins_errs = sum(ins.values()) |
|
|
del_errs = sum(dels.values()) |
|
|
tot_errs = sub_errs + ins_errs + del_errs |
|
|
tot_err_rate = float("%.2f" % (100.0 * tot_errs / ref_len)) |
|
|
|
|
|
if with_end_time: |
|
|
mean_delay = (float("inf"), float("inf")) |
|
|
var_delay = (float("inf"), float("inf")) |
|
|
else: |
|
|
mean_delay = float("inf") |
|
|
var_delay = float("inf") |
|
|
num_delay = len(all_delay) |
|
|
if num_delay > 0: |
|
|
if with_end_time: |
|
|
all_delay_start = [i[0] for i in all_delay] |
|
|
mean_delay_start = sum(all_delay_start) / num_delay |
|
|
var_delay_start = ( |
|
|
sum([(i - mean_delay_start) ** 2 for i in all_delay_start]) / num_delay |
|
|
) |
|
|
|
|
|
all_delay_end = [i[1] for i in all_delay] |
|
|
mean_delay_end = sum(all_delay_end) / num_delay |
|
|
var_delay_end = ( |
|
|
sum([(i - mean_delay_end) ** 2 for i in all_delay_end]) / num_delay |
|
|
) |
|
|
|
|
|
mean_delay = ( |
|
|
float("%.3f" % mean_delay_start), |
|
|
float("%.3f" % mean_delay_end), |
|
|
) |
|
|
var_delay = (float("%.3f" % var_delay_start), float("%.3f" % var_delay_end)) |
|
|
else: |
|
|
mean_delay = sum(all_delay) / num_delay |
|
|
var_delay = sum([(i - mean_delay) ** 2 for i in all_delay]) / num_delay |
|
|
mean_delay = float("%.3f" % mean_delay) |
|
|
var_delay = float("%.3f" % var_delay) |
|
|
|
|
|
if enable_log: |
|
|
logging.info( |
|
|
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} " |
|
|
f"[{tot_errs} / {ref_len}, {ins_errs} ins, " |
|
|
f"{del_errs} del, {sub_errs} sub ]" |
|
|
) |
|
|
logging.info( |
|
|
f"[{test_set_name}] %symbol-delay mean (s): " |
|
|
f"{mean_delay}, variance: {var_delay} " |
|
|
f"computed on {num_delay} correct words" |
|
|
) |
|
|
|
|
|
print(f"%WER = {tot_err_rate}", file=f) |
|
|
print( |
|
|
f"Errors: {ins_errs} insertions, {del_errs} deletions, " |
|
|
f"{sub_errs} substitutions, over {ref_len} reference " |
|
|
f"words ({num_corr} correct)", |
|
|
file=f, |
|
|
) |
|
|
print( |
|
|
"Search below for sections starting with PER-UTT DETAILS:, " |
|
|
"SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:", |
|
|
file=f, |
|
|
) |
|
|
|
|
|
print("", file=f) |
|
|
print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f) |
|
|
for cut_id, ref, hyp, _, _ in results: |
|
|
ali = kaldialign.align(ref, hyp, ERR) |
|
|
combine_successive_errors = True |
|
|
if combine_successive_errors: |
|
|
ali = [[[x], [y]] for x, y in ali] |
|
|
for i in range(len(ali) - 1): |
|
|
if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]: |
|
|
ali[i + 1][0] = ali[i][0] + ali[i + 1][0] |
|
|
ali[i + 1][1] = ali[i][1] + ali[i + 1][1] |
|
|
ali[i] = [[], []] |
|
|
ali = [ |
|
|
[ |
|
|
list(filter(lambda a: a != ERR, x)), |
|
|
list(filter(lambda a: a != ERR, y)), |
|
|
] |
|
|
for x, y in ali |
|
|
] |
|
|
ali = list(filter(lambda x: x != [[], []], ali)) |
|
|
ali = [ |
|
|
[ |
|
|
ERR if x == [] else " ".join(x), |
|
|
ERR if y == [] else " ".join(y), |
|
|
] |
|
|
for x, y in ali |
|
|
] |
|
|
|
|
|
print( |
|
|
f"{cut_id}:\t" |
|
|
+ " ".join( |
|
|
( |
|
|
ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})" |
|
|
for ref_word, hyp_word in ali |
|
|
) |
|
|
), |
|
|
file=f, |
|
|
) |
|
|
|
|
|
print("", file=f) |
|
|
print("SUBSTITUTIONS: count ref -> hyp", file=f) |
|
|
|
|
|
for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True): |
|
|
print(f"{count} {ref} -> {hyp}", file=f) |
|
|
|
|
|
print("", file=f) |
|
|
print("DELETIONS: count ref", file=f) |
|
|
for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True): |
|
|
print(f"{count} {ref}", file=f) |
|
|
|
|
|
print("", file=f) |
|
|
print("INSERTIONS: count hyp", file=f) |
|
|
for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True): |
|
|
print(f"{count} {hyp}", file=f) |
|
|
|
|
|
print("", file=f) |
|
|
print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f) |
|
|
for _, word, counts in sorted( |
|
|
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True |
|
|
): |
|
|
(corr, ref_sub, hyp_sub, ins, dels) = counts |
|
|
tot_errs = ref_sub + hyp_sub + ins + dels |
|
|
ref_count = corr + ref_sub + dels |
|
|
hyp_count = corr + hyp_sub + ins |
|
|
|
|
|
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) |
|
|
    # tot_err_rate is already a float; mean_delay and var_delay may be
    # (start, end) tuples when with_end_time is True, so don't cast them.
    return tot_err_rate, mean_delay, var_delay
|
|
|
|
|
|
|
|
def write_surt_error_stats( |
|
|
f: TextIO, |
|
|
test_set_name: str, |
|
|
results: List[Tuple[str, str]], |
|
|
enable_log: bool = True, |
|
|
num_channels: int = 2, |
|
|
) -> float: |
|
|
"""Write statistics based on predicted results and reference transcripts for SURT |
|
|
multi-talker ASR systems. The difference between this and the `write_error_stats` |
|
|
is that this function finds the optimal speaker-agnostic WER using the ``meeteval`` |
|
|
toolkit. |
|
|
|
|
|
Args: |
|
|
f: File to write the statistics to. |
|
|
test_set_name: Name of the test set. |
|
|
      results: List of tuples containing the utterance ID, the reference
        transcripts, and the predicted transcripts.
|
|
enable_log: Whether to enable logging. |
|
|
num_channels: Number of output channels/branches. Defaults to 2. |
|
|
Returns: |
|
|
      Return the total word error rate, in percent, as a float.
|
|
""" |
|
|
from meeteval.wer import wer |
|
|
|
|
|
subs: Dict[Tuple[str, str], int] = defaultdict(int) |
|
|
ins: Dict[str, int] = defaultdict(int) |
|
|
dels: Dict[str, int] = defaultdict(int) |
|
|
ref_lens: List[int] = [] |
|
|
|
|
|
print( |
|
|
"Search below for sections starting with PER-UTT DETAILS:, " |
|
|
"SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:", |
|
|
file=f, |
|
|
) |
|
|
|
|
|
print("", file=f) |
|
|
print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f) |
|
|
|
|
|
|
|
|
|
|
|
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0]) |
|
|
num_corr = 0 |
|
|
ERR = "*" |
|
|
for cut_id, ref, hyp in results: |
|
|
|
|
|
orc_wer = wer.orc_word_error_rate(ref, hyp) |
|
|
assignment = orc_wer.assignment |
|
|
refs = [[] for _ in range(num_channels)] |
|
|
|
|
|
for i, ref_text in zip(assignment, ref): |
|
|
refs[i] += ref_text.split() |
|
|
hyps = [hyp_text.split() for hyp_text in hyp] |
|
|
|
|
|
for ref_c, hyp_c in zip(refs, hyps): |
|
|
ref_lens.append(len(ref_c)) |
|
|
ali = kaldialign.align(ref_c, hyp_c, ERR) |
|
|
for ref_word, hyp_word in ali: |
|
|
if ref_word == ERR: |
|
|
ins[hyp_word] += 1 |
|
|
words[hyp_word][3] += 1 |
|
|
elif hyp_word == ERR: |
|
|
dels[ref_word] += 1 |
|
|
words[ref_word][4] += 1 |
|
|
elif hyp_word != ref_word: |
|
|
subs[(ref_word, hyp_word)] += 1 |
|
|
words[ref_word][1] += 1 |
|
|
words[hyp_word][2] += 1 |
|
|
else: |
|
|
words[ref_word][0] += 1 |
|
|
num_corr += 1 |
|
|
combine_successive_errors = True |
|
|
if combine_successive_errors: |
|
|
ali = [[[x], [y]] for x, y in ali] |
|
|
for i in range(len(ali) - 1): |
|
|
if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]: |
|
|
ali[i + 1][0] = ali[i][0] + ali[i + 1][0] |
|
|
ali[i + 1][1] = ali[i][1] + ali[i + 1][1] |
|
|
ali[i] = [[], []] |
|
|
ali = [ |
|
|
[ |
|
|
list(filter(lambda a: a != ERR, x)), |
|
|
list(filter(lambda a: a != ERR, y)), |
|
|
] |
|
|
for x, y in ali |
|
|
] |
|
|
ali = list(filter(lambda x: x != [[], []], ali)) |
|
|
ali = [ |
|
|
[ |
|
|
ERR if x == [] else " ".join(x), |
|
|
ERR if y == [] else " ".join(y), |
|
|
] |
|
|
for x, y in ali |
|
|
] |
|
|
|
|
|
print( |
|
|
f"{cut_id}:\t" |
|
|
+ " ".join( |
|
|
( |
|
|
( |
|
|
ref_word |
|
|
if ref_word == hyp_word |
|
|
else f"({ref_word}->{hyp_word})" |
|
|
) |
|
|
for ref_word, hyp_word in ali |
|
|
) |
|
|
), |
|
|
file=f, |
|
|
) |
|
|
ref_len = sum(ref_lens) |
|
|
sub_errs = sum(subs.values()) |
|
|
ins_errs = sum(ins.values()) |
|
|
del_errs = sum(dels.values()) |
|
|
tot_errs = sub_errs + ins_errs + del_errs |
|
|
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len) |
|
|
|
|
|
if enable_log: |
|
|
logging.info( |
|
|
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} " |
|
|
f"[{tot_errs} / {ref_len}, {ins_errs} ins, " |
|
|
f"{del_errs} del, {sub_errs} sub ]" |
|
|
) |
|
|
|
|
|
print(f"%WER = {tot_err_rate}", file=f) |
|
|
print( |
|
|
f"Errors: {ins_errs} insertions, {del_errs} deletions, " |
|
|
f"{sub_errs} substitutions, over {ref_len} reference " |
|
|
f"words ({num_corr} correct)", |
|
|
file=f, |
|
|
) |
|
|
|
|
|
print("", file=f) |
|
|
print("SUBSTITUTIONS: count ref -> hyp", file=f) |
|
|
|
|
|
for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True): |
|
|
print(f"{count} {ref} -> {hyp}", file=f) |
|
|
|
|
|
print("", file=f) |
|
|
print("DELETIONS: count ref", file=f) |
|
|
for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True): |
|
|
print(f"{count} {ref}", file=f) |
|
|
|
|
|
print("", file=f) |
|
|
print("INSERTIONS: count hyp", file=f) |
|
|
for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True): |
|
|
print(f"{count} {hyp}", file=f) |
|
|
|
|
|
print("", file=f) |
|
|
print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f) |
|
|
for _, word, counts in sorted( |
|
|
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True |
|
|
): |
|
|
(corr, ref_sub, hyp_sub, ins, dels) = counts |
|
|
tot_errs = ref_sub + hyp_sub + ins + dels |
|
|
ref_count = corr + ref_sub + dels |
|
|
hyp_count = corr + hyp_sub + ins |
|
|
|
|
|
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) |
|
|
|
|
|
print(f"%WER = {tot_err_rate}", file=f) |
|
|
return float(tot_err_rate) |
|
|
|
|
|
|
|
|
class MetricsTracker(collections.defaultdict): |
|
|
def __init__(self): |
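        # Passing ``int`` to the defaultdict base class makes missing keys
        # default to ``int()`` == 0, so counters can be accumulated directly.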
|
|
|
|
|
|
|
|
|
|
|
|
|
|
super(MetricsTracker, self).__init__(int) |
|
|
|
|
|
def __add__(self, other: "MetricsTracker") -> "MetricsTracker": |
|
|
ans = MetricsTracker() |
|
|
for k, v in self.items(): |
|
|
ans[k] = v |
|
|
for k, v in other.items(): |
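            # ``v - v == 0`` is False for inf and nan, so non-finite
            # values are skipped during accumulation.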
|
|
if v - v == 0: |
|
|
ans[k] = ans[k] + v |
|
|
return ans |
|
|
|
|
|
def __mul__(self, alpha: float) -> "MetricsTracker": |
|
|
ans = MetricsTracker() |
|
|
for k, v in self.items(): |
|
|
ans[k] = v * alpha |
|
|
return ans |
|
|
|
|
|
def __str__(self) -> str: |
|
|
ans_frames = "" |
|
|
ans_utterances = "" |
|
|
for k, v in self.norm_items(): |
|
|
norm_value = "%.4g" % v |
|
|
if "utt_" not in k: |
|
|
ans_frames += str(k) + "=" + str(norm_value) + ", " |
|
|
else: |
|
|
ans_utterances += str(k) + "=" + str(norm_value) |
|
|
if k == "utt_duration": |
|
|
ans_utterances += " frames, " |
|
|
elif k == "utt_pad_proportion": |
|
|
ans_utterances += ", " |
|
|
else: |
|
|
raise ValueError(f"Unexpected key: {k}") |
|
|
frames = "%.2f" % self["frames"] |
|
|
ans_frames += "over " + str(frames) + " frames. " |
|
|
if ans_utterances != "": |
|
|
utterances = "%.2f" % self["utterances"] |
|
|
ans_utterances += "over " + str(utterances) + " utterances." |
|
|
|
|
|
return ans_frames + ans_utterances |
|
|
|
|
|
def norm_items(self) -> List[Tuple[str, float]]: |
|
|
""" |
|
|
Returns a list of pairs, like: |
|
|
[('ctc_loss', 0.1), ('att_loss', 0.07)] |
|
|
""" |
|
|
num_frames = self["frames"] if "frames" in self else 1 |
|
|
num_utterances = self["utterances"] if "utterances" in self else 1 |
|
|
ans = [] |
|
|
for k, v in self.items(): |
|
|
if k == "frames" or k == "utterances": |
|
|
continue |
|
|
norm_value = ( |
|
|
float(v) / num_frames if "utt_" not in k else float(v) / num_utterances |
|
|
) |
|
|
ans.append((k, norm_value)) |
|
|
return ans |
|
|
|
|
|
def reduce(self, device): |
|
|
""" |
|
|
        Reduce using torch.distributed (all_reduce with SUM) so that
        all processes end up with the same totals.
|
|
""" |
|
|
keys = sorted(self.keys()) |
|
|
s = torch.tensor([float(self[k]) for k in keys], device=device) |
|
|
dist.all_reduce(s, op=dist.ReduceOp.SUM) |
|
|
for k, v in zip(keys, s.cpu().tolist()): |
|
|
self[k] = v |
|
|
|
|
|
def write_summary( |
|
|
self, |
|
|
tb_writer: SummaryWriter, |
|
|
prefix: str, |
|
|
batch_idx: int, |
|
|
) -> None: |
|
|
"""Add logging information to a TensorBoard writer. |
|
|
|
|
|
Args: |
|
|
tb_writer: a TensorBoard writer |
|
|
prefix: a prefix for the name of the loss, e.g. "train/valid_", |
|
|
or "train/current_" |
|
|
batch_idx: The current batch index, used as the x-axis of the plot. |
|
|
""" |
|
|
for k, v in self.norm_items(): |
|
|
tb_writer.add_scalar(prefix + k, v, batch_idx) |
|
|
|
|
|
|
|
|
def concat(ragged: k2.RaggedTensor, value: int, direction: str) -> k2.RaggedTensor: |
|
|
"""Prepend a value to the beginning of each sublist or append a value. |
|
|
to the end of each sublist. |
|
|
|
|
|
Args: |
|
|
ragged: |
|
|
A ragged tensor with two axes. |
|
|
value: |
|
|
The value to prepend or append. |
|
|
direction: |
|
|
It can be either "left" or "right". If it is "left", we |
|
|
prepend the value to the beginning of each sublist; |
|
|
if it is "right", we append the value to the end of each |
|
|
sublist. |
|
|
|
|
|
Returns: |
|
|
Return a new ragged tensor, whose sublists either start with |
|
|
or end with the given value. |
|
|
|
|
|
>>> a = k2.RaggedTensor([[1, 3], [5]]) |
|
|
>>> a |
|
|
[ [ 1 3 ] [ 5 ] ] |
|
|
>>> concat(a, value=0, direction="left") |
|
|
[ [ 0 1 3 ] [ 0 5 ] ] |
|
|
>>> concat(a, value=0, direction="right") |
|
|
[ [ 1 3 0 ] [ 5 0 ] ] |
|
|
|
|
|
""" |
|
|
dtype = ragged.dtype |
|
|
device = ragged.device |
|
|
|
|
|
assert ragged.num_axes == 2, f"num_axes: {ragged.num_axes}" |
|
|
pad_values = torch.full( |
|
|
size=(ragged.tot_size(0), 1), |
|
|
fill_value=value, |
|
|
device=device, |
|
|
dtype=dtype, |
|
|
) |
|
|
pad = k2.RaggedTensor(pad_values) |
|
|
|
|
|
if direction == "left": |
|
|
ans = k2.ragged.cat([pad, ragged], axis=1) |
|
|
elif direction == "right": |
|
|
ans = k2.ragged.cat([ragged, pad], axis=1) |
|
|
else: |
|
|
        raise ValueError(
            f"Unsupported direction: {direction}. "
            'Expect either "left" or "right".'
        )
|
|
return ans |
|
|
|
|
|
|
|
|
def add_sos(ragged: k2.RaggedTensor, sos_id: int) -> k2.RaggedTensor: |
|
|
"""Add SOS to each sublist. |
|
|
|
|
|
Args: |
|
|
ragged: |
|
|
A ragged tensor with two axes. |
|
|
sos_id: |
|
|
The ID of the SOS symbol. |
|
|
|
|
|
Returns: |
|
|
Return a new ragged tensor, where each sublist starts with SOS. |
|
|
|
|
|
>>> a = k2.RaggedTensor([[1, 3], [5]]) |
|
|
>>> a |
|
|
[ [ 1 3 ] [ 5 ] ] |
|
|
>>> add_sos(a, sos_id=0) |
|
|
[ [ 0 1 3 ] [ 0 5 ] ] |
|
|
|
|
|
""" |
|
|
return concat(ragged, sos_id, direction="left") |
|
|
|
|
|
|
|
|
def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor: |
|
|
"""Add EOS to each sublist. |
|
|
|
|
|
Args: |
|
|
ragged: |
|
|
A ragged tensor with two axes. |
|
|
eos_id: |
|
|
The ID of the EOS symbol. |
|
|
|
|
|
Returns: |
|
|
Return a new ragged tensor, where each sublist ends with EOS. |
|
|
|
|
|
>>> a = k2.RaggedTensor([[1, 3], [5]]) |
|
|
>>> a |
|
|
[ [ 1 3 ] [ 5 ] ] |
|
|
>>> add_eos(a, eos_id=0) |
|
|
[ [ 1 3 0 ] [ 5 0 ] ] |
|
|
|
|
|
""" |
|
|
return concat(ragged, eos_id, direction="right") |
|
|
|
|
|
|
|
|
def make_pad_mask( |
|
|
lengths: torch.Tensor, |
|
|
max_len: int = 0, |
|
|
pad_left: bool = False, |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
Args: |
|
|
lengths: |
|
|
A 1-D tensor containing sentence lengths. |
|
|
max_len: |
|
|
The length of masks. |
|
|
pad_left: |
|
|
If ``False`` (default), padding is on the right. |
|
|
If ``True``, padding is on the left. |
|
|
Returns: |
|
|
Return a 2-D bool tensor, where masked positions |
|
|
are filled with `True` and non-masked positions are |
|
|
filled with `False`. |
|
|
|
|
|
>>> lengths = torch.tensor([1, 3, 2, 5]) |
|
|
>>> make_pad_mask(lengths) |
|
|
tensor([[False, True, True, True, True], |
|
|
[False, False, False, True, True], |
|
|
[False, False, True, True, True], |
|
|
[False, False, False, False, False]]) |
|
|
""" |
|
|
assert lengths.ndim == 1, lengths.ndim |
|
|
max_len = max(max_len, lengths.max()) |
|
|
n = lengths.size(0) |
|
|
seq_range = torch.arange(0, max_len, device=lengths.device) |
|
|
expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len) |
|
|
|
|
|
if pad_left: |
|
|
mask = expanded_lengths < (max_len - lengths).unsqueeze(1) |
|
|
else: |
|
|
mask = expanded_lengths >= lengths.unsqueeze(-1) |
|
|
|
|
|
return mask |
|
|
|
|
|
|
|
|
|
|
|
def subsequent_chunk_mask( |
|
|
size: int, |
|
|
chunk_size: int, |
|
|
num_left_chunks: int = -1, |
|
|
device: torch.device = torch.device("cpu"), |
|
|
) -> torch.Tensor: |
|
|
"""Create mask for subsequent steps (size, size) with chunk size, |
|
|
this is for streaming encoder |
|
|
Args: |
|
|
size (int): size of mask |
|
|
chunk_size (int): size of chunk |
|
|
num_left_chunks (int): number of left chunks |
|
|
<0: use full chunk |
|
|
>=0: use num_left_chunks |
|
|
device (torch.device): "cpu" or "cuda" or torch.Tensor.device |
|
|
Returns: |
|
|
torch.Tensor: mask |
|
|
Examples: |
|
|
>>> subsequent_chunk_mask(4, 2) |
|
|
[[1, 1, 0, 0], |
|
|
[1, 1, 0, 0], |
|
|
[1, 1, 1, 1], |
|
|
[1, 1, 1, 1]] |
|
|
""" |
|
|
ret = torch.zeros(size, size, device=device, dtype=torch.bool) |
|
|
for i in range(size): |
|
|
if num_left_chunks < 0: |
|
|
start = 0 |
|
|
else: |
|
|
start = max((i // chunk_size - num_left_chunks) * chunk_size, 0) |
|
|
ending = min((i // chunk_size + 1) * chunk_size, size) |
|
|
ret[i, start:ending] = True |
|
|
return ret |
|
|
|
|
|
|
|
|
def l1_norm(x): |
|
|
return torch.sum(torch.abs(x)) |
|
|
|
|
|
|
|
|
def l2_norm(x): |
|
|
return torch.sum(torch.pow(x, 2)) |
|
|
|
|
|
|
|
|
def linf_norm(x): |
|
|
return torch.max(torch.abs(x)) |
|
|
|
|
|
|
|
|
def measure_weight_norms(model: nn.Module, norm: str = "l2") -> Dict[str, float]: |
|
|
""" |
|
|
Compute the norms of the model's parameters. |
|
|
|
|
|
:param model: a torch.nn.Module instance |
|
|
:param norm: how to compute the norm. Available values: 'l1', 'l2', 'linf' |
|
|
:return: a dict mapping from parameter's name to its norm. |
|
|
""" |
|
|
with torch.no_grad(): |
|
|
norms = {} |
|
|
for name, param in model.named_parameters(): |
|
|
if norm == "l1": |
|
|
val = l1_norm(param) |
|
|
elif norm == "l2": |
|
|
val = l2_norm(param) |
|
|
elif norm == "linf": |
|
|
val = linf_norm(param) |
|
|
else: |
|
|
raise ValueError(f"Unknown norm type: {norm}") |
|
|
norms[name] = val.item() |
|
|
return norms |
|
|
|
|
|
|
|
|
def measure_gradient_norms(model: nn.Module, norm: str = "l1") -> Dict[str, float]: |
|
|
""" |
|
|
Compute the norms of the gradients for each of model's parameters. |
|
|
|
|
|
:param model: a torch.nn.Module instance |
|
|
:param norm: how to compute the norm. Available values: 'l1', 'l2', 'linf' |
|
|
:return: a dict mapping from parameter's name to its gradient's norm. |
|
|
""" |
|
|
with torch.no_grad(): |
|
|
norms = {} |
|
|
for name, param in model.named_parameters(): |
|
|
if norm == "l1": |
|
|
val = l1_norm(param.grad) |
|
|
elif norm == "l2": |
|
|
val = l2_norm(param.grad) |
|
|
elif norm == "linf": |
|
|
val = linf_norm(param.grad) |
|
|
else: |
|
|
raise ValueError(f"Unknown norm type: {norm}") |
|
|
norms[name] = val.item() |
|
|
return norms |
|
|
|
|
|
|
|
|
def get_parameter_groups_with_lrs( |
|
|
model: nn.Module, |
|
|
lr: float, |
|
|
include_names: bool = False, |
|
|
freeze_modules: List[str] = [], |
|
|
) -> List[dict]: |
|
|
""" |
|
|
This is for use with the ScaledAdam optimizers (more recent versions that accept lists of |
|
|
named-parameters; we can, if needed, create a version without the names). |
|
|
|
|
|
It provides a way to specify learning-rate scales inside the module, so that if |
|
|
any nn.Module in the hierarchy has a floating-point parameter 'lr_scale', it will |
|
|
scale the LR of any parameters inside that module or its submodules. Note: you |
|
|
can set module parameters outside the __init__ function, e.g.: |
|
|
>>> a = nn.Linear(10, 10) |
|
|
>>> a.lr_scale = 0.5 |
|
|
|
|
|
Returns: a list of dicts, of the following form: |
|
|
if include_names == False: |
|
|
[ { 'params': [ tensor1, tensor2, ... ], 'lr': 0.01 }, |
|
|
{ 'params': [ tensor3, tensor4, ... ], 'lr': 0.005 }, |
|
|
... ] |
|
|
    if include_names == True:
       [ { 'named_params': [ (name1, tensor1), (name2, tensor2), ... ], 'lr': 0.01 },
|
|
{ 'named_params': [ (name3, tensor3), (name4, tensor4), ... ], 'lr': 0.005 }, |
|
|
... ] |
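
    A hypothetical pairing with the ScaledAdam optimizer (``ScaledAdam`` is
    assumed to come from the recipe's local ``optim`` module)::

        optimizer = ScaledAdam(
            get_parameter_groups_with_lrs(model, lr=0.045, include_names=True),
            lr=0.045,  # should keep the same value as above
            clipping_scale=2.0,
        )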
|
|
|
|
|
""" |
|
|
named_modules = list(model.named_modules()) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flat_lr_scale = defaultdict(lambda: 1.0) |
|
|
names = [] |
|
|
for name, m in model.named_modules(): |
|
|
names.append(name) |
|
|
if hasattr(m, "lr_scale"): |
|
|
flat_lr_scale[name] = m.lr_scale |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lr_to_params = defaultdict(list) |
|
|
|
|
|
for name, parameter in model.named_parameters(): |
|
|
split_name = name.split(".") |
|
|
|
|
|
prefix = split_name[0] |
|
|
if prefix == "module": |
|
|
module_name = split_name[1] |
|
|
if module_name in freeze_modules: |
|
|
logging.info(f"Remove {name} from parameters") |
|
|
continue |
|
|
else: |
|
|
if prefix in freeze_modules: |
|
|
logging.info(f"Remove {name} from parameters") |
|
|
continue |
|
|
cur_lr = lr * flat_lr_scale[prefix] |
|
|
if prefix != "": |
|
|
cur_lr *= flat_lr_scale[""] |
|
|
for part in split_name[1:]: |
|
|
prefix = ".".join([prefix, part]) |
|
|
cur_lr *= flat_lr_scale[prefix] |
|
|
lr_to_params[cur_lr].append((name, parameter) if include_names else parameter) |
|
|
|
|
|
if include_names: |
|
|
return [{"named_params": pairs, "lr": lr} for lr, pairs in lr_to_params.items()] |
|
|
else: |
|
|
return [{"params": params, "lr": lr} for lr, params in lr_to_params.items()] |
|
|
|
|
|
|
|
|
def optim_step_and_measure_param_change( |
|
|
model: nn.Module, |
|
|
old_parameters: Dict[str, nn.parameter.Parameter], |
|
|
) -> Dict[str, float]: |
|
|
""" |
|
|
Measure the "relative change in parameters per minibatch." |
|
|
    It is understood as the ratio between the squared L2 norm of the difference
    between the original and updated parameters, and the squared L2 norm of the
    original parameter. It is given by the formula:
|
|
|
|
|
.. math:: |
|
|
|
|
|
\begin{aligned} |
|
|
\delta = \frac{\Vert\theta - \theta_{new}\Vert^2}{\Vert\theta\Vert^2} |
|
|
\end{aligned} |
|
|
|
|
|
This function is supposed to be used as follows: |
|
|
|
|
|
.. code-block:: python |
|
|
|
|
|
old_parameters = { |
|
|
n: p.detach().clone() for n, p in model.named_parameters() |
|
|
} |
|
|
|
|
|
optimizer.step() |
|
|
|
|
|
        deltas = optim_step_and_measure_param_change(model, old_parameters)
|
|
|
|
|
Args: |
|
|
model: A torch.nn.Module instance. |
|
|
old_parameters: |
|
|
A Dict of named_parameters before optimizer.step(). |
|
|
|
|
|
Return: |
|
|
A Dict containing the relative change for each parameter. |
|
|
""" |
|
|
relative_change = {} |
|
|
with torch.no_grad(): |
|
|
for n, p_new in model.named_parameters(): |
|
|
p_orig = old_parameters[n] |
|
|
delta = l2_norm(p_orig - p_new) / l2_norm(p_orig) |
|
|
relative_change[n] = delta.item() |
|
|
return relative_change |
|
|
|
|
|
|
|
|
def load_averaged_model( |
|
|
model_dir: str, |
|
|
model: torch.nn.Module, |
|
|
epoch: int, |
|
|
avg: int, |
|
|
device: torch.device, |
|
|
): |
|
|
""" |
|
|
    Load a model whose weights are the average of the last ``avg`` epoch
    checkpoints up to and including ``epoch``.
|
|
|
|
|
:param model_dir: a str of the experiment directory |
|
|
:param model: a torch.nn.Module instance |
|
|
|
|
|
:param epoch: the last epoch to load from |
|
|
:param avg: how many models to average from |
|
|
:param device: move model to this device |
|
|
|
|
|
    :return: The model loaded with the averaged weights
|
|
""" |
|
|
|
|
|
|
|
|
start = max(epoch - avg + 1, 0) |
|
|
filenames = [f"{model_dir}/epoch-{i}.pt" for i in range(start, epoch + 1)] |
|
|
|
|
|
logging.info(f"averaging {filenames}") |
|
|
model.to(device) |
|
|
model.load_state_dict(average_checkpoints(filenames, device=device)) |
|
|
|
|
|
return model |
|
|
|
|
|
|
|
|
def text_to_pinyin( |
|
|
txt: str, mode: str = "full_with_tone", errors: str = "default" |
|
|
) -> List[str]: |
|
|
""" |
|
|
    Convert a Chinese text (which might contain some Latin characters) to a
    pinyin sequence.
|
|
|
|
|
Args: |
|
|
txt: |
|
|
The input Chinese text. |
|
|
mode: |
|
|
The style of the output pinyin, should be: |
|
|
full_with_tone : zhōng guó |
|
|
full_no_tone : zhong guo |
|
|
partial_with_tone : zh ōng g uó |
|
|
partial_no_tone : zh ong g uo |
|
|
errors: |
|
|
        How to handle (Latin) characters that have no pinyin:
          default : output the same as input.
          split : split into single characters (i.e., letters).
|
|
|
|
|
Return: |
|
|
Return a list of str. |
|
|
|
|
|
Examples: |
|
|
txt: 想吃KFC |
|
|
output: ['xiǎng', 'chī', 'KFC'] # mode=full_with_tone; errors=default |
|
|
output: ['xiǎng', 'chī', 'K', 'F', 'C'] # mode=full_with_tone; errors=split |
|
|
output: ['xiang', 'chi', 'KFC'] # mode=full_no_tone; errors=default |
|
|
output: ['xiang', 'chi', 'K', 'F', 'C'] # mode=full_no_tone; errors=split |
|
|
output: ['x', 'iǎng', 'ch', 'ī', 'KFC'] # mode=partial_with_tone; errors=default |
|
|
output: ['x', 'iang', 'ch', 'i', 'KFC'] # mode=partial_no_tone; errors=default |
|
|
""" |
|
|
|
|
|
assert mode in ( |
|
|
"full_with_tone", |
|
|
"full_no_tone", |
|
|
"partial_no_tone", |
|
|
"partial_with_tone", |
|
|
), mode |
|
|
|
|
|
assert errors in ("default", "split"), errors |
|
|
|
|
|
txt = txt.strip() |
|
|
res = [] |
|
|
if "full" in mode: |
|
|
if errors == "default": |
|
|
py = pinyin(txt) if mode == "full_with_tone" else lazy_pinyin(txt) |
|
|
else: |
|
|
py = ( |
|
|
pinyin(txt, errors=lambda x: list(x)) |
|
|
if mode == "full_with_tone" |
|
|
else lazy_pinyin(txt, errors=lambda x: list(x)) |
|
|
) |
|
|
res = [x[0] for x in py] if mode == "full_with_tone" else py |
|
|
else: |
|
|
if errors == "default": |
|
|
py = pinyin(txt) if mode == "partial_with_tone" else lazy_pinyin(txt) |
|
|
else: |
|
|
py = ( |
|
|
pinyin(txt, errors=lambda x: list(x)) |
|
|
if mode == "partial_with_tone" |
|
|
else lazy_pinyin(txt, errors=lambda x: list(x)) |
|
|
) |
|
|
py = [x[0] for x in py] if mode == "partial_with_tone" else py |
|
|
for x in py: |
|
|
initial = to_initials(x, strict=False) |
|
|
final = ( |
|
|
to_finals(x, strict=False) |
|
|
if mode == "partial_no_tone" |
|
|
else to_finals_tone(x, strict=False) |
|
|
) |
|
|
if initial == "" and final == "": |
|
|
res.append(x) |
|
|
else: |
|
|
if initial != "": |
|
|
res.append(initial) |
|
|
if final != "": |
|
|
res.append(final) |
|
|
return res |
|
|
|
|
|
|
|
|
def tokenize_by_bpe_model( |
|
|
sp: spm.SentencePieceProcessor, |
|
|
txt: str, |
|
|
) -> str: |
|
|
""" |
|
|
    Tokenize text with a BPE model. This function is adapted from
|
|
https://github1s.com/wenet-e2e/wenet/blob/main/wenet/dataset/processor.py#L322-L342. |
|
|
Args: |
|
|
sp: spm.SentencePieceProcessor. |
|
|
txt: str |
|
|
|
|
|
Return: |
|
|
      A new string of CJK chars and BPE pieces, joined by "/".
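
    A hypothetical example (the exact BPE pieces depend on the model; the
    model path is an assumption)::

        sp = spm.SentencePieceProcessor()
        sp.load("bpe.model")
        tokenize_by_bpe_model(sp, "你好 hello")
        # -> something like "你/好/▁HE/LLO" (illustrative only)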
|
|
""" |
|
|
tokens = [] |
|
|
|
|
|
|
|
|
pattern = re.compile(r"([\u4e00-\u9fff])") |
|
|
|
|
|
|
|
|
|
|
|
chars = pattern.split(txt.upper()) |
|
|
mix_chars = [w for w in chars if len(w.strip()) > 0] |
|
|
for ch_or_w in mix_chars: |
|
|
|
|
|
if pattern.fullmatch(ch_or_w) is not None: |
|
|
tokens.append(ch_or_w) |
|
|
|
|
|
|
|
|
else: |
|
|
for p in sp.encode_as_pieces(ch_or_w): |
|
|
tokens.append(p) |
|
|
txt_with_bpe = "/".join(tokens) |
|
|
|
|
|
return txt_with_bpe |
|
|
|
|
|
|
|
|
def tokenize_by_CJK_char(line: str) -> str: |
|
|
""" |
|
|
    Tokenize a line of text by CJK characters.

    Note: All returned characters will be upper case.
|
|
|
|
|
Example: |
|
|
input = "你好世界是 hello world 的中文" |
|
|
output = "你 好 世 界 是 HELLO WORLD 的 中 文" |
|
|
|
|
|
Args: |
|
|
line: |
|
|
The input text. |
|
|
|
|
|
Return: |
|
|
      A new string tokenized by CJK characters.
|
|
""" |
|
|
|
|
|
pattern = re.compile( |
|
|
r"([\u1100-\u11ff\u2e80-\ua4cf\ua840-\uD7AF\uF900-\uFAFF\uFE30-\uFE4F\uFF65-\uFFDC\U00020000-\U0002FFFF])" |
|
|
) |
|
|
chars = pattern.split(line.strip().upper()) |
|
|
return " ".join([w.strip() for w in chars if w.strip()]) |
|
|
|
|
|
|
|
|
def tokenize_by_ja_char(line: str) -> str: |
|
|
""" |
|
|
Tokenize a line of text with Japanese characters. |
|
|
|
|
|
Note: All non-Japanese characters will be upper case. |
|
|
|
|
|
Example: |
|
|
input = "こんにちは世界は hello world の日本語" |
|
|
output = "こ ん に ち は 世 界 は HELLO WORLD の 日 本 語" |
|
|
|
|
|
Args: |
|
|
line: |
|
|
The input text. |
|
|
|
|
|
Return: |
|
|
A new string tokenized by Japanese characters. |
|
|
""" |
|
|
pattern = re.compile(r"([\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF])") |
|
|
chars = pattern.split(line.strip()) |
|
|
return " ".join( |
|
|
[w.strip().upper() if not pattern.match(w) else w for w in chars if w.strip()] |
|
|
) |
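

# Quick runnable check that mirrors the docstring example above:
def _example_tokenize_by_ja_char() -> None:
    out = tokenize_by_ja_char("こんにちは世界は hello world の日本語")
    assert out == "こ ん に ち は 世 界 は HELLO WORLD の 日 本 語", out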
|
|
|
|
|
|
|
|
def display_and_save_batch( |
|
|
batch: dict, |
|
|
params: AttributeDict, |
|
|
sp: spm.SentencePieceProcessor, |
|
|
) -> None: |
|
|
"""Display the batch statistics and save the batch into disk. |
|
|
|
|
|
Args: |
|
|
batch: |
|
|
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` |
|
|
for the content in it. |
|
|
params: |
|
|
Parameters for training. See :func:`get_params`. |
|
|
sp: |
|
|
The BPE model. |
|
|
""" |
|
|
from lhotse.utils import uuid4 |
|
|
|
|
|
filename = f"{params.exp_dir}/batch-{uuid4()}.pt" |
|
|
logging.info(f"Saving batch to {filename}") |
|
|
torch.save(batch, filename) |
|
|
|
|
|
supervisions = batch["supervisions"] |
|
|
features = batch["inputs"] |
|
|
|
|
|
logging.info(f"features shape: {features.shape}") |
|
|
|
|
|
y = sp.encode(supervisions["text"], out_type=int) |
|
|
num_tokens = sum(len(i) for i in y) |
|
|
logging.info(f"num tokens: {num_tokens}") |
|
|
|
|
|
|
|
|
def convert_timestamp( |
|
|
frames: List[int], |
|
|
subsampling_factor: int, |
|
|
frame_shift_ms: float = 10, |
|
|
) -> List[float]: |
|
|
"""Convert frame numbers to time (in seconds) given subsampling factor |
|
|
and frame shift (in milliseconds). |
|
|
|
|
|
Args: |
|
|
frames: |
|
|
A list of frame numbers after subsampling. |
|
|
subsampling_factor: |
|
|
The subsampling factor of the model. |
|
|
frame_shift_ms: |
|
|
Frame shift in milliseconds between two contiguous frames. |
|
|
Return: |
|
|
Return the time in seconds corresponding to each given frame. |
|
|
""" |
|
|
frame_shift = frame_shift_ms / 1000.0 |
|
|
time = [] |
|
|
for f in frames: |
|
|
time.append(round(f * subsampling_factor * frame_shift, ndigits=3)) |
|
|
|
|
|
return time |
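

# Worked example: with a 10 ms frame shift and 4x subsampling, each output
# frame covers 40 ms, so frames [0, 5, 10] map to [0.0, 0.2, 0.4] seconds.
def _example_convert_timestamp() -> None:
    assert convert_timestamp([0, 5, 10], subsampling_factor=4) == [0.0, 0.2, 0.4]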
|
|
|
|
|
|
|
|
def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]: |
|
|
""" |
|
|
    Parse the timestamp of each word from token-level timestamps.
|
|
|
|
|
Args: |
|
|
tokens: |
|
|
List of tokens. |
|
|
timestamp: |
|
|
List of timestamp of each token. |
|
|
|
|
|
Returns: |
|
|
List of timestamp of each word. |
|
|
""" |
|
|
    # "▁" (U+2581): the SentencePiece word-boundary marker.
    start_token = b"\xe2\x96\x81".decode()
|
|
assert len(tokens) == len(timestamp), (len(tokens), len(timestamp)) |
|
|
ans = [] |
|
|
for i in range(len(tokens)): |
|
|
flag = False |
|
|
if i == 0 or tokens[i].startswith(start_token): |
|
|
flag = True |
|
|
if len(tokens[i]) == 1 and tokens[i].startswith(start_token): |
|
|
|
|
|
if i == len(tokens) - 1: |
|
|
|
|
|
                    # A lone "▁" as the last token does not start a word.
                    flag = False
|
|
elif tokens[i + 1].startswith(start_token): |
|
|
|
|
|
                    # A lone "▁" directly followed by a word-initial token
                    # does not start a word either.
                    flag = False
|
|
if flag: |
|
|
ans.append(timestamp[i]) |
|
|
return ans |
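

# Worked example: only tokens that begin a word keep their timestamps.
# "▁" below is the SentencePiece word-boundary marker.
def _example_parse_timestamp() -> None:
    tokens = ["▁HE", "LLO", "▁WORLD"]
    assert parse_timestamp(tokens, [0.1, 0.2, 0.5]) == [0.1, 0.5]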
|
|
|
|
|
|
|
|
def parse_hyp_and_timestamp( |
|
|
res: DecodingResults, |
|
|
subsampling_factor: int, |
|
|
frame_shift_ms: float = 10, |
|
|
sp: Optional[spm.SentencePieceProcessor] = None, |
|
|
word_table: Optional[k2.SymbolTable] = None, |
|
|
) -> Tuple[List[List[str]], List[List[float]]]: |
|
|
"""Parse hypothesis and timestamp. |
|
|
|
|
|
Args: |
|
|
res: |
|
|
A DecodingResults object. |
|
|
subsampling_factor: |
|
|
The integer subsampling factor. |
|
|
frame_shift_ms: |
|
|
The float frame shift used for feature extraction. |
|
|
sp: |
|
|
The BPE model. |
|
|
word_table: |
|
|
The word symbol table. |
|
|
|
|
|
Returns: |
|
|
      Return a pair of lists: the hypotheses (one word list per utterance)
      and the corresponding word-level timestamps.
|
|
""" |
|
|
hyps = [] |
|
|
timestamps = [] |
|
|
|
|
|
N = len(res.hyps) |
|
|
assert len(res.timestamps) == N, (len(res.timestamps), N) |
|
|
use_word_table = False |
|
|
if word_table is not None: |
|
|
assert sp is None |
|
|
use_word_table = True |
|
|
else: |
|
|
assert sp is not None and word_table is None |
|
|
|
|
|
for i in range(N): |
|
|
time = convert_timestamp(res.timestamps[i], subsampling_factor, frame_shift_ms) |
|
|
if use_word_table: |
|
|
            words = [word_table[idx] for idx in res.hyps[i]]
|
|
else: |
|
|
tokens = sp.id_to_piece(res.hyps[i]) |
|
|
words = sp.decode_pieces(tokens).split() |
|
|
time = parse_timestamp(tokens, time) |
|
|
assert len(time) == len(words), (len(time), len(words)) |
|
|
|
|
|
hyps.append(words) |
|
|
timestamps.append(time) |
|
|
|
|
|
return hyps, timestamps |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_module_available(*modules: str) -> bool: |
|
|
r"""Returns if a top-level module with :attr:`name` exists *without** |
|
|
importing it. This is generally safer than try-catch block around a |
|
|
`import X`. |
|
|
|
|
|
Note: "borrowed" from torchaudio: |
|
|
""" |
|
|
import importlib |
|
|
|
|
|
return all(importlib.util.find_spec(m) is not None for m in modules) |
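

# Usage sketch: probe for an optional dependency without importing it.
# "no_such_module_xyz" is a made-up name that should not exist.
def _example_is_module_available() -> None:
    assert is_module_available("re")
    assert not is_module_available("no_such_module_xyz")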
|
|
|
|
|
|
|
|
def filter_uneven_sized_batch(batch: dict, allowed_max_frames: int): |
|
|
"""For the uneven-sized batch, the total duration after padding would possibly |
|
|
cause OOM. Hence, for each batch, which is sorted in descending order by length, |
|
|
we simply drop the last few shortest samples, so that the retained total frames |
|
|
(after padding) would not exceed the given allow_max_frames. |
|
|
|
|
|
Args: |
|
|
batch: |
|
|
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` |
|
|
for the content in it. |
|
|
allowed_max_frames: |
|
|
The allowed max number of frames in batch. |
|
|
""" |
|
|
features = batch["inputs"] |
|
|
supervisions = batch["supervisions"] |
|
|
|
|
|
N, T, _ = features.size() |
|
|
assert T == supervisions["num_frames"].max(), (T, supervisions["num_frames"].max()) |
|
|
kept_num_utt = allowed_max_frames // T |
|
|
|
|
|
if kept_num_utt >= N or kept_num_utt == 0: |
|
|
return batch |
|
|
|
|
|
|
|
|
logging.info( |
|
|
f"Filtering uneven-sized batch, original batch size is {N}, " |
|
|
f"retained batch size is {kept_num_utt}." |
|
|
) |
|
|
batch["inputs"] = features[:kept_num_utt] |
|
|
for k, v in supervisions.items(): |
|
|
assert len(v) == N, (len(v), N) |
|
|
batch["supervisions"][k] = v[:kept_num_utt] |
|
|
|
|
|
return batch |
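

# Minimal sketch with a synthetic batch (sizes invented for illustration):
# 4 utterances padded to 100 frames each; with allowed_max_frames=250 only
# 250 // 100 = 2 utterances are retained.
def _example_filter_uneven_sized_batch() -> None:
    batch = {
        "inputs": torch.zeros(4, 100, 80),
        "supervisions": {"num_frames": torch.tensor([100, 90, 80, 70])},
    }
    batch = filter_uneven_sized_batch(batch, allowed_max_frames=250)
    assert batch["inputs"].size(0) == 2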
|
|
|
|
|
|
|
|
def parse_bpe_start_end_pairs( |
|
|
tokens: List[str], is_first_token: List[bool] |
|
|
) -> List[Tuple[int, int]]: |
|
|
"""Parse pairs of start and end frame indexes for each word. |
|
|
|
|
|
Args: |
|
|
tokens: |
|
|
List of BPE tokens. |
|
|
is_first_token: |
|
|
        List of bool values indicating whether each position is the first
        frame of a token, i.e., not a repeat or a blank.
|
|
|
|
|
Returns: |
|
|
List of (start-frame-index, end-frame-index) pairs for each word. |
|
|
""" |
|
|
assert len(tokens) == len(is_first_token), (len(tokens), len(is_first_token)) |
|
|
|
|
|
    # "▁" (U+2581): the SentencePiece word-boundary marker.
    start_token = b"\xe2\x96\x81".decode()
|
|
blank_token = "<blk>" |
|
|
|
|
|
non_blank_idx = [i for i in range(len(tokens)) if tokens[i] != blank_token] |
|
|
num_non_blank = len(non_blank_idx) |
|
|
|
|
|
pairs = [] |
|
|
start = -1 |
|
|
end = -1 |
|
|
for j in range(num_non_blank): |
|
|
|
|
|
i = non_blank_idx[j] |
|
|
|
|
|
found_start = False |
|
|
if is_first_token[i] and (j == 0 or tokens[i].startswith(start_token)): |
|
|
found_start = True |
|
|
if tokens[i] == start_token: |
|
|
if j == num_non_blank - 1: |
|
|
|
|
|
                    # A lone "▁" as the last non-blank token starts no word.
                    found_start = False
|
|
elif is_first_token[non_blank_idx[j + 1]] and tokens[ |
|
|
non_blank_idx[j + 1] |
|
|
].startswith(start_token): |
|
|
|
|
|
                    # A lone "▁" right before a new word starts no word.
                    found_start = False
|
|
if found_start: |
|
|
start = i |
|
|
|
|
|
if start != -1: |
|
|
found_end = False |
|
|
if j == num_non_blank - 1: |
|
|
|
|
|
                # The last non-blank token always ends the current word.
                found_end = True
|
|
elif is_first_token[non_blank_idx[j + 1]] and tokens[ |
|
|
non_blank_idx[j + 1] |
|
|
].startswith(start_token): |
|
|
|
|
|
                # The next non-blank token starts a new word, so the
                # current word ends here.
                found_end = True
|
|
if found_end: |
|
|
end = i |
|
|
|
|
|
if start != -1 and end != -1: |
|
|
            # Skip segments made up solely of "▁": they contain no word.
            if not all(tokens[t] == start_token for t in range(start, end + 1)):
|
|
|
|
|
pairs.append((start, end)) |
|
|
|
|
|
start = -1 |
|
|
end = -1 |
|
|
|
|
|
return pairs |
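

# Worked example: "▁HE LLO <blk> ▁WORLD" spans two words; the blank frame
# is skipped when locating word boundaries.
def _example_parse_bpe_start_end_pairs() -> None:
    tokens = ["▁HE", "LLO", "<blk>", "▁WORLD"]
    is_first_token = [True, True, True, True]
    assert parse_bpe_start_end_pairs(tokens, is_first_token) == [(0, 1), (3, 3)]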
|
|
|
|
|
|
|
|
def parse_bpe_timestamps_and_texts( |
|
|
best_paths: k2.Fsa, sp: spm.SentencePieceProcessor |
|
|
) -> Tuple[List[Tuple[int, int]], List[List[str]]]: |
|
|
"""Parse timestamps (frame indexes) and texts. |
|
|
|
|
|
Args: |
|
|
best_paths: |
|
|
A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. |
|
|
containing multiple FSAs, which is expected to be the result |
|
|
of k2.shortest_path (otherwise the returned values won't |
|
|
be meaningful). Its attributes `labels` and `aux_labels` |
|
|
are both BPE tokens. |
|
|
sp: |
|
|
The BPE model. |
|
|
|
|
|
Returns: |
|
|
utt_index_pairs: |
|
|
        A list of pair lists. utt_index_pairs[i] is a list of
        (start-frame-index, end-frame-index) pairs for each word in
        utterance-i.
      utt_words:
        A list of str lists. utt_words[i] is the word list of utterance-i.
|
|
""" |
|
|
shape = best_paths.arcs.shape().remove_axis(1) |
|
|
|
|
|
|
|
|
    # One sub-list of arc labels per utterance.
    labels = k2.RaggedTensor(shape, best_paths.labels.contiguous())
    # Remove the -1 labels on the final arcs.
    labels = labels.remove_values_eq(-1)
    labels = labels.tolist()
|
|
|
|
|
|
|
|
    # One sub-list of arc aux_labels (also BPE tokens here) per utterance.
    aux_labels = k2.RaggedTensor(shape, best_paths.aux_labels.contiguous())
|
|
|
|
|
|
|
|
    # Keep the 0 entries (repeats and blanks): one entry per frame.
    all_aux_labels = aux_labels.remove_values_eq(-1)
    all_aux_labels = all_aux_labels.tolist()
|
|
|
|
|
|
|
|
    # Drop the 0 and -1 entries: one entry per emitted token.
    out_aux_labels = aux_labels.remove_values_leq(0)
    out_aux_labels = out_aux_labels.tolist()
|
|
|
|
|
utt_index_pairs = [] |
|
|
utt_words = [] |
|
|
for i in range(len(labels)): |
|
|
tokens = sp.id_to_piece(labels[i]) |
|
|
words = sp.decode(out_aux_labels[i]).split() |
|
|
|
|
|
|
|
|
        # A non-zero aux_label marks the first frame of a token.
        is_first_token = [a != 0 for a in all_aux_labels[i]]
|
|
index_pairs = parse_bpe_start_end_pairs(tokens, is_first_token) |
|
|
assert len(index_pairs) == len(words), (len(index_pairs), len(words), tokens) |
|
|
utt_index_pairs.append(index_pairs) |
|
|
utt_words.append(words) |
|
|
|
|
|
return utt_index_pairs, utt_words |
|
|
|
|
|
|
|
|
def parse_timestamps_and_texts( |
|
|
best_paths: k2.Fsa, word_table: k2.SymbolTable |
|
|
) -> Tuple[List[Tuple[int, int]], List[List[str]]]: |
|
|
"""Parse timestamps (frame indexes) and texts. |
|
|
|
|
|
Args: |
|
|
best_paths: |
|
|
A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. |
|
|
containing multiple FSAs, which is expected to be the result |
|
|
of k2.shortest_path (otherwise the returned values won't |
|
|
be meaningful). Attribute `labels` is the prediction unit, |
|
|
e.g., phone or BPE tokens. Attribute `aux_labels` is the word index. |
|
|
word_table: |
|
|
The word symbol table. |
|
|
|
|
|
Returns: |
|
|
utt_index_pairs: |
|
|
        A list of pair lists. utt_index_pairs[i] is a list of
        (start-frame-index, end-frame-index) pairs for each word in
        utterance-i.
      utt_words:
        A list of str lists. utt_words[i] is the word list of utterance-i.
|
|
""" |
|
|
|
|
|
    # One list of decoded word IDs per utterance.
    word_ids = get_texts(best_paths)
|
|
|
|
|
shape = best_paths.arcs.shape().remove_axis(1) |
|
|
|
|
|
|
|
|
    # One sub-list of arc labels per utterance.
    labels = k2.RaggedTensor(shape, best_paths.labels.contiguous())
    # Remove the -1 labels on the final arcs.
    labels = labels.remove_values_eq(-1)
    labels = labels.tolist()
|
|
|
|
|
|
|
|
    # aux_labels is itself a ragged tensor; compose shapes to index it per arc.
    aux_shape = shape.compose(best_paths.aux_labels.shape)
|
|
aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels.values.contiguous()) |
|
|
aux_labels = aux_labels.tolist() |
|
|
|
|
|
utt_index_pairs = [] |
|
|
utt_words = [] |
|
|
for i, (label, aux_label) in enumerate(zip(labels, aux_labels)): |
|
|
num_arcs = len(label) |
|
|
|
|
|
        # aux_label has one extra entry: the -1 on the final arc is kept here.
        assert num_arcs == len(aux_label) - 1, (num_arcs, len(aux_label))
|
|
|
|
|
index_pairs = [] |
|
|
start = -1 |
|
|
end = -1 |
|
|
for arc in range(num_arcs): |
|
|
|
|
|
            # len(aux_label[arc]) is 0 or 1; a non-empty entry starts a new word.
            if label[arc] != 0 and len(aux_label[arc]) != 0:
|
|
if start != -1 and end != -1: |
|
|
index_pairs.append((start, end)) |
|
|
start = arc |
|
|
if label[arc] != 0: |
|
|
end = arc |
|
|
if start != -1 and end != -1: |
|
|
index_pairs.append((start, end)) |
|
|
|
|
|
words = [word_table[w] for w in word_ids[i]] |
|
|
assert len(index_pairs) == len(words), (len(index_pairs), len(words)) |
|
|
|
|
|
utt_index_pairs.append(index_pairs) |
|
|
utt_words.append(words) |
|
|
|
|
|
return utt_index_pairs, utt_words |
|
|
|
|
|
|
|
|
def parse_fsa_timestamps_and_texts( |
|
|
best_paths: k2.Fsa, |
|
|
sp: Optional[spm.SentencePieceProcessor] = None, |
|
|
word_table: Optional[k2.SymbolTable] = None, |
|
|
subsampling_factor: int = 4, |
|
|
frame_shift_ms: float = 10, |
|
|
) -> Tuple[List[Tuple[float, float]], List[List[str]]]: |
|
|
"""Parse timestamps (in seconds) and texts for given decoded fsa paths. |
|
|
Currently it supports two cases: |
|
|
(1) ctc-decoding, the attributes `labels` and `aux_labels` |
|
|
are both BPE tokens. In this case, sp should be provided. |
|
|
    (2) HLG-based 1best, where the attribute `labels` is the prediction unit,
    e.g., phone or BPE tokens; attribute `aux_labels` is the word index.
|
|
In this case, word_table should be provided. |
|
|
|
|
|
Args: |
|
|
best_paths: |
|
|
A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. |
|
|
containing multiple FSAs, which is expected to be the result |
|
|
of k2.shortest_path (otherwise the returned values won't |
|
|
be meaningful). |
|
|
sp: |
|
|
The BPE model. |
|
|
word_table: |
|
|
The word symbol table. |
|
|
subsampling_factor: |
|
|
The subsampling factor of the model. |
|
|
frame_shift_ms: |
|
|
Frame shift in milliseconds between two contiguous frames. |
|
|
|
|
|
Returns: |
|
|
utt_time_pairs: |
|
|
        A list of pair lists. utt_time_pairs[i] is a list of
        (start-time, end-time) pairs for each word in utterance-i.
      utt_words:
        A list of str lists. utt_words[i] is the word list of utterance-i.
|
|
""" |
|
|
if sp is not None: |
|
|
assert word_table is None, "word_table is not needed if sp is provided." |
|
|
utt_index_pairs, utt_words = parse_bpe_timestamps_and_texts( |
|
|
best_paths=best_paths, sp=sp |
|
|
) |
|
|
elif word_table is not None: |
|
|
assert sp is None, "sp is not needed if word_table is provided." |
|
|
utt_index_pairs, utt_words = parse_timestamps_and_texts( |
|
|
best_paths=best_paths, word_table=word_table |
|
|
) |
|
|
else: |
|
|
raise ValueError("Either sp or word_table should be provided.") |
|
|
|
|
|
utt_time_pairs = [] |
|
|
for utt in utt_index_pairs: |
|
|
start = convert_timestamp( |
|
|
frames=[i[0] for i in utt], |
|
|
subsampling_factor=subsampling_factor, |
|
|
frame_shift_ms=frame_shift_ms, |
|
|
) |
|
|
end = convert_timestamp( |
|
|
|
|
|
            # i[1] is the index of the last frame of the word; +1 turns it
            # into an exclusive end frame before converting to seconds.
            frames=[i[1] + 1 for i in utt],
|
|
subsampling_factor=subsampling_factor, |
|
|
frame_shift_ms=frame_shift_ms, |
|
|
) |
|
|
utt_time_pairs.append(list(zip(start, end))) |
|
|
|
|
|
return utt_time_pairs, utt_words |
|
|
|
|
|
|
|
|
|
|
|
def is_cjk(character): |
|
|
""" |
|
|
    Python port of Moses' code to check for a CJK character.
|
|
|
|
|
>>> is_cjk(u'\u33fe') |
|
|
True |
|
|
>>> is_cjk(u'\uFE5F') |
|
|
False |
|
|
|
|
|
:param character: The character that needs to be checked. |
|
|
:type character: char |
|
|
:return: bool |
|
|
""" |
|
|
return any( |
|
|
[ |
|
|
start <= ord(character) <= end |
|
|
for start, end in [ |
|
|
(4352, 4607), |
|
|
(11904, 42191), |
|
|
(43072, 43135), |
|
|
(44032, 55215), |
|
|
(63744, 64255), |
|
|
(65072, 65103), |
|
|
(65381, 65500), |
|
|
(131072, 196607), |
|
|
] |
|
|
] |
|
|
) |
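

# The doctest examples above, as a runnable check:
def _example_is_cjk() -> None:
    assert is_cjk("\u33fe")
    assert not is_cjk("\ufe5f")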
|
|
|
|
|
|
|
|
def symlink_or_copy(exp_dir: Path, src: str, dst: str): |
|
|
""" |
|
|
In the experiment directory, create a symlink pointing to src named dst. |
|
|
If symlink creation fails (Windows?), fall back to copyfile.""" |
|
|
|
|
|
    dir_fd = os.open(exp_dir, os.O_RDONLY)
    try:
        try:
            os.remove(dst, dir_fd=dir_fd)
        except FileNotFoundError:
            pass
        try:
            os.symlink(src=src, dst=dst, dir_fd=dir_fd)
        except OSError:
            copyfile(src=exp_dir / src, dst=exp_dir / dst)
    finally:
        # Close the directory fd even if the symlink/copy fails.
        os.close(dir_fd)
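

# Usage sketch with a throwaway directory (file names are invented):
def _example_symlink_or_copy() -> None:
    import tempfile

    with tempfile.TemporaryDirectory() as d:
        exp_dir = Path(d)
        (exp_dir / "epoch-10.pt").touch()
        symlink_or_copy(exp_dir, src="epoch-10.pt", dst="best-valid-loss.pt")
        assert (exp_dir / "best-valid-loss.pt").exists()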
|
|
|
|
|
|
|
|
def num_tokens(
    token_table: k2.SymbolTable, disambig_pattern: re.Pattern = re.compile(r"^#\d+$")
) -> int:
|
|
"""Return the number of tokens excluding those from |
|
|
disambiguation symbols. |
|
|
|
|
|
Caution: |
|
|
0 is not a token ID so it is excluded from the return value. |
|
|
""" |
|
|
symbols = token_table.symbols |
|
|
ans = [] |
|
|
for s in symbols: |
|
|
if not disambig_pattern.match(s): |
|
|
ans.append(token_table[s]) |
|
|
num_tokens = len(ans) |
|
|
if 0 in ans: |
|
|
num_tokens -= 1 |
|
|
return num_tokens |
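

# Worked example: disambig symbols ("#0", "#1", ...) and ID 0 are excluded,
# so a table with <blk>=0, A=1, B=2, #0=3 has 2 tokens.
def _example_num_tokens() -> None:
    table = k2.SymbolTable.from_str("<blk> 0\nA 1\nB 2\n#0 3")
    assert num_tokens(table) == 2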
|
|
|
|
|
|
|
|
|
|
|
def time_warp( |
|
|
features: torch.Tensor, |
|
|
p: float = 0.9, |
|
|
time_warp_factor: Optional[int] = 80, |
|
|
supervision_segments: Optional[torch.Tensor] = None, |
|
|
): |
|
|
"""Apply time warping on a batch of features""" |
|
|
if time_warp_factor is None or time_warp_factor < 1: |
|
|
return features |
|
|
assert ( |
|
|
len(features.shape) == 3 |
|
|
), f"SpecAugment only supports batches of single-channel feature matrices. {features.shape}" |
|
|
features = features.clone() |
|
|
if supervision_segments is None: |
|
|
|
|
|
        # Warp each feature matrix in the batch as a whole.
        for sequence_idx in range(features.size(0)):
|
|
            # With probability 1 - p, leave this sequence unwarped.
            if random.random() > p:
|
|
|
|
|
continue |
|
|
features[sequence_idx] = time_warp_impl( |
|
|
features[sequence_idx], factor=time_warp_factor |
|
|
) |
|
|
else: |
|
|
|
|
|
        # Warp each supervision segment independently.
        for sequence_idx, start_frame, num_frames in supervision_segments:
|
|
            # With probability 1 - p, leave this segment unwarped.
            if random.random() > p:
|
|
|
|
|
continue |
|
|
end_frame = start_frame + num_frames |
|
|
features[sequence_idx, start_frame:end_frame] = time_warp_impl( |
|
|
features[sequence_idx, start_frame:end_frame], factor=time_warp_factor |
|
|
) |
|
|
|
|
|
return features |
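

# Shape sanity sketch (sizes invented): warping never changes the tensor
# shape, and p=1.0 forces every sequence to be warped.
def _example_time_warp() -> None:
    feats = torch.randn(2, 300, 80)  # (batch, frames, feature bins)
    warped = time_warp(feats, p=1.0, time_warp_factor=80)
    assert warped.shape == feats.shape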
|
|
|