# Fix: load safetensors checkpoints correctly (commit 08836fe)
import torch
import torch.nn as nn
import torchaudio
import json
import re
import os
from g2p_en import G2p
import pytorch_lightning as pl
from .model import PhonemeCorrector
from transformers import Wav2Vec2Processor, HubertModel
from safetensors.torch import load_file as safetensors_load_file
class PhonemeCorrectionInference:
    """Inference wrapper for the PhonemeCorrector edit-sequence model.

    Loads the op/insert vocabulary, model weights (PyTorch ``.ckpt``/``.pt``/
    ``.pth`` or ``.safetensors``) plus an optional ``hparams.json``, and a
    HuBERT audio tokenizer. :meth:`predict` then maps a (wav file, phoneme
    string) pair to a corrected phoneme sequence and a per-token edit log.
    """

    def __init__(self, checkpoint_path, vocab_path, audio_model_name="facebook/hubert-large-ls960-ft", device=None):
        """Build the corrector and move all components to ``device``.

        Args:
            checkpoint_path: Path to trained weights (.ckpt/.pt/.pth/.safetensors).
            vocab_path: Path to vocab.json with ``op_to_id`` / ``insert_to_id``.
            audio_model_name: HF hub id of the HuBERT feature extractor.
            device: torch.device; auto-selects CUDA when available if None.

        Raises:
            FileNotFoundError: if ``checkpoint_path`` does not exist.
            ValueError: on unsupported checkpoint formats, or when vocab.json
                disagrees with the embedding sizes found in the weights.
        """
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # 1) Load vocab
        print(f"Loading config from {vocab_path}...")
        with open(vocab_path, "r") as f:
            self.config = json.load(f)
        self.op_map = self.config["op_to_id"]
        self.ins_map = self.config["insert_to_id"]
        self.id2op = {v: k for k, v in self.op_map.items()}
        self.id2ins = {v: k for k, v in self.ins_map.items()}

        # 2) Load G2P (kept available for callers; predict() currently expects
        # pre-phonemized input and does not invoke it).
        self.g2p = G2p()

        # 3) Load hparams.json (prefer same dir as checkpoint, fallback to parent)
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")
        hparams = self._load_hparams(checkpoint_path)

        # 4) Load weights/state_dict; Lightning checkpoints may embed their
        # own hyperparameters, used only when no hparams.json was found.
        print(f"Loading model weights from {checkpoint_path}...")
        state_dict, ckpt_hparams = self._load_state_dict(checkpoint_path)
        if not hparams:
            hparams = ckpt_hparams

        # 5) Build model with correct hyperparams
        vocab_size_from_vocab = max(self.ins_map.values()) + 1
        # Prefer hparams.json, but also sanity-check against state_dict shapes
        vocab_size = int(hparams.get("vocab_size", vocab_size_from_vocab))
        audio_vocab_size = int(hparams.get("audio_vocab_size", 2048))
        d_model = int(hparams.get("d_model", 256))
        nhead = int(hparams.get("nhead", 4))
        num_layers = int(hparams.get("num_layers", 4))
        dropout = float(hparams.get("dropout", 0.1))
        lr = float(hparams.get("lr", 1e-4))
        weight_decay = float(hparams.get("weight_decay", 0.01))
        scheduler_config = hparams.get("scheduler_config", None)
        optimizer_config = hparams.get("optimizer_config", None)

        # Hard check: vocab.json and weights must agree
        if "text_embedding.weight" in state_dict:
            vsd, dsd = state_dict["text_embedding.weight"].shape
            if vsd != vocab_size_from_vocab:
                raise ValueError(
                    f"vocab.json (vocab_size={vocab_size_from_vocab}) does not match weights (vocab_size={vsd}). "
                    "Please upload the matching vocab.json."
                )
            # Override to match weights exactly (safer)
            vocab_size = vsd
            d_model = dsd
            # FIX: the audio embedding was previously indexed unconditionally,
            # which raised KeyError for checkpoints lacking that tensor.
            if "audio_embedding.weight" in state_dict:
                audio_vocab_size = state_dict["audio_embedding.weight"].shape[0]

        self.model = PhonemeCorrector(
            vocab_size=vocab_size,
            audio_vocab_size=audio_vocab_size,
            d_model=d_model,
            nhead=nhead,
            num_layers=num_layers,
            dropout=dropout,
            lr=lr,
            weight_decay=weight_decay,
            scheduler_config=scheduler_config,
            optimizer_config=optimizer_config,
        )
        # strict=False tolerates e.g. Lightning prefixes / optimizer leftovers,
        # but any mismatch is surfaced for the operator to inspect.
        missing, unexpected = self.model.load_state_dict(state_dict, strict=False)
        if missing or unexpected:
            print(f"[load_state_dict] missing={len(missing)} unexpected={len(unexpected)}")
            if missing:
                print(" missing (first 5):", missing[:5])
            if unexpected:
                print(" unexpected (first 5):", unexpected[:5])
        self.model.to(self.device).eval()

        # 6) Load Audio Tokenizer
        print(f"Loading Audio Tokenizer: {audio_model_name}")
        self.audio_processor = Wav2Vec2Processor.from_pretrained(audio_model_name)
        self.audio_model = HubertModel.from_pretrained(audio_model_name).eval().to(self.device)

    @staticmethod
    def _load_hparams(checkpoint_path):
        """Return the first hparams.json found next to (or one level above)
        the checkpoint, or an empty dict when none exists."""
        ckpt_dir = os.path.dirname(checkpoint_path)
        candidates = (
            os.path.join(ckpt_dir, "hparams.json"),
            os.path.join(os.path.dirname(ckpt_dir), "hparams.json"),
        )
        for candidate in candidates:
            if os.path.exists(candidate):
                with open(candidate, "r") as f:
                    return json.load(f)
        return {}

    @staticmethod
    def _load_state_dict(checkpoint_path):
        """Load model weights from disk.

        Returns:
            (state_dict, embedded_hparams): the raw tensor dict and, for
            Lightning-style checkpoints, their ``hyper_parameters`` dict
            (empty for safetensors / plain tensor files).

        Raises:
            ValueError: for unrecognized file extensions.
        """
        lower = checkpoint_path.lower()
        if lower.endswith(".safetensors"):
            # safetensors files hold tensors only -- no embedded hparams.
            return safetensors_load_file(checkpoint_path, device="cpu"), {}
        if lower.endswith((".ckpt", ".pt", ".pth")):
            # NOTE: weights_only=False is needed for Lightning-style checkpoints in PyTorch 2.6+
            ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
            if isinstance(ckpt, dict):
                # Lightning wraps weights under "state_dict"; a bare torch.save
                # of a state_dict is already the mapping itself.
                return ckpt.get("state_dict", ckpt), (ckpt.get("hyper_parameters", {}) or {})
            return ckpt, {}
        raise ValueError(f"Unsupported checkpoint format: {checkpoint_path}")

    def _clean_phn(self, phn_list):
        """Standard cleaning to match training: strip stress digits (0/1/2)
        and drop silence/noise/punctuation markers."""
        IGNORED = {"SIL", "'", "SPN", " "}
        cleaned = []
        for p in phn_list:
            base = p.rstrip('012')
            if base not in IGNORED:
                cleaned.append(base)
        return cleaned

    def _get_audio_tokens(self, wav_path):
        """Tokenize a wav file into discrete audio tokens, shape (1, T).

        IMPORTANT: This must match your training data generation logic.
        """
        waveform, sr = torchaudio.load(wav_path)
        # FIX: downmix multi-channel audio to mono. The previous .squeeze()
        # left a (channels, T) array that the processor treats as a batch,
        # corrupting every downstream shape for stereo input.
        if waveform.dim() > 1 and waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        if sr != 16000:
            resampler = torchaudio.transforms.Resample(sr, 16000)
            waveform = resampler(waveform)
        inputs = self.audio_processor(waveform.squeeze().numpy(), return_tensors="pt", sampling_rate=16000)
        input_values = inputs.input_values.to(self.device)
        with torch.no_grad():
            outputs = self.audio_model(input_values)
        # Placeholder Quantization (Argmax) - Replace if using K-Means
        features = outputs.last_hidden_state
        tokens = torch.argmax(features, dim=-1).squeeze()
        # FIX: a single-frame output collapses to 0-dim after squeeze(),
        # which would make the stride below raise; restore 1-dim.
        if tokens.dim() == 0:
            tokens = tokens.unsqueeze(0)
        # Downsample to 25Hz (Assuming model is 50Hz)
        tokens = tokens[::2]
        return tokens.unsqueeze(0)  # (1, T)

    def predict(self, wav_path, text):
        """Correct a phoneme sequence against the audio evidence.

        Args:
            wav_path: Path to the reference audio.
            text: Space-separated phoneme string (already phonemized).

        Returns:
            (final_phonemes, log): the corrected phoneme list and a list of
            per-source-token dicts {"src", "op", "ins"} describing each edit.
        """
        # A. Prepare Inputs
        # 1. Text -> Phonemes -> IDs
        # raw_phns = self.g2p(text)
        raw_phns = text.split()  # Assuming input text is already phonemized for inference
        src_phns = self._clean_phn(raw_phns)
        # Create text vocab from insert_to_id (same as dataset)
        text_vocab = {k: v for k, v in self.ins_map.items() if k not in ['<NONE>', '<PAD>']}
        # Unknown phonemes fall back to the "AA" id (dataset convention).
        text_ids = [text_vocab.get(p, text_vocab.get("AA", 2)) for p in src_phns]
        text_tensor = torch.tensor([text_ids], dtype=torch.long).to(self.device)
        # 2. Audio -> Tokens
        audio_tensor = self._get_audio_tokens(wav_path)

        # B. Run Model
        with torch.no_grad():
            # Nothing is padded at inference, so masks are all-ones.
            txt_mask = torch.ones_like(text_tensor)
            aud_mask = torch.ones_like(audio_tensor)
            logits_op, logits_ins = self.model(
                text_tensor, audio_tensor, txt_mask, aud_mask
            )

        # C. Decode
        pred_ops = torch.argmax(logits_op, dim=-1).squeeze().tolist()
        pred_ins = torch.argmax(logits_ins, dim=-1).squeeze().tolist()
        # Ensure lists (a length-1 sequence squeezes down to a scalar)
        if not isinstance(pred_ops, list): pred_ops = [pred_ops]
        if not isinstance(pred_ins, list): pred_ins = [pred_ins]

        # D. Reconstruct Sequence (zip truncates to the shortest of the three)
        final_phonemes = []
        log = []
        for i, (orig, op_id, ins_id) in enumerate(zip(src_phns, pred_ops, pred_ins)):
            # 1. Apply Operation
            op_str = self.id2op.get(op_id, "KEEP")
            curr_log = {"src": orig, "op": op_str, "ins": "NONE"}
            if op_str == "KEEP":
                final_phonemes.append(orig)
            elif op_str == "DEL":
                pass  # Do not append
            elif op_str.startswith("SUB:"):
                # Extract phoneme: "SUB:AA" -> "AA"
                # FIX: split at most once so a payload containing ':' survives.
                new_phn = op_str.split(":", 1)[1]
                final_phonemes.append(new_phn)
            # 2. Apply Insertion
            ins_str = self.id2ins.get(ins_id, "<NONE>")
            if ins_str != "<NONE>":
                final_phonemes.append(ins_str)
                curr_log["ins"] = ins_str
            log.append(curr_log)
        return final_phonemes, log
def _demo():
    """Smoke-test entry point: run one correction on hard-coded paths."""
    ckpt_path = "/data/chenxu/checkpoints/edit_seq_speech/phoneme-corrector/last.ckpt"
    vocab_path = "edit_seq_speech/config/vocab.json"
    wav_file = "test.wav"
    text_input = "Last Sunday"
    # Guard clause: bail out early when the fixture files are absent.
    if not (os.path.exists(ckpt_path) and os.path.exists(wav_file)):
        print("Please set valid paths for checkpoint and wav file.")
        return
    infer = PhonemeCorrectionInference(ckpt_path, vocab_path)
    result, details = infer.predict(wav_file, text_input)
    print(f"Input Text: {text_input}")
    print(f"Result Phn: {result}")
    print("-" * 20)
    for step in details:
        print(f"{step['src']} -> {step['op']} + Insert({step['ins']})")


if __name__ == "__main__":
    _demo()