|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torchaudio |
|
|
import json |
|
|
import re |
|
|
import os |
|
|
from g2p_en import G2p |
|
|
import pytorch_lightning as pl |
|
|
|
|
|
from .model import PhonemeCorrector |
|
|
from transformers import Wav2Vec2Processor, HubertModel |
|
|
from safetensors.torch import load_file as safetensors_load_file |
|
|
|
|
|
class PhonemeCorrectionInference:
    """Inference wrapper for the PhonemeCorrector edit-sequence model.

    Loads (1) the edit-operation vocabulary from a JSON config, (2) trained
    PhonemeCorrector weights from a Lightning/PyTorch checkpoint or a
    safetensors file, and (3) a HuBERT model used to turn a waveform into
    discrete "audio tokens".  ``predict()`` then maps a phoneme sequence plus
    an audio file to a corrected phoneme sequence and a per-position edit log.
    """

    def __init__(self, checkpoint_path, vocab_path, audio_model_name="facebook/hubert-large-ls960-ft", device=None):
        """Set up the model, vocabularies, and audio tokenizer.

        Args:
            checkpoint_path: Path to weights (.safetensors, .ckpt, .pt, .pth).
            vocab_path: JSON file with "op_to_id" and "insert_to_id" mappings.
            audio_model_name: HF hub id of the HuBERT audio tokenizer; must be
                the same model used when generating the training data.
            device: Optional torch.device; defaults to CUDA when available.

        Raises:
            FileNotFoundError: If ``checkpoint_path`` does not exist.
            ValueError: For an unsupported checkpoint extension, or when the
                insert-vocabulary size disagrees with the embedding weights.
        """
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # ---- Edit-operation vocabulary -----------------------------------
        print(f"Loading config from {vocab_path}...")
        with open(vocab_path, "r") as f:
            self.config = json.load(f)

        # Forward maps (symbol -> id) and their inverses (id -> symbol),
        # used for greedy decoding in predict().
        self.op_map = self.config["op_to_id"]
        self.ins_map = self.config["insert_to_id"]
        self.id2op = {v: k for k, v in self.op_map.items()}
        self.id2ins = {v: k for k, v in self.ins_map.items()}

        # NOTE(review): G2p is instantiated but never used in this class —
        # predict() expects an already-phonemized, space-separated string.
        # Either wire this into predict() for raw-word input or remove it;
        # confirm the intended usage.
        self.g2p = G2p()

        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")

        # ---- Hyper-parameters --------------------------------------------
        # Look for an hparams.json next to the checkpoint, then one directory
        # above it.  A Lightning .ckpt may also embed them under
        # "hyper_parameters" (used as a fallback further below).
        hparams = {}
        hp_candidates = [
            os.path.join(os.path.dirname(checkpoint_path), "hparams.json"),
            os.path.join(os.path.dirname(os.path.dirname(checkpoint_path)), "hparams.json"),
        ]
        for hp in hp_candidates:
            if os.path.exists(hp):
                with open(hp, "r") as f:
                    hparams = json.load(f)
                break

        # ---- Weights: dispatch on file extension -------------------------
        print(f"Loading model weights from {checkpoint_path}...")
        lower = checkpoint_path.lower()
        if lower.endswith(".safetensors"):
            state_dict = safetensors_load_file(checkpoint_path, device="cpu")
        elif lower.endswith(".ckpt") or lower.endswith(".pt") or lower.endswith(".pth"):
            # SECURITY: weights_only=False unpickles arbitrary objects —
            # only load checkpoints from trusted sources.
            ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
            # Lightning checkpoints nest the weights under "state_dict";
            # a plain torch.save'd state dict is used as-is.
            state_dict = ckpt.get("state_dict", ckpt)
            if not hparams and isinstance(ckpt, dict):
                hparams = ckpt.get("hyper_parameters", {}) or {}
        else:
            raise ValueError(f"Unsupported checkpoint format: {checkpoint_path}")

        # Insert-vocabulary ids are assumed dense, so the largest id + 1
        # gives the text/insert vocab size implied by vocab.json.
        vocab_size_from_vocab = max(self.ins_map.values()) + 1

        # Model dimensions: hparams win; otherwise fall back to defaults.
        vocab_size = int(hparams.get("vocab_size", vocab_size_from_vocab))
        audio_vocab_size = int(hparams.get("audio_vocab_size", 2048))
        d_model = int(hparams.get("d_model", 256))
        nhead = int(hparams.get("nhead", 4))
        num_layers = int(hparams.get("num_layers", 4))
        dropout = float(hparams.get("dropout", 0.1))
        lr = float(hparams.get("lr", 1e-4))
        weight_decay = float(hparams.get("weight_decay", 0.01))
        scheduler_config = hparams.get("scheduler_config", None)
        optimizer_config = hparams.get("optimizer_config", None)

        # When embedding weights are present, their shapes are ground truth
        # and override the sizes above — but refuse to run if vocab.json
        # clearly belongs to a different model.
        if "text_embedding.weight" in state_dict:
            vsd, dsd = state_dict["text_embedding.weight"].shape
            asd = state_dict["audio_embedding.weight"].shape[0]
            if vsd != vocab_size_from_vocab:
                raise ValueError(
                    f"vocab.json (vocab_size={vocab_size_from_vocab}) does not match weights (vocab_size={vsd}). "
                    "Please upload the matching vocab.json."
                )
            vocab_size = vsd
            audio_vocab_size = asd
            d_model = dsd

        self.model = PhonemeCorrector(
            vocab_size=vocab_size,
            audio_vocab_size=audio_vocab_size,
            d_model=d_model,
            nhead=nhead,
            num_layers=num_layers,
            dropout=dropout,
            lr=lr,
            weight_decay=weight_decay,
            scheduler_config=scheduler_config,
            optimizer_config=optimizer_config,
        )
        # strict=False tolerates extra/missing keys (e.g. Lightning
        # bookkeeping tensors); report any mismatch so it stays visible.
        missing, unexpected = self.model.load_state_dict(state_dict, strict=False)
        if missing or unexpected:
            print(f"[load_state_dict] missing={len(missing)} unexpected={len(unexpected)}")
            if missing[:5]:
                print(" missing (first 5):", missing[:5])
            if unexpected[:5]:
                print(" unexpected (first 5):", unexpected[:5])

        self.model.to(self.device).eval()

        # ---- Audio tokenizer (HuBERT feature extractor + encoder) --------
        print(f"Loading Audio Tokenizer: {audio_model_name}")
        self.audio_processor = Wav2Vec2Processor.from_pretrained(audio_model_name)
        self.audio_model = HubertModel.from_pretrained(audio_model_name).eval().to(self.device)

    def _clean_phn(self, phn_list):
        """Standard cleaning to match training.

        Strips ARPABET stress digits (0/1/2) and drops silence/noise markers.
        """
        IGNORED = {"SIL", "'", "SPN", " "}
        return [p.rstrip('012') for p in phn_list if p.rstrip('012') not in IGNORED]

    def _get_audio_tokens(self, wav_path):
        """
        Runs the audio tokenizer.

        IMPORTANT: This must match your training data generation logic.

        Returns a (1, T') long tensor of discrete audio token ids.
        """
        waveform, sr = torchaudio.load(wav_path)
        if sr != 16000:
            # HuBERT expects 16 kHz input; resample anything else.
            resampler = torchaudio.transforms.Resample(sr, 16000)
            waveform = resampler(waveform)

        # NOTE(review): squeeze() assumes mono audio; multi-channel files are
        # not down-mixed here — confirm inputs are single-channel.
        inputs = self.audio_processor(waveform.squeeze().numpy(), return_tensors="pt", sampling_rate=16000)
        input_values = inputs.input_values.to(self.device)

        with torch.no_grad():
            outputs = self.audio_model(input_values)

        # NOTE(review): "tokens" here are the argmax over the hidden-state
        # feature dimension, not ids from a learned codebook — presumably
        # this mirrors the training-data tokenization; verify against the
        # data-generation pipeline.
        features = outputs.last_hidden_state
        tokens = torch.argmax(features, dim=-1).squeeze()

        # Downsample by 2 along time (keep every other frame).
        tokens = tokens[::2]
        return tokens.unsqueeze(0)

    def predict(self, wav_path, text):
        """Correct a phoneme sequence against the audio evidence.

        Args:
            wav_path: Path to the audio file to check against.
            text: Space-separated phoneme string (ARPABET-style); stress
                digits are stripped.  NOTE(review): despite ``self.g2p``
                being available, raw words are NOT converted here — callers
                must pass phonemes.

        Returns:
            Tuple ``(final_phonemes, log)``: the corrected phoneme list and,
            per source phoneme, a dict {"src", "op", "ins"} recording the
            applied edit.
        """
        raw_phns = text.split()
        src_phns = self._clean_phn(raw_phns)

        # Map phonemes to insert-vocabulary ids; unknown symbols fall back
        # to the id of "AA" (or 2 if even that key is missing).
        text_vocab = {k: v for k, v in self.ins_map.items() if k not in ['<NONE>', '<PAD>']}
        text_ids = [text_vocab.get(p, text_vocab.get("AA", 2)) for p in src_phns]
        text_tensor = torch.tensor([text_ids], dtype=torch.long).to(self.device)

        audio_tensor = self._get_audio_tokens(wav_path)

        with torch.no_grad():
            # All-ones masks: a single, unpadded sequence per batch.
            txt_mask = torch.ones_like(text_tensor)
            aud_mask = torch.ones_like(audio_tensor)

            logits_op, logits_ins = self.model(
                text_tensor, audio_tensor, txt_mask, aud_mask
            )

        # Greedy decoding of the per-position edit op and insertion.
        pred_ops = torch.argmax(logits_op, dim=-1).squeeze().tolist()
        pred_ins = torch.argmax(logits_ins, dim=-1).squeeze().tolist()

        # A length-1 sequence squeezes down to a scalar; re-wrap as lists.
        if not isinstance(pred_ops, list): pred_ops = [pred_ops]
        if not isinstance(pred_ins, list): pred_ins = [pred_ins]

        # Apply the edit script left-to-right: KEEP/DEL/SUB on the source
        # phoneme, then an optional insertion after it.
        final_phonemes = []
        log = []

        for i, (orig, op_id, ins_id) in enumerate(zip(src_phns, pred_ops, pred_ins)):
            # Unknown op ids default to KEEP (safe no-op).
            op_str = self.id2op.get(op_id, "KEEP")
            curr_log = {"src": orig, "op": op_str, "ins": "NONE"}

            if op_str == "KEEP":
                final_phonemes.append(orig)
            elif op_str == "DEL":
                pass
            elif op_str.startswith("SUB:"):
                # "SUB:<phn>" replaces the source phoneme with <phn>.
                new_phn = op_str.split(":")[1]
                final_phonemes.append(new_phn)

            ins_str = self.id2ins.get(ins_id, "<NONE>")
            if ins_str != "<NONE>":
                final_phonemes.append(ins_str)
                curr_log["ins"] = ins_str

            log.append(curr_log)

        return final_phonemes, log
|
|
|
|
|
if __name__ == "__main__":
    # Demo entry point — adjust these paths before running.
    # NOTE(review): text_input is raw words, but predict() expects a
    # space-separated phoneme string; confirm the intended demo input.
    ckpt_path = "/data/chenxu/checkpoints/edit_seq_speech/phoneme-corrector/last.ckpt"
    vocab_path = "edit_seq_speech/config/vocab.json"
    wav_file = "test.wav"
    text_input = "Last Sunday"

    # Guard clause: bail out early when the required files are absent.
    if not (os.path.exists(ckpt_path) and os.path.exists(wav_file)):
        print("Please set valid paths for checkpoint and wav file.")
    else:
        engine = PhonemeCorrectionInference(ckpt_path, vocab_path)
        phonemes, trace = engine.predict(wav_file, text_input)

        print(f"Input Text: {text_input}")
        print(f"Result Phn: {phonemes}")
        print("-" * 20)
        for entry in trace:
            print(f"{entry['src']} -> {entry['op']} + Insert({entry['ins']})")