Spaces: Running
Shikhar committed on
Commit · 84f8437
1 Parent(s): d876521
Deploy PhoneticXeus Gradio demo (CPU)
Browse files
- README.md +15 -5
- app.py +110 -0
- requirements.txt +8 -0
- src/__init__.py +0 -0
- src/core/__init__.py +1 -0
- src/core/utils.py +59 -0
- src/espnet_import/__init__.py +0 -0
- src/espnet_import/attention.py +457 -0
- src/espnet_import/cgmlp.py +123 -0
- src/espnet_import/embedding.py +523 -0
- src/espnet_import/fastformer.py +153 -0
- src/espnet_import/label_smoothing_loss.py +64 -0
- src/espnet_import/layer_norm.py +43 -0
- src/espnet_import/nets_utils.py +690 -0
- src/espnet_import/positionwise_feed_forward.py +32 -0
- src/espnet_import/repeat.py +46 -0
- src/espnet_import/subsampling.py +873 -0
- src/model/__init__.py +0 -0
- src/model/powsm/__init__.py +0 -0
- src/model/powsm/ctc.py +230 -0
- src/model/powsm/e_branchformer.py +555 -0
- src/model/powsm/specaug.py +384 -0
- src/model/powsm/utils.py +80 -0
- src/model/xeusphoneme/__init__.py +0 -0
- src/model/xeusphoneme/builders.py +307 -0
- src/model/xeusphoneme/cnn_frontend.py +261 -0
- src/model/xeusphoneme/linear_layer.py +21 -0
- src/model/xeusphoneme/resources/ipa_vocab.json +430 -0
- src/model/xeusphoneme/xeuspr_inference.py +86 -0
- src/model/xeusphoneme/xeuspr_model.py +378 -0
- src/recipe/__init__.py +0 -0
- src/recipe/phone_recognition/__init__.py +0 -0
- src/recipe/phone_recognition/greedy_ctc_strategy.py +63 -0
- src/utils/__init__.py +1 -0
- src/utils/pylogger.py +23 -0
README.md
CHANGED
@@ -1,12 +1,22 @@
 ---
 title: PhoneticXeus
-emoji:
-colorFrom:
-colorTo:
+emoji: "\U0001F4DE"
+colorFrom: blue
+colorTo: purple
 sdk: gradio
-sdk_version:
+sdk_version: "5.0"
 app_file: app.py
 pinned: false
+license: apache-2.0
+models:
+- changelinglab/PhoneticXeus
+hardware: cpu-basic
 ---
 
-
+# PhoneticXeus -- Multilingual Phone Recognition
+
+Record or upload audio to get an IPA phone transcription.
+
+Based on [PhoneticXeus](https://huggingface.co/changelinglab/PhoneticXeus), a multilingual phone recognition model using self-conditioned CTC on the XEUS speech encoder.
+
+Paper: [An Empirical Recipe for Universal Phone Recognition](https://arxiv.org/abs/2603.29042)
app.py
ADDED
@@ -0,0 +1,110 @@
import os
import sys

# Ensure vendored src/ is importable
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import gradio as gr
import torch
import torchaudio
from huggingface_hub import hf_hub_download

from src.model.xeusphoneme.builders import build_xeus_pr_inference

MAX_SECONDS = 60
SAMPLE_RATE = 16000

inference = None


def load_model():
    ckpt = hf_hub_download(
        "changelinglab/PhoneticXeus", "checkpoint-22000.ckpt"
    )
    vocab = os.path.join(
        os.path.dirname(__file__),
        "src", "model", "xeusphoneme", "resources", "ipa_vocab.json",
    )
    return build_xeus_pr_inference(
        work_dir="/tmp/cache/xeus",
        checkpoint=ckpt,
        vocab_file=vocab,
        hf_repo="espnet/xeus",
        device="cpu",
    )


def transcribe(audio_path):
    """Run phone recognition on uploaded/recorded audio."""
    global inference
    if audio_path is None:
        return "", ""

    if inference is None:
        inference = load_model()

    waveform, sr = torchaudio.load(audio_path)
    if sr != SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, sr, SAMPLE_RATE)
    waveform = waveform.mean(dim=0)  # mono
    waveform = waveform[: SAMPLE_RATE * MAX_SECONDS]

    if waveform.numel() == 0:
        return "", ""

    results = inference(waveform)

    processed = results[0]["processed_transcript"]
    predicted = results[0]["predicted_transcript"]
    spaced = " ".join(
        t for t in predicted.split("/")
        if not (t.startswith("<") and t.endswith(">"))
    )
    return spaced, processed


with gr.Blocks(title="PhoneticXeus") as demo:
    gr.Markdown(
        "# PhoneticXeus\n"
        "Multilingual phone recognition -- record or upload audio "
        "to get an IPA transcription.\n\n"
        "Model: [changelinglab/PhoneticXeus]"
        "(https://huggingface.co/changelinglab/PhoneticXeus) "
        "| Paper: [arXiv 2603.29042]"
        "(https://arxiv.org/abs/2603.29042)"
    )

    with gr.Row():
        audio_input = gr.Audio(
            sources=["microphone", "upload"],
            type="filepath",
            label="Input Audio",
        )

    btn = gr.Button("Transcribe", variant="primary")

    with gr.Row():
        phones_output = gr.Textbox(
            label="IPA Phones (space-separated)",
            lines=3,
            show_copy_button=True,
        )
        raw_output = gr.Textbox(
            label="Raw output (concatenated)",
            lines=3,
            show_copy_button=True,
        )

    btn.click(
        fn=transcribe,
        inputs=[audio_input],
        outputs=[phones_output, raw_output],
    )

    gr.Markdown(
        "---\n"
        f"Max audio length: {MAX_SECONDS}s. "
        "Audio is resampled to 16 kHz mono."
    )

demo.launch()
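For reference, the same inference path can be exercised without launching the Gradio UI. A minimal sketch mirroring load_model() and transcribe() above; "sample.wav" is a hypothetical input file:

# Sketch: run the demo's inference pipeline outside Gradio.
# Mirrors load_model()/transcribe() from app.py; "sample.wav" is hypothetical.
import torchaudio
from huggingface_hub import hf_hub_download
from src.model.xeusphoneme.builders import build_xeus_pr_inference

ckpt = hf_hub_download("changelinglab/PhoneticXeus", "checkpoint-22000.ckpt")
inference = build_xeus_pr_inference(
    work_dir="/tmp/cache/xeus",
    checkpoint=ckpt,
    vocab_file="src/model/xeusphoneme/resources/ipa_vocab.json",
    hf_repo="espnet/xeus",
    device="cpu",
)

# Load, downmix to mono, and resample to the 16 kHz the model expects.
waveform, sr = torchaudio.load("sample.wav")
waveform = torchaudio.functional.resample(waveform, sr, 16000).mean(dim=0)
results = inference(waveform)
print(results[0]["predicted_transcript"])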
requirements.txt
ADDED
@@ -0,0 +1,8 @@
torch
torchaudio
huggingface_hub
pyyaml
typeguard
packaging
numpy
gradio>=5.0
src/__init__.py
ADDED
File without changes
src/core/__init__.py
ADDED
@@ -0,0 +1 @@
"""Core modules for PhoneticXeus."""
src/core/utils.py
ADDED
@@ -0,0 +1,59 @@
import logging
from huggingface_hub import snapshot_download
from huggingface_hub.utils import LocalEntryNotFoundError


def download_hf_snapshot(
    repo_id: str,
    work_dir: str,
    force_download: bool = False,
    **kwargs,
) -> str:
    """Download a snapshot from Hugging Face Hub to `work_dir`.

    Args:
        repo_id: e.g. "espnet/xeus"
        work_dir: path to local directory where to store snapshot
        force_download: if True, enforce re-download
        **kwargs: other snapshot_download arguments

    Returns:
        The path to the local snapshot folder
    """
    if force_download:
        logging.info(
            f"Force-downloading snapshot for {repo_id} into {work_dir}..."
        )
        path = snapshot_download(
            repo_id=repo_id,
            local_dir=work_dir,
            force_download=True,
            local_files_only=False,
            **kwargs,
        )
        logging.info(f"Downloaded snapshot for {repo_id} to {path}")
        return path

    try:
        path = snapshot_download(
            repo_id=repo_id,
            local_dir=work_dir,
            local_files_only=True,
            **kwargs,
        )
        logging.info(
            f"Using existing local snapshot for {repo_id} at {path}"
        )
        return path
    except LocalEntryNotFoundError:
        logging.info(
            f"No local snapshot found for {repo_id}. Downloading now..."
        )
        path = snapshot_download(
            repo_id=repo_id,
            local_dir=work_dir,
            local_files_only=False,
            **kwargs,
        )
        logging.info(f"Downloaded snapshot for {repo_id} to {path}")
        return path
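A minimal usage sketch for the helper above; "/tmp/cache/xeus" mirrors the work_dir that app.py passes and is otherwise illustrative:

# Sketch: fetch (or reuse) the XEUS encoder snapshot used by the demo.
from src.core.utils import download_hf_snapshot

local_path = download_hf_snapshot("espnet/xeus", "/tmp/cache/xeus")
print("Snapshot available at:", local_path)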
src/espnet_import/__init__.py
ADDED
File without changes
src/espnet_import/attention.py
ADDED
@@ -0,0 +1,457 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Multi-Head Attention layer definition."""

import logging
import math

import torch
from torch import nn

from src.espnet_import.layer_norm import LayerNorm

try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import pad_input, unpad_input
except Exception:
    pass


class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        qk_norm (bool): Normalize q and k before dot product.
        use_flash_attn (bool): Use flash_attn implementation.
        causal (bool): Apply causal attention.
        cross_attn (bool): Cross attention instead of self attention.
        use_sdpa (bool): Use PyTorch's scaled dot product attention.

    """

    def __init__(
        self,
        n_head,
        n_feat,
        dropout_rate,
        qk_norm=False,
        use_flash_attn=False,
        causal=False,
        cross_attn=False,
        use_sdpa=False,
    ):
        """Construct a MultiHeadedAttention object."""
        super(MultiHeadedAttention, self).__init__()

        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.attn = None
        self.dropout = (
            nn.Dropout(p=dropout_rate) if not use_flash_attn else nn.Identity()
        )
        self.dropout_rate = dropout_rate

        # LayerNorm for q and k
        self.q_norm = LayerNorm(self.d_k) if qk_norm else nn.Identity()
        self.k_norm = LayerNorm(self.d_k) if qk_norm else nn.Identity()

        self.use_flash_attn = use_flash_attn
        self.causal = causal  # only used with flash_attn
        self.cross_attn = cross_attn  # only used with flash_attn

        self.use_sdpa = use_sdpa

    def forward_qkv(self, query, key, value, expand_kv=False):
        """Transform query, key and value.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            expand_kv (bool): Used only for partially autoregressive (PAR) decoding.

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).

        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)

        if expand_kv:
            k_shape = key.shape
            k = (
                self.linear_k(key[:1, :, :])
                .expand(n_batch, k_shape[1], k_shape[2])
                .view(n_batch, -1, self.h, self.d_k)
            )
            v_shape = value.shape
            v = (
                self.linear_v(value[:1, :, :])
                .expand(n_batch, v_shape[1], v_shape[2])
                .view(n_batch, -1, self.h, self.d_k)
            )
        else:
            k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
            v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)

        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)

        q = self.q_norm(q)
        k = self.k_norm(k)

        return q, k, v

    def forward_attention(self, value, scores, mask):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).

        """
        n_batch = value.size(0)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, query, key, value, mask, expand_kv=False):
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).
            expand_kv (bool): Used only for partially autoregressive (PAR) decoding.
                When set to `True`, `Linear` layers are computed only for the first
                batch. This is useful to reduce the memory usage during decoding
                when the batch size is #beam_size x #mask_count, which can be large.
                Typically, in single waveform inference of PAR, `Linear` layers
                should not be computed for all batches for source-attention.

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        # Use PyTorch's Scaled Dot Product Attention implementation
        if getattr(self, "use_sdpa", False):
            q, k, v = self.forward_qkv(query, key, value, expand_kv)

            # The shape of mask must be broadcastable to the shape of attention weights
            out = torch.nn.functional.scaled_dot_product_attention(
                q,
                k,
                v,
                mask.unsqueeze(1) if mask is not None else None,
                dropout_p=self.dropout_rate if self.training else 0.0,
            )  # (batch, head, time1, d_k)

            out = out.transpose(1, 2)  # (batch, time1, head, d_k)
            out = out.reshape(out.shape[0], out.shape[1], -1)  # (batch, time1, d_model)
            return self.linear_out(out)  # (batch, time1, d_model)

        # Use Flash Attention implementation
        if self.use_flash_attn:
            try:
                # In the causal case, the last row will be the key mask
                key_nonpad_mask = mask[:, -1, :]  # (#batch, time2)
                if self.cross_attn:
                    # For cross attention, we do not know the query padding
                    query_nonpad_mask = torch.ones(
                        size=query.shape[:2], dtype=torch.bool, device=query.device
                    )
                else:
                    query_nonpad_mask = key_nonpad_mask

                if key_nonpad_mask.eq(0).any():
                    # Use variable length implementation if padded
                    q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
                        query, query_nonpad_mask
                    )[:4]
                    k, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(
                        key, key_nonpad_mask
                    )[:4]
                    v, _, _, _ = unpad_input(value, key_nonpad_mask)[:4]

                    q = self.linear_q(q).reshape(-1, self.h, self.d_k)
                    k = self.linear_k(k).reshape(-1, self.h, self.d_k)
                    v = self.linear_v(v).reshape(-1, self.h, self.d_k)

                    q = self.q_norm(q)
                    k = self.k_norm(k)

                    out = flash_attn_varlen_func(
                        q,
                        k,
                        v,
                        cu_seqlens_q,
                        cu_seqlens_k,
                        max_seqlen_q,
                        max_seqlen_k,
                        dropout_p=self.dropout_rate if self.training else 0.0,
                        causal=self.causal,
                    )  # (total, nheads, headdim)

                    out = out.reshape(out.shape[0], -1)
                    out = self.linear_out(out)

                    out = pad_input(out, indices_q, query.shape[0], query.shape[1])
                    return out

                else:
                    # Use fixed length implementation if not padded,
                    # which is faster than the variable length implementation
                    del key_nonpad_mask
                    q, k, v = self.forward_qkv(query, key, value)

                    out = flash_attn_func(
                        q.transpose(1, 2),
                        k.transpose(1, 2),
                        v.transpose(1, 2),
                        dropout_p=self.dropout_rate if self.training else 0.0,
                        causal=self.causal,
                    )  # (batch_size, seqlen, nheads, headdim)
                    del q, k, v

                    out = out.reshape(out.shape[0], out.shape[1], -1)
                    out = self.linear_out(out)
                    return out

            except Exception:
                # flash_attn failed (e.g. not installed or CPU-only);
                # disable it and fall through to the default path.
                self.use_flash_attn = False

        # Fall back to the default implementation
        q, k, v = self.forward_qkv(query, key, value, expand_kv)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask)


class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding (old version).

    Details can be found in https://github.com/espnet/espnet/pull/2816.

    Paper: https://arxiv.org/abs/1901.02860

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.

    """

    def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
        """Construct a LegacyRelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate)
        self.zero_triu = zero_triu
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # these two learnable biases are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x):
        """Compute relative positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, head, time1, time2).

        Returns:
            torch.Tensor: Output tensor.

        """
        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)

        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)

        if self.zero_triu:
            ones = torch.ones((x.size(2), x.size(3)))
            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]

        return x

    def forward(self, query, key, value, pos_emb, mask):
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).

        """
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch, head, time1, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # compute matrix b and matrix d
        # (batch, head, time1, time1)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k
        )  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask)


class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding (new implementation).

    Details can be found in https://github.com/espnet/espnet/pull/2816.

    Paper: https://arxiv.org/abs/1901.02860

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.

    """

    def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
        """Construct a RelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate)
        self.zero_triu = zero_triu
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # these two learnable biases are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x):
        """Compute relative positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
                time1 means the length of query vector.

        Returns:
            torch.Tensor: Output tensor.

        """
        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)

        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)[
            :, :, :, : x.size(-1) // 2 + 1
        ]  # only keep the positions from 0 to time2

        if self.zero_triu:
            ones = torch.ones((x.size(2), x.size(3)), device=x.device)
            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]

        return x

    def forward(self, query, key, value, pos_emb, mask):
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            pos_emb (torch.Tensor): Positional embedding tensor
                (#batch, 2*time1-1, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).

        """
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # compute matrix b and matrix d
        # (batch, head, time1, 2*time1-1)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k
        )  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask)
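As a quick sanity check on the vendored layer, a shape-check sketch; the 4-head, 256-dim configuration below is illustrative, not the model's actual setting:

# Sketch: exercise MultiHeadedAttention with dummy tensors.
import torch
from src.espnet_import.attention import MultiHeadedAttention

attn = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
x = torch.randn(2, 50, 256)          # (batch, time, feature)
mask = torch.ones(2, 1, 50).bool()   # all frames valid (no padding)
out = attn(x, x, x, mask)            # self-attention, default code path
print(out.shape)                     # torch.Size([2, 50, 256])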
src/espnet_import/cgmlp.py
ADDED
@@ -0,0 +1,123 @@
"""MLP with convolutional gating (cgMLP) definition.

References:
    https://openreview.net/forum?id=RA-zVvZLYIy
    https://arxiv.org/abs/2105.08050

"""

import torch

from src.espnet_import.nets_utils import get_activation
from src.espnet_import.layer_norm import LayerNorm


class ConvolutionalSpatialGatingUnit(torch.nn.Module):
    """Convolutional Spatial Gating Unit (CSGU)."""

    def __init__(
        self,
        size: int,
        kernel_size: int,
        dropout_rate: float,
        use_linear_after_conv: bool,
        gate_activation: str,
    ):
        super().__init__()

        n_channels = size // 2  # split input channels
        self.norm = LayerNorm(n_channels)
        self.conv = torch.nn.Conv1d(
            n_channels,
            n_channels,
            kernel_size,
            1,
            (kernel_size - 1) // 2,
            groups=n_channels,
        )
        if use_linear_after_conv:
            self.linear = torch.nn.Linear(n_channels, n_channels)
        else:
            self.linear = None

        if gate_activation == "identity":
            self.act = torch.nn.Identity()
        else:
            self.act = get_activation(gate_activation)

        self.dropout = torch.nn.Dropout(dropout_rate)

    def espnet_initialization_fn(self):
        torch.nn.init.normal_(self.conv.weight, std=1e-6)
        torch.nn.init.ones_(self.conv.bias)
        if self.linear is not None:
            torch.nn.init.normal_(self.linear.weight, std=1e-6)
            torch.nn.init.ones_(self.linear.bias)

    def forward(self, x, gate_add=None):
        """Forward method.

        Args:
            x (torch.Tensor): (N, T, D)
            gate_add (torch.Tensor): (N, T, D/2)

        Returns:
            out (torch.Tensor): (N, T, D/2)
        """
        x_r, x_g = x.chunk(2, dim=-1)

        x_g = self.norm(x_g)  # (N, T, D/2)
        x_g = self.conv(x_g.transpose(1, 2)).transpose(1, 2)  # (N, T, D/2)
        if self.linear is not None:
            x_g = self.linear(x_g)

        if gate_add is not None:
            x_g = x_g + gate_add

        x_g = self.act(x_g)
        out = x_r * x_g  # (N, T, D/2)
        out = self.dropout(out)
        return out


class ConvolutionalGatingMLP(torch.nn.Module):
    """Convolutional Gating MLP (cgMLP)."""

    def __init__(
        self,
        size: int,
        linear_units: int,
        kernel_size: int,
        dropout_rate: float,
        use_linear_after_conv: bool,
        gate_activation: str,
    ):
        super().__init__()

        self.channel_proj1 = torch.nn.Sequential(
            torch.nn.Linear(size, linear_units), torch.nn.GELU()
        )
        self.csgu = ConvolutionalSpatialGatingUnit(
            size=linear_units,
            kernel_size=kernel_size,
            dropout_rate=dropout_rate,
            use_linear_after_conv=use_linear_after_conv,
            gate_activation=gate_activation,
        )
        self.channel_proj2 = torch.nn.Linear(linear_units // 2, size)

    def forward(self, x, mask):
        if isinstance(x, tuple):
            xs_pad, pos_emb = x
        else:
            xs_pad, pos_emb = x, None

        xs_pad = self.channel_proj1(xs_pad)  # size -> linear_units
        xs_pad = self.csgu(xs_pad)  # linear_units -> linear_units/2
        xs_pad = self.channel_proj2(xs_pad)  # linear_units/2 -> size

        if pos_emb is not None:
            out = (xs_pad, pos_emb)
        else:
            out = xs_pad
        return out
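Similarly, a shape-check sketch for the cgMLP block; all sizes below are illustrative:

# Sketch: run ConvolutionalGatingMLP on dummy features.
import torch
from src.espnet_import.cgmlp import ConvolutionalGatingMLP

mlp = ConvolutionalGatingMLP(
    size=256,
    linear_units=1024,
    kernel_size=31,
    dropout_rate=0.1,
    use_linear_after_conv=False,
    gate_activation="identity",
)
x = torch.randn(2, 50, 256)  # (batch, time, size)
out = mlp(x, mask=None)      # mask is accepted but unused in this path
print(out.shape)             # torch.Size([2, 50, 256])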
src/espnet_import/embedding.py
ADDED
|
@@ -0,0 +1,523 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
# Copyright 2019 Shigeki Karita
|
| 5 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
| 6 |
+
|
| 7 |
+
"""Positional Encoding Module."""
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
import math
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from packaging.version import parse as V
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# from espnet2.asr.frontend.cnn import dim_1_layer_norm
|
| 17 |
+
def dim_1_layer_norm(x, eps=1e-05, gamma=None, beta=None):
|
| 18 |
+
"""Functional version of Dim1LayerNorm."""
|
| 19 |
+
|
| 20 |
+
B, D, T = x.shape
|
| 21 |
+
mean = torch.mean(x, 1, keepdim=True)
|
| 22 |
+
variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
|
| 23 |
+
|
| 24 |
+
x = (x - mean) * torch.rsqrt(variance + eps)
|
| 25 |
+
|
| 26 |
+
if gamma is not None:
|
| 27 |
+
x = x * gamma.view(1, -1, 1)
|
| 28 |
+
if beta is not None:
|
| 29 |
+
x = x + beta.view(1, -1, 1)
|
| 30 |
+
return x
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _pre_hook(
|
| 34 |
+
state_dict,
|
| 35 |
+
prefix,
|
| 36 |
+
local_metadata,
|
| 37 |
+
strict,
|
| 38 |
+
missing_keys,
|
| 39 |
+
unexpected_keys,
|
| 40 |
+
error_msgs,
|
| 41 |
+
):
|
| 42 |
+
"""Perform pre-hook in load_state_dict for backward compatibility.
|
| 43 |
+
|
| 44 |
+
Note:
|
| 45 |
+
We saved self.pe until v.0.5.2 but we have omitted it later.
|
| 46 |
+
Therefore, we remove the item "pe" from `state_dict` for backward compatibility.
|
| 47 |
+
|
| 48 |
+
"""
|
| 49 |
+
k = prefix + "pe"
|
| 50 |
+
if k in state_dict:
|
| 51 |
+
state_dict.pop(k)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class PositionalEncoding(torch.nn.Module):
|
| 55 |
+
"""Positional encoding.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
d_model (int): Embedding dimension.
|
| 59 |
+
dropout_rate (float): Dropout rate.
|
| 60 |
+
max_len (int): Maximum input length.
|
| 61 |
+
reverse (bool): Whether to reverse the input position. Only for
|
| 62 |
+
the class LegacyRelPositionalEncoding. We remove it in the current
|
| 63 |
+
class RelPositionalEncoding.
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
|
| 67 |
+
"""Construct an PositionalEncoding object."""
|
| 68 |
+
super(PositionalEncoding, self).__init__()
|
| 69 |
+
self.d_model = d_model
|
| 70 |
+
self.reverse = reverse
|
| 71 |
+
self.xscale = math.sqrt(self.d_model)
|
| 72 |
+
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
| 73 |
+
self.pe = None
|
| 74 |
+
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
| 75 |
+
self._register_load_state_dict_pre_hook(_pre_hook)
|
| 76 |
+
|
| 77 |
+
def extend_pe(self, x):
|
| 78 |
+
"""Reset the positional encodings."""
|
| 79 |
+
if self.pe is not None:
|
| 80 |
+
if self.pe.size(1) >= x.size(1):
|
| 81 |
+
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
| 82 |
+
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
| 83 |
+
return
|
| 84 |
+
pe = torch.zeros(x.size(1), self.d_model)
|
| 85 |
+
if self.reverse:
|
| 86 |
+
position = torch.arange(
|
| 87 |
+
x.size(1) - 1, -1, -1.0, dtype=torch.float32
|
| 88 |
+
).unsqueeze(1)
|
| 89 |
+
else:
|
| 90 |
+
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
| 91 |
+
div_term = torch.exp(
|
| 92 |
+
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
| 93 |
+
* -(math.log(10000.0) / self.d_model)
|
| 94 |
+
)
|
| 95 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
| 96 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
| 97 |
+
pe = pe.unsqueeze(0)
|
| 98 |
+
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
| 99 |
+
|
| 100 |
+
def forward(self, x: torch.Tensor):
|
| 101 |
+
"""Add positional encoding.
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
| 105 |
+
|
| 106 |
+
Returns:
|
| 107 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
| 108 |
+
"""
|
| 109 |
+
self.extend_pe(x)
|
| 110 |
+
x = x * self.xscale + self.pe[:, : x.size(1)]
|
| 111 |
+
return self.dropout(x)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class ScaledPositionalEncoding(PositionalEncoding):
|
| 115 |
+
"""Scaled positional encoding module.
|
| 116 |
+
|
| 117 |
+
See Sec. 3.2 https://arxiv.org/abs/1809.08895
|
| 118 |
+
|
| 119 |
+
Args:
|
| 120 |
+
d_model (int): Embedding dimension.
|
| 121 |
+
dropout_rate (float): Dropout rate.
|
| 122 |
+
max_len (int): Maximum input length.
|
| 123 |
+
|
| 124 |
+
"""
|
| 125 |
+
|
| 126 |
+
def __init__(self, d_model, dropout_rate, max_len=5000):
|
| 127 |
+
"""Initialize class."""
|
| 128 |
+
super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
|
| 129 |
+
self.alpha = torch.nn.Parameter(torch.tensor(1.0))
|
| 130 |
+
|
| 131 |
+
def reset_parameters(self):
|
| 132 |
+
"""Reset parameters."""
|
| 133 |
+
self.alpha.data = torch.tensor(1.0)
|
| 134 |
+
|
| 135 |
+
def forward(self, x):
|
| 136 |
+
"""Add positional encoding.
|
| 137 |
+
|
| 138 |
+
Args:
|
| 139 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
| 140 |
+
|
| 141 |
+
Returns:
|
| 142 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
| 143 |
+
|
| 144 |
+
"""
|
| 145 |
+
self.extend_pe(x)
|
| 146 |
+
x = x + self.alpha * self.pe[:, : x.size(1)]
|
| 147 |
+
return self.dropout(x)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class LearnableFourierPosEnc(torch.nn.Module):
|
| 151 |
+
"""Learnable Fourier Features for Positional Encoding.
|
| 152 |
+
|
| 153 |
+
See https://arxiv.org/pdf/2106.02795.pdf
|
| 154 |
+
|
| 155 |
+
Args:
|
| 156 |
+
d_model (int): Embedding dimension.
|
| 157 |
+
dropout_rate (float): Dropout rate.
|
| 158 |
+
max_len (int): Maximum input length.
|
| 159 |
+
gamma (float): init parameter for the positional kernel variance
|
| 160 |
+
see https://arxiv.org/pdf/2106.02795.pdf.
|
| 161 |
+
apply_scaling (bool): Whether to scale the input before adding the pos encoding.
|
| 162 |
+
hidden_dim (int): if not None, we modulate the pos encodings with
|
| 163 |
+
an MLP whose hidden layer has hidden_dim neurons.
|
| 164 |
+
"""
|
| 165 |
+
|
| 166 |
+
def __init__(
|
| 167 |
+
self,
|
| 168 |
+
d_model,
|
| 169 |
+
dropout_rate=0.0,
|
| 170 |
+
max_len=5000,
|
| 171 |
+
gamma=1.0,
|
| 172 |
+
apply_scaling=False,
|
| 173 |
+
hidden_dim=None,
|
| 174 |
+
):
|
| 175 |
+
"""Initialize class."""
|
| 176 |
+
super(LearnableFourierPosEnc, self).__init__()
|
| 177 |
+
|
| 178 |
+
self.d_model = d_model
|
| 179 |
+
|
| 180 |
+
if apply_scaling:
|
| 181 |
+
self.xscale = math.sqrt(self.d_model)
|
| 182 |
+
else:
|
| 183 |
+
self.xscale = 1.0
|
| 184 |
+
|
| 185 |
+
self.dropout = torch.nn.Dropout(dropout_rate)
|
| 186 |
+
self.max_len = max_len
|
| 187 |
+
|
| 188 |
+
self.gamma = gamma
|
| 189 |
+
if self.gamma is None:
|
| 190 |
+
self.gamma = self.d_model // 2
|
| 191 |
+
|
| 192 |
+
assert (
|
| 193 |
+
d_model % 2 == 0
|
| 194 |
+
), "d_model should be divisible by two in order to use this layer."
|
| 195 |
+
self.w_r = torch.nn.Parameter(torch.empty(1, d_model // 2))
|
| 196 |
+
self._reset() # init the weights
|
| 197 |
+
|
| 198 |
+
self.hidden_dim = hidden_dim
|
| 199 |
+
if self.hidden_dim is not None:
|
| 200 |
+
self.mlp = torch.nn.Sequential(
|
| 201 |
+
torch.nn.Linear(d_model, hidden_dim),
|
| 202 |
+
torch.nn.GELU(),
|
| 203 |
+
torch.nn.Linear(hidden_dim, d_model),
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
def _reset(self):
|
| 207 |
+
self.w_r.data = torch.normal(
|
| 208 |
+
0, (1 / math.sqrt(self.gamma)), (1, self.d_model // 2)
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
def extend_pe(self, x):
|
| 212 |
+
"""Reset the positional encodings."""
|
| 213 |
+
position_v = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1).to(x)
|
| 214 |
+
|
| 215 |
+
cosine = torch.cos(torch.matmul(position_v, self.w_r))
|
| 216 |
+
sine = torch.sin(torch.matmul(position_v, self.w_r))
|
| 217 |
+
pos_enc = torch.cat((cosine, sine), -1)
|
| 218 |
+
pos_enc /= math.sqrt(self.d_model)
|
| 219 |
+
|
| 220 |
+
if self.hidden_dim is None:
|
| 221 |
+
return pos_enc.unsqueeze(0)
|
| 222 |
+
else:
|
| 223 |
+
return self.mlp(pos_enc.unsqueeze(0))
|
| 224 |
+
|
| 225 |
+
def forward(self, x: torch.Tensor):
|
| 226 |
+
"""Add positional encoding.
|
| 227 |
+
|
| 228 |
+
Args:
|
| 229 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
| 230 |
+
|
| 231 |
+
Returns:
|
| 232 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
| 233 |
+
"""
|
| 234 |
+
pe = self.extend_pe(x)
|
| 235 |
+
x = x * self.xscale + pe
|
| 236 |
+
return self.dropout(x)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class LegacyRelPositionalEncoding(PositionalEncoding):
|
| 240 |
+
"""Relative positional encoding module (old version).
|
| 241 |
+
|
| 242 |
+
Details can be found in https://github.com/espnet/espnet/pull/2816.
|
| 243 |
+
|
| 244 |
+
See : Appendix B in https://arxiv.org/abs/1901.02860
|
| 245 |
+
|
| 246 |
+
Args:
|
| 247 |
+
d_model (int): Embedding dimension.
|
| 248 |
+
dropout_rate (float): Dropout rate.
|
| 249 |
+
max_len (int): Maximum input length.
|
| 250 |
+
|
| 251 |
+
"""
|
| 252 |
+
|
| 253 |
+
def __init__(self, d_model, dropout_rate, max_len=5000):
|
| 254 |
+
"""Initialize class."""
|
| 255 |
+
super().__init__(
|
| 256 |
+
d_model=d_model,
|
| 257 |
+
dropout_rate=dropout_rate,
|
| 258 |
+
max_len=max_len,
|
| 259 |
+
reverse=True,
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
def forward(self, x):
|
| 263 |
+
"""Compute positional encoding.
|
| 264 |
+
|
| 265 |
+
Args:
|
| 266 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
| 267 |
+
|
| 268 |
+
Returns:
|
| 269 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
| 270 |
+
torch.Tensor: Positional embedding tensor (1, time, `*`).
|
| 271 |
+
|
| 272 |
+
"""
|
| 273 |
+
self.extend_pe(x)
|
| 274 |
+
x = x * self.xscale
|
| 275 |
+
pos_emb = self.pe[:, : x.size(1)]
|
| 276 |
+
return self.dropout(x), self.dropout(pos_emb)
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
class RelPositionalEncoding(torch.nn.Module):
|
| 280 |
+
"""Relative positional encoding module (new implementation).
|
| 281 |
+
|
| 282 |
+
Details can be found in https://github.com/espnet/espnet/pull/2816.
|
| 283 |
+
|
| 284 |
+
See : Appendix B in https://arxiv.org/abs/1901.02860
|
| 285 |
+
|
| 286 |
+
Args:
|
| 287 |
+
d_model (int): Embedding dimension.
|
| 288 |
+
dropout_rate (float): Dropout rate.
|
| 289 |
+
max_len (int): Maximum input length.
|
| 290 |
+
|
| 291 |
+
"""
|
| 292 |
+
|
| 293 |
+
def __init__(self, d_model, dropout_rate, max_len=5000):
|
| 294 |
+
"""Construct an PositionalEncoding object."""
|
| 295 |
+
super(RelPositionalEncoding, self).__init__()
|
| 296 |
+
self.d_model = d_model
|
| 297 |
+
self.xscale = math.sqrt(self.d_model)
|
| 298 |
+
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
| 299 |
+
self.pe = None
|
| 300 |
+
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
| 301 |
+
|
| 302 |
+
def extend_pe(self, x):
|
| 303 |
+
"""Reset the positional encodings."""
|
| 304 |
+
if self.pe is not None:
|
| 305 |
+
# self.pe contains both positive and negative parts
|
| 306 |
+
# the length of self.pe is 2 * input_len - 1
|
| 307 |
+
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
| 308 |
+
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
| 309 |
+
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
| 310 |
+
return
|
| 311 |
+
# Suppose `i` means to the position of query vecotr and `j` means the
|
| 312 |
+
# position of key vector. We use position relative positions when keys
|
| 313 |
+
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
| 314 |
+
pe_positive = torch.zeros(x.size(1), self.d_model)
|
| 315 |
+
pe_negative = torch.zeros(x.size(1), self.d_model)
|
| 316 |
+
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
| 317 |
+
div_term = torch.exp(
|
| 318 |
+
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
| 319 |
+
* -(math.log(10000.0) / self.d_model)
|
| 320 |
+
)
|
| 321 |
+
pe_positive[:, 0::2] = torch.sin(position * div_term)
|
| 322 |
+
pe_positive[:, 1::2] = torch.cos(position * div_term)
|
| 323 |
+
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
|
| 324 |
+
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
|
| 325 |
+
|
| 326 |
+
# Reserve the order of positive indices and concat both positive and
|
| 327 |
+
# negative indices. This is used to support the shifting trick
|
| 328 |
+
# as in https://arxiv.org/abs/1901.02860
|
| 329 |
+
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
| 330 |
+
pe_negative = pe_negative[1:].unsqueeze(0)
|
| 331 |
+
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
| 332 |
+
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
| 333 |
+
|
| 334 |
+
def forward(self, x: torch.Tensor):
|
| 335 |
+
"""Add positional encoding.
|
| 336 |
+
|
| 337 |
+
Args:
|
| 338 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
| 339 |
+
|
| 340 |
+
Returns:
|
| 341 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
| 342 |
+
|
| 343 |
+
"""
|
| 344 |
+
self.extend_pe(x)
|
| 345 |
+
x = x * self.xscale
|
| 346 |
+
pos_emb = self.pe[
|
| 347 |
+
:,
|
| 348 |
+
self.pe.size(1) // 2 - x.size(1) + 1 : self.pe.size(1) // 2 + x.size(1),
|
| 349 |
+
]
|
| 350 |
+
return self.dropout(x), self.dropout(pos_emb)
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
class StreamPositionalEncoding(torch.nn.Module):
|
| 354 |
+
"""Streaming Positional encoding.
|
| 355 |
+
|
| 356 |
+
Args:
|
| 357 |
+
d_model (int): Embedding dimension.
|
| 358 |
+
dropout_rate (float): Dropout rate.
|
| 359 |
+
max_len (int): Maximum input length.
|
| 360 |
+
|
| 361 |
+
"""
|
| 362 |
+
|
| 363 |
+
def __init__(self, d_model, dropout_rate, max_len=5000):
|
        """Construct a StreamPositionalEncoding object."""
        super(StreamPositionalEncoding, self).__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        self.tmp = torch.tensor(0.0).expand(1, max_len)
        self.extend_pe(self.tmp.size(1), self.tmp.device, self.tmp.dtype)
        self._register_load_state_dict_pre_hook(_pre_hook)

    def extend_pe(self, length, device, dtype):
        """Reset the positional encodings."""
        if self.pe is not None:
            if self.pe.size(1) >= length:
                if self.pe.dtype != dtype or self.pe.device != device:
                    self.pe = self.pe.to(dtype=dtype, device=device)
                return
        pe = torch.zeros(length, self.d_model)
        position = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=device, dtype=dtype)

    def forward(self, x: torch.Tensor, start_idx: int = 0):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).

        """
        self.extend_pe(x.size(1) + start_idx, x.device, x.dtype)
        x = x * self.xscale + self.pe[:, start_idx : start_idx + x.size(1)]
        return self.dropout(x)


class ConvolutionalPositionalEmbedding(torch.nn.Module):
    """Convolutional positional embedding.

    Used in wav2vec2/HuBERT SSL models.
    https://arxiv.org/abs/1904.11660

    Args:
        embed_dim (int): Feature dimension of the input Tensor.
        dropout (float): unused
        max_len (int): unused
        num_layers (int): number of conv layers
        kernel_size (int): The number of frames to be used.
        groups (int): The number of groups in feature dimensions.
        weight_norm (str): [new, legacy, none].
            How to init conv weights. Recommended setting is
            none if num_layers > 1.
    """

    def __init__(
        self,
        embed_dim: int,
        dropout: float,
        max_len: int = 5000,
        num_layers: int = 1,
        kernel_size: int = 128,
        groups: int = 16,
        weight_norm: str = "new",
        use_residual: bool = False,
    ):
        """Initialize Convolutional Positional Embedding."""
        super().__init__()
        self.embed_dim = embed_dim
        self.kernel_size = kernel_size
        self.weight_norm = weight_norm

        convs = []
        for layer in range(num_layers):
            conv = torch.nn.Conv1d(
                in_channels=embed_dim,
                out_channels=embed_dim,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                groups=groups,
            )
            if weight_norm != "none" and weight_norm is not None:
                std = math.sqrt((4 * (1.0)) / (kernel_size * embed_dim))
                torch.nn.init.normal_(conv.weight, mean=0, std=std)
                torch.nn.init.constant_(conv.bias, 0)
                # torch.nn.utils.weight_norm leads to weird behavior
                # with copy.deepcopy(). It usually isn't needed,
                # but it's important for models that use EMA.
                if weight_norm == "new":
                    if V(torch.__version__) >= V("2.2.0"):
                        conv = torch.nn.utils.parametrizations.weight_norm(
                            conv, name="weight", dim=2
                        )
                    else:
                        weight_norm = "legacy"
                        logging.warning(
                            "torch.nn.utils.parametrizations.weight_norm is only "
                            + "supported for pytorch versions >= 2.2.0. "
                            + "Defaulting to torch.nn.utils.weight_norm."
                        )
                if weight_norm == "legacy":
                    conv = torch.nn.utils.weight_norm(conv, name="weight", dim=2)
            convs.append(conv)
        self.convs = torch.nn.ModuleList(convs)
        self.num_remove: int = 1 if kernel_size % 2 == 0 else 0
        self.use_residual = use_residual

    def __prepare_scriptable__(self):
        """Prepare Scriptable method."""
        # Iterate over self.convs (a ModuleList); the original code referenced a
        # non-existent self.conv attribute here.
        for conv in self.convs:
            for hook in conv._forward_pre_hooks.values():
                # The hook we want to remove is an instance of WeightNorm class, so
                # normally we would do `if isinstance(...)` but this class is not
                # accessible because of shadowing, so we check the module name directly.
                # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
                if (
                    hook.__module__ == "torch.nn.utils.weight_norm"
                    and hook.__class__.__name__ == "WeightNorm"
                ):
                    logging.warning(
                        "Removing weight_norm from %s", self.__class__.__name__
                    )
                    torch.nn.utils.remove_weight_norm(conv)
        return self

    def forward(self, x):
        """Forward Method.

        Args:
            x (Tensor): shape ``[batch, frame, feature]``.

        Returns:
            Tensor: The resulting feature. Shape ``[batch, frame, feature]``.
        """
        if self.use_residual:
            residual = x

        x = x.transpose(-2, -1)
        for conv in self.convs:
            x = conv(x)

        # remove extra padding
        if self.num_remove > 0:
            x = x[..., : -self.num_remove]

        x = torch.nn.functional.gelu(x)

        # manually normalize if the conv is not parameterized
        # with weight norm
        if self.weight_norm is None or self.weight_norm == "none":
            x = dim_1_layer_norm(x)

        x = x.transpose(-2, -1)

        if self.use_residual:
            x = x + residual
        return x
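A minimal usage sketch for ConvolutionalPositionalEmbedding (illustrative only, not part of the commit; the dimensions are made up, and it assumes PyTorch >= 2.2 so the parametrized weight-norm branch is taken):

import torch

pos_conv = ConvolutionalPositionalEmbedding(
    embed_dim=512, dropout=0.0, kernel_size=128, groups=16, weight_norm="new"
)
feats = torch.randn(2, 50, 512)  # (batch, frame, feature)
out = pos_conv(feats)
assert out.shape == feats.shape  # the positional conv is shape-preserving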
src/espnet_import/fastformer.py
ADDED
@@ -0,0 +1,153 @@
"""Fastformer attention definition.

Reference:
    Wu et al., "Fastformer: Additive Attention Can Be All You Need"
    https://arxiv.org/abs/2108.09084
    https://github.com/wuch15/Fastformer

"""

import numpy
import torch


class FastSelfAttention(torch.nn.Module):
    """Fast self-attention used in Fastformer."""

    def __init__(
        self,
        size,
        attention_heads,
        dropout_rate,
    ):
        super().__init__()
        if size % attention_heads != 0:
            raise ValueError(
                f"Hidden size ({size}) is not an integer multiple "
                f"of attention heads ({attention_heads})"
            )
        self.attention_head_size = size // attention_heads
        self.num_attention_heads = attention_heads

        self.query = torch.nn.Linear(size, size)
        self.query_att = torch.nn.Linear(size, attention_heads)
        self.key = torch.nn.Linear(size, size)
        self.key_att = torch.nn.Linear(size, attention_heads)
        self.transform = torch.nn.Linear(size, size)
        self.dropout = torch.nn.Dropout(dropout_rate)

    def espnet_initialization_fn(self):
        self.apply(self.init_weights)

    def init_weights(self, module):
        if isinstance(module, torch.nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, torch.nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def transpose_for_scores(self, x):
        """Reshape and transpose to compute scores.

        Args:
            x: (batch, time, size = n_heads * attn_dim)

        Returns:
            (batch, n_heads, time, attn_dim)
        """
        new_x_shape = x.shape[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        return x.reshape(*new_x_shape).transpose(1, 2)

    def forward(self, xs_pad, mask):
        """Forward method.

        Args:
            xs_pad: (batch, time, size = n_heads * attn_dim)
            mask: (batch, 1, time), nonpadding is 1, padding is 0

        Returns:
            torch.Tensor: (batch, time, size)
        """
        batch_size, seq_len, _ = xs_pad.shape
        mixed_query_layer = self.query(xs_pad)  # (batch, time, size)
        mixed_key_layer = self.key(xs_pad)  # (batch, time, size)

        if mask is not None:
            mask = mask.eq(0)  # padding is 1, nonpadding is 0

        # (batch, n_heads, time)
        query_for_score = (
            self.query_att(mixed_query_layer).transpose(1, 2)
            / self.attention_head_size**0.5
        )
        if mask is not None:
            min_value = float(
                numpy.finfo(
                    torch.tensor(0, dtype=query_for_score.dtype).numpy().dtype
                ).min
            )
            query_for_score = query_for_score.masked_fill(mask, min_value)
            query_weight = torch.softmax(query_for_score, dim=-1).masked_fill(mask, 0.0)
        else:
            query_weight = torch.softmax(query_for_score, dim=-1)

        query_weight = query_weight.unsqueeze(2)  # (batch, n_heads, 1, time)
        query_layer = self.transpose_for_scores(
            mixed_query_layer
        )  # (batch, n_heads, time, attn_dim)

        pooled_query = (
            torch.matmul(query_weight, query_layer)
            .transpose(1, 2)
            .reshape(-1, 1, self.num_attention_heads * self.attention_head_size)
        )  # (batch, 1, size = n_heads * attn_dim)
        pooled_query = self.dropout(pooled_query)
        pooled_query_repeat = pooled_query.repeat(1, seq_len, 1)  # (batch, time, size)

        mixed_query_key_layer = (
            mixed_key_layer * pooled_query_repeat
        )  # (batch, time, size)

        # (batch, n_heads, time)
        query_key_score = (
            self.key_att(mixed_query_key_layer) / self.attention_head_size**0.5
        ).transpose(1, 2)
        if mask is not None:
            min_value = float(
                numpy.finfo(
                    torch.tensor(0, dtype=query_key_score.dtype).numpy().dtype
                ).min
            )
            query_key_score = query_key_score.masked_fill(mask, min_value)
            query_key_weight = torch.softmax(query_key_score, dim=-1).masked_fill(
                mask, 0.0
            )
        else:
            query_key_weight = torch.softmax(query_key_score, dim=-1)

        query_key_weight = query_key_weight.unsqueeze(2)  # (batch, n_heads, 1, time)
        key_layer = self.transpose_for_scores(
            mixed_query_key_layer
        )  # (batch, n_heads, time, attn_dim)
        pooled_key = torch.matmul(
            query_key_weight, key_layer
        )  # (batch, n_heads, 1, attn_dim)
        pooled_key = self.dropout(pooled_key)

        # NOTE: value = query, due to param sharing
        weighted_value = (pooled_key * query_layer).transpose(
            1, 2
        )  # (batch, time, n_heads, attn_dim)
        weighted_value = weighted_value.reshape(
            weighted_value.shape[:-2]
            + (self.num_attention_heads * self.attention_head_size,)
        )  # (batch, time, size)
        weighted_value = (
            self.dropout(self.transform(weighted_value)) + mixed_query_layer
        )

        return weighted_value
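A minimal sketch of how FastSelfAttention might be called (illustrative shapes, not from the repo):

import torch

attn = FastSelfAttention(size=256, attention_heads=4, dropout_rate=0.1)
xs = torch.randn(2, 30, 256)                  # (batch, time, size)
mask = torch.ones(2, 1, 30, dtype=torch.bool)
mask[1, :, 20:] = False                       # second utterance: 20 valid frames
out = attn(xs, mask)                          # additive attention in O(time)
assert out.shape == (2, 30, 256)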
src/espnet_import/label_smoothing_loss.py
ADDED
@@ -0,0 +1,64 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

# from espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py
"""Label smoothing module."""

import torch
from torch import nn


class LabelSmoothingLoss(nn.Module):
    """Label-smoothing loss.

    :param int size: the number of classes
    :param int padding_idx: ignored class id
    :param float smoothing: smoothing rate (0.0 means the conventional CE)
    :param bool normalize_length: normalize loss by sequence length if True
    :param torch.nn.Module criterion: loss function to be smoothed
    """

    def __init__(
        self,
        size,
        padding_idx,
        smoothing,
        normalize_length=False,
        criterion=nn.KLDivLoss(reduction="none"),
    ):
        """Construct a LabelSmoothingLoss object."""
        super(LabelSmoothingLoss, self).__init__()
        self.criterion = criterion
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None
        self.normalize_length = normalize_length

    def forward(self, x, target):
        """Compute loss between x and target.

        :param torch.Tensor x: prediction (batch, seqlen, class)
        :param torch.Tensor target:
            target signal masked with self.padding_idx (batch, seqlen)
        :return: scalar float value
        :rtype torch.Tensor
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        x = x.view(-1, self.size)
        target = target.view(-1)
        with torch.no_grad():
            true_dist = x.clone()
            true_dist.fill_(self.smoothing / (self.size - 1))
            ignore = target == self.padding_idx  # (B,)
            total = len(target) - ignore.sum().item()
            target = target.masked_fill(ignore, 0)  # avoid -1 index
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
        denom = total if self.normalize_length else batch_size
        return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
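A quick sketch of the loss on toy data (hypothetical sizes; positions carrying the padding id are excluded from the loss):

import torch

criterion = LabelSmoothingLoss(size=10, padding_idx=-1, smoothing=0.1)
logits = torch.randn(2, 5, 10, requires_grad=True)  # (batch, seqlen, class)
target = torch.randint(0, 10, (2, 5))
target[1, 3:] = -1                                   # ignored positions
loss = criterion(logits, target)
loss.backward()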
src/espnet_import/layer_norm.py
ADDED
@@ -0,0 +1,43 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
# from https://github.com/espnet/espnet/blob/master/espnet2/legacy/nets/pytorch_backend/transformer/layer_norm.py

"""Layer normalization module."""

import torch


class LayerNorm(torch.nn.LayerNorm):
    """Layer normalization module.

    Args:
        nout (int): Output dim size.
        dim (int): Dimension to be normalized.

    """

    def __init__(self, nout, dim=-1):
        """Construct a LayerNorm object."""
        super(LayerNorm, self).__init__(nout, eps=1e-12)
        self.dim = dim

    def forward(self, x):
        """Apply layer normalization.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Normalized tensor.

        """
        if self.dim == -1:
            return super(LayerNorm, self).forward(x)
        return (
            super(LayerNorm, self)
            .forward(x.transpose(self.dim, -1))
            .transpose(self.dim, -1)
        )
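The dim argument exists so channel-first tensors can be normalized without manual transposes; a small sketch with assumed shapes:

import torch

norm = LayerNorm(80, dim=1)
x = torch.randn(4, 80, 100)  # (batch, channel, time)
y = norm(x)                  # normalizes the 80 channels at each time step
assert y.shape == x.shape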
src/espnet_import/nets_utils.py
ADDED
@@ -0,0 +1,690 @@
# -*- coding: utf-8 -*-

# from https://github.com/espnet/espnet/blob/master/espnet2/legacy/nets/pytorch_backend/nets_utils.py
"""Network related utility tools."""

import logging
from typing import Dict, Optional

import numpy as np
import torch


def to_device(m, x):
    """Send tensor into the device of the module.

    Args:
        m (torch.nn.Module): Torch module.
        x (Tensor): Torch tensor.

    Returns:
        Tensor: Torch tensor located in the same place as torch module.

    """
    if isinstance(m, torch.nn.Module):
        device = next(m.parameters()).device
    elif isinstance(m, torch.Tensor):
        device = m.device
    else:
        raise TypeError(
            "Expected torch.nn.Module or torch.Tensor, " f"but got: {type(m)}"
        )
    return x.to(device)


def pad_list(xs, pad_value):
    """Perform padding for the list of tensors.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])

    """
    n_batch = len(xs)
    max_len = max(x.size(0) for x in xs)
    pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)

    for i in range(n_batch):
        pad[i, : xs[i].size(0)] = xs[i]

    return pad


@torch.compiler.disable
def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
    """Make mask tensor containing indices of padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.
            See the example.
        maxlen (int, optional): Fixed mask length. If set, xs must be None
            and maxlen must be >= max(lengths).

    Returns:
        Tensor: Mask tensor containing indices of padded part.
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (including 1.2)

    Examples:
        With only lengths.

        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]

        With the reference tensor.

        >>> xs = torch.zeros((3, 2, 4))
        >>> make_pad_mask(lengths, xs)
        tensor([[[0, 0, 0, 0],
                 [0, 0, 0, 0]],
                [[0, 0, 0, 1],
                 [0, 0, 0, 1]],
                [[0, 0, 1, 1],
                 [0, 0, 1, 1]]], dtype=torch.uint8)
        >>> xs = torch.zeros((3, 2, 6))
        >>> make_pad_mask(lengths, xs)
        tensor([[[0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1]],
                [[0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1]],
                [[0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)

        With the reference tensor and dimension indicator.

        >>> xs = torch.zeros((3, 6, 6))
        >>> make_pad_mask(lengths, xs, 1)
        tensor([[[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1]],
                [[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1]],
                [[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
        >>> make_pad_mask(lengths, xs, 2)
        tensor([[[0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1]],
                [[0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1]],
                [[0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)

    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    # If the input dimension is 2 or 3,
    # then we use the ESPnet-ONNX based implementation for traceable modeling.
    # Otherwise we use the traditional implementation for research use.
    if isinstance(lengths, list):
        logging.warning(
            "Using make_pad_mask with a list of lengths is not traceable. "
            + "If you try to trace this function with type(lengths) == list, "
            + "please change the type of lengths to torch.LongTensor."
        )

    if (
        (xs is None or xs.dim() in (2, 3))
        and length_dim <= 2
        and (not isinstance(lengths, list) and lengths.dim() == 1)
    ):
        return _make_pad_mask_traceable(lengths, xs, length_dim, maxlen)
    else:
        return _make_pad_mask(lengths, xs, length_dim, maxlen)


def _make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
    if not isinstance(lengths, list):
        lengths = lengths.long().tolist()

    bs = int(len(lengths))
    if maxlen is None:
        if xs is None:
            maxlen = int(max(lengths))
        else:
            maxlen = xs.size(length_dim)
    else:
        assert xs is None, "When maxlen is specified, xs must not be specified."
        assert maxlen >= int(
            max(lengths)
        ), f"maxlen {maxlen} must be >= max(lengths) {max(lengths)}"

    seq_range = torch.arange(0, maxlen, dtype=torch.int64)
    seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
    seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand

    if xs is not None:
        assert (
            xs.size(0) == bs
        ), f"The size of x.size(0) {xs.size(0)} must match the batch size {bs}"

        if length_dim < 0:
            length_dim = xs.dim() + length_dim
        # ind = (:, None, ..., None, :, , None, ..., None)
        ind = tuple(
            slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
        )
        mask = mask[ind].expand_as(xs).to(xs.device)
    return mask


def _make_pad_mask_traceable(lengths, xs, length_dim, maxlen=None):
    """Make mask tensor containing indices of padded part.

    This is a simplified implementation of make_pad_mask without the xs input
    that supports JIT tracing for applications like exporting models to ONNX.
    Dimension length of xs should be 2 or 3.
    This function will create torch.ones(maxlen, maxlen).triu(diagonal=1) and
    select rows to create the mask tensor.
    """
    if xs is None:
        device = lengths.device
    else:
        device = xs.device

    if xs is not None and len(xs.shape) == 3:
        if length_dim == 1:
            lengths = lengths.unsqueeze(1).expand(*xs.transpose(1, 2).shape[:2])
        else:
            # Then length_dim is 2 or -1.
            if length_dim not in (-1, 2):
                logging.warning(
                    f"Invalid length_dim {length_dim}. "
                    + "We set it to -1, which is the default value."
                )
                length_dim = -1
            lengths = lengths.unsqueeze(1).expand(*xs.shape[:2])

    if maxlen is not None:
        assert xs is None
        assert maxlen >= lengths.max()
    elif xs is not None:
        maxlen = xs.shape[length_dim]
    else:
        maxlen = lengths.max()

    # clip max(length) to maxlen
    lengths = torch.clamp(lengths, max=maxlen).type(torch.long)

    mask = torch.ones(maxlen + 1, maxlen + 1, dtype=torch.bool, device=device)
    mask = triu_onnx(mask)[1:, :-1]  # onnx cannot handle diagonal argument.
    mask = mask[lengths - 1][..., :maxlen]

    if xs is not None and len(xs.shape) == 3 and length_dim == 1:
        return mask.transpose(1, 2)
    else:
        return mask


def triu_onnx(x):
    """Make TriU for ONNX."""
    arange = torch.arange(x.size(0), device=x.device)
    mask = arange.unsqueeze(-1).expand(-1, x.size(0)) <= arange
    return x * mask


def make_non_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of non-padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.
            See the example.

    Returns:
        ByteTensor: mask tensor containing indices of padded part.
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (including 1.2)

    Examples:
        With only lengths.

        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1, 1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]

        With the reference tensor.

        >>> xs = torch.zeros((3, 2, 4))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1],
                 [1, 1, 1, 1]],
                [[1, 1, 1, 0],
                 [1, 1, 1, 0]],
                [[1, 1, 0, 0],
                 [1, 1, 0, 0]]], dtype=torch.uint8)
        >>> xs = torch.zeros((3, 2, 6))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)

        With the reference tensor and dimension indicator.

        >>> xs = torch.zeros((3, 6, 6))
        >>> make_non_pad_mask(lengths, xs, 1)
        tensor([[[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
        >>> make_non_pad_mask(lengths, xs, 2)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)

    """
    return ~make_pad_mask(lengths, xs, length_dim)


def mask_by_length(xs, lengths, fill=0):
    """Mask tensor according to length.

    Args:
        xs (Tensor): Batch of input tensor (B, `*`).
        lengths (LongTensor or List): Batch of lengths (B,).
        fill (int or float): Value to fill masked part.

    Returns:
        Tensor: Batch of masked input tensor (B, `*`).

    Examples:
        >>> x = torch.arange(5).repeat(3, 1) + 1
        >>> x
        tensor([[1, 2, 3, 4, 5],
                [1, 2, 3, 4, 5],
                [1, 2, 3, 4, 5]])
        >>> lengths = [5, 3, 2]
        >>> mask_by_length(x, lengths)
        tensor([[1, 2, 3, 4, 5],
                [1, 2, 3, 0, 0],
                [1, 2, 0, 0, 0]])

    """
    assert xs.size(0) == len(lengths)
    ret = xs.data.new(*xs.size()).fill_(fill)
    for i, l in enumerate(lengths):
        ret[i, :l] = xs[i, :l]
    return ret


def th_accuracy(pad_outputs, pad_targets, ignore_label):
    """Calculate accuracy.

    Args:
        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
        ignore_label (int): Ignore label id.

    Returns:
        float: Accuracy value (0.0 - 1.0).

    """
    pad_pred = pad_outputs.view(
        pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)
    ).argmax(2)
    mask = pad_targets != ignore_label
    numerator = torch.sum(
        pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
    )
    denominator = torch.sum(mask)
    return float(numerator) / float(denominator)


def to_torch_tensor(x):
    """Change to torch.Tensor or ComplexTensor from numpy.ndarray.

    Args:
        x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict.

    Returns:
        Tensor or ComplexTensor: Type converted inputs.

    Examples:
        >>> xs = np.ones(3, dtype=np.float32)
        >>> xs = to_torch_tensor(xs)
        tensor([1., 1., 1.])
        >>> xs = torch.ones(3, 4, 5)
        >>> assert to_torch_tensor(xs) is xs
        >>> xs = {'real': xs, 'imag': xs}
        >>> to_torch_tensor(xs)
        ComplexTensor(
        Real:
        tensor([1., 1., 1.])
        Imag:
        tensor([1., 1., 1.])
        )

    """
    # If numpy, change to torch tensor
    if isinstance(x, np.ndarray):
        if x.dtype.kind == "c":
            # Dynamically importing because torch_complex requires python3
            from torch_complex.tensor import ComplexTensor

            return ComplexTensor(x)
        else:
            return torch.from_numpy(x)

    # If {'real': ..., 'imag': ...}, convert to ComplexTensor
    elif isinstance(x, dict):
        # Dynamically importing because torch_complex requires python3
        from torch_complex.tensor import ComplexTensor

        if "real" not in x or "imag" not in x:
            raise ValueError("must have 'real' and 'imag' keys: {}".format(list(x)))
        # Relative importing because of using python3 syntax
        return ComplexTensor(x["real"], x["imag"])

    # If torch.Tensor, as it is
    elif isinstance(x, torch.Tensor):
        return x

    else:
        error = (
            "x must be numpy.ndarray, torch.Tensor or a dict like "
            "{{'real': torch.Tensor, 'imag': torch.Tensor}}, "
            "but got {}".format(type(x))
        )
        try:
            from torch_complex.tensor import ComplexTensor
        except Exception:
            # If PY2
            raise ValueError(error)
        else:
            # If PY3
            if isinstance(x, ComplexTensor):
                return x
            else:
                raise ValueError(error)


def get_subsample(train_args, mode, arch):
    """Parse the subsampling factors from the args for the specified `mode` and `arch`.

    Args:
        train_args: argument Namespace containing options.
        mode: one of ('asr', 'mt', 'st')
        arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer')

    Returns:
        np.ndarray / List[np.ndarray]: subsampling factors.
    """
    if arch == "transformer":
        return np.array([1])

    elif mode == "mt" and arch == "rnn":
        # +1 means input (+1) and layers outputs (train_args.elayers)
        subsample = np.ones(train_args.elayers + 1, dtype=np.int64)
        logging.warning("Subsampling is not performed for machine translation.")
        logging.info("subsample: " + " ".join([str(x) for x in subsample]))
        return subsample

    elif (
        (mode == "asr" and arch in ("rnn", "rnn-t"))
        or (mode == "mt" and arch == "rnn")
        or (mode == "st" and arch == "rnn")
    ):
        subsample = np.ones(train_args.elayers + 1, dtype=np.int64)
        if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
            ss = train_args.subsample.split("_")
            for j in range(min(train_args.elayers + 1, len(ss))):
                subsample[j] = int(ss[j])
        else:
            logging.warning(
                "Subsampling is not performed for vgg*. "
                "It is performed in max pooling layers at CNN."
            )
        logging.info("subsample: " + " ".join([str(x) for x in subsample]))
        return subsample

    elif mode == "asr" and arch == "rnn_mix":
        subsample = np.ones(
            train_args.elayers_sd + train_args.elayers + 1, dtype=np.int64
        )
        if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
            ss = train_args.subsample.split("_")
            for j in range(
                min(train_args.elayers_sd + train_args.elayers + 1, len(ss))
            ):
                subsample[j] = int(ss[j])
        else:
            logging.warning(
                "Subsampling is not performed for vgg*. "
                "It is performed in max pooling layers at CNN."
            )
        logging.info("subsample: " + " ".join([str(x) for x in subsample]))
        return subsample

    elif mode == "asr" and arch == "rnn_mulenc":
        subsample_list = []
        for idx in range(train_args.num_encs):
            subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int64)
            if train_args.etype[idx].endswith("p") and not train_args.etype[
                idx
            ].startswith("vgg"):
                ss = train_args.subsample[idx].split("_")
                for j in range(min(train_args.elayers[idx] + 1, len(ss))):
                    subsample[j] = int(ss[j])
            else:
                logging.warning(
                    "Encoder %d: Subsampling is not performed for vgg*. "
                    "It is performed in max pooling layers at CNN.",
                    idx + 1,
                )
            logging.info("subsample: " + " ".join([str(x) for x in subsample]))
            subsample_list.append(subsample)
        return subsample_list

    else:
        raise ValueError("Invalid options: mode={}, arch={}".format(mode, arch))


def rename_state_dict(
    old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]
):
    """Replace keys of old prefix with new prefix in state dict."""
    # need this list not to break the dict iterator
    old_keys = [k for k in state_dict if k.startswith(old_prefix)]
    if len(old_keys) > 0:
        logging.warning(f"Rename: {old_prefix} -> {new_prefix}")
    for k in old_keys:
        v = state_dict.pop(k)
        new_k = k.replace(old_prefix, new_prefix)
        state_dict[new_k] = v


# from espnet2.legacy.nets.pytorch_backend.conformer.swish import Swish
class Swish(torch.nn.Module):
    """Construct a Swish object."""

    def forward(self, x):
        """Return Swish activation function."""
        return x * torch.sigmoid(x)


def get_activation(act):
    """Return activation function."""

    activation_funcs = {
        "hardtanh": torch.nn.Hardtanh,
        "tanh": torch.nn.Tanh,
        "relu": torch.nn.ReLU,
        "selu": torch.nn.SELU,
        "swish": Swish,
    }

    return activation_funcs[act]()


def trim_by_ctc_posterior(
    h: torch.Tensor,
    ctc_probs: torch.Tensor,
    masks: torch.Tensor,
    pos_emb: Optional[torch.Tensor] = None,
):
    """Trim the encoder hidden output using CTC posterior.

    The continuous frames in the tail that confidently represent
    blank symbols are trimmed.
    """
    # Empirical settings
    frame_tolerance = 5
    conf_tolerance = 0.95
    blank_id = 0

    assert masks.size(1) == 1
    masks = masks.squeeze(1)
    hlens = masks.sum(dim=1)
    assert h.size()[:2] == ctc_probs.size()[:2]
    assert h.size(0) == hlens.size(0)

    # blank frames
    max_values, max_indices = ctc_probs.max(dim=2)
    blank_masks = torch.logical_and(
        max_values > conf_tolerance, max_indices == blank_id
    )

    # plus ignored frames
    joint_masks = torch.logical_or(blank_masks, ~masks)

    # lengths after the trimming
    B, T, _ = h.size()
    frame_idx = torch.where(
        joint_masks, -1, torch.arange(T).unsqueeze(0).repeat(B, 1).to(h.device)
    )
    after_lens = torch.where(
        frame_idx.max(dim=-1)[0] + frame_tolerance + 1 < hlens,
        frame_idx.max(dim=-1)[0] + frame_tolerance + 1,
        hlens,
    )

    h = h[:, : max(after_lens)]
    masks = ~make_pad_mask(after_lens).to(h.device).unsqueeze(1)

    if pos_emb is None:
        pos_emb = None
    elif (hlens.max() * 2 - 1).item() == pos_emb.size(1):  # RelPositionalEncoding
        pos_emb = pos_emb[
            :, pos_emb.size(1) // 2 - h.size(1) + 1 : pos_emb.size(1) // 2 + h.size(1)
        ]
    else:
        pos_emb = pos_emb[:, : h.size(1)]

    return h, masks, pos_emb


def roll_tensor(
    x: torch.Tensor,
    lengths: torch.Tensor,
    roll_amounts: Optional[torch.Tensor] = None,
    fixed_intervals: Optional[int] = None,
) -> torch.Tensor:
    """Left-roll tensor x by roll_amounts, only within lengths and optionally quantized.

    Args:
        x: input tensor (B, T, D)
        lengths: lengths of each sequence (B,)
        roll_amounts: random shift amounts (B,). If None, random shift
            amounts are generated.
        fixed_intervals: if not None, roll_amounts are quantized to
            multiples of this.
    Returns:
        rolled_x: rolled tensor (B, T, D)
    Useful to apply roll augmentation to the input, while considering
    the input length for each sample.
    """
    B, T, D = x.shape

    indices = torch.arange(T).unsqueeze(0).expand(B, T).to(x.device)  # (B, T)
    lengths = lengths.unsqueeze(1)  # (B, 1)

    if roll_amounts is None:
        roll_amounts = torch.randint(0, lengths.max(), (B,), device=x.device)
    if fixed_intervals is not None:
        roll_amounts = (roll_amounts // fixed_intervals) * fixed_intervals
    roll_indices = (indices - roll_amounts.unsqueeze(1)) % lengths  # (B, T)
    roll_indices = roll_indices.unsqueeze(2).expand(-1, -1, D)  # (B, T, D)

    mask = indices < lengths  # (B, T), True if position is valid
    rolled_x = torch.empty_like(x)
    rolled_x[mask] = x.gather(1, roll_indices)[mask]
    rolled_x[~mask] = x[~mask]
    return rolled_x
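roll_tensor has no doctest of its own; a small sketch (made-up shapes) of the length-aware roll with quantized shift amounts:

import torch

x = torch.arange(16, dtype=torch.float32).reshape(2, 8, 1)  # (B, T, D)
lengths = torch.tensor([8, 5])
rolled = roll_tensor(
    x, lengths, roll_amounts=torch.tensor([4, 6]), fixed_intervals=4
)
# Shifts are floored to multiples of 4, so both samples roll by 4;
# frames at positions >= length are left untouched.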
src/espnet_import/positionwise_feed_forward.py
ADDED
@@ -0,0 +1,32 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Positionwise feed forward layer definition."""

import torch


class PositionwiseFeedForward(torch.nn.Module):
    """Positionwise feed forward layer.

    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.

    """

    def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()):
        """Construct a PositionwiseFeedForward object."""
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.w_2 = torch.nn.Linear(hidden_units, idim)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.activation = activation

    def forward(self, x):
        """Forward function."""
        return self.w_2(self.dropout(self.activation(self.w_1(x))))
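A one-line sanity check (illustrative sizes) showing the block is shape-preserving:

import torch

ffn = PositionwiseFeedForward(idim=256, hidden_units=1024, dropout_rate=0.1)
x = torch.randn(2, 30, 256)
assert ffn(x).shape == x.shape  # project up to 1024, back down to 256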
src/espnet_import/repeat.py
ADDED
@@ -0,0 +1,46 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Repeat the same layer definition."""

import torch


class MultiSequential(torch.nn.Sequential):
    """Multi-input multi-output torch.nn.Sequential."""

    def __init__(self, *args, layer_drop_rate=0.0):
        """Initialize MultiSequential with layer_drop.

        Args:
            layer_drop_rate (float): Probability of dropping out each fn (layer).

        """
        super(MultiSequential, self).__init__(*args)
        self.layer_drop_rate = layer_drop_rate

    def forward(self, *args):
        """Repeat."""
        _probs = torch.empty(len(self)).uniform_()
        for idx, m in enumerate(self):
            if not self.training or (_probs[idx] >= self.layer_drop_rate):
                args = m(*args)
        return args


def repeat(N, fn, layer_drop_rate=0.0):
    """Repeat module N times.

    Args:
        N (int): Number of repeat times.
        fn (Callable): Function to generate module.
        layer_drop_rate (float): Probability of dropping out each fn (layer).

    Returns:
        MultiSequential: Repeated model instance.

    """
    return MultiSequential(*[fn(n) for n in range(N)], layer_drop_rate=layer_drop_rate)
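MultiSequential threads a tuple of arguments through every surviving layer, so each layer must accept and return the same tuple shape (as encoder layers returning (x, mask) do). A sketch with a hypothetical ToyLayer:

import torch

class ToyLayer(torch.nn.Module):
    """Hypothetical layer that threads (x, mask) through, like encoder layers."""

    def __init__(self, dim):
        super().__init__()
        self.lin = torch.nn.Linear(dim, dim)

    def forward(self, x, mask):
        return torch.relu(self.lin(x)), mask

layers = repeat(6, lambda n: ToyLayer(256), layer_drop_rate=0.1)
x = torch.randn(2, 30, 256)
mask = torch.ones(2, 1, 30, dtype=torch.bool)
x, mask = layers(x, mask)  # each kept layer consumes and returns (x, mask)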
src/espnet_import/subsampling.py
ADDED
@@ -0,0 +1,873 @@
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
# Copyright 2019 Shigeki Karita
|
| 5 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
| 6 |
+
|
| 7 |
+
"""Subsampling layer definition."""
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from src.espnet_import.embedding import PositionalEncoding
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class TooShortUttError(Exception):
|
| 15 |
+
"""Raised when the utt is too short for subsampling.
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
message (str): Message for error catch
|
| 19 |
+
actual_size (int): the short size that cannot pass the subsampling
|
| 20 |
+
limit (int): the limit size for subsampling
|
| 21 |
+
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(self, message, actual_size, limit):
|
| 25 |
+
"""Construct a TooShortUttError for error handler."""
|
| 26 |
+
super().__init__(message)
|
| 27 |
+
self.actual_size = actual_size
|
| 28 |
+
self.limit = limit
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def check_short_utt(ins, size):
|
| 32 |
+
"""Check if the utterance is too short for subsampling."""
|
| 33 |
+
if isinstance(ins, Conv1dSubsampling1) and size < 5:
|
| 34 |
+
return True, 5
|
| 35 |
+
if isinstance(ins, Conv1dSubsampling2) and size < 5:
|
| 36 |
+
return True, 5
|
| 37 |
+
if isinstance(ins, Conv1dSubsampling3) and size < 7:
|
| 38 |
+
return True, 7
|
| 39 |
+
if isinstance(ins, Conv2dSubsampling1) and size < 5:
|
| 40 |
+
return True, 5
|
| 41 |
+
if isinstance(ins, Conv2dSubsampling2) and size < 7:
|
| 42 |
+
return True, 7
|
| 43 |
+
if isinstance(ins, Conv2dSubsampling) and size < 7:
|
| 44 |
+
return True, 7
|
| 45 |
+
if isinstance(ins, Conv2dSubsampling6) and size < 11:
|
| 46 |
+
return True, 11
|
| 47 |
+
if isinstance(ins, Conv2dSubsampling8) and size < 15:
|
| 48 |
+
return True, 15
|
| 49 |
+
return False, -1
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _upgrade_legacy_subsampling_state_dict(state_dict, prefix):
|
| 53 |
+
"""Remap legacy nn.Sequential keys for subsampling modules."""
|
| 54 |
+
w_new = prefix + "out.weight"
|
| 55 |
+
b_new = prefix + "out.bias"
|
| 56 |
+
w_old = prefix + "out.0.weight"
|
| 57 |
+
b_old = prefix + "out.0.bias"
|
| 58 |
+
|
| 59 |
+
if w_new not in state_dict and w_old in state_dict:
|
| 60 |
+
state_dict[w_new] = state_dict.pop(w_old)
|
| 61 |
+
elif w_new in state_dict and w_old in state_dict:
|
| 62 |
+
state_dict.pop(w_old)
|
| 63 |
+
|
| 64 |
+
if b_new not in state_dict and b_old in state_dict:
|
| 65 |
+
state_dict[b_new] = state_dict.pop(b_old)
|
| 66 |
+
elif b_new in state_dict and b_old in state_dict:
|
| 67 |
+
state_dict.pop(b_old)
|
| 68 |
+
|
| 69 |
+
old_pos_prefix = prefix + "out.1."
|
| 70 |
+
new_pos_prefix = prefix + "pos_enc."
|
| 71 |
+
for k in list(state_dict.keys()):
|
| 72 |
+
if not k.startswith(old_pos_prefix):
|
| 73 |
+
continue
|
| 74 |
+
new_k = new_pos_prefix + k[len(old_pos_prefix) :]
|
| 75 |
+
if new_k not in state_dict:
|
| 76 |
+
state_dict[new_k] = state_dict[k]
|
| 77 |
+
state_dict.pop(k, None)
|
| 78 |
+
|
| 79 |
+
|

class Conv1dSubsampling1(torch.nn.Module):
    """Convolutional 1D frontend (stride 1; trims 4 frames, no subsampling).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.

    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct a Conv1dSubsampling1 object."""
        super(Conv1dSubsampling1, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv1d(idim, odim, 3, 1),
            torch.nn.ReLU(),
            torch.nn.Conv1d(odim, odim, 3, 1),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Linear(odim, odim)
        self.pos_enc = (
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate)
        )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        _upgrade_legacy_subsampling_state_dict(state_dict, prefix)
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, x, x_mask, prefix_embeds=None):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
            prefix_embeds (torch.Tensor or None): Prefix token embeddings
                (#batch, prefix_len, odim).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time - 4.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time - 4.

        """
        x = x.transpose(2, 1)  # (#batch, idim, time)
        x = self.conv(x)
        b, c, t = x.size()
        x = self.out(x.transpose(1, 2).contiguous())
        if x_mask is not None:
            x_mask = x_mask[:, :, :-2:1][:, :, :-2:1]

        if prefix_embeds is not None:
            x = torch.cat([prefix_embeds, x], dim=1)
            if x_mask is not None:
                x_mask = torch.cat(
                    [
                        torch.ones(
                            x_mask.shape[0],
                            1,
                            prefix_embeds.size(1),
                            dtype=x_mask.dtype,
                            device=x_mask.device,
                        ),
                        x_mask,
                    ],
                    dim=-1,
                )

        x = self.pos_enc(x)

        return x, x_mask

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positional encoding.

        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.pos_enc
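A small shape check (not part of the commit) confirming the `time - 4` relation documented above:

import torch

m = Conv1dSubsampling1(idim=80, odim=256, dropout_rate=0.0)
x = torch.randn(2, 100, 80)
mask = torch.ones(2, 1, 100, dtype=torch.bool)
y, y_mask = m(x, mask)
assert y.shape == (2, 96, 256)     # 100 - 4 frames survive
assert y_mask.shape == (2, 1, 96)  # mask slicing matches the conv output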

class Conv1dSubsampling2(torch.nn.Module):
    """Convolutional 1D subsampling (to 1/2 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.

    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct a Conv1dSubsampling2 object."""
        super(Conv1dSubsampling2, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv1d(idim, odim, 3, 1),
            torch.nn.ReLU(),
            torch.nn.Conv1d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Linear(odim, odim)
        self.pos_enc = (
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate)
        )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        _upgrade_legacy_subsampling_state_dict(state_dict, prefix)
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, x, x_mask, prefix_embeds=None):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
            prefix_embeds (torch.Tensor or None): Prefix token embeddings
                (#batch, prefix_len, odim).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 2.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 2.

        """
        x = x.transpose(2, 1)  # (#batch, idim, time)
        x = self.conv(x)
        b, c, t = x.size()
        x = self.out(x.transpose(1, 2).contiguous())
        if x_mask is not None:
            x_mask = x_mask[:, :, :-2:1][:, :, :-2:2]

        if prefix_embeds is not None:
            x = torch.cat([prefix_embeds, x], dim=1)
            if x_mask is not None:
                x_mask = torch.cat(
                    [
                        torch.ones(
                            x_mask.shape[0],
                            1,
                            prefix_embeds.size(1),
                            dtype=x_mask.dtype,
                            device=x_mask.device,
                        ),
                        x_mask,
                    ],
                    dim=-1,
                )

        x = self.pos_enc(x)

        return x, x_mask

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positional encoding.

        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.pos_enc
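The two mask slices mirror the convolution arithmetic: `[:, :, :-2:1]` accounts for the first valid kernel-3 conv and `[:, :, :-2:2]` for the stride-2 one. A sketch (not part of the commit):

import torch

m = Conv1dSubsampling2(idim=80, odim=256, dropout_rate=0.0)
x = torch.randn(1, 100, 80)
mask = torch.ones(1, 1, 100, dtype=torch.bool)
y, y_mask = m(x, mask)
# (100 - 2) frames after the first conv, then ((98 - 3) // 2 + 1) = 48 after
# the stride-2 conv; the strided mask slice lands on the same count.
assert y.shape[1] == y_mask.shape[-1] == 48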

class Conv1dSubsampling3(torch.nn.Module):
    """Convolutional 1D subsampling (to 1/3 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.

    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct a Conv1dSubsampling3 object."""
        super(Conv1dSubsampling3, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv1d(idim, odim, 3, 1),
            torch.nn.ReLU(),
            torch.nn.Conv1d(odim, odim, 5, 3),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Linear(odim, odim)
        self.pos_enc = (
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate)
        )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        _upgrade_legacy_subsampling_state_dict(state_dict, prefix)
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, x, x_mask, prefix_embeds=None):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
            prefix_embeds (torch.Tensor or None): Prefix token embeddings
                (#batch, prefix_len, odim).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 3.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 3.

        """
        x = x.transpose(2, 1)  # (#batch, idim, time)
        x = self.conv(x)
        b, c, t = x.size()
        x = self.out(x.transpose(1, 2).contiguous())
        if x_mask is not None:
            x_mask = x_mask[:, :, :-2:1][:, :, :-4:3]

        if prefix_embeds is not None:
            x = torch.cat([prefix_embeds, x], dim=1)
            if x_mask is not None:
                x_mask = torch.cat(
                    [
                        torch.ones(
                            x_mask.shape[0],
                            1,
                            prefix_embeds.size(1),
                            dtype=x_mask.dtype,
                            device=x_mask.device,
                        ),
                        x_mask,
                    ],
                    dim=-1,
                )

        x = self.pos_enc(x)

        return x, x_mask

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positional encoding.

        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.pos_enc

class Conv2dSubsampling(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/4 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.

    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct a Conv2dSubsampling object."""
        super(Conv2dSubsampling, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
        self.pos_enc = (
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate)
        )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        _upgrade_legacy_subsampling_state_dict(state_dict, prefix)
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, x, x_mask, prefix_embeds=None):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
            prefix_embeds (torch.Tensor or None): Prefix token embeddings
                (#batch, prefix_len, odim).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 4.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 4.

        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is not None:
            x_mask = x_mask[:, :, :-2:2][:, :, :-2:2]

        if prefix_embeds is not None:
            x = torch.cat([prefix_embeds, x], dim=1)
            if x_mask is not None:
                x_mask = torch.cat(
                    [
                        torch.ones(
                            x_mask.shape[0],
                            1,
                            prefix_embeds.size(1),
                            dtype=x_mask.dtype,
                            device=x_mask.device,
                        ),
                        x_mask,
                    ],
                    dim=-1,
                )

        x = self.pos_enc(x)

        return x, x_mask

    # def __getitem__(self, key):
    #     """Get item.

    #     When reset_parameters() is called, if use_scaled_pos_enc is used,
    #     return the positional encoding.

    #     """
    #     if key != -1:
    #         raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
    #     return self.pos_enc
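The width of the output Linear follows the frequency-axis shrinkage: each stride-2, kernel-3 conv maps f to (f - 1) // 2. A usage sketch (not part of the commit), assuming 80-dim features:

import torch

m = Conv2dSubsampling(idim=80, odim=256, dropout_rate=0.1)
x = torch.randn(4, 200, 80)
mask = torch.ones(4, 1, 200, dtype=torch.bool)
y, y_mask = m(x, mask)
# time: 200 -> 99 -> 49; freq: 80 -> 39 -> 19, so the flatten feeds
# odim * 19 features into the output Linear.
assert y.shape == (4, 49, 256)
assert y_mask.shape == (4, 1, 49)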

class Conv2dSubsampling1(torch.nn.Module):
    """Similar to Conv2dSubsampling module, but without any subsampling performed.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.

    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct a Conv2dSubsampling1 object."""
        super(Conv2dSubsampling1, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 1),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Linear(odim * (idim - 4), odim)
        self.pos_enc = (
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate)
        )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        _upgrade_legacy_subsampling_state_dict(state_dict, prefix)
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, x, x_mask, prefix_embeds=None):
        """Pass x through 2 Conv2d layers without subsampling.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
            prefix_embeds (torch.Tensor or None): Prefix token embeddings
                (#batch, prefix_len, odim).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time - 4.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time - 4.

        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is not None:
            x_mask = x_mask[:, :, :-4]

        if prefix_embeds is not None:
            x = torch.cat([prefix_embeds, x], dim=1)
            if x_mask is not None:
                x_mask = torch.cat(
                    [
                        torch.ones(
                            x_mask.shape[0],
                            1,
                            prefix_embeds.size(1),
                            dtype=x_mask.dtype,
                            device=x_mask.device,
                        ),
                        x_mask,
                    ],
                    dim=-1,
                )

        x = self.pos_enc(x)

        return x, x_mask

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positional encoding.

        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.pos_enc

class Conv2dSubsampling2(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/2 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.

    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct a Conv2dSubsampling2 object."""
        super(Conv2dSubsampling2, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 1),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Linear(odim * (((idim - 1) // 2 - 2)), odim)
        self.pos_enc = (
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate)
        )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        _upgrade_legacy_subsampling_state_dict(state_dict, prefix)
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, x, x_mask, prefix_embeds=None):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
            prefix_embeds (torch.Tensor or None): Prefix token embeddings
                (#batch, prefix_len, odim).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 2.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 2.

        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is not None:
            x_mask = x_mask[:, :, :-2:2][:, :, :-2:1]

        if prefix_embeds is not None:
            x = torch.cat([prefix_embeds, x], dim=1)
            if x_mask is not None:
                x_mask = torch.cat(
                    [
                        torch.ones(
                            x_mask.shape[0],
                            1,
                            prefix_embeds.size(1),
                            dtype=x_mask.dtype,
                            device=x_mask.device,
                        ),
                        x_mask,
                    ],
                    dim=-1,
                )

        x = self.pos_enc(x)

        return x, x_mask

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positional encoding.

        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.pos_enc

class Conv2dSubsampling6(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/6 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.

    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct a Conv2dSubsampling6 object."""
        super(Conv2dSubsampling6, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 5, 3),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim)
        self.pos_enc = (
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate)
        )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        _upgrade_legacy_subsampling_state_dict(state_dict, prefix)
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, x, x_mask, prefix_embeds=None):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
            prefix_embeds (torch.Tensor or None): Prefix token embeddings
                (#batch, prefix_len, odim).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 6.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 6.

        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is not None:
            x_mask = x_mask[:, :, :-2:2][:, :, :-4:3]

        if prefix_embeds is not None:
            x = torch.cat([prefix_embeds, x], dim=1)
            if x_mask is not None:
                x_mask = torch.cat(
                    [
                        torch.ones(
                            x_mask.shape[0],
                            1,
                            prefix_embeds.size(1),
                            dtype=x_mask.dtype,
                            device=x_mask.device,
                        ),
                        x_mask,
                    ],
                    dim=-1,
                )

        x = self.pos_enc(x)

        return x, x_mask

class Conv2dSubsampling8(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/8 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.

    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct a Conv2dSubsampling8 object."""
        super(Conv2dSubsampling8, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
        self.pos_enc = (
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate)
        )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        _upgrade_legacy_subsampling_state_dict(state_dict, prefix)
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, x, x_mask, prefix_embeds=None):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
            prefix_embeds (torch.Tensor or None): Prefix token embeddings
                (#batch, prefix_len, odim).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 8.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 8.

        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is not None:
            x_mask = x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2]

        if prefix_embeds is not None:
            x = torch.cat([prefix_embeds, x], dim=1)
            if x_mask is not None:
                x_mask = torch.cat(
                    [
                        torch.ones(
                            x_mask.shape[0],
                            1,
                            prefix_embeds.size(1),
                            dtype=x_mask.dtype,
                            device=x_mask.device,
                        ),
                        x_mask,
                    ],
                    dim=-1,
                )

        x = self.pos_enc(x)

        return x, x_mask
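The variants above differ only in their kernel/stride schedule. A sanity-check sketch (not part of the commit) that prints each module's output length for a fixed input:

import torch

x = torch.randn(1, 160, 80)
mask = torch.ones(1, 1, 160, dtype=torch.bool)
for cls in (
    Conv1dSubsampling1, Conv1dSubsampling2, Conv1dSubsampling3,
    Conv2dSubsampling1, Conv2dSubsampling2, Conv2dSubsampling,
    Conv2dSubsampling6, Conv2dSubsampling8,
):
    m = cls(80, 256, 0.0)
    y, _ = m(x, mask)
    print(f"{cls.__name__}: 160 -> {y.shape[1]} frames")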
src/model/__init__.py ADDED
File without changes

src/model/powsm/__init__.py ADDED
File without changes

src/model/powsm/ctc.py ADDED
@@ -0,0 +1,230 @@
from typing import List, Optional

import torch
import torch.nn.functional as F
from typeguard import typechecked

from src.utils import RankedLogger

log = RankedLogger(__name__, rank_zero_only=True)


class CTC(torch.nn.Module):
    """CTC module.

    Args:
        odim: dimension of outputs
        encoder_output_size: number of encoder projection units
        dropout_rate: dropout rate (0.0 ~ 1.0)
        ctc_type: "builtin" or "builtin2" ("gtnctc" and "brctc" are not bundled here)
        reduce: reduce the CTC loss into a scalar
        ignore_nan_grad: same as zero_infinity (kept for backward compatibility)
        zero_infinity: whether to zero infinite losses and the associated gradients
    """

    @typechecked
    def __init__(
        self,
        odim: int,
        encoder_output_size: int,
        dropout_rate: float = 0.0,
        ctc_type: str = "builtin",
        reduce: bool = True,
        ignore_nan_grad: Optional[bool] = None,
        zero_infinity: bool = True,
        brctc_risk_strategy: str = "exp",
        brctc_group_strategy: str = "end",
        brctc_risk_factor: float = 0.0,
    ):
        super().__init__()
        eprojs = encoder_output_size
        self.dropout_rate = dropout_rate
        self.ctc_lo = torch.nn.Linear(eprojs, odim)
        self.ctc_type = ctc_type
        if ignore_nan_grad is not None:
            zero_infinity = ignore_nan_grad

        if self.ctc_type == "builtin":
            self.ctc_loss = torch.nn.CTCLoss(
                reduction="none", zero_infinity=zero_infinity
            )
        elif self.ctc_type == "builtin2":
            self.ignore_nan_grad = True
            log.warning("builtin2")
            self.ctc_loss = torch.nn.CTCLoss(reduction="none")

        elif self.ctc_type == "gtnctc":
            raise ImportError("gtnctc requires gtn_ctc which is not bundled here.")

        elif self.ctc_type == "brctc":
            try:
                import k2  # noqa
            except ImportError:
                raise ImportError("You should install K2 to use Bayes Risk CTC")

            raise ImportError("brctc requires BayesRiskCTC which is not bundled here.")
        else:
            raise ValueError(
                f'ctc_type must be "builtin" or "builtin2": {self.ctc_type}'
            )

        self.reduce = reduce
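A construction sketch (not part of the commit); the sizes are placeholders, and index 0 is assumed to be the CTC blank, as torch.nn.CTCLoss expects by default:

import torch

# Hypothetical sizes: 1024-dim encoder states, 400-symbol vocabulary.
ctc = CTC(odim=400, encoder_output_size=1024, dropout_rate=0.1)

hs_pad = torch.randn(2, 120, 1024)          # (B, Tmax, D) encoder output
hlens = torch.tensor([120, 95])             # valid frames per utterance
ys_pad = torch.randint(1, 400, (2, 30))     # padded label ids (no blanks)
ys_lens = torch.tensor([30, 22])            # valid labels per utterance
loss = ctc(hs_pad, hlens, ys_pad, ys_lens)  # scalar, batch-averaged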

    def loss_fn(
        self,
        th_pred,
        th_target,
        th_ilen,
        th_olen,
        lang_sym: Optional[List[str]] = None,
        accent_sym: Optional[List[str]] = None,
    ) -> torch.Tensor:
        if self.ctc_type in ["builtin", "brctc"]:
            th_pred = th_pred.log_softmax(2).float()
            loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)

            if self.ctc_type == "builtin":
                size = th_pred.size(1)
            else:
                size = loss.size(0)  # some invalid examples will be excluded

            if self.reduce:
                # Batch-size average
                loss = loss.sum() / size
            else:
                loss = loss / size
            return loss

        # builtin2 ignores nan losses using the logic below, while
        # builtin relies on the zero_infinity flag in pytorch CTC
        elif self.ctc_type == "builtin2":
            th_pred = th_pred.log_softmax(2).float()
            loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)

            if loss.requires_grad and self.ignore_nan_grad:
                # ctc_grad: (L, B, O)
                ctc_grad = loss.grad_fn(torch.ones_like(loss))
                ctc_grad = ctc_grad.sum([0, 2])
                indices = torch.isfinite(ctc_grad)
                size = indices.long().sum()
                if size == 0:
                    # Return as is
                    log.warning(
                        "All samples in this mini-batch got nan grad."
                        " Returning nan value instead of CTC loss"
                    )
                elif size != th_pred.size(1):
                    log.warning(
                        f"{th_pred.size(1) - size}/{th_pred.size(1)}"
                        " samples got nan grad."
                        " These were ignored for CTC loss."
                    )

                    # Create mask for target
                    target_mask = torch.full(
                        [th_target.size(0)],
                        1,
                        dtype=torch.bool,
                        device=th_target.device,
                    )
                    s = 0
                    for ind, le in enumerate(th_olen):
                        if not indices[ind]:
                            target_mask[s : s + le] = 0
                        s += le

                    # Calc loss again using masked data
                    loss = self.ctc_loss(
                        th_pred[:, indices, :],
                        th_target[target_mask],
                        th_ilen[indices],
                        th_olen[indices],
                    )
            else:
                size = th_pred.size(1)

            if self.reduce:
                # Batch-size average
                loss = loss.sum() / size
            else:
                loss = loss / size
            return loss

        elif self.ctc_type == "gtnctc":
            log_probs = torch.nn.functional.log_softmax(th_pred, dim=2)
            return self.ctc_loss(log_probs, th_target, th_ilen, 0, "none")

        else:
            raise NotImplementedError

    def forward(
        self,
        hs_pad,
        hlens,
        ys_pad,
        ys_lens,
        lang_sym: Optional[List[str]] = None,
        accent_sym: Optional[List[str]] = None,
    ):
        """Calculate CTC loss.

        Args:
            hs_pad: batch of padded hidden state sequences (B, Tmax, D)
            hlens: batch of lengths of hidden state sequences (B)
            ys_pad: batch of padded character id sequence tensor (B, Lmax)
            ys_lens: batch of lengths of character sequence (B)
            lang_sym: optional list of language codes per utterance
            accent_sym: optional list of accent codes per utterance
        """
        # hs_pad: (B, L, NProj) -> ys_hat: (B, L, Nvocab)
        ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))

        if self.ctc_type == "brctc":
            loss = self.loss_fn(
                ys_hat, ys_pad, hlens, ys_lens, lang_sym=lang_sym, accent_sym=accent_sym
            ).to(device=hs_pad.device, dtype=hs_pad.dtype)
            return loss

        elif self.ctc_type == "gtnctc":
            # gtn expects list form for ys
            ys_true = [y[y != -1] for y in ys_pad]  # parse padded ys
        else:
            # ys_hat: (B, L, D) -> (L, B, D)
            ys_hat = ys_hat.transpose(0, 1)
            # (B, L) -> (BxL,)
            ys_true = torch.cat([ys_pad[i, :l] for i, l in enumerate(ys_lens)])

        loss = self.loss_fn(
            ys_hat, ys_true, hlens, ys_lens, lang_sym=lang_sym, accent_sym=accent_sym
        ).to(device=hs_pad.device, dtype=hs_pad.dtype)

        return loss

    def softmax(self, hs_pad):
        """softmax of frame activations

        Args:
            Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        Returns:
            torch.Tensor: softmax applied 3d tensor (B, Tmax, odim)
        """
        return F.softmax(self.ctc_lo(hs_pad), dim=2)

    def log_softmax(self, hs_pad):
        """log_softmax of frame activations

        Args:
            Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        Returns:
            torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
        """
        return F.log_softmax(self.ctc_lo(hs_pad), dim=2)

    def argmax(self, hs_pad):
        """argmax of frame activations

        Args:
            torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        Returns:
            torch.Tensor: argmax applied 2d tensor (B, Tmax)
        """
        return torch.argmax(self.ctc_lo(hs_pad), dim=2)
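`argmax` is the hook for greedy decoding (cf. `src/recipe/phone_recognition/greedy_ctc_strategy.py` elsewhere in this commit). A minimal sketch (not part of the commit), assuming blank id 0:

import itertools
import torch

def greedy_ctc_decode(ctc: CTC, hs_pad: torch.Tensor, blank: int = 0):
    """Collapse repeats, then drop blanks: standard greedy CTC decoding."""
    frame_ids = ctc.argmax(hs_pad)  # (B, Tmax)
    hyps = []
    for ids in frame_ids.tolist():
        collapsed = [k for k, _ in itertools.groupby(ids)]
        hyps.append([k for k in collapsed if k != blank])
    return hyps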
src/model/powsm/e_branchformer.py ADDED
@@ -0,0 +1,555 @@
# Copyright 2022 Kwangyoun Kim (ASAPP inc.)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""E-Branchformer encoder definition.

Reference:
    Kwangyoun Kim, Felix Wu, Yifan Peng, Jing Pan,
    Prashant Sridhar, Kyu J. Han, Shinji Watanabe,
    "E-Branchformer: Branchformer with Enhanced merging
    for speech recognition," in SLT 2022.
"""

import logging
from typing import List, Optional, Tuple

import torch
from typeguard import typechecked

from src.model.powsm.ctc import CTC
from src.espnet_import.fastformer import FastSelfAttention
from src.espnet_import.cgmlp import ConvolutionalGatingMLP

from src.espnet_import.nets_utils import get_activation, make_pad_mask
from src.espnet_import.attention import (
    LegacyRelPositionMultiHeadedAttention,
    MultiHeadedAttention,
    RelPositionMultiHeadedAttention,
)
from src.espnet_import.embedding import (
    ConvolutionalPositionalEmbedding,
    LegacyRelPositionalEncoding,
    PositionalEncoding,
    RelPositionalEncoding,
    ScaledPositionalEncoding,
)
from src.espnet_import.layer_norm import LayerNorm
from src.espnet_import.positionwise_feed_forward import PositionwiseFeedForward
from src.espnet_import.repeat import repeat
from src.espnet_import.subsampling import (
    Conv1dSubsampling1,
    Conv1dSubsampling2,
    Conv1dSubsampling3,
    Conv2dSubsampling,
    Conv2dSubsampling1,
    Conv2dSubsampling2,
    Conv2dSubsampling6,
    Conv2dSubsampling8,
    TooShortUttError,
    check_short_utt,
)

class EBranchformerEncoderLayer(torch.nn.Module):
    """E-Branchformer encoder layer module.

    Args:
        size (int): model dimension
        attn: standard self-attention or efficient attention
        cgmlp: ConvolutionalGatingMLP
        feed_forward: feed-forward module, optional
        feed_forward_macaron: macaron-style feed-forward module, optional
        dropout_rate (float): dropout probability
        merge_conv_kernel (int): kernel size of the depth-wise conv in merge module
    """

    def __init__(
        self,
        size: int,
        attn: torch.nn.Module,
        cgmlp: torch.nn.Module,
        feed_forward: Optional[torch.nn.Module],
        feed_forward_macaron: Optional[torch.nn.Module],
        dropout_rate: float,
        merge_conv_kernel: int = 3,
    ):
        super().__init__()

        self.size = size
        self.attn = attn
        self.cgmlp = cgmlp

        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.ff_scale = 1.0
        if self.feed_forward is not None:
            self.norm_ff = LayerNorm(size)
        if self.feed_forward_macaron is not None:
            self.ff_scale = 0.5
            self.norm_ff_macaron = LayerNorm(size)

        self.norm_mha = LayerNorm(size)  # for the MHA module
        self.norm_mlp = LayerNorm(size)  # for the MLP module
        self.norm_final = LayerNorm(size)  # for the final output of the block

        self.dropout = torch.nn.Dropout(dropout_rate)

        self.depthwise_conv_fusion = torch.nn.Conv1d(
            size + size,
            size + size,
            kernel_size=merge_conv_kernel,
            stride=1,
            padding=(merge_conv_kernel - 1) // 2,
            groups=size + size,
            bias=True,
        )
        self.merge_proj = torch.nn.Linear(size + size, size)

    def forward(self, x_input, mask, cache=None):
        """Compute encoded features.

        Args:
            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
                - w/o pos emb: Tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, 1, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
        """

        if cache is not None:
            raise NotImplementedError("cache is not None, which is not tested")

        if isinstance(x_input, tuple):
            x, pos_emb = x_input[0], x_input[1]
        else:
            x, pos_emb = x_input, None

        if self.feed_forward_macaron is not None:
            residual = x
            x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))

        # Two branches
        x1 = x
        x2 = x

        # Branch 1: multi-headed attention module
        x1 = self.norm_mha(x1)

        if isinstance(self.attn, FastSelfAttention):
            x_att = self.attn(x1, mask)
        else:
            if pos_emb is not None:
                x_att = self.attn(x1, x1, x1, pos_emb, mask)
            else:
                x_att = self.attn(x1, x1, x1, mask)

        x1 = self.dropout(x_att)

        # Branch 2: convolutional gating mlp
        x2 = self.norm_mlp(x2)

        if pos_emb is not None:
            x2 = (x2, pos_emb)
        x2 = self.cgmlp(x2, mask)
        if isinstance(x2, tuple):
            x2 = x2[0]

        x2 = self.dropout(x2)

        # Merge two branches
        x_concat = torch.cat([x1, x2], dim=-1)
        x_tmp = x_concat.transpose(1, 2)
        x_tmp = self.depthwise_conv_fusion(x_tmp)
        x_tmp = x_tmp.transpose(1, 2)
        x = x + self.dropout(self.merge_proj(x_concat + x_tmp))

        if self.feed_forward is not None:
            # feed forward module
            residual = x
            x = self.norm_ff(x)
            x = residual + self.ff_scale * self.dropout(self.feed_forward(x))

        x = self.norm_final(x)

        if pos_emb is not None:
            return (x, pos_emb), mask

        return x, mask
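The merge step concatenates the two branches, smooths the concatenation with a depth-wise convolution, and projects back to `size`. The same computation in isolation (not part of the commit):

import torch

size, T = 4, 10
x1 = torch.randn(2, T, size)  # attention branch output
x2 = torch.randn(2, T, size)  # cgMLP branch output

dw = torch.nn.Conv1d(2 * size, 2 * size, 3, padding=1, groups=2 * size)
proj = torch.nn.Linear(2 * size, size)

x_concat = torch.cat([x1, x2], dim=-1)                 # (B, T, 2*size)
x_tmp = dw(x_concat.transpose(1, 2)).transpose(1, 2)   # local smoothing
merged = proj(x_concat + x_tmp)                        # back to (B, T, size)
assert merged.shape == (2, T, size)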

class EBranchformerEncoder(torch.nn.Module):
    """E-Branchformer encoder module."""

    @typechecked
    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        attention_layer_type: str = "rel_selfattn",
        pos_enc_layer_type: str = "rel_pos",
        rel_pos_type: str = "latest",
        cgmlp_linear_units: int = 2048,
        cgmlp_conv_kernel: int = 31,
        use_linear_after_conv: bool = False,
        gate_activation: str = "identity",
        num_blocks: int = 12,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        zero_triu: bool = False,
        padding_idx: int = -1,
        layer_drop_rate: float = 0.0,
        max_pos_emb_len: int = 5000,
        use_ffn: bool = False,
        macaron_ffn: bool = False,
        ffn_activation_type: str = "swish",
        linear_units: int = 2048,
        positionwise_layer_type: str = "linear",
        merge_conv_kernel: int = 3,
        interctc_layer_idx=None,
        interctc_use_conditioning: bool = False,
        qk_norm: bool = False,
        use_flash_attn: bool = True,
        gradient_checkpoint_layers: Optional[List[int]] = None,
    ):
        super().__init__()
        self._output_size = output_size

        if rel_pos_type == "legacy":
            if pos_enc_layer_type == "rel_pos":
                pos_enc_layer_type = "legacy_rel_pos"
            if attention_layer_type == "rel_selfattn":
                attention_layer_type = "legacy_rel_selfattn"
        elif rel_pos_type == "latest":
            assert attention_layer_type != "legacy_rel_selfattn"
            assert pos_enc_layer_type != "legacy_rel_pos"
        else:
            raise ValueError("unknown rel_pos_type: " + rel_pos_type)

        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "conv":
            pos_enc_class = ConvolutionalPositionalEmbedding
        elif pos_enc_layer_type == "scaled_abs_pos":
            pos_enc_class = ScaledPositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            assert attention_layer_type == "rel_selfattn"
            pos_enc_class = RelPositionalEncoding
        elif pos_enc_layer_type == "legacy_rel_pos":
            assert attention_layer_type == "legacy_rel_selfattn"
            pos_enc_class = LegacyRelPositionalEncoding
            logging.warning(
                "Using legacy_rel_pos and it will be deprecated in the future."
            )
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)

        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "conv1d1":
            self.embed = Conv1dSubsampling1(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "conv1d2":
            self.embed = Conv1dSubsampling2(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "conv1d3":
            self.embed = Conv1dSubsampling3(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "conv2d1":
            self.embed = Conv2dSubsampling1(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer is None:
            if input_size == output_size:
                self.embed = torch.nn.Sequential(
                    pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len)
                )
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        else:
            raise ValueError("unknown input_layer: " + input_layer)

        activation = get_activation(ffn_activation_type)
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
                activation,
            )
        elif positionwise_layer_type is None:
            logging.warning("no macaron ffn")
        else:
            raise ValueError("Support only linear.")

        if attention_layer_type == "selfattn":
            # Default to flash attention unless overridden by the user
            if use_flash_attn:
                try:
                    import flash_attn_interface  # noqa
                except Exception:
                    use_flash_attn = False
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
                qk_norm,
                use_flash_attn,
                False,
                False,
            )
        elif attention_layer_type == "legacy_rel_selfattn":
            assert pos_enc_layer_type == "legacy_rel_pos"
            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
            logging.warning(
                "Using legacy_rel_selfattn and it will be deprecated in the future."
            )
        elif attention_layer_type == "rel_selfattn":
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
                zero_triu,
            )
        elif attention_layer_type == "fast_selfattn":
            assert pos_enc_layer_type in ["abs_pos", "scaled_abs_pos"]
            encoder_selfattn_layer = FastSelfAttention
            encoder_selfattn_layer_args = (
                output_size,
                attention_heads,
                attention_dropout_rate,
            )
        else:
            raise ValueError("unknown encoder_attn_layer: " + attention_layer_type)

        cgmlp_layer = ConvolutionalGatingMLP
        cgmlp_layer_args = (
            output_size,
            cgmlp_linear_units,
            cgmlp_conv_kernel,
            dropout_rate,
            use_linear_after_conv,
            gate_activation,
        )

        self.encoders = repeat(
            num_blocks,
            lambda lnum: EBranchformerEncoderLayer(
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                cgmlp_layer(*cgmlp_layer_args),
                positionwise_layer(*positionwise_layer_args) if use_ffn else None,
                (
                    positionwise_layer(*positionwise_layer_args)
                    if use_ffn and macaron_ffn
                    else None
                ),
                dropout_rate,
                merge_conv_kernel,
            ),
            layer_drop_rate,
        )
        self.after_norm = LayerNorm(output_size)

        self.layer_drop_rate = layer_drop_rate

        if interctc_layer_idx is None:
            interctc_layer_idx = []
        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None

        # For gradient checkpointing
        # 0 is the embedding layer, 1 is the first encoder layer, etc.
        # (a None default avoids the mutable-default-argument pitfall)
        self.gradient_checkpoint_layers = gradient_checkpoint_layers or []
        # logging.info(f"Gradient checkpoint layers: {self.gradient_checkpoint_layers}")

    def output_size(self) -> int:
        return self._output_size
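A configuration sketch (not part of the commit). The hyperparameters are placeholders, not the released checkpoint's, and the `(output, output lengths, None)` return shape is assumed from the forward docstring below:

import torch

encoder = EBranchformerEncoder(
    input_size=80,
    output_size=256,
    attention_heads=4,
    num_blocks=12,
    input_layer="conv2d",  # ~4x time reduction via Conv2dSubsampling
    use_ffn=True,
    macaron_ffn=True,
)
xs = torch.randn(2, 200, 80)
ilens = torch.tensor([200, 160])
ys, olens, _ = encoder(xs, ilens)
print(ys.shape)  # (2, 49, 256): 200 frames -> 49 after two stride-2 convs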

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        masks: torch.Tensor = None,
        ctc: CTC = None,
        max_layer: int = None,
        return_all_hs: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
            ilens (torch.Tensor): Input length (#batch).
            prev_states (torch.Tensor): Not to be used now.
            masks (torch.Tensor): Optional precomputed pad mask (#batch, L),
                True at padded positions; derived from ilens when None.
            ctc (CTC): Intermediate CTC module.
            max_layer (int): Layer depth below which InterCTC is applied.
            return_all_hs (bool): Whether to collect every layer's hidden states.
        Returns:
            torch.Tensor: Output tensor (#batch, L, output_size).
            torch.Tensor: Output length (#batch).
            torch.Tensor: Not to be used now.
        """

        if masks is None:
            masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        else:
            masks = ~masks[:, None, :]

        if (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv1dSubsampling1)
            or isinstance(self.embed, Conv1dSubsampling2)
            or isinstance(self.embed, Conv1dSubsampling3)
            or isinstance(self.embed, Conv2dSubsampling1)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            if 0 in self.gradient_checkpoint_layers:
                xs_pad, masks = torch.utils.checkpoint.checkpoint(
                    self.embed, xs_pad, masks, use_reentrant=False
                )
            else:
                xs_pad, masks = self.embed(xs_pad, masks)
        elif self.embed is not None:
            if 0 in self.gradient_checkpoint_layers:
                xs_pad = torch.utils.checkpoint.checkpoint(
                    self.embed, xs_pad, use_reentrant=False
                )
            else:
                xs_pad = self.embed(xs_pad)

        intermediate_outs = []
        for layer_idx, encoder_layer in enumerate(self.encoders):
            if max_layer is not None and layer_idx >= max_layer:
                break

            if (
                self.training
                and torch.empty(1).uniform_().item() < self.layer_drop_rate
            ):
                continue

            if layer_idx + 1 in self.gradient_checkpoint_layers:
                xs_pad, masks = torch.utils.checkpoint.checkpoint(
                    encoder_layer, xs_pad, masks, use_reentrant=False
                )
            else:
                xs_pad, masks = encoder_layer(xs_pad, masks)

            if return_all_hs:
                if isinstance(xs_pad, tuple):
                    intermediate_outs.append(xs_pad[0])
                else:
                    intermediate_outs.append(xs_pad)

            elif layer_idx + 1 in self.interctc_layer_idx:
                encoder_out = xs_pad

                if isinstance(encoder_out, tuple):
                    encoder_out = encoder_out[0]

                intermediate_outs.append((layer_idx + 1, encoder_out))

                if self.interctc_use_conditioning:
                    ctc_out = ctc.softmax(encoder_out)

                    if isinstance(xs_pad, tuple):
                        xs_pad = list(xs_pad)
                        xs_pad[0] = xs_pad[0] + self.conditioning_layer(ctc_out)
                        xs_pad = tuple(xs_pad)
                    else:
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)

        if isinstance(xs_pad, tuple):
            xs_pad = xs_pad[0]
|
| 550 |
+
|
| 551 |
+
xs_pad = self.after_norm(xs_pad)
|
| 552 |
+
olens = masks.squeeze(1).sum(1)
|
| 553 |
+
if len(intermediate_outs) > 0:
|
| 554 |
+
return (xs_pad, intermediate_outs), olens, None
|
| 555 |
+
return xs_pad, olens, None
|
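For orientation, a minimal sketch of driving this forward pass. The constructor call assumes the vendored EBranchformerEncoder keeps ESPnet-style defaults for everything except input_size; the 80-dim features and batch shapes are illustrative, not the Space's actual config.

import torch
from src.model.powsm.e_branchformer import EBranchformerEncoder

# Hypothetical toy setup; defaults other than input_size are assumed to exist.
encoder = EBranchformerEncoder(input_size=80)
xs_pad = torch.randn(2, 100, 80)   # (batch, frames, input_size)
ilens = torch.tensor([100, 73])    # valid frames per utterance

# With return_all_hs=True the first return value is (final_hs, all_hidden_states).
out, olens, _ = encoder(xs_pad, ilens, return_all_hs=True)
final_hs, all_hs = out if isinstance(out, tuple) else (out, [])
print(final_hs.shape, len(all_hs), olens)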
src/model/powsm/specaug.py
ADDED
@@ -0,0 +1,384 @@
"""SpecAugment module."""

from typing import Optional, Sequence, Union
import math
from typeguard import typechecked
import torch
from src.espnet_import.nets_utils import pad_list

DEFAULT_TIME_WARP_MODE = "bicubic"


def time_warp(x: torch.Tensor, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE):
    """Time warping using torch.interpolate.

    Args:
        x: (Batch, Time, Freq)
        window: time warp parameter
        mode: Interpolate mode
    """

    # bicubic supports 4D or more dimension tensor
    org_size = x.size()
    if x.dim() == 3:
        # x: (Batch, Time, Freq) -> (Batch, 1, Time, Freq)
        x = x[:, None]

    t = x.shape[2]
    if t - window <= window:
        return x.view(*org_size)

    center = torch.randint(window, t - window, (1,))[0]
    warped = torch.randint(center - window, center + window, (1,))[0] + 1

    # left: (Batch, Channel, warped, Freq)
    # right: (Batch, Channel, time - warped, Freq)
    left = torch.nn.functional.interpolate(
        x[:, :, :center], (warped, x.shape[3]), mode=mode, align_corners=False
    )
    right = torch.nn.functional.interpolate(
        x[:, :, center:], (t - warped, x.shape[3]), mode=mode, align_corners=False
    )

    if x.requires_grad:
        x = torch.cat([left, right], dim=-2)
    else:
        x[:, :, :warped] = left
        x[:, :, warped:] = right

    return x.view(*org_size)


def mask_along_axis(
    spec: torch.Tensor,
    spec_lengths: torch.Tensor,
    mask_width_range: Sequence[int] = (0, 30),
    dim: int = 1,
    num_mask: int = 2,
    replace_with_zero: bool = True,
):
    """Apply mask along the specified direction.

    Args:
        spec: (Batch, Length, Freq)
        spec_lengths: (Length): Not using lengths in this implementation
        mask_width_range: Select the width randomly between this range
    """

    org_size = spec.size()
    if spec.dim() == 4:
        # spec: (Batch, Channel, Length, Freq) -> (Batch * Channel, Length, Freq)
        spec = spec.view(-1, spec.size(2), spec.size(3))

    B = spec.shape[0]
    # D = Length or Freq
    D = spec.shape[dim]
    # mask_length: (B, num_mask, 1)
    mask_length = torch.randint(
        mask_width_range[0],
        mask_width_range[1],
        (B, num_mask),
        device=spec.device,
    ).unsqueeze(2)

    # mask_pos: (B, num_mask, 1)
    mask_pos = torch.randint(
        0, max(1, D - mask_length.max()), (B, num_mask), device=spec.device
    ).unsqueeze(2)

    # aran: (1, 1, D)
    aran = torch.arange(D, device=spec.device)[None, None, :]
    # mask: (Batch, num_mask, D)
    mask = (mask_pos <= aran) * (aran < (mask_pos + mask_length))
    # Multiply masks: (Batch, num_mask, D) -> (Batch, D)
    mask = mask.any(dim=1)
    if dim == 1:
        # mask: (Batch, Length, 1)
        mask = mask.unsqueeze(2)
    elif dim == 2:
        # mask: (Batch, 1, Freq)
        mask = mask.unsqueeze(1)

    if replace_with_zero:
        value = 0.0
    else:
        value = spec.mean()

    if spec.requires_grad:
        spec = spec.masked_fill(mask, value)
    else:
        spec = spec.masked_fill_(mask, value)
    spec = spec.view(*org_size)
    return spec, spec_lengths


class TimeWarp(torch.nn.Module):
    """Time warping using torch.interpolate.

    Args:
        window: time warp parameter
        mode: Interpolate mode
    """

    def __init__(self, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE):
        super().__init__()
        self.window = window
        self.mode = mode

    def extra_repr(self):
        return f"window={self.window}, mode={self.mode}"

    def forward(self, x: torch.Tensor, x_lengths: torch.Tensor = None):
        """Forward function.

        Args:
            x: (Batch, Time, Freq)
            x_lengths: (Batch,)
        """

        if x_lengths is None or all(le == x_lengths[0] for le in x_lengths):
            # Note that applying same warping for each sample
            y = time_warp(x, window=self.window, mode=self.mode)
        else:
            # FIXME(kamo): I have no idea to batchify Timewarp
            ys = []
            for i in range(x.size(0)):
                _y = time_warp(
                    x[i][None, : x_lengths[i]],
                    window=self.window,
                    mode=self.mode,
                )[0]
                ys.append(_y)
            y = pad_list(ys, 0.0)

        return y, x_lengths


class MaskAlongAxis(torch.nn.Module):
    @typechecked
    def __init__(
        self,
        mask_width_range: Union[int, Sequence[int]] = (0, 30),
        num_mask: int = 2,
        dim: Union[int, str] = "time",
        replace_with_zero: bool = True,
    ):
        if isinstance(mask_width_range, int):
            mask_width_range = (0, mask_width_range)
        if len(mask_width_range) != 2:
            raise TypeError(
                f"mask_width_range must be a tuple of int and int values: "
                f"{mask_width_range}",
            )

        assert mask_width_range[1] > mask_width_range[0]
        if isinstance(dim, str):
            if dim == "time":
                dim = 1
            elif dim == "freq":
                dim = 2
            else:
                raise ValueError("dim must be int, 'time' or 'freq'")
        if dim == 1:
            self.mask_axis = "time"
        elif dim == 2:
            self.mask_axis = "freq"
        else:
            self.mask_axis = "unknown"

        super().__init__()
        self.mask_width_range = mask_width_range
        self.num_mask = num_mask
        self.dim = dim
        self.replace_with_zero = replace_with_zero

    def extra_repr(self):
        return (
            f"mask_width_range={self.mask_width_range}, "
            f"num_mask={self.num_mask}, axis={self.mask_axis}"
        )

    def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None):
        """Forward function.

        Args:
            spec: (Batch, Length, Freq)
        """

        return mask_along_axis(
            spec,
            spec_lengths,
            mask_width_range=self.mask_width_range,
            dim=self.dim,
            num_mask=self.num_mask,
            replace_with_zero=self.replace_with_zero,
        )


class MaskAlongAxisVariableMaxWidth(torch.nn.Module):
    """Mask input spec along a specified axis with variable maximum width.

    Formula:
        max_width = max_width_ratio * seq_len
    """

    @typechecked
    def __init__(
        self,
        mask_width_ratio_range: Union[float, Sequence[float]] = (0.0, 0.05),
        num_mask: int = 2,
        dim: Union[int, str] = "time",
        replace_with_zero: bool = True,
    ):
        if isinstance(mask_width_ratio_range, float):
            mask_width_ratio_range = (0.0, mask_width_ratio_range)
        if len(mask_width_ratio_range) != 2:
            raise TypeError(
                f"mask_width_ratio_range must be a tuple of float and float values: "
                f"{mask_width_ratio_range}",
            )

        assert mask_width_ratio_range[1] > mask_width_ratio_range[0]
        if isinstance(dim, str):
            if dim == "time":
                dim = 1
            elif dim == "freq":
                dim = 2
            else:
                raise ValueError("dim must be int, 'time' or 'freq'")
        if dim == 1:
            self.mask_axis = "time"
        elif dim == 2:
            self.mask_axis = "freq"
        else:
            self.mask_axis = "unknown"

        super().__init__()
        self.mask_width_ratio_range = mask_width_ratio_range
        self.num_mask = num_mask
        self.dim = dim
        self.replace_with_zero = replace_with_zero

    def extra_repr(self):
        return (
            f"mask_width_ratio_range={self.mask_width_ratio_range}, "
            f"num_mask={self.num_mask}, axis={self.mask_axis}"
        )

    def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None):
        """Forward function.

        Args:
            spec: (Batch, Length, Freq)
        """

        max_seq_len = spec.shape[self.dim]
        min_mask_width = math.floor(max_seq_len * self.mask_width_ratio_range[0])
        min_mask_width = max([0, min_mask_width])
        max_mask_width = math.floor(max_seq_len * self.mask_width_ratio_range[1])
        max_mask_width = min([max_seq_len, max_mask_width])

        if max_mask_width > min_mask_width:
            return mask_along_axis(
                spec,
                spec_lengths,
                mask_width_range=(min_mask_width, max_mask_width),
                dim=self.dim,
                num_mask=self.num_mask,
                replace_with_zero=self.replace_with_zero,
            )
        return spec, spec_lengths


class SpecAug(torch.nn.Module):
    """Implementation of SpecAug.

    Reference:
        Daniel S. Park et al.
        "SpecAugment: A Simple Data
        Augmentation Method for Automatic Speech Recognition"

    .. warning::
        When using cuda mode, time_warp doesn't have reproducibility
        due to `torch.nn.functional.interpolate`.

    """

    def __init__(
        self,
        apply_time_warp: bool = True,
        time_warp_window: int = 5,
        time_warp_mode: str = "bicubic",
        apply_freq_mask: bool = True,
        freq_mask_width_range: Union[int, Sequence[int]] = (0, 20),
        num_freq_mask: int = 2,
        apply_time_mask: bool = True,
        time_mask_width_range: Optional[Union[int, Sequence[int]]] = None,
        time_mask_width_ratio_range: Optional[Union[float, Sequence[float]]] = None,
        num_time_mask: int = 2,
        replace_with_zero: bool = True,
    ):
        if not apply_time_warp and not apply_time_mask and not apply_freq_mask:
            raise ValueError(
                "Either one of time_warp, time_mask, or freq_mask should be applied"
            )
        if (
            apply_time_mask
            and (time_mask_width_range is not None)
            and (time_mask_width_ratio_range is not None)
        ):
            raise ValueError(
                'Either one of "time_mask_width_range" or '
                '"time_mask_width_ratio_range" can be used'
            )
        super().__init__()
        self.apply_time_warp = apply_time_warp
        self.apply_freq_mask = apply_freq_mask
        self.apply_time_mask = apply_time_mask

        if apply_time_warp:
            self.time_warp = TimeWarp(window=time_warp_window, mode=time_warp_mode)
        else:
            self.time_warp = None

        if apply_freq_mask:
            self.freq_mask = MaskAlongAxis(
                dim="freq",
                mask_width_range=freq_mask_width_range,
                num_mask=num_freq_mask,
                replace_with_zero=replace_with_zero,
            )
        else:
            self.freq_mask = None

        if apply_time_mask:
            if time_mask_width_range is not None:
                self.time_mask = MaskAlongAxis(
                    dim="time",
                    mask_width_range=time_mask_width_range,
                    num_mask=num_time_mask,
                    replace_with_zero=replace_with_zero,
                )
            elif time_mask_width_ratio_range is not None:
                self.time_mask = MaskAlongAxisVariableMaxWidth(
                    dim="time",
                    mask_width_ratio_range=time_mask_width_ratio_range,
                    num_mask=num_time_mask,
                    replace_with_zero=replace_with_zero,
                )
            else:
                raise ValueError(
                    'Either one of "time_mask_width_range" or '
                    '"time_mask_width_ratio_range" should be used.'
                )
        else:
            self.time_mask = None

    def forward(self, x, x_lengths=None):
        if self.time_warp is not None:
            x, x_lengths = self.time_warp(x, x_lengths)
        if self.freq_mask is not None:
            x, x_lengths = self.freq_mask(x, x_lengths)
        if self.time_mask is not None:
            x, x_lengths = self.time_mask(x, x_lengths)
        return x, x_lengths
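A minimal usage sketch of the SpecAug module above; the mask widths and feature shapes here are illustrative, not the Space's training config.

import torch
from src.model.powsm.specaug import SpecAug

specaug = SpecAug(
    apply_time_warp=True,
    time_warp_window=5,
    apply_freq_mask=True,
    freq_mask_width_range=(0, 20),
    apply_time_mask=True,
    time_mask_width_range=(0, 40),
)
feats = torch.randn(4, 200, 80)                    # (batch, time, freq)
feat_lengths = torch.tensor([200, 180, 150, 120])
aug, aug_lengths = specaug(feats, feat_lengths)    # same shapes, warped + masked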
src/model/powsm/utils.py
ADDED
@@ -0,0 +1,80 @@
import dataclasses
import warnings

import numpy as np
import torch


def force_gatherable(data, device):
    """Change object to gatherable in torch.nn.DataParallel recursively

    The restriction to the returned value in DataParallel:
        The object must be
        - torch.cuda.Tensor
        - 1 or more dimension. 0-dimension-tensor sends warning.
        or a list, tuple, dict.

    """
    if isinstance(data, dict):
        return {k: force_gatherable(v, device) for k, v in data.items()}
    # DataParallel can't handle NamedTuple well
    elif isinstance(data, tuple) and type(data) is not tuple:
        return type(data)(*[force_gatherable(o, device) for o in data])
    elif isinstance(data, (list, tuple, set)):
        return type(data)(force_gatherable(v, device) for v in data)
    elif isinstance(data, np.ndarray):
        return force_gatherable(torch.from_numpy(data), device)
    elif isinstance(data, torch.Tensor):
        if data.dim() == 0:
            # To 1-dim array
            data = data[None]
        return data.to(device)
    elif isinstance(data, float):
        return torch.tensor([data], dtype=torch.float, device=device)
    elif isinstance(data, int):
        return torch.tensor([data], dtype=torch.long, device=device)
    elif data is None:
        return None
    else:
        warnings.warn(f"{type(data)} may not be gatherable by DataParallel")
        return data


def to_device(data, device=None, dtype=None, non_blocking=False, copy=False):
    """Change the device of object recursively"""
    if isinstance(data, dict):
        return {
            k: to_device(v, device, dtype, non_blocking, copy) for k, v in data.items()
        }
    elif dataclasses.is_dataclass(data) and not isinstance(data, type):
        return type(data)(
            *[
                to_device(v, device, dtype, non_blocking, copy)
                for v in dataclasses.astuple(data)
            ]
        )
    # maybe namedtuple. I don't know the correct way to judge namedtuple.
    elif isinstance(data, tuple) and type(data) is not tuple:
        return type(data)(
            *[to_device(o, device, dtype, non_blocking, copy) for o in data]
        )
    elif isinstance(data, (list, tuple)):
        return type(data)(to_device(v, device, dtype, non_blocking, copy) for v in data)
    elif isinstance(data, np.ndarray):
        return to_device(torch.from_numpy(data), device, dtype, non_blocking, copy)
    elif isinstance(data, torch.Tensor):
        if dtype is not None:
            dtype = str(dtype).removeprefix("torch.")
            cur_dtype = str(data.dtype).removeprefix("torch.")

            if not (
                ("int" in dtype and "int" in cur_dtype)
                or ("float" in dtype and "float" in cur_dtype)
            ):
                dtype = None  # avoid conversion between int and float.
            else:
                dtype = getattr(torch, dtype)

        return data.to(device, dtype, non_blocking, copy)
    else:
        return data
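A minimal sketch of the two helpers above on a nested batch; the keys and values are made up for illustration.

import torch
from src.model.powsm.utils import force_gatherable, to_device

# to_device walks dicts/lists/tuples/dataclasses and moves tensors only.
batch = {"speech": torch.randn(2, 16000), "speech_lengths": [16000, 12000]}
batch = to_device(batch, device="cpu", dtype=torch.float32)

# force_gatherable promotes 0-dim tensors and Python scalars to 1-dim tensors
# so DataParallel can gather them across replicas.
stats = {"loss": torch.tensor(1.23), "acc": 0.9}
stats = force_gatherable(stats, device="cpu")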
src/model/xeusphoneme/__init__.py
ADDED
File without changes
src/model/xeusphoneme/builders.py
ADDED
@@ -0,0 +1,307 @@
import copy
from pathlib import Path
from typing import Dict, Optional, Tuple
import argparse
import yaml
import json
import torch

from src.model.powsm.specaug import SpecAug
from src.model.powsm.e_branchformer import EBranchformerEncoder
from src.model.xeusphoneme.cnn_frontend import CNNFrontend as Wav2VecCNN
from src.model.xeusphoneme.linear_layer import LinearProjection
from src.core.utils import download_hf_snapshot
from src.model.xeusphoneme.xeuspr_model import XeusPRModel
from src.model.xeusphoneme.xeuspr_inference import XeusPRInference
from src.model.powsm.ctc import CTC
from src.utils import RankedLogger


log = RankedLogger(__name__, rank_zero_only=False)


class XeusPRTokenizer:
    """Tokenizer that maps IPA phones to IDs using the xeuspr ipa_vocab.json."""

    def __init__(self, vocab_file: str):
        with open(vocab_file) as f:
            self.vocab: Dict[str, int] = json.load(f)
        self.unk_id = self.vocab.get("<unk>", 0)

    def tokens2ids(self, tokens) -> list:
        return [self.vocab.get(t, self.unk_id) for t in tokens]



def build_xeus_pr(
    config_file: str,
    checkpoint: Optional[str] = None,
    vocab_file: Optional[str] = None,
    ctc_config: Optional[dict] = None,
    weighted_sum: bool = False,
    interctc_layer_idx: Optional[list] = None,
    interctc_weight: float = 0.0,
    interctc_use_conditioning: bool = False,
    interctc_ctc_type: str = "phone",
    ctc_aux_config: Optional[dict] = None,
    decoder_config: Optional[dict] = None,
    ctc_weight: float = 1.0,
) -> XeusPRModel:
    """Build Xeus PR model from config and optional checkpoint.

    Args:
        config_file: Path to config yaml file
        checkpoint: Path to model checkpoint (pretrained or fully trained)
        vocab_file: Path to vocabulary file. If None, use vocab in config.
        ctc_config: Optional dict of CTC config
        weighted_sum: Whether to use weighted sum of transformer layers

    Returns:
        XeusPRModel
    """
    with open(config_file, "r", encoding="utf-8") as f:
        args = argparse.Namespace(**yaml.safe_load(f))
    if vocab_file is not None:
        with open(vocab_file) as f:
            tok2id = json.load(f)
        id2tok = {v: k for k, v in tok2id.items()}
        token_list = [id2tok[i] for i in range(len(id2tok))]
    elif isinstance(args.token_list, str):
        with open(args.token_list, encoding="utf-8") as f:
            token_list = [line.rstrip() for line in f]
    else:
        token_list = list(args.token_list)
    vocab_size = len(token_list)
    log.info(f"Vocabulary size: {vocab_size}")

    assert (
        getattr(args, "frontend") == "wav2vec_cnn"
    ), "Config must specify wav2vec_cnn frontend"
    frontend = Wav2VecCNN(**args.frontend_conf)
    input_size = frontend.output_size()

    specaug = None
    if hasattr(args, "specaug") and args.specaug == "specaug":
        specaug = SpecAug(**args.specaug_conf)

    normalize = None
    assert (
        getattr(args, "preencoder") == "linear"
    ), "Config must specify linear preencoder"
    preencoder = LinearProjection(input_size=input_size, **args.preencoder_conf)
    input_size = preencoder.output_size()
    assert (
        args.encoder == "e_branchformer"
    ), f"Only e_branchformer supported, got {args.encoder}"
    encoder_conf = dict(args.encoder_conf)
    if interctc_layer_idx:
        encoder_conf["interctc_layer_idx"] = interctc_layer_idx
    if interctc_use_conditioning:
        encoder_conf["interctc_use_conditioning"] = True
    encoder = EBranchformerEncoder(input_size=input_size, **encoder_conf)

    ctc_config = ctc_config or getattr(args, "ctc_conf", {})
    ctc_config_orig = copy.deepcopy(ctc_config)
    # Build CTC
    ctc = CTC(
        odim=vocab_size,
        encoder_output_size=encoder.output_size(),
        **ctc_config,
    )

    # Build optional aux CTC (orthographic vocabulary)
    ctc_aux = None
    if ctc_aux_config is not None:
        import sentencepiece as spm

        ctc_aux_config = dict(ctc_aux_config)  # copy to avoid mutating caller's dict
        sp = spm.SentencePieceProcessor()
        sp.load(ctc_aux_config.pop("vocab_file"))
        aux_vocab_size = sp.get_piece_size()
        ctc_aux = CTC(
            odim=aux_vocab_size,
            encoder_output_size=encoder.output_size(),
            ctc_type="builtin",
            **ctc_aux_config,
        )
        log.info(f"Built aux CTC with vocab size {aux_vocab_size}")

    # Build optional attention decoder
    decoder = None
    if decoder_config:
        from src.model.powsm.transformer_decoder import TransformerDecoder

        decoder = TransformerDecoder(
            vocab_size=vocab_size,
            encoder_output_size=encoder.output_size(),
            **decoder_config,
        )

    # Build model
    model = XeusPRModel(
        encoder=encoder,
        ctc=ctc,
        token_list=token_list,
        frontend=frontend,
        specaug=specaug,
        normalize=normalize,
        preencoder=preencoder,
        ignore_id=getattr(args, "ignore_id", -1),
        sym_blank=getattr(args, "sym_blank", "<blank>"),
        freeze_frontend=checkpoint is not None,
        weighted_sum=weighted_sum,
        interctc_weight=interctc_weight,
        interctc_use_conditioning=interctc_use_conditioning,
        interctc_ctc_type=interctc_ctc_type,
        ctc_aux=ctc_aux,
        decoder=decoder,
        ctc_weight=ctc_weight,
    )

    if checkpoint:
        state_dict = torch.load(checkpoint, map_location="cpu", weights_only=False)
        if "state_dict" in state_dict:
            # convert to standard xeus style checkpoint
            state_dict = state_dict["state_dict"]  # for finetuned lightning checkpoints
            state_dict = {
                k.replace("net.", ""): v
                for k, v in state_dict.items()
                if k.startswith("net.")
            }
        load_info = model.load_state_dict(state_dict, strict=False)
        log.info(f"Loaded checkpoint: {checkpoint} with load info: {load_info}")
        print(f"Loaded checkpoint: {checkpoint} with load info: {load_info}")

    model.training_args = args
    model._net_config = {
        "ctc_config": ctc_config_orig,
        "weighted_sum": weighted_sum,
        "interctc_layer_idx": interctc_layer_idx,
        "interctc_weight": interctc_weight,
        "interctc_use_conditioning": interctc_use_conditioning,
        "interctc_ctc_type": interctc_ctc_type,
        "ctc_aux_config": ctc_aux_config,
        "decoder_config": decoder_config,
        "ctc_weight": ctc_weight,
    }
    return model


def build_xeus_pr_from_hf(
    *,
    work_dir: str,
    hf_repo: Optional[str] = None,
    force: bool = False,
    config_file: Optional[str] = None,
    checkpoint: Optional[str] = None,
    vocab_file: Optional[str] = None,
    ctc_config: Optional[dict] = None,
    load_ckpt: bool = True,
    weighted_sum: bool = False,
    interctc_layer_idx: Optional[list] = None,
    interctc_weight: float = 0.0,
    interctc_use_conditioning: bool = False,
    interctc_ctc_type: str = "phone",
    ctc_aux_config: Optional[dict] = None,
    decoder_config: Optional[dict] = None,
    ctc_weight: float = 1.0,
) -> XeusPRModel:
    """Build Xeus PR model from local files or HuggingFace repo.

    Args:
        work_dir: Directory to store downloaded files from HF repo
        hf_repo: HuggingFace repo name (e.g., "username/xeus-pr")
            If None, load from local files only
        force: Whether to force re-download from HF repo
        config_file: Path to config file. If None, use default path in work_dir.
            Takes precedence over hf_repo download.
        checkpoint: Path to checkpoint file. If None, use default path in work_dir.
            Takes precedence over hf_repo download.
        vocab_file: Path to vocabulary file. If None, use path in config.
        ctc_config: Optional dict of CTC config
        load_ckpt: Whether to load checkpoint weights
        weighted_sum: Whether to use weighted sum of transformer layers
    Returns:
        XeusPRModel
    """
    # Default relative paths in HF repo
    REL_CONFIG = "model/config.yaml"
    REL_CKPT = "model/xeus_checkpoint_new.pth"

    # Download from HF if repo specified
    if hf_repo:
        log.info(f"Downloading snapshot from HuggingFace: {hf_repo}")
        download_hf_snapshot(
            repo_id=hf_repo,
            force_download=force,
            work_dir=work_dir,
        )

    # Resolve file paths
    root = Path(work_dir)
    cfg = config_file or str(root / REL_CONFIG)
    ckpt = checkpoint or str(root / REL_CKPT)

    # Verify files exist
    assert Path(cfg).exists(), f"Config file not found: {cfg}"
    if not load_ckpt:
        ckpt = None
    else:
        assert Path(ckpt).exists(), f"Checkpoint file not found: {ckpt}"

    log.info(f"Building model from config: {cfg}")
    log.info(f"Loading checkpoint: {ckpt}")

    return build_xeus_pr(
        config_file=cfg,
        checkpoint=ckpt,
        vocab_file=vocab_file,
        ctc_config=ctc_config,
        weighted_sum=weighted_sum,
        interctc_layer_idx=interctc_layer_idx,
        interctc_weight=interctc_weight,
        interctc_use_conditioning=interctc_use_conditioning,
        interctc_ctc_type=interctc_ctc_type,
        ctc_aux_config=ctc_aux_config,
        decoder_config=decoder_config,
        ctc_weight=ctc_weight,
    )


def build_xeus_pr_inference(
    work_dir: str,
    checkpoint: str,
    vocab_file: str,
    device,
    config_file: Optional[str] = None,
    hf_repo: Optional[str] = None,
    force_download: bool = False,
    dtype: str = "float32",
    ctc_config: Optional[dict] = None,
    weighted_sum: bool = False,
    interctc_layer_idx: Optional[list] = None,
    interctc_weight: float = 0.0,
    interctc_use_conditioning: bool = False,
    interctc_ctc_type: str = "phone",
    ctc_aux_config: Optional[dict] = None,
    decoder_config: Optional[dict] = None,
) -> XeusPRInference:
    model = build_xeus_pr_from_hf(
        work_dir=work_dir,
        hf_repo=hf_repo,
        force=force_download,
        config_file=config_file,
        checkpoint=checkpoint,
        vocab_file=vocab_file,
        ctc_config=ctc_config,
        load_ckpt=True,
        weighted_sum=weighted_sum,
        interctc_layer_idx=interctc_layer_idx,
        interctc_weight=interctc_weight,
        interctc_use_conditioning=interctc_use_conditioning,
        interctc_ctc_type=interctc_ctc_type,
        ctc_aux_config=ctc_aux_config,
        decoder_config=decoder_config,
    )
    inference_obj = XeusPRInference(model, device=device, dtype=dtype)
    return inference_obj
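A minimal sketch of the entry point above. The checkpoint path mirrors the REL_CKPT default and the vocab path matches the file added in this commit; whether to pass an hf_repo depends on the deployment, so treat these arguments as placeholders rather than the Space's exact call.

import torch
from src.model.xeusphoneme.builders import build_xeus_pr_inference

# With hf_repo=None everything is loaded locally, so work_dir must already
# contain model/config.yaml (and the checkpoint at the path given below).
infer = build_xeus_pr_inference(
    work_dir="./work",
    checkpoint="./work/model/xeus_checkpoint_new.pth",
    vocab_file="src/model/xeusphoneme/resources/ipa_vocab.json",
    device=torch.device("cpu"),
    hf_repo=None,  # set to a repo id to download a snapshot first
)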
src/model/xeusphoneme/cnn_frontend.py
ADDED
@@ -0,0 +1,261 @@
from typing import List, Literal, Optional, Tuple, Union

import torch
from torch import Tensor, nn
from torch.nn import Module
from torch.nn import functional as F


def dim_1_layer_norm(x, eps=1e-05, gamma=None, beta=None):
    """Functional version of Dim1LayerNorm."""

    B, D, T = x.shape
    mean = torch.mean(x, 1, keepdim=True)
    variance = torch.mean((x - mean) ** 2, 1, keepdim=True)

    x = (x - mean) * torch.rsqrt(variance + eps)

    if gamma is not None:
        x = x * gamma.view(1, -1, 1)
    if beta is not None:
        x = x + beta.view(1, -1, 1)
    return x


class Dim1LayerNorm(Module):
    def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, bias=True):
        """LayerNorm on middle dim.

        It assumes the input is shape B, D, T
        to avoid transposing.
        Faster than TransposedLayerNorm, but
        may lead to minor numerical differences.
        """
        super().__init__()
        self.normalized_shape = normalized_shape
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        self.weight = None
        self.bias = None
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(normalized_shape))
            if bias:
                self.bias = nn.Parameter(torch.zeros(normalized_shape))

    def forward(self, x):
        assert x.size(1) == self.normalized_shape
        return dim_1_layer_norm(x, self.eps, self.weight, self.bias)


class TransposedLayerNorm(nn.LayerNorm):
    """Layer norm with transpose"""

    def forward(self, input: Tensor) -> Tensor:
        x = input.transpose(-2, -1)
        x = nn.functional.layer_norm(
            x, self.normalized_shape, self.weight, self.bias, self.eps
        )
        x = x.transpose(-2, -1)
        return x


class ConvLayerBlock(Module):
    """Convolution unit of FeatureExtractor"""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        bias: bool,
        layer_norm: Optional[Module],
        conv_mode: str,
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.layer_norm = layer_norm

        if conv_mode == "standard":
            self.conv = nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                bias=bias,
            )
        elif conv_mode == "depth_only":
            self.conv = nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                bias=bias,
                groups=in_channels,
            )
        elif conv_mode == "depth_sep":
            self.conv = nn.Sequential(
                nn.Conv1d(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    bias=bias,
                    groups=in_channels,
                ),
                nn.Conv1d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=1,
                    stride=1,
                    bias=bias,
                ),
            )
        nn.init.kaiming_normal_(self.conv.weight)

    def forward(
        self,
        x: Tensor,
        length: Optional[Tensor],
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """ConvLayerBlock Forward.

        Args:
            x (Tensor): Shape: ``[batch, in_channels, in_frame]``.
            length (Tensor or None, optional): Shape ``[batch, ]``.
        Returns:
            Tensor: Shape ``[batch, out_channels, out_frames]``.
            Optional[Tensor]: Shape ``[batch, ]``.
        """
        x = self.conv(x)
        if self.layer_norm is not None:
            x = self.layer_norm(x)
        x = nn.functional.gelu(x)

        if length is not None:
            length = (
                torch.div(length - self.kernel_size, self.stride, rounding_mode="floor")
                + 1
            )
            # When input length is 0, the resulting length can be negative.
            length = torch.max(torch.zeros_like(length), length)
        return x, length


class CNNFrontend(Module):
    """Convolutional feature extractor.

    Typically used in SSL models.
    Uses raw waveforms as input.
    """

    def __init__(
        self,
        norm_mode: str,
        conv_mode: str,
        bias: bool,
        shapes: List[Tuple[int, int, int]] = [
            (512, 10, 5),
            (512, 3, 2),
            (512, 3, 2),
            (512, 3, 2),
            (512, 3, 2),
            (512, 2, 2),
            (512, 2, 2),
        ],
        fs: Union[int, str] = 16000,
        normalize_audio: bool = False,
        normalize_output: bool = False,
        layer_norm_cls: Literal["transposed", "dim1"] = "transposed",
    ):

        super().__init__()

        if norm_mode not in ["group_norm", "layer_norm"]:
            raise ValueError("Invalid norm mode")

        if conv_mode not in ["standard", "depth_only", "depth_sep"]:
            raise ValueError("Invalid cnn mode")

        self.output_channels = shapes[-1][0]
        self.normalize_audio = normalize_audio

        if layer_norm_cls == "dim1":
            layer_norm_func = Dim1LayerNorm
        else:
            layer_norm_func = TransposedLayerNorm

        blocks = []
        in_channels = 1
        self.downsampling_factor = 1
        for i, (out_channels, kernel_size, stride) in enumerate(shapes):
            normalization = None
            if norm_mode == "group_norm" and i == 0:
                normalization = nn.GroupNorm(
                    num_groups=out_channels,
                    num_channels=out_channels,
                    affine=True,
                )
            elif norm_mode == "layer_norm":
                normalization = layer_norm_func(
                    normalized_shape=out_channels,
                )
            blocks.append(
                ConvLayerBlock(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    bias=bias,
                    layer_norm=normalization,
                    conv_mode=conv_mode,
                )
            )
            in_channels = out_channels
            self.downsampling_factor *= stride
        self.layers = nn.Sequential(*blocks)

        if normalize_output:
            self.final_norm = nn.LayerNorm(self.output_channels)
        else:
            self.final_norm = nn.Identity()

    def output_size(self) -> int:
        return self.output_channels

    def forward(
        self,
        x: Tensor,
        length: Optional[Tensor],
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """CNNFrontend Forward.

        Args:
            x (Tensor):
                Input Tensor representing a batch of audio,
                shape: ``[batch, time]``.
            length (Tensor or None, optional):
                Valid length of each input sample. shape: ``[batch, ]``.

        Returns:
            Tensor:
                The resulting feature, shape: ``[batch, frame, feature]``
            Optional[Tensor]:
                Valid length of each output sample. shape: ``[batch, ]``.
        """
        if x.ndim != 2:
            raise ValueError(
                f"Expected the input to be 2D (batch, time). Found: {list(x.shape)}"
            )

        if self.normalize_audio:
            x = F.layer_norm(x, x.shape)

        x = x.unsqueeze(1)  # (batch, channel==1, frame)
        for layer in self.layers:
            x, length = layer(x, length)  # (batch, feature, frame)
        x = x.transpose(1, 2)  # (batch, frame, feature)
        x = self.final_norm(x)
        return x, length
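A minimal usage sketch of the frontend above. With the default shapes the stride product is 5*2*2*2*2*2*2 = 320, i.e. one output frame per 20 ms of 16 kHz audio; the constructor flags here are illustrative, not the Space's config.

import torch
from src.model.xeusphoneme.cnn_frontend import CNNFrontend

frontend = CNNFrontend(norm_mode="layer_norm", conv_mode="standard", bias=False)
wav = torch.randn(2, 16000)                 # (batch, samples) at 16 kHz
lengths = torch.tensor([16000, 8000])
feats, feat_lengths = frontend(wav, lengths)
print(feats.shape, frontend.downsampling_factor)  # torch.Size([2, 49, 512]) 320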
src/model/xeusphoneme/linear_layer.py
ADDED
@@ -0,0 +1,21 @@
"""Linear Projection."""

from typing import Tuple
import torch


class LinearProjection(torch.nn.Module):
    def __init__(self, input_size: int, output_size: int, dropout: float = 0.0):
        super().__init__()
        self.output_dim = output_size
        self.linear_out = torch.nn.Linear(input_size, output_size)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        output = self.linear_out(self.dropout(input))
        return output, input_lengths  # no state in this layer

    def output_size(self) -> int:
        return self.output_dim
src/model/xeusphoneme/resources/ipa_vocab.json
ADDED
@@ -0,0 +1,430 @@
{
    "<blank>": 0,
    "<sos>": 1,
    "<eos>": 2,
    "<unk>": 3,
    "ʈ": 4,
    "ʎː": 5,
    "cː": 6,
    "œ̞": 7,
    "ʔʲ": 8,
    "o̤": 9,
    "ɠ": 10,
    "ø": 11,
    "kˀ": 12,
    "e̝": 13,
    "ʈ͡ʂ": 14,
    "ɡʰ": 15,
    "ɟ": 16,
    "z": 17,
    "ʃˠ": 18,
    "vˠ": 19,
    "ǃʰ": 20,
    "dʷ": 21,
    "ĩ": 22,
    "nˠ": 23,
    "ə": 24,
    "t͡ʃʰ": 25,
    "d̤": 26,
    "fʲ": 27,
    "xʷ": 28,
    "ɛ̃": 29,
    "ʃʰ": 30,
    "ʃ̩": 31,
    "ɤˀ": 32,
    "əː": 33,
    "ɛ̯": 34,
    "ɞ": 35,
    "yː": 36,
    "fʷ": 37,
    "ẽ": 38,
    "rˤ": 39,
    "ɒ": 40,
    "ɲː": 41,
    "j": 42,
    "f": 43,
    "ɲ̥": 44,
    "ʃː": 45,
    "l": 46,
    "ʒ̩": 47,
    "ɛ̝": 48,
    "ð̞": 49,
    "ʃʲ": 50,
    "ɛ": 51,
    "ɟː": 52,
    "ʌ": 53,
    "ʍ": 54,
    "kʰ": 55,
    "p͡f": 56,
    "ɜː": 57,
    "ɘ": 58,
    "bʷ": 59,
    "sː": 60,
    "ɡː": 61,
    "o̝": 62,
    "cʼ": 63,
    "tʰ": 64,
    "kʷ": 65,
    "ŋ̥": 66,
    "r̝": 67,
    "ɸː": 68,
    "u̝": 69,
    "ṳ": 70,
    "β̞": 71,
    "ɾː": 72,
    "ɔˤ": 73,
    "ʎ": 74,
    "ʊ̃": 75,
    "pˀ": 76,
    "m̩": 77,
    "ɕː": 78,
    "ɪ̯": 79,
    "ɖʰ": 80,
    "ɰ": 81,
    "t̠": 82,
    "t͡ʃʲ": 83,
    "ɡ̤": 84,
    "j̩": 85,
    "ɭ̩": 86,
    "ŋ̰": 87,
    "p": 88,
    "ɾ": 89,
    "sʲ": 90,
    "ɲ̤": 91,
    "cʰ": 92,
    "a̯": 93,
    "ɡʷ": 94,
    "t͡s": 95,
    "ɨ̯": 96,
    "n̩": 97,
    "ʌː": 98,
    "ɤ": 99,
    "l̩": 100,
    "l̴": 101,
    "pʲ": 102,
    "k": 103,
    "jː": 104,
    "ɛ̈": 105,
    "t͡ʃː": 106,
    "dˠ": 107,
    "ɱ̩": 108,
    "ɯː": 109,
    "kʼ": 110,
    "ɑ̯": 111,
    "zʷ": 112,
    "çː": 113,
    "ã": 114,
    "sˠ": 115,
    "s̻": 116,
    "ɐ": 117,
    "ɸʷ": 118,
    "ɔ̃": 119,
    "bˠ": 120,
    "ʈː": 121,
    "ʂ": 122,
    "ɑ": 123,
    "ë": 124,
    "ɸ": 125,
    "ɮʲ": 126,
    "nː": 127,
    "mʷ": 128,
    "ǁ": 129,
    "ʒ": 130,
    "jˠ": 131,
    "d": 132,
    "tː": 133,
    "ɤ̆": 134,
    "s̺": 135,
    "mː": 136,
    "ɻ": 137,
    "l̪": 138,
    "ɜ": 139,
    "ɓ": 140,
    "ü": 141,
    "lʲ": 142,
    "tˠ": 143,
    "ŋː": 144,
    "ŋʲ": 145,
    "h̩": 146,
    "qʷ": 147,
    "tʼ": 148,
    "ə̯": 149,
    "t͡sʲː": 150,
    "m̤": 151,
    "ɕʰ": 152,
    "nʲ": 153,
    "rˠ": 154,
    "ɖ̤": 155,
    "ø̈": 156,
    "ɯˀ": 157,
    "mʲ": 158,
    "n̥": 159,
    "mˤ": 160,
    "ʒʲ": 161,
    "æ": 162,
    "tʷ": 163,
    "d̪": 164,
    "ʔ": 165,
    "a̠": 166,
    "ɾˠ": 167,
    "ʉ": 168,
    "ɔ̯": 169,
    "zʲ": 170,
    "ɳː": 171,
    "t͡sː": 172,
    "æ̯": 173,
    "r̤": 174,
    "ɑː": 175,
    "ɘː": 176,
    "ə˞": 177,
    "zˤ": 178,
    "õ": 179,
    "əˀ": 180,
    "e": 181,
    "nˤ": 182,
    "u": 183,
    "ɑ̃": 184,
    "o": 185,
    "ħ": 186,
    "ŋ": 187,
    "mˠ": 188,
    "i": 189,
    "rʲ": 190,
    "ɔ": 191,
    "xʰ": 192,
    "dˤ": 193,
    "s̩": 194,
    "t͡ɕʰ": 195,
    "ɔ̈": 196,
    "ĕ": 197,
    "ɴ": 198,
    "k͡x": 199,
    "d͡ʒ": 200,
    "dʲ": 201,
    "æ̞": 202,
    "ɡ̃": 203,
    "uː": 204,
    "pʰ": 205,
    "ʁ": 206,
    "n̪": 207,
    "zˠ": 208,
    "ø̞": 209,
    "ɔː": 210,
    "ɳ": 211,
    "vʲ": 212,
    "œ̃": 213,
    "ɾ̝": 214,
    "ũ": 215,
    "ĭ": 216,
    "ɐ̯": 217,
    "ʁ̝": 218,
    "qʼ": 219,
    "β": 220,
    "pʼ": 221,
    "ɡ͡b": 222,
    "oː": 223,
    "ɲ": 224,
    "j̃": 225,
    "l̠": 226,
    "a": 227,
    "d͡ʑ": 228,
    "œː": 229,
    "t̪": 230,
    "zː": 231,
    "ʁ̩": 232,
    "ɔ̤": 233,
    "œ": 234,
    "dʰ": 235,
    "lː": 236,
    "z̤": 237,
    "sʰ": 238,
    "ʏ̯": 239,
    "ð": 240,
    "r̩": 241,
    "n̤": 242,
    "ɭʲ": 243,
    "ɭː": 244,
    "ə̃": 245,
    "ä": 246,
    "ʀ": 247,
    "æː": 248,
    "ɡʲ": 249,
    "ɪ̃": 250,
    "lˠ": 251,
    "ʊː": 252,
    "cʲ": 253,
    "ă": 254,
    "d͡ʒː": 255,
    "i̯": 256,
    "ʉː": 257,
    "t͡ɕː": 258,
    "ɬ": 259,
    "fˀ": 260,
    "bʲ": 261,
    "ɐ̃": 262,
    "ɣ̤": 263,
    "xʲ": 264,
    "ɛ̆": 265,
    "θ": 266,
    "ɵː": 267,
    "ɨ̞": 268,
    "ɡ": 269,
    "ð̠": 270,
    "l̤": 271,
    "w̃": 272,
    "ɹ": 273,
    "ɣʲ": 274,
    "wˠ": 275,
    "u̯": 276,
    "wː": 277,
    "ʐ": 278,
    "ɵ": 279,
    "ðˠ": 280,
    "t͡ʃʼ": 281,
    "pʷ": 282,
    "v̤": 283,
    "ǀʰ": 284,
    "x": 285,
    "ɥ": 286,
    "ʂː": 287,
    "r": 288,
    "o̞": 289,
    "ðˤ": 290,
    "ɨ̃": 291,
    "ʊ": 292,
    "ʙ": 293,
    "b̤": 294,
    "ŋ̤": 295,
    "kʲ": 296,
    "ʏː": 297,
    "ʄ": 298,
    "eː": 299,
    "ɗ": 300,
    "ʏ̈": 301,
    "ɛˤ": 302,
    "w": 303,
    "pː": 304,
    "ɖ": 305,
    "ɧ": 306,
    "h": 307,
    "ǁʰ": 308,
    "hʲ": 309,
    "ʃ": 310,
    "ɑ̈": 311,
    "d͡z": 312,
    "bˤ": 313,
    "k͡p": 314,
    "ð̩": 315,
    "n̠": 316,
    "bː": 317,
    "f̩": 318,
    "wʲ": 319,
    "o̯": 320,
    "ʁː": 321,
    "pˠ": 322,
    "kː": 323,
    "ɪˤ": 324,
    "ʑː": 325,
    "ʌ̃": 326,
    "ɪː": 327,
    "ǃ": 328,
    "ç": 329,
    "s": 330,
    "hː": 331,
    "rː": 332,
    "tˤ": 333,
    "ɦʲ": 334,
    "ŋ̩": 335,
    "m̥": 336,
    "ɖː": 337,
    "ɭ": 338,
    "mˀ": 339,
    "n": 340,
    "iː": 341,
    "æ̝": 342,
    "xː": 343,
    "i̤": 344,
    "ɽ̤": 345,
    "ɶ": 346,
    "ˀs": 347,
    "l̥": 348,
    "ɱ": 349,
    "e̞": 350,
    "ʋ": 351,
    "y̯": 352,
    "lˤ": 353,
    "ö": 354,
    "a̝": 355,
    "ɶː": 356,
    "t͡sʼ": 357,
    "s̠": 358,
    "t͡sʲ": 359,
| 362 |
+
"ɪ": 360,
|
| 363 |
+
"y̆": 361,
|
| 364 |
+
"ɤː": 362,
|
| 365 |
+
"ɟʰ": 363,
|
| 366 |
+
"ʒː": 364,
|
| 367 |
+
"tʲ": 365,
|
| 368 |
+
"ɕ": 366,
|
| 369 |
+
"ɨ": 367,
|
| 370 |
+
"c": 368,
|
| 371 |
+
"t͡ʃ": 369,
|
| 372 |
+
"ʑ": 370,
|
| 373 |
+
"ʝ": 371,
|
| 374 |
+
"ʋ̥": 372,
|
| 375 |
+
"ɢ": 373,
|
| 376 |
+
"ɛː": 374,
|
| 377 |
+
"b": 375,
|
| 378 |
+
"øː": 376,
|
| 379 |
+
"ǀ": 377,
|
| 380 |
+
"ʏ": 378,
|
| 381 |
+
"i̝": 379,
|
| 382 |
+
"ʊ̯": 380,
|
| 383 |
+
"ʊˤ": 381,
|
| 384 |
+
"ɐˤ": 382,
|
| 385 |
+
"r̥": 383,
|
| 386 |
+
"t͡sʰ": 384,
|
| 387 |
+
"aː": 385,
|
| 388 |
+
"t͡ɬ": 386,
|
| 389 |
+
"ʋː": 387,
|
| 390 |
+
"sˤ": 388,
|
| 391 |
+
"s̪": 389,
|
| 392 |
+
"dː": 390,
|
| 393 |
+
"ɪ̈": 391,
|
| 394 |
+
"ɨː": 392,
|
| 395 |
+
"ɽʷ": 393,
|
| 396 |
+
"ʕ": 394,
|
| 397 |
+
"ɒː": 395,
|
| 398 |
+
"χ": 396,
|
| 399 |
+
"fˠ": 397,
|
| 400 |
+
"ɯ": 398,
|
| 401 |
+
"hˠ": 399,
|
| 402 |
+
"jˤ": 400,
|
| 403 |
+
"tˀ": 401,
|
| 404 |
+
"ɣ": 402,
|
| 405 |
+
"y": 403,
|
| 406 |
+
"ɦ": 404,
|
| 407 |
+
"ʈʰ": 405,
|
| 408 |
+
"t͡ɕ": 406,
|
| 409 |
+
"vː": 407,
|
| 410 |
+
"m": 408,
|
| 411 |
+
"ɮ": 409,
|
| 412 |
+
"e̤": 410,
|
| 413 |
+
"ʋʲ": 411,
|
| 414 |
+
"æ̃": 412,
|
| 415 |
+
"v": 413,
|
| 416 |
+
"ɽ": 414,
|
| 417 |
+
"t": 415,
|
| 418 |
+
"a̤": 416,
|
| 419 |
+
"e̯": 417,
|
| 420 |
+
"ɜ˞": 418,
|
| 421 |
+
"q": 419,
|
| 422 |
+
"bʰ": 420,
|
| 423 |
+
"t͡sˠ": 421,
|
| 424 |
+
"ʂʰ": 422,
|
| 425 |
+
"fː": 423,
|
| 426 |
+
"sʷ": 424,
|
| 427 |
+
"ɾʲ": 425,
|
| 428 |
+
"w̤": 426,
|
| 429 |
+
"fˤ": 427
|
| 430 |
+
}
|
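For orientation: ipa_vocab.json is a flat IPA-symbol-to-id map that serves as the model's token list. A minimal loading sketch, assuming the ids are contiguous from 0 to 427 (consistent with the file's 428 entries; the path matches this repo's layout):

import json

# Load the vocab and build the inverse map used to turn CTC ids back into IPA symbols.
with open("src/model/xeusphoneme/resources/ipa_vocab.json", encoding="utf-8") as f:
    vocab = json.load(f)                      # {"ɲ̥": 44, ..., "fˤ": 427}

id2phone = {i: p for p, i in vocab.items()}   # invert: id -> symbol
token_list = [id2phone[i] for i in range(len(id2phone))]  # list index == CTC id
print(token_list[44])                         # "ɲ̥"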
src/model/xeusphoneme/xeuspr_inference.py
ADDED
@@ -0,0 +1,86 @@
+# Compatible with the distributed inference API; uses the greedy CTC inference strategy
+# python -m src.model.xeusphoneme.xeuspr_inference
+import torch
+import numpy as np
+import torch.nn.functional as F
+from typing import Union, List, Dict, Any, Optional
+
+from src.recipe.phone_recognition.greedy_ctc_strategy import GreedyCTCInference
+
+
+class XeusPRInference:
+    """Greedy inference for Xeus Phoneme Recognition model."""
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        device: str = "cpu",
+        dtype: str = "float32",
+    ):
+        self.device = device
+        self.dtype = getattr(torch, dtype)
+        self.model = model.to(device=self.device, dtype=self.dtype).eval()
+
+        self.token_list = model.token_list
+        self.blank_id = model.get_blank_id()
+        self.ignore_id = getattr(model, "ignore_id", -1)
+        self.inference_strategy = GreedyCTCInference(
+            token_list=self.token_list, blank_id=self.blank_id
+        )
+
+    @torch.no_grad()
+    def __call__(
+        self, speech: Union[torch.Tensor, np.ndarray], **kwargs
+    ) -> List[Dict[str, Any]]:
+        """
+        Perform greedy inference.
+        Args:
+            speech: Input speech of shape (nsamples,) or (batch, nsamples)
+        Returns:
+            List of results matching Powsm API
+        """
+        # 1. Prepare input
+        if isinstance(speech, np.ndarray):
+            speech = torch.from_numpy(speech)
+
+        if speech.dim() == 1:
+            speech = speech.unsqueeze(0)
+
+        speech = speech.to(device=self.device, dtype=self.dtype)
+        speech_lengths = torch.full(
+            (speech.size(0),), speech.size(1), device=self.device, dtype=torch.long
+        )
+        results = self.inference_strategy(
+            model=self.model,
+            speech=speech,
+            speech_lengths=speech_lengths,
+            **kwargs,
+        )
+        return results
+
+
+if __name__ == "__main__":
+    from src.model.xeusphoneme.builders import build_xeus_pr_inference
+
+    # Example usage
+    ckpt_path = "path/to/checkpoints/last.ckpt"
+    work_dir = "path/to/exp/cache/xeus"
+    vocab_file = "src/model/xeusphoneme/resources/ipa_vocab.json"
+    device = "cpu" if not torch.cuda.is_available() else "cuda:0"
+    inference_obj = build_xeus_pr_inference(
+        work_dir=work_dir,
+        checkpoint=ckpt_path,
+        vocab_file=vocab_file,
+        hf_repo="espnet/xeus",
+        config_file=None,
+        device=device,
+        force_download=False,
+    )
+    import torchaudio
+
+    speechpath = "path/to/test_audio.wav"
+    speech = torchaudio.load(speechpath)[0].squeeze(0)
+    # speech = speech[: 16000 * 40]  # 40 seconds of audio
+    # dummy_speech = np.random.randn(16000 * 5).astype(np.float32)  # 5 seconds of audio
+    results = inference_obj(speech=speech)
+    print(results)
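Beyond the `__main__` block above, a minimal usage sketch with a raw waveform. It assumes `inference_obj` was built as shown, and relies only on the documented input contract (16 kHz mono, shape (nsamples,) or (batch, nsamples)) and the result keys produced by GreedyCTCInference:

import numpy as np

wav = np.random.randn(16000 * 3).astype(np.float32)  # 3 s of 16 kHz audio
out = inference_obj(speech=wav)                      # list with one dict per utterance
print(out[0]["predicted_transcript"])   # "/"-joined raw CTC tokens
print(out[0]["processed_transcript"])   # IPA string with special tokens stripped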
src/model/xeusphoneme/xeuspr_model.py
ADDED
@@ -0,0 +1,378 @@
+#!/usr/bin/env python3
+"""Xeus Phoneme Recognition Model.
+# -*- coding: utf-8 -*-
+
+# Copyright 2025 William Chen. Adapted from ESPnet.
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+Usage:
+    python -m src.model.xeusphoneme.xeuspr_model \
+        --work_dir path/to/cache/xeus
+"""
+from typing import Any, Dict, Optional, Tuple, Union
+import argparse
+
+import torch
+import torch.nn.functional as F
+import torchaudio
+from src.model.powsm.utils import force_gatherable
+from src.espnet_import.nets_utils import make_pad_mask, pad_list, th_accuracy
+from src.espnet_import.label_smoothing_loss import LabelSmoothingLoss
+
+try:
+    from src.recipe.phone_recognition.error_calculator import (
+        ErrorCalculator,
+    )
+except ImportError:
+
+    class ErrorCalculator:
+        """No-op stub when rapidfuzz/panphon are unavailable."""
+
+        def __init__(self, *args, **kwargs):
+            pass
+
+        def __call__(self, *args, **kwargs):
+            return {}
+
+
+from src.model.powsm.ctc import CTC
+from src.utils import RankedLogger
+
+log = RankedLogger(__name__, rank_zero_only=False)
+
+
+class XeusPRModel(torch.nn.Module):
+    """Encoder-only CTC model for phone recognition using Xeus pretrained weights."""
+
+    def __init__(
+        self,
+        encoder: Any,
+        ctc: CTC,
+        token_list: Union[Tuple, list],
+        frontend: Optional[Any] = None,
+        specaug: Optional[Any] = None,
+        normalize: Optional[Any] = None,
+        preencoder: Optional[Any] = None,
+        ignore_id: int = -1,
+        sym_blank: str = "<blank>",
+        freeze_frontend: bool = True,
+        weighted_sum: bool = False,
+        interctc_weight: float = 0.0,
+        interctc_use_conditioning: bool = False,
+        interctc_ctc_type: str = "phone",
+        ctc_aux: Optional[Any] = None,
+        decoder: Optional[Any] = None,
+        ctc_weight: float = 1.0,
+        lsm_weight: float = 0.0,
+        sym_sos: str = "<sos>",
+        sym_eos: str = "<eos>",
+        **kwargs,
+    ):
+        super().__init__()
+        self.frontend = frontend
+        self.specaug = specaug
+        self.normalize = normalize
+        self.preencoder = preencoder
+        self.encoder = encoder
+        self.ctc = ctc
+        self.ctc_aux = ctc_aux
+        self.interctc_ctc_type = interctc_ctc_type
+        if interctc_use_conditioning:
+            vocab_size_cond = (
+                ctc_aux.ctc_lo.out_features
+                if interctc_ctc_type == "ortho" and ctc_aux is not None
+                else len(token_list)
+            )
+            self.encoder.conditioning_layer = torch.nn.Linear(
+                vocab_size_cond, encoder.output_size()
+            )
+            self.encoder.interctc_use_conditioning = True
+        self.token_list = list(token_list)
+        self.ignore_id = ignore_id
+        self.blank_id = token_list.index(sym_blank) if sym_blank in token_list else 0
+        sym_space = kwargs.get("sym_space", "<space>")
+        self.freeze_frontend = freeze_frontend
+        self.error_calculator = ErrorCalculator(
+            token_list,
+            blank_id=self.blank_id,
+            sym_space=sym_space,
+            ignore_id=ignore_id,
+            log_phone_metrics=True,
+        )
+
+        self.decoder = decoder
+        self.ctc_weight = ctc_weight
+        if decoder is not None:
+            self.sos = token_list.index(sym_sos)
+            self.eos = token_list.index(sym_eos)
+            self.criterion_att = LabelSmoothingLoss(
+                size=len(token_list),
+                padding_idx=ignore_id,
+                smoothing=lsm_weight,
+                normalize_length=False,
+            )
+
+        self.weighted_sum = weighted_sum
+        if self.weighted_sum:
+            n_layers = encoder.num_blocks
+            assert (
+                n_layers is not None and n_layers > 0
+            ), "Cannot infer number of encoder layers for weighted_sum"
+            self.layer_weights = torch.nn.Parameter(torch.zeros(int(n_layers)))
+        self.interctc_weight = interctc_weight
+        self.sampling_rate = 16000
+
+    def points_by_frames(self) -> int:
+        """Samples per encoder frame (CNN downsampling factor)."""
+        return self.frontend.downsampling_factor
+
+    @torch.no_grad()
+    def forced_align(self, speech, speech_lengths, text, text_lengths, utt_id=None):
+        """CTC forced alignment via torchaudio.functional.forced_align (batch size 1)."""
+        assert speech.shape[0] == 1, "forced_align requires batch size 1"
+        text = text[:, : text_lengths.max()]
+        logits, logit_lengths = self.ctc_logits(speech, speech_lengths)
+        log_probs = F.log_softmax(logits, dim=-1)
+        align_label, align_prob = torchaudio.functional.forced_align(
+            log_probs, text, logit_lengths, text_lengths, blank=self.blank_id
+        )
+        return align_label, align_prob
+
+    def collect_feats(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs
+    ) -> Dict[str, torch.Tensor]:
+        """Extract features for stats collection."""
+        feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+        return {"feats": feats, "feats_lengths": feats_lengths}
+
+    def forward(self, speech, speech_lengths, text, text_lengths, **kwargs):
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+        intermediate_outs = None
+        if isinstance(encoder_out, tuple):
+            intermediate_outs = encoder_out[1]
+            encoder_out = encoder_out[0]
+
+        loss_ctc, stats = self._calc_ctc_loss(
+            encoder_out, encoder_out_lens, text, text_lengths, **kwargs
+        )
+
+        if self.interctc_weight > 0.0 and intermediate_outs:
+            if self.interctc_ctc_type == "ortho" and self.ctc_aux is not None:
+                ctc_inter = self.ctc_aux
+                ys_inter = kwargs.get("asr_text_tokens")
+                ys_inter_lens = kwargs.get("asr_text_length")
+            else:
+                ctc_inter = self.ctc
+                ys_inter = torch.where(text == -1, self.ignore_id, text)[
+                    :, : text_lengths.max()
+                ]
+                ys_inter_lens = text_lengths
+
+            if ys_inter is not None and ys_inter_lens is not None:
+                loss_interctc = 0.0
+                for layer_idx, intermediate_out in intermediate_outs:
+                    loss_ic = ctc_inter(
+                        intermediate_out,
+                        encoder_out_lens,
+                        ys_inter,
+                        ys_inter_lens,
+                    )
+                    loss_interctc = loss_interctc + loss_ic
+                    stats[f"loss_interctc_layer{layer_idx}"] = loss_ic.detach()
+                loss_interctc = loss_interctc / len(intermediate_outs)
+                loss_ctc = (
+                    1 - self.interctc_weight
+                ) * loss_ctc + self.interctc_weight * loss_interctc
+
+        # Attention branch
+        if self.ctc_weight < 1.0 and self.decoder is not None:
+            loss_att, acc_att = self._calc_att_loss(
+                encoder_out, encoder_out_lens, text, text_lengths
+            )
+            stats["loss_att"] = loss_att.detach()
+            stats["acc_att"] = acc_att
+            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
+        else:
+            loss = loss_ctc
+
+        loss, stats, weight = force_gatherable(
+            (loss, stats, speech.shape[0]), loss.device
+        )
+        return {"loss": loss, "stats": stats, "weight": weight}
+
+    def _extract_feats(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Extract features using frontend."""
+        speech = speech[:, : speech_lengths.max()]
+        return (
+            self.frontend(speech, speech_lengths)
+            if self.frontend
+            else (speech, speech_lengths)
+        )
+
+    def _apply_preprocessing(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Apply frontend, specaug, normalize, and preencoder."""
+        speech, speech_lengths = self._extract_feats(speech, speech_lengths)
+
+        if self.specaug and self.training:
+            speech, speech_lengths = self.specaug(speech, speech_lengths)
+
+        if self.normalize:
+            speech, speech_lengths = self.normalize(speech, speech_lengths)
+
+        if self.preencoder:
+            speech, speech_lengths = self.preencoder(speech, speech_lengths)
+
+        return speech, speech_lengths
+
+    def encode(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encode speech to frame-level representations.
+
+        When weighted_sum=True, returns a weighted sum of all encoder layers.
+        Otherwise, calls the encoder without return_all_hs; if interctc_layer_idx
+        is configured on the encoder, returns (final_out, [(layer_idx, tensor), ...]).
+        """
+        speech, speech_lengths = self._apply_preprocessing(speech, speech_lengths)
+        pad_masks = make_pad_mask(speech_lengths).to(speech.device)
+        if self.weighted_sum:
+            encoder_out, encoder_out_lens, _ = self.encoder(
+                speech, speech_lengths, masks=pad_masks, return_all_hs=True
+            )
+            hs_list = encoder_out[1]
+            assert len(hs_list) == self.layer_weights.numel()
+            w = torch.softmax(self.layer_weights, dim=0).to(
+                hs_list[0].device, hs_list[0].dtype
+            )
+            hs = torch.stack(hs_list, dim=0)  # (L, B, T, D)
+            return (w.view(-1, 1, 1, 1) * hs).sum(0), encoder_out_lens
+        else:
+            ctc_for_encoder = (
+                self.ctc_aux
+                if self.interctc_ctc_type == "ortho" and self.ctc_aux is not None
+                else self.ctc
+            )
+            encoder_out, encoder_out_lens, _ = self.encoder(
+                speech, speech_lengths, masks=pad_masks, ctc=ctc_for_encoder
+            )
+            return encoder_out, encoder_out_lens
+
+    def ctc_collapse_batch(self, x: torch.Tensor, max_length: int, pad: int = -1):
+        B, T = x.shape
+        blank = self.blank_id
+        x_prev = torch.cat(
+            [torch.full((B, 1), blank, device=x.device, dtype=x.dtype), x[:, :-1]],
+            dim=1,
+        )
+        keep = (x != blank) & ((x_prev == blank) | (x != x_prev))
+        pos = keep.long().cumsum(1) - 1
+        lengths = keep.sum(1)
+        out = torch.full((B, T), pad, device=x.device, dtype=x.dtype)
+        # Compute batch indices and output positions for kept elements
+        batch_idx = (
+            torch.arange(B, device=x.device, dtype=torch.long).unsqueeze(1).expand_as(x)
+        )
+        output_pos = pos.clone()
+        # Only use positions where keep is True
+        batch_idx_keep = batch_idx[keep]
+        output_pos_keep = output_pos[keep]
+        # Flatten the output and set values at correct positions
+        flat_out = out.view(-1)
+        flat_idx = batch_idx_keep * T + output_pos_keep
+        flat_out[flat_idx] = x[keep]
+        out = flat_out.view(B, T)
+        # Trim to max_length from ground truth lengths
+        out = out[:, :max_length]
+        lengths = torch.clamp(lengths, max=max_length)
+        return out, lengths
+
+    def _calc_att_loss(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens):
+        ys_pad = torch.where(ys_pad == -1, self.ignore_id, ys_pad)
+        ys = [y[y != self.ignore_id][:l] for y, l in zip(ys_pad, ys_pad_lens)]
+        _sos = ys_pad.new([self.sos])
+        _eos = ys_pad.new([self.eos])
+        ys_in = [torch.cat([_sos, y]) for y in ys]
+        ys_out = [torch.cat([y, _eos]) for y in ys]
+        ys_in_pad = pad_list(ys_in, self.eos)
+        ys_out_pad = pad_list(ys_out, self.ignore_id)
+        ys_in_lens = torch.tensor([len(y) for y in ys_in], device=ys_pad.device)
+
+        decoder_out, _ = self.decoder(
+            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
+        )
+        loss_att = self.criterion_att(decoder_out, ys_out_pad)
+        acc_att = th_accuracy(
+            decoder_out.view(-1, len(self.token_list)),
+            ys_out_pad,
+            ignore_label=self.ignore_id,
+        )
+        return loss_att, acc_att
+
+    def _calc_ctc_loss(
+        self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, **kwargs
+    ):
+        ys_pad = torch.where(ys_pad == -1, self.ignore_id, ys_pad)
+        ys_pad = ys_pad[:, : ys_pad_lens.max()]
+        loss_ctc = self.ctc(
+            encoder_out,
+            encoder_out_lens,
+            ys_pad,
+            ys_pad_lens,
+            lang_sym=kwargs.get("lang_sym"),
+            accent_sym=kwargs.get("accent_sym"),
+        )
+        stats = {}
+        assert self.error_calculator is not None, "ErrorCalculator not initialized"
+        if not self.training:  # err calc, slow?
+            with torch.no_grad():
+                ys_hat = self.ctc.argmax(encoder_out).data  # greedy-top1
+                metrics = self.error_calculator(
+                    ys_hat.cpu(), ys_pad.cpu(), ys_pad_lens.cpu()
+                )
+                for k, v in metrics.items():
+                    stats[k + "_ctc"] = v
+        return loss_ctc, stats
+
+    def ctc_logits(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Get CTC logits for inference."""
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        if isinstance(encoder_out, tuple):
+            encoder_out = encoder_out[0]
+        return self.ctc.ctc_lo(encoder_out), encoder_out_lens
+
+    def encoder_output_size(self) -> int:
+        return self.encoder.output_size()
+
+    def get_blank_id(self) -> int:
+        return self.blank_id
+
+    def get_frontend(self):
+        return self.frontend
+
+    def get_trainable_parameters(self):
+        trainable_params = {"head": [], "encoder": []}
+        for n, p in self.named_parameters():
+            if (
+                n.startswith("ctc")
+                or n.startswith("decoder")
+                or n.startswith("criterion_att")
+            ):
+                trainable_params["head"].append(p)
+            elif n.startswith("encoder"):
+                trainable_params["encoder"].append(p)
+            elif n.startswith("frontend"):
+                if self.freeze_frontend:
+                    p.requires_grad = False
+                else:
+                    trainable_params["encoder"].append(p)
+            else:
+                # freeze other parts:
+                p.requires_grad = False
+        return trainable_params
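To make the weighted_sum branch of XeusPRModel.encode concrete, here is a self-contained sketch of the same aggregation on dummy tensors (shapes are illustrative, not the real XEUS dimensions): a learnable per-layer logit vector is softmaxed into a convex combination of the stacked layer outputs.

import torch

L, B, T, D = 4, 2, 50, 8                             # layers, batch, frames, dim (dummy)
hs_list = [torch.randn(B, T, D) for _ in range(L)]   # one hidden state per layer

layer_weights = torch.nn.Parameter(torch.zeros(L))   # learnable; uniform 1/L at init
w = torch.softmax(layer_weights, dim=0)              # (L,), sums to 1
hs = torch.stack(hs_list, dim=0)                     # (L, B, T, D)
fused = (w.view(-1, 1, 1, 1) * hs).sum(0)            # (B, T, D), convex combination
assert fused.shape == (B, T, D)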
src/recipe/__init__.py
ADDED
File without changes
src/recipe/phone_recognition/__init__.py
ADDED
File without changes
src/recipe/phone_recognition/greedy_ctc_strategy.py
ADDED
@@ -0,0 +1,63 @@
+import torch
+from typing import List, Dict, Any, Union
+
+
+def ctc_collapse_vectorized(
+    ids: torch.Tensor, blank_id: int, ignore_id: int = -1
+) -> List[List[int]]:
+    """Optimized CTC collapse for batch tensors."""
+    mask = torch.ones_like(ids, dtype=torch.bool)
+    mask[:, 1:] = ids[:, 1:] != ids[:, :-1]
+    mask &= ids != blank_id
+    if ignore_id != -1:
+        mask &= ids != ignore_id
+
+    return [ids[i][mask[i]].tolist() for i in range(ids.size(0))]
+
+
+class GreedyCTCInference:
+    """A scalable inference engine for any CTC-based phone recognizer."""
+
+    def __init__(self, token_list: List[str], blank_id: int):
+        self.token_list = token_list
+        self.blank_id = blank_id
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        model: torch.nn.Module,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        **kwargs
+    ) -> List[Dict[str, Any]]:
+        # 1. Standardized forward pass
+        # Works as long as model has .encode() and .ctc
+        encoder_out, _ = model.encode(speech, speech_lengths)
+        if isinstance(encoder_out, tuple):
+            encoder_out = encoder_out[0]
+        logits = model.ctc.ctc_lo(encoder_out)
+
+        # 2. Greedy search
+        y_hat = torch.argmax(logits, dim=-1)
+
+        # 3. Collapse
+        collapsed_ids = ctc_collapse_vectorized(y_hat, self.blank_id)
+
+        # 4. Map to text
+        results = []
+        for ids in collapsed_ids:
+            tokens = [self.token_list[i] for i in ids]
+            raw_text = "/".join(tokens)
+            # Filter special tokens
+            clean_tokens = [
+                t for t in tokens if not (t.startswith("<") and t.endswith(">"))
+            ]
+            processed = "".join(clean_tokens).strip()  # replace(self.sym_space, " ")
+
+            results.append(
+                {
+                    "processed_transcript": processed,
+                    "predicted_transcript": raw_text,
+                }
+            )
+        return results
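A quick sanity check of ctc_collapse_vectorized on a toy batch (ids are arbitrary; 0 plays the blank) shows the two CTC decoding rules: consecutive repeats merge and blanks drop, while a repeat separated by a blank survives as a distinct token.

import torch
from src.recipe.phone_recognition.greedy_ctc_strategy import ctc_collapse_vectorized

ids = torch.tensor([[0, 5, 5, 0, 5, 7, 7, 0]])
# [5, 5] merges to one 5; the 5 after the blank is kept; the 7s merge; blanks vanish.
print(ctc_collapse_vectorized(ids, blank_id=0))  # [[5, 5, 7]]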
src/utils/__init__.py
ADDED
@@ -0,0 +1 @@
+from src.utils.pylogger import RankedLogger
src/utils/pylogger.py
ADDED
@@ -0,0 +1,23 @@
+import logging
+from typing import Mapping, Optional
+
+
+class RankedLogger(logging.LoggerAdapter):
+    """Simplified logger for single-process inference (no Lightning)."""
+
+    def __init__(
+        self,
+        name: str = __name__,
+        rank_zero_only: bool = False,
+        extra: Optional[Mapping[str, object]] = None,
+    ) -> None:
+        logger = logging.getLogger(name)
+        super().__init__(logger=logger, extra=extra)
+        self.rank_zero_only = rank_zero_only
+
+    def log(
+        self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs
+    ) -> None:
+        if self.isEnabledFor(level):
+            msg, kwargs = self.process(msg, kwargs)
+            self.logger.log(level, msg, *args, **kwargs)