Commit 7d0662d
Parent(s): 132d83f

Release HuPER Corrector weights and inference code

Files changed:
- .gitattributes                    +0 -34
- edit_seq_speech/__init__.py       +0 -0
- edit_seq_speech/config/vocab.json +95 -0
- edit_seq_speech/inference.py      +165 -0
- edit_seq_speech/model.py          +330 -0
.gitattributes CHANGED
@@ -1,35 +1 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
 *.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
edit_seq_speech/__init__.py ADDED
File without changes (empty file)
edit_seq_speech/config/vocab.json ADDED
@@ -0,0 +1,95 @@
{
  "op_to_id": {
    "KEEP": 0,
    "DEL": 1,
    "SUB:<PAD>": 2,
    "SUB:AA": 3,
    "SUB:AE": 4,
    "SUB:AH": 5,
    "SUB:AO": 6,
    "SUB:AW": 7,
    "SUB:AY": 8,
    "SUB:B": 9,
    "SUB:CH": 10,
    "SUB:D": 11,
    "SUB:DH": 12,
    "SUB:DX": 13,
    "SUB:EH": 14,
    "SUB:ER": 15,
    "SUB:EY": 16,
    "SUB:F": 17,
    "SUB:G": 18,
    "SUB:HH": 19,
    "SUB:IH": 20,
    "SUB:IY": 21,
    "SUB:JH": 22,
    "SUB:K": 23,
    "SUB:L": 24,
    "SUB:M": 25,
    "SUB:N": 26,
    "SUB:NG": 27,
    "SUB:OW": 28,
    "SUB:OY": 29,
    "SUB:P": 30,
    "SUB:R": 31,
    "SUB:S": 32,
    "SUB:SH": 33,
    "SUB:T": 34,
    "SUB:TH": 35,
    "SUB:UH": 36,
    "SUB:UW": 37,
    "SUB:V": 38,
    "SUB:W": 39,
    "SUB:Y": 40,
    "SUB:Z": 41,
    "SUB:ZH": 42
  },
  "insert_to_id": {
    "<NONE>": 0,
    "<PAD>": 1,
    "AA": 2,
    "AE": 3,
    "AH": 4,
    "AO": 5,
    "AW": 6,
    "AY": 7,
    "B": 8,
    "CH": 9,
    "D": 10,
    "DH": 11,
    "DX": 12,
    "EH": 13,
    "ER": 14,
    "EY": 15,
    "F": 16,
    "G": 17,
    "HH": 18,
    "IH": 19,
    "IY": 20,
    "JH": 21,
    "K": 22,
    "L": 23,
    "M": 24,
    "N": 25,
    "NG": 26,
    "OW": 27,
    "OY": 28,
    "P": 29,
    "R": 30,
    "S": 31,
    "SH": 32,
    "T": 33,
    "TH": 34,
    "UH": 35,
    "UW": 36,
    "V": 37,
    "W": 38,
    "Y": 39,
    "Z": 40,
    "ZH": 41
  },
  "stats": {
    "num_ops": 43,
    "num_inserts": 42
  }
}
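A quick consistency check (illustrative, not part of the release): the "stats" block should agree with the sizes of the two maps.

import json

with open("edit_seq_speech/config/vocab.json") as f:
    vocab = json.load(f)

# 43 ops = KEEP + DEL + 41 SUB targets (SUB:<PAD> plus 40 phonemes)
assert len(vocab["op_to_id"]) == vocab["stats"]["num_ops"]
# 42 insert targets = <NONE> + <PAD> + 40 phonemes
assert len(vocab["insert_to_id"]) == vocab["stats"]["num_inserts"]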
edit_seq_speech/inference.py ADDED
@@ -0,0 +1,165 @@
import torch
import torch.nn as nn
import torchaudio
import json
import re
import os
from g2p_en import G2p
import pytorch_lightning as pl

from .model import PhonemeCorrector
from transformers import Wav2Vec2Processor, HubertModel

class PhonemeCorrectionInference:
    def __init__(self, checkpoint_path, vocab_path, audio_model_name="facebook/hubert-large-ls960-ft", device=None):
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # 1. Load vocab / config
        print(f"Loading config from {vocab_path}...")
        with open(vocab_path, 'r') as f:
            self.config = json.load(f)

        self.op_map = self.config['op_to_id']
        self.ins_map = self.config['insert_to_id']

        # Create reverse maps (ID -> string)
        self.id2op = {v: k for k, v in self.op_map.items()}
        self.id2ins = {v: k for k, v in self.ins_map.items()}

        # 2. Load G2P
        self.g2p = G2p()

        # 3. Load model
        print(f"Loading model from {checkpoint_path}...")
        if os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            hparams = checkpoint.get('hyper_parameters', {})

            vocab_size = max(self.ins_map.values()) + 1
            audio_vocab_size = hparams.get('audio_vocab_size', 2048)

            self.model = PhonemeCorrector.load_from_checkpoint(
                checkpoint_path,
                map_location=self.device,
                vocab_size=vocab_size,
                audio_vocab_size=audio_vocab_size
            )
        else:
            raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")

        self.model.to(self.device)
        self.model.eval()

        # 4. Load audio tokenizer
        print(f"Loading Audio Tokenizer: {audio_model_name}")
        self.audio_processor = Wav2Vec2Processor.from_pretrained(audio_model_name)
        self.audio_model = HubertModel.from_pretrained(audio_model_name).eval().to(self.device)

    def _clean_phn(self, phn_list):
        """Standard cleaning to match training: strip stress digits, drop non-phoneme symbols."""
        IGNORED = {"SIL", "'", "SPN", " "}
        return [p.rstrip('012') for p in phn_list if p.rstrip('012') not in IGNORED]

    def _get_audio_tokens(self, wav_path):
        """
        Runs the audio tokenizer.
        IMPORTANT: This must match your training data generation logic.
        """
        waveform, sr = torchaudio.load(wav_path)
        if sr != 16000:
            resampler = torchaudio.transforms.Resample(sr, 16000)
            waveform = resampler(waveform)

        inputs = self.audio_processor(waveform.squeeze().numpy(), return_tensors="pt", sampling_rate=16000)
        input_values = inputs.input_values.to(self.device)

        with torch.no_grad():
            outputs = self.audio_model(input_values)

        # Placeholder quantization (argmax over hidden dims) - replace if using k-means
        features = outputs.last_hidden_state
        tokens = torch.argmax(features, dim=-1).squeeze()

        # Downsample to 25 Hz (assuming the model emits 50 Hz frames)
        tokens = tokens[::2]
        return tokens.unsqueeze(0)  # (1, T)

    def predict(self, wav_path, text):
        # A. Prepare inputs
        # 1. Text -> phonemes -> IDs
        # raw_phns = self.g2p(text)  # enable for raw English text
        raw_phns = text.split()  # assumes input text is already phonemized for inference
        src_phns = self._clean_phn(raw_phns)

        # Create text vocab from insert_to_id (same as the dataset)
        text_vocab = {k: v for k, v in self.ins_map.items() if k not in ['<NONE>', '<PAD>']}
        text_ids = [text_vocab.get(p, text_vocab.get("AA", 2)) for p in src_phns]
        text_tensor = torch.tensor([text_ids], dtype=torch.long).to(self.device)

        # 2. Audio -> tokens
        audio_tensor = self._get_audio_tokens(wav_path)

        # B. Run model
        with torch.no_grad():
            # Create masks (single unpadded sample, so all ones)
            txt_mask = torch.ones_like(text_tensor)
            aud_mask = torch.ones_like(audio_tensor)

            logits_op, logits_ins = self.model(
                text_tensor, audio_tensor, txt_mask, aud_mask
            )

        # C. Decode
        pred_ops = torch.argmax(logits_op, dim=-1).squeeze().tolist()
        pred_ins = torch.argmax(logits_ins, dim=-1).squeeze().tolist()

        # Ensure lists (squeeze().tolist() yields a scalar for length-1 sequences)
        if not isinstance(pred_ops, list): pred_ops = [pred_ops]
        if not isinstance(pred_ins, list): pred_ins = [pred_ins]

        # D. Reconstruct sequence
        final_phonemes = []
        log = []

        for i, (orig, op_id, ins_id) in enumerate(zip(src_phns, pred_ops, pred_ins)):

            # 1. Apply operation
            op_str = self.id2op.get(op_id, "KEEP")
            curr_log = {"src": orig, "op": op_str, "ins": "NONE"}

            if op_str == "KEEP":
                final_phonemes.append(orig)
            elif op_str == "DEL":
                pass  # do not append
            elif op_str.startswith("SUB:"):
                # Extract phoneme: "SUB:AA" -> "AA"
                new_phn = op_str.split(":")[1]
                final_phonemes.append(new_phn)

            # 2. Apply insertion (after the current position)
            ins_str = self.id2ins.get(ins_id, "<NONE>")
            if ins_str != "<NONE>":
                final_phonemes.append(ins_str)
                curr_log["ins"] = ins_str

            log.append(curr_log)

        return final_phonemes, log

if __name__ == "__main__":
    ckpt_path = "/data/chenxu/checkpoints/edit_seq_speech/phoneme-corrector/last.ckpt"
    vocab_path = "edit_seq_speech/config/vocab.json"
    wav_file = "test.wav"
    text_input = "Last Sunday"

    if os.path.exists(ckpt_path) and os.path.exists(wav_file):
        infer = PhonemeCorrectionInference(ckpt_path, vocab_path)
        result, details = infer.predict(wav_file, text_input)

        print(f"Input Text: {text_input}")
        print(f"Result Phn: {result}")
        print("-" * 20)
        for step in details:
            print(f"{step['src']} -> {step['op']} + Insert({step['ins']})")
    else:
        print("Please set valid paths for checkpoint and wav file.")
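The argmax over HuBERT hidden dimensions in _get_audio_tokens is explicitly a placeholder. If the training tokens came from a k-means codebook over HuBERT features (the usual HuBERT-unit recipe), the quantizer would look roughly like this sketch; the codebook file km_2048.npy and its 2048-cluster size are assumptions, not part of this release.

import numpy as np
import torch

# Hypothetical codebook: (num_clusters, feature_dim) matrix of k-means centroids.
centroids = torch.from_numpy(np.load("km_2048.npy")).float()  # assumed path

def kmeans_tokens(features: torch.Tensor) -> torch.Tensor:
    """Map frame features (T, D) to nearest-centroid unit ids (T,)."""
    dists = torch.cdist(features, centroids.to(features.device))  # (T, K)
    return dists.argmin(dim=-1)

# Drop-in for the argmax line in _get_audio_tokens:
#   tokens = kmeans_tokens(outputs.last_hidden_state.squeeze(0))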
edit_seq_speech/model.py ADDED
@@ -0,0 +1,330 @@
import torch
import torch.nn as nn
import math
import pytorch_lightning as pl

class PhonemeCorrector(pl.LightningModule):
    def __init__(self, vocab_size, audio_vocab_size, d_model=256, nhead=4, num_layers=4, dropout=0.1, lr=1e-4,
                 weight_decay=0.01, scheduler_config=None, optimizer_config=None):
        super().__init__()
        self.save_hyperparameters()
        self.scheduler_config = scheduler_config or {}
        self.optimizer_config = optimizer_config or {}

        # 1. Embeddings
        self.text_embedding = nn.Embedding(vocab_size, d_model)
        self.audio_embedding = nn.Embedding(audio_vocab_size, d_model)

        # Positional encoding (standard sinusoidal)
        self.pos_encoder = PositionalEncoding(d_model, dropout)

        # 2. The core transformer (text queries attend over audio memory)
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        # 3. Prediction heads - two-head architecture
        # Head 1: operation (KEEP, DEL, SUB:AA, SUB:AE, ...)
        # num_ops = vocab_size + 2 (KEEP=0, DEL=1, SUB:phonemes from 2 up),
        # matching the precomputed op_ids format.
        num_ops = vocab_size + 2
        self.head_op = nn.Linear(d_model, num_ops)

        # Head 2: insertion (<NONE>=0, <PAD>=1, then phonemes)
        # num_inserts = vocab_size
        num_inserts = vocab_size
        self.head_ins = nn.Linear(d_model, num_inserts)

    def forward(self, text_ids, audio_ids, text_mask=None, audio_mask=None):
        """
        text_ids:  (Batch, Text_Len)
        audio_ids: (Batch, Audio_Len)
        masks:     (Batch, Len) - 1 for valid, 0 for pad.
        """
        text_emb = self.pos_encoder(self.text_embedding(text_ids))
        audio_emb = self.pos_encoder(self.audio_embedding(audio_ids))

        # nn.Transformer expects True at padded positions
        txt_pad_mask = (text_mask == 0) if text_mask is not None else None
        aud_pad_mask = (audio_mask == 0) if audio_mask is not None else None

        encoded_features = self.transformer(
            tgt=text_emb,
            memory=audio_emb,
            tgt_key_padding_mask=txt_pad_mask,
            memory_key_padding_mask=aud_pad_mask
        )

        logits_op = self.head_op(encoded_features)
        logits_ins = self.head_ins(encoded_features)

        return logits_op, logits_ins

    def training_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        audio_tokens = batch['audio_tokens']
        lbl_op = batch['labels']['op']
        lbl_ins = batch['labels']['ins']
        txt_mask = batch['masks']['text']
        audio_mask = batch['masks']['audio']

        logits_op, logits_ins = self(input_ids, audio_tokens, txt_mask, audio_mask)

        # Active loss mask (only compute loss on valid text tokens)
        active_loss = txt_mask.view(-1) == 1

        # OP loss (covers KEEP, DEL, and all SUB:phoneme operations)
        num_ops = self.hparams.vocab_size + 2
        loss_op = nn.functional.cross_entropy(
            logits_op.view(-1, num_ops)[active_loss],
            lbl_op.view(-1)[active_loss]
        )

        # INS loss
        loss_ins = nn.functional.cross_entropy(
            logits_ins.view(-1, self.hparams.vocab_size)[active_loss],
            lbl_ins.view(-1)[active_loss]
        )

        loss = loss_op + loss_ins
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_loss_op', loss_op)
        self.log('train_loss_ins', loss_ins)
        return loss

    def validation_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        audio_tokens = batch['audio_tokens']
        lbl_op = batch['labels']['op']
        lbl_ins = batch['labels']['ins']
        txt_mask = batch['masks']['text']
        audio_mask = batch['masks']['audio']

        logits_op, logits_ins = self(input_ids, audio_tokens, txt_mask, audio_mask)

        # Compute losses
        active_loss = txt_mask.view(-1) == 1
        num_ops = self.hparams.vocab_size + 2

        loss_op = nn.functional.cross_entropy(
            logits_op.view(-1, num_ops)[active_loss],
            lbl_op.view(-1)[active_loss]
        )

        loss_ins = nn.functional.cross_entropy(
            logits_ins.view(-1, self.hparams.vocab_size)[active_loss],
            lbl_ins.view(-1)[active_loss]
        )

        loss = loss_op + loss_ins

        # Compute accuracy
        pred_op = torch.argmax(logits_op, dim=-1)
        pred_ins = torch.argmax(logits_ins, dim=-1)

        # OP accuracy
        op_correct = (pred_op == lbl_op) & txt_mask.bool()
        op_acc = op_correct.sum().float() / txt_mask.sum().float()

        # INS accuracy
        ins_correct = (pred_ins == lbl_ins) & txt_mask.bool()
        ins_acc = ins_correct.sum().float() / txt_mask.sum().float()

        # Overall accuracy: correct OP prediction
        overall_acc = op_acc

        # Per-operation accuracy (KEEP=0, DEL=1, SUB>=2)
        keep_mask = (lbl_op == 0) & txt_mask.bool()
        del_mask = (lbl_op == 1) & txt_mask.bool()
        sub_op_mask = (lbl_op >= 2) & txt_mask.bool()

        keep_acc = torch.tensor(0.0, device=loss.device)
        del_acc = torch.tensor(0.0, device=loss.device)
        sub_op_acc = torch.tensor(0.0, device=loss.device)

        if keep_mask.sum() > 0:
            keep_correct = (pred_op == lbl_op) & keep_mask
            keep_acc = keep_correct.sum().float() / keep_mask.sum().float()

        if del_mask.sum() > 0:
            del_correct = (pred_op == lbl_op) & del_mask
            del_acc = del_correct.sum().float() / del_mask.sum().float()

        if sub_op_mask.sum() > 0:
            sub_op_correct = (pred_op == lbl_op) & sub_op_mask
            sub_op_acc = sub_op_correct.sum().float() / sub_op_mask.sum().float()

        # Log metrics
        self.log('val_loss', loss, prog_bar=True, sync_dist=True)
        self.log('val_loss_op', loss_op, sync_dist=True)
        self.log('val_loss_ins', loss_ins, sync_dist=True)
        self.log('val_acc', overall_acc, prog_bar=True, sync_dist=True)
        self.log('val_acc_op', op_acc, sync_dist=True)
        self.log('val_acc_ins', ins_acc, sync_dist=True)
        self.log('val_acc_keep', keep_acc, sync_dist=True)
        self.log('val_acc_del', del_acc, sync_dist=True)
        self.log('val_acc_sub_op', sub_op_acc, sync_dist=True)

        return {
            'val_loss': loss,
            'val_acc': overall_acc,
            'val_acc_op': op_acc,
            'val_acc_ins': ins_acc
        }

    def configure_optimizers(self):
        # Get optimizer configuration
        optimizer_name = self.optimizer_config.get("name", "adamw").lower()
        lr = self.hparams.lr
        weight_decay = getattr(self.hparams, 'weight_decay', 0.01)

        if optimizer_name == "adamw":
            optimizer = torch.optim.AdamW(
                self.parameters(),
                lr=lr,
                weight_decay=weight_decay,
                betas=self.optimizer_config.get("betas", [0.9, 0.999]),
                eps=self.optimizer_config.get("eps", 1.0e-8)
            )
        elif optimizer_name == "adam":
            optimizer = torch.optim.Adam(
                self.parameters(),
                lr=lr,
                weight_decay=weight_decay,
                betas=self.optimizer_config.get("betas", [0.9, 0.999]),
                eps=self.optimizer_config.get("eps", 1.0e-8)
            )
        else:
            raise ValueError(f"Unknown optimizer: {optimizer_name}")

        # Configure scheduler
        scheduler_type = self.scheduler_config.get("type", "cosine").lower()

        # Calculate total training steps
        max_epochs = getattr(self.trainer, 'max_epochs', 50)
        if self.trainer and hasattr(self.trainer, 'estimated_stepping_batches'):
            total_steps = self.trainer.estimated_stepping_batches
        else:
            # Fallback: estimate steps per epoch
            estimated_steps_per_epoch = 1000  # conservative estimate
            total_steps = max_epochs * estimated_steps_per_epoch

        warmup_ratio = self.scheduler_config.get("warmup_ratio", 0.1)
        warmup_steps = max(1, int(total_steps * warmup_ratio))

        if scheduler_type == "cosine":
            # Use transformers' cosine scheduler with warmup
            try:
                from transformers import get_cosine_schedule_with_warmup
                scheduler = get_cosine_schedule_with_warmup(
                    optimizer,
                    num_warmup_steps=warmup_steps,
                    num_training_steps=total_steps,
                    num_cycles=0.5,  # default cosine cycles
                    last_epoch=-1
                )
            except ImportError:
                # Fallback to a PyTorch implementation
                from torch.optim.lr_scheduler import LambdaLR
                eta_min = self.scheduler_config.get("eta_min", 1.0e-6)
                def lr_lambda(step):
                    if step < warmup_steps:
                        return step / warmup_steps
                    else:
                        # Cosine annealing after warmup
                        progress = (step - warmup_steps) / (total_steps - warmup_steps)
                        cosine_value = 0.5 * (1 + math.cos(math.pi * progress))
                        return eta_min / lr + (1 - eta_min / lr) * cosine_value
                scheduler = LambdaLR(optimizer, lr_lambda)

        elif scheduler_type == "linear":
            # Use transformers' linear scheduler with warmup
            try:
                from transformers import get_linear_schedule_with_warmup
                scheduler = get_linear_schedule_with_warmup(
                    optimizer,
                    num_warmup_steps=warmup_steps,
                    num_training_steps=total_steps
                )
            except ImportError:
                # Fallback to a PyTorch implementation
                from torch.optim.lr_scheduler import LambdaLR
                def lr_lambda(step):
                    if step < warmup_steps:
                        return step / warmup_steps
                    else:
                        progress = (step - warmup_steps) / (total_steps - warmup_steps)
                        return max(0.0, 1.0 - progress)
                scheduler = LambdaLR(optimizer, lr_lambda)

        elif scheduler_type == "polynomial":
            # Use transformers' polynomial scheduler with warmup
            power = self.scheduler_config.get("power", 1.0)  # defined before the try so the fallback can use it
            try:
                from transformers import get_polynomial_decay_schedule_with_warmup
                scheduler = get_polynomial_decay_schedule_with_warmup(
                    optimizer,
                    num_warmup_steps=warmup_steps,
                    num_training_steps=total_steps,
                    power=power
                )
            except ImportError:
                # Fallback: polynomial decay via LambdaLR
                from torch.optim.lr_scheduler import LambdaLR
                def lr_lambda(step):
                    if step < warmup_steps:
                        return step / warmup_steps
                    else:
                        progress = (step - warmup_steps) / (total_steps - warmup_steps)
                        return max(0.0, (1.0 - progress) ** power)
                scheduler = LambdaLR(optimizer, lr_lambda)

        elif scheduler_type == "reduce_on_plateau":
            from torch.optim.lr_scheduler import ReduceLROnPlateau
            scheduler = ReduceLROnPlateau(
                optimizer,
                mode='min',
                factor=self.scheduler_config.get("factor", 0.5),
                patience=self.scheduler_config.get("patience", 3),
                min_lr=self.scheduler_config.get("min_lr", 1.0e-6),
                verbose=True
            )
            # ReduceLROnPlateau needs a monitored metric and epoch-level stepping
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": "train_loss",
                    "interval": "epoch",
                    "frequency": 1,
                }
            }
        else:
            # No scheduler
            return optimizer

        # Step-based schedulers (cosine / linear / polynomial)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
            }
        }

# Helper for positional encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        # x: (Batch, Seq, Dim)
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
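As a quick shape check (illustrative only), the model can be instantiated with the sizes implied by vocab.json: insert_to_id has ids 0-41, so vocab_size = 42, and inference.py defaults audio_vocab_size to 2048.

import torch
from edit_seq_speech.model import PhonemeCorrector

model = PhonemeCorrector(vocab_size=42, audio_vocab_size=2048)

text = torch.randint(0, 42, (2, 10))     # (Batch, Text_Len)
audio = torch.randint(0, 2048, (2, 50))  # (Batch, Audio_Len)
logits_op, logits_ins = model(text, audio)

# head_op emits vocab_size + 2 = 44 logits; vocab.json defines 43 op ids,
# so the last logit is never a valid label.
print(logits_op.shape)   # torch.Size([2, 10, 44])
print(logits_ins.shape)  # torch.Size([2, 10, 42])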