import torch
import torch.nn as nn
import torchaudio
import json
import re
import os
from g2p_en import G2p
import pytorch_lightning as pl
from .model import PhonemeCorrector
from transformers import Wav2Vec2Processor, HubertModel
from safetensors.torch import load_file as safetensors_load_file


class PhonemeCorrectionInference:
    """End-to-end inference wrapper for the phoneme-correction model.

    Loads the vocab/config JSON, the trained ``PhonemeCorrector`` weights
    (Lightning ``.ckpt`` / ``.pt`` / ``.pth`` or ``.safetensors``), and a
    HuBERT audio tokenizer, then exposes :meth:`predict` which returns the
    corrected phoneme sequence plus a per-position edit log.
    """

    def __init__(self, checkpoint_path, vocab_path,
                 audio_model_name="facebook/hubert-large-ls960-ft", device=None):
        """Build the inference pipeline.

        Args:
            checkpoint_path: Path to model weights (.ckpt/.pt/.pth/.safetensors).
            vocab_path: Path to the vocab/config JSON containing ``op_to_id``
                and ``insert_to_id`` maps.
            audio_model_name: HuggingFace id of the HuBERT audio tokenizer.
            device: Explicit torch device; defaults to CUDA when available.

        Raises:
            FileNotFoundError: If ``checkpoint_path`` does not exist.
            ValueError: On an unsupported checkpoint extension, or when
                vocab.json disagrees with the checkpoint's embedding shapes.
        """
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # 1) Load vocab
        print(f"Loading config from {vocab_path}...")
        with open(vocab_path, "r") as f:
            self.config = json.load(f)
        self.op_map = self.config["op_to_id"]
        self.ins_map = self.config["insert_to_id"]
        # Inverse maps for decoding predicted ids back to strings.
        self.id2op = {v: k for k, v in self.op_map.items()}
        self.id2ins = {v: k for k, v in self.ins_map.items()}

        # 2) Load G2P
        self.g2p = G2p()

        # 3) Load hparams.json (prefer same dir as checkpoint, fallback to parent)
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")
        hparams = {}
        hp_candidates = [
            os.path.join(os.path.dirname(checkpoint_path), "hparams.json"),
            os.path.join(os.path.dirname(os.path.dirname(checkpoint_path)), "hparams.json"),
        ]
        for hp in hp_candidates:
            if os.path.exists(hp):
                with open(hp, "r") as f:
                    hparams = json.load(f)
                break

        # 4) Load weights/state_dict
        print(f"Loading model weights from {checkpoint_path}...")
        lower = checkpoint_path.lower()
        if lower.endswith(".safetensors"):
            state_dict = safetensors_load_file(checkpoint_path, device="cpu")
        elif lower.endswith(".ckpt") or lower.endswith(".pt") or lower.endswith(".pth"):
            # NOTE: weights_only=False is needed for Lightning-style checkpoints
            # in PyTorch 2.6+ (they pickle non-tensor objects).
            ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
            state_dict = ckpt.get("state_dict", ckpt)
            # Lightning checkpoints may carry hyperparameters; use them only if
            # no hparams.json was found on disk.
            if not hparams and isinstance(ckpt, dict):
                hparams = ckpt.get("hyper_parameters", {}) or {}
        else:
            raise ValueError(f"Unsupported checkpoint format: {checkpoint_path}")

        # 5) Build model with correct hyperparams
        vocab_size_from_vocab = max(self.ins_map.values()) + 1
        # Prefer hparams.json, but also sanity-check against state_dict shapes below.
        vocab_size = int(hparams.get("vocab_size", vocab_size_from_vocab))
        audio_vocab_size = int(hparams.get("audio_vocab_size", 2048))
        d_model = int(hparams.get("d_model", 256))
        nhead = int(hparams.get("nhead", 4))
        num_layers = int(hparams.get("num_layers", 4))
        dropout = float(hparams.get("dropout", 0.1))
        lr = float(hparams.get("lr", 1e-4))
        weight_decay = float(hparams.get("weight_decay", 0.01))
        scheduler_config = hparams.get("scheduler_config", None)
        optimizer_config = hparams.get("optimizer_config", None)

        # Hard check: vocab.json and weights must agree
        if "text_embedding.weight" in state_dict:
            vsd, dsd = state_dict["text_embedding.weight"].shape
            asd = state_dict["audio_embedding.weight"].shape[0]
            if vsd != vocab_size_from_vocab:
                raise ValueError(
                    f"vocab.json (vocab_size={vocab_size_from_vocab}) does not match weights (vocab_size={vsd}). "
                    "Please upload the matching vocab.json."
                )
            # Override to match weights exactly (safer)
            vocab_size = vsd
            audio_vocab_size = asd
            d_model = dsd

        self.model = PhonemeCorrector(
            vocab_size=vocab_size,
            audio_vocab_size=audio_vocab_size,
            d_model=d_model,
            nhead=nhead,
            num_layers=num_layers,
            dropout=dropout,
            lr=lr,
            weight_decay=weight_decay,
            scheduler_config=scheduler_config,
            optimizer_config=optimizer_config,
        )
        # strict=False tolerates extra/missing keys (e.g. optimizer state,
        # renamed heads); report them so silent mismatches are visible.
        missing, unexpected = self.model.load_state_dict(state_dict, strict=False)
        if missing or unexpected:
            print(f"[load_state_dict] missing={len(missing)} unexpected={len(unexpected)}")
            if missing[:5]:
                print(" missing (first 5):", missing[:5])
            if unexpected[:5]:
                print(" unexpected (first 5):", unexpected[:5])
        self.model.to(self.device).eval()

        # 6) Load Audio Tokenizer
        print(f"Loading Audio Tokenizer: {audio_model_name}")
        self.audio_processor = Wav2Vec2Processor.from_pretrained(audio_model_name)
        self.audio_model = HubertModel.from_pretrained(audio_model_name).eval().to(self.device)

    def _clean_phn(self, phn_list):
        """Strip stress digits and drop non-phoneme tokens, matching training.

        Args:
            phn_list: Raw phoneme strings (possibly with ARPAbet stress marks).

        Returns:
            List of stress-stripped phonemes with SIL/SPN/punctuation removed.
        """
        IGNORED = {"SIL", "'", "SPN", " "}
        cleaned = []
        for p in phn_list:
            # Hoist the rstrip: original computed it twice per token.
            base = p.rstrip("012")
            if base not in IGNORED:
                cleaned.append(base)
        return cleaned

    def _get_audio_tokens(self, wav_path):
        """Tokenize a wav file into discrete audio tokens.

        IMPORTANT: This must match your training data generation logic.

        Args:
            wav_path: Path to the input audio file.

        Returns:
            LongTensor of shape (1, T) of audio token ids on ``self.device``.
        """
        waveform, sr = torchaudio.load(wav_path)
        # HuBERT expects 16 kHz input.
        if sr != 16000:
            resampler = torchaudio.transforms.Resample(sr, 16000)
            waveform = resampler(waveform)

        # NOTE(review): squeeze() assumes mono audio; a stereo file stays 2-D
        # here — confirm upstream data is single-channel.
        inputs = self.audio_processor(waveform.squeeze().numpy(), return_tensors="pt", sampling_rate=16000)
        input_values = inputs.input_values.to(self.device)

        with torch.no_grad():
            outputs = self.audio_model(input_values)

        # Placeholder Quantization (Argmax) - Replace if using K-Means
        features = outputs.last_hidden_state
        tokens = torch.argmax(features, dim=-1).squeeze()

        # Downsample to 25Hz (Assuming model is 50Hz)
        tokens = tokens[::2]
        return tokens.unsqueeze(0)  # (1, T)

    def predict(self, wav_path, text):
        """Run correction on one utterance.

        Args:
            wav_path: Path to the audio file.
            text: Space-separated phoneme string (already phonemized).

        Returns:
            Tuple ``(final_phonemes, log)`` — the reconstructed phoneme list
            and a per-source-phoneme dict log of the applied op/insertion.
        """
        # A. Prepare Inputs
        # 1. Text -> Phonemes -> IDs
        # raw_phns = self.g2p(text)
        raw_phns = text.split()  # Assuming input text is already phonemized for inference
        src_phns = self._clean_phn(raw_phns)

        # Create text vocab from insert_to_id (same as dataset).
        # NOTE(review): the filter list was ['', ''] — a duplicated empty
        # string, likely mangled special tokens (e.g. pad/none). Behavior
        # preserved: only "" is excluded. Verify against the dataset code.
        text_vocab = {k: v for k, v in self.ins_map.items() if k != ''}
        # Unknown phonemes fall back to "AA" (or id 2 if even that is absent).
        text_ids = [text_vocab.get(p, text_vocab.get("AA", 2)) for p in src_phns]
        text_tensor = torch.tensor([text_ids], dtype=torch.long).to(self.device)

        # 2. Audio -> Tokens
        audio_tensor = self._get_audio_tokens(wav_path)

        # B. Run Model
        with torch.no_grad():
            # Create masks (all positions valid — single unpadded sequence).
            txt_mask = torch.ones_like(text_tensor)
            aud_mask = torch.ones_like(audio_tensor)
            logits_op, logits_ins = self.model(
                text_tensor, audio_tensor, txt_mask, aud_mask
            )

        # C. Decode
        pred_ops = torch.argmax(logits_op, dim=-1).squeeze().tolist()
        pred_ins = torch.argmax(logits_ins, dim=-1).squeeze().tolist()

        # Ensure lists: squeeze() on a length-1 sequence yields a scalar.
        if not isinstance(pred_ops, list):
            pred_ops = [pred_ops]
        if not isinstance(pred_ins, list):
            pred_ins = [pred_ins]

        # D. Reconstruct Sequence
        final_phonemes = []
        log = []
        for i, (orig, op_id, ins_id) in enumerate(zip(src_phns, pred_ops, pred_ins)):
            # 1. Apply Operation
            op_str = self.id2op.get(op_id, "KEEP")
            curr_log = {"src": orig, "op": op_str, "ins": "NONE"}

            if op_str == "KEEP":
                final_phonemes.append(orig)
            elif op_str == "DEL":
                pass  # Do not append
            elif op_str.startswith("SUB:"):
                # Extract phoneme: "SUB:AA" -> "AA"
                new_phn = op_str.split(":")[1]
                final_phonemes.append(new_phn)

            # 2. Apply Insertion
            ins_str = self.id2ins.get(ins_id, "")
            if ins_str != "":
                final_phonemes.append(ins_str)
                curr_log["ins"] = ins_str

            log.append(curr_log)

        return final_phonemes, log


if __name__ == "__main__":
    ckpt_path = "/data/chenxu/checkpoints/edit_seq_speech/phoneme-corrector/last.ckpt"
    vocab_path = "edit_seq_speech/config/vocab.json"
    wav_file = "test.wav"
    text_input = "Last Sunday"

    # FIX: also guard vocab_path — previously a missing vocab.json crashed
    # inside __init__ with an unguarded FileNotFoundError.
    if os.path.exists(ckpt_path) and os.path.exists(vocab_path) and os.path.exists(wav_file):
        infer = PhonemeCorrectionInference(ckpt_path, vocab_path)
        result, details = infer.predict(wav_file, text_input)
        print(f"Input Text: {text_input}")
        print(f"Result Phn: {result}")
        print("-" * 20)
        for step in details:
            print(f"{step['src']} -> {step['op']} + Insert({step['ins']})")
    else:
        print("Please set valid paths for checkpoint, vocab, and wav file.")