import torch
import torch.nn as nn
import torchaudio
import json
import re
import os
from g2p_en import G2p
import pytorch_lightning as pl
from .model import PhonemeCorrector
from transformers import Wav2Vec2Processor, HubertModel
from safetensors.torch import load_file as safetensors_load_file
class PhonemeCorrectionInference:
    """End-to-end inference wrapper around :class:`PhonemeCorrector`.

    Loads the edit-operation vocabulary, model hyperparameters and weights,
    a G2P frontend, and a HuBERT-based audio tokenizer, then exposes
    :meth:`predict` to turn a (wav file, phoneme string) pair into a
    corrected phoneme sequence plus a per-position edit log.
    """

    def __init__(self, checkpoint_path, vocab_path, audio_model_name="facebook/hubert-large-ls960-ft", device=None):
        """Build the full inference pipeline.

        Args:
            checkpoint_path: Path to model weights (.safetensors/.ckpt/.pt/.pth).
            vocab_path: Path to vocab.json holding "op_to_id" / "insert_to_id".
            audio_model_name: HuggingFace model id for the audio tokenizer.
            device: Optional torch.device; defaults to CUDA when available.

        Raises:
            FileNotFoundError: If ``checkpoint_path`` does not exist.
            ValueError: On an unsupported checkpoint extension, or when
                vocab.json disagrees with the weight shapes.
        """
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # 1) Load vocab: forward maps (token -> id) plus inverse maps for decoding.
        print(f"Loading config from {vocab_path}...")
        with open(vocab_path, "r") as f:
            self.config = json.load(f)
        self.op_map = self.config["op_to_id"]
        self.ins_map = self.config["insert_to_id"]
        self.id2op = {v: k for k, v in self.op_map.items()}
        self.id2ins = {v: k for k, v in self.ins_map.items()}
        # 2) Load G2P (not used by predict(), which expects pre-phonemized text).
        self.g2p = G2p()
        # 3) Load hparams.json (prefer same dir as checkpoint, fallback to parent)
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")
        hparams = {}
        hp_candidates = [
            os.path.join(os.path.dirname(checkpoint_path), "hparams.json"),
            os.path.join(os.path.dirname(os.path.dirname(checkpoint_path)), "hparams.json"),
        ]
        for hp in hp_candidates:
            if os.path.exists(hp):
                with open(hp, "r") as f:
                    hparams = json.load(f)
                break
        # 4) Load weights/state_dict
        print(f"Loading model weights from {checkpoint_path}...")
        lower = checkpoint_path.lower()
        if lower.endswith(".safetensors"):
            state_dict = safetensors_load_file(checkpoint_path, device="cpu")
        elif lower.endswith((".ckpt", ".pt", ".pth")):
            # NOTE: weights_only=False is needed for Lightning-style checkpoints in PyTorch 2.6+
            ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
            state_dict = ckpt.get("state_dict", ckpt)
            # Lightning checkpoints embed hparams; use them if no hparams.json was found.
            if not hparams and isinstance(ckpt, dict):
                hparams = ckpt.get("hyper_parameters", {}) or {}
        else:
            raise ValueError(f"Unsupported checkpoint format: {checkpoint_path}")
        # 5) Build model with correct hyperparams
        vocab_size_from_vocab = max(self.ins_map.values()) + 1
        # Prefer hparams.json, but also sanity-check against state_dict shapes below.
        vocab_size = int(hparams.get("vocab_size", vocab_size_from_vocab))
        audio_vocab_size = int(hparams.get("audio_vocab_size", 2048))
        d_model = int(hparams.get("d_model", 256))
        nhead = int(hparams.get("nhead", 4))
        num_layers = int(hparams.get("num_layers", 4))
        dropout = float(hparams.get("dropout", 0.1))
        lr = float(hparams.get("lr", 1e-4))
        weight_decay = float(hparams.get("weight_decay", 0.01))
        scheduler_config = hparams.get("scheduler_config", None)
        optimizer_config = hparams.get("optimizer_config", None)
        # Hard check: vocab.json and weights must agree
        if "text_embedding.weight" in state_dict:
            vsd, dsd = state_dict["text_embedding.weight"].shape
            if vsd != vocab_size_from_vocab:
                raise ValueError(
                    f"vocab.json (vocab_size={vocab_size_from_vocab}) does not match weights (vocab_size={vsd}). "
                    "Please upload the matching vocab.json."
                )
            # Override to match weights exactly (safer than trusting hparams)
            vocab_size = vsd
            d_model = dsd
            # Guarded lookup: a checkpoint missing the audio embedding should not
            # crash with a KeyError here (the original indexed unconditionally).
            if "audio_embedding.weight" in state_dict:
                audio_vocab_size = state_dict["audio_embedding.weight"].shape[0]
        self.model = PhonemeCorrector(
            vocab_size=vocab_size,
            audio_vocab_size=audio_vocab_size,
            d_model=d_model,
            nhead=nhead,
            num_layers=num_layers,
            dropout=dropout,
            lr=lr,
            weight_decay=weight_decay,
            scheduler_config=scheduler_config,
            optimizer_config=optimizer_config,
        )
        # strict=False so key mismatches are reported instead of raising;
        # surface the first few for debugging.
        missing, unexpected = self.model.load_state_dict(state_dict, strict=False)
        if missing or unexpected:
            print(f"[load_state_dict] missing={len(missing)} unexpected={len(unexpected)}")
            if missing[:5]:
                print(" missing (first 5):", missing[:5])
            if unexpected[:5]:
                print(" unexpected (first 5):", unexpected[:5])
        self.model.to(self.device).eval()
        # 6) Load Audio Tokenizer
        print(f"Loading Audio Tokenizer: {audio_model_name}")
        self.audio_processor = Wav2Vec2Processor.from_pretrained(audio_model_name)
        self.audio_model = HubertModel.from_pretrained(audio_model_name).eval().to(self.device)

    def _clean_phn(self, phn_list):
        """Strip stress digits and drop non-phoneme tokens (matches training)."""
        IGNORED = {"SIL", "'", "SPN", " "}
        # Strip once per token instead of twice (original stripped in both the
        # filter and the output expression).
        stripped = (p.rstrip('012') for p in phn_list)
        return [p for p in stripped if p not in IGNORED]

    def _get_audio_tokens(self, wav_path):
        """Tokenize a wav file into discrete audio ids of shape (1, T).

        IMPORTANT: This must match your training data generation logic.
        Quantization here is a placeholder argmax over HuBERT hidden states;
        replace it if training used K-Means.
        """
        waveform, sr = torchaudio.load(wav_path)
        # Downmix multi-channel audio to mono: the feature extractor expects a
        # 1-D signal, and squeezing a (channels, T) tensor with channels > 1
        # would otherwise feed a 2-D array (interpreted as a batch).
        if waveform.dim() > 1 and waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        if sr != 16000:
            resampler = torchaudio.transforms.Resample(sr, 16000)
            waveform = resampler(waveform)
        inputs = self.audio_processor(waveform.squeeze().numpy(), return_tensors="pt", sampling_rate=16000)
        input_values = inputs.input_values.to(self.device)
        with torch.no_grad():
            outputs = self.audio_model(input_values)
        # Placeholder Quantization (Argmax) - Replace if using K-Means
        features = outputs.last_hidden_state
        tokens = torch.argmax(features, dim=-1).squeeze()
        # Downsample to 25Hz (Assuming model is 50Hz)
        tokens = tokens[::2]
        return tokens.unsqueeze(0)  # (1, T)

    def predict(self, wav_path, text):
        """Correct one utterance's phoneme sequence against its audio.

        Args:
            wav_path: Path to the input wav file.
            text: Space-separated phoneme string (already phonemized; the G2P
                frontend is intentionally bypassed here).

        Returns:
            Tuple of (final_phonemes, log): the corrected phoneme list and a
            list of dicts with keys "src"/"op"/"ins" describing each edit.
        """
        # A. Prepare Inputs
        # 1. Text -> Phonemes -> IDs
        # raw_phns = self.g2p(text)
        raw_phns = text.split()  # Assuming input text is already phonemized for inference
        src_phns = self._clean_phn(raw_phns)
        # Create text vocab from insert_to_id (same as dataset)
        text_vocab = {k: v for k, v in self.ins_map.items() if k not in ['<NONE>', '<PAD>']}
        # Unknown phonemes fall back to the id of "AA" (or 2 if even that is absent).
        text_ids = [text_vocab.get(p, text_vocab.get("AA", 2)) for p in src_phns]
        text_tensor = torch.tensor([text_ids], dtype=torch.long).to(self.device)
        # 2. Audio -> Tokens
        audio_tensor = self._get_audio_tokens(wav_path)
        # B. Run Model
        with torch.no_grad():
            # All-ones masks: a single unpadded sequence per modality.
            txt_mask = torch.ones_like(text_tensor)
            aud_mask = torch.ones_like(audio_tensor)
            logits_op, logits_ins = self.model(
                text_tensor, audio_tensor, txt_mask, aud_mask
            )
        # C. Decode
        pred_ops = torch.argmax(logits_op, dim=-1).squeeze().tolist()
        pred_ins = torch.argmax(logits_ins, dim=-1).squeeze().tolist()
        # squeeze() on a length-1 sequence yields a scalar; normalize to lists.
        if not isinstance(pred_ops, list):
            pred_ops = [pred_ops]
        if not isinstance(pred_ins, list):
            pred_ins = [pred_ins]
        # D. Reconstruct Sequence
        final_phonemes = []
        log = []
        for orig, op_id, ins_id in zip(src_phns, pred_ops, pred_ins):
            # 1. Apply Operation
            op_str = self.id2op.get(op_id, "KEEP")
            curr_log = {"src": orig, "op": op_str, "ins": "NONE"}
            if op_str == "KEEP":
                final_phonemes.append(orig)
            elif op_str == "DEL":
                pass  # Do not append
            elif op_str.startswith("SUB:"):
                # Extract phoneme: "SUB:AA" -> "AA"
                new_phn = op_str.split(":")[1]
                final_phonemes.append(new_phn)
            # 2. Apply Insertion
            ins_str = self.id2ins.get(ins_id, "<NONE>")
            if ins_str != "<NONE>":
                final_phonemes.append(ins_str)
                curr_log["ins"] = ins_str
            log.append(curr_log)
        return final_phonemes, log
if __name__ == "__main__":
    # Example invocation; paths below are environment-specific.
    ckpt_path = "/data/chenxu/checkpoints/edit_seq_speech/phoneme-corrector/last.ckpt"
    vocab_path = "edit_seq_speech/config/vocab.json"
    wav_file = "test.wav"
    text_input = "Last Sunday"
    # Guard clause: bail out early when the required files are missing.
    if not (os.path.exists(ckpt_path) and os.path.exists(wav_file)):
        print("Please set valid paths for checkpoint and wav file.")
    else:
        infer = PhonemeCorrectionInference(ckpt_path, vocab_path)
        result, details = infer.predict(wav_file, text_input)
        print(f"Input Text: {text_input}")
        print(f"Result Phn: {result}")
        print("-" * 20)
        for step in details:
            print(f"{step['src']} -> {step['op']} + Insert({step['ins']})")