Commit 08836fe (parent: 941a3bf)
Fix: load safetensors checkpoints correctly
Changed file: edit_seq_speech/inference.py (+89, -31)
|
@@ -9,48 +9,106 @@ import pytorch_lightning as pl
|
|
| 9 |
|
| 10 |
from .model import PhonemeCorrector
|
| 11 |
from transformers import Wav2Vec2Processor, HubertModel
|
|
|
|
| 12 |
|
| 13 |
class PhonemeCorrectionInference:
    """Inference wrapper for the PhonemeCorrector model.

    Loads the vocab/config JSON, restores model weights from a checkpoint
    (Lightning ``.ckpt`` via ``load_from_checkpoint`` or plain ``.safetensors``),
    and prepares the HuBERT audio tokenizer.

    Args:
        checkpoint_path: Path to a ``.ckpt`` or ``.safetensors`` checkpoint.
        vocab_path: Path to the JSON config holding ``op_to_id`` / ``insert_to_id``.
        audio_model_name: HF model id for the audio feature extractor.
        device: Optional torch device; defaults to CUDA when available.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """

    def __init__(self, checkpoint_path, vocab_path, audio_model_name="facebook/hubert-large-ls960-ft", device=None):
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # 1) Load vocab / config
        print(f"Loading config from {vocab_path}...")
        with open(vocab_path, "r") as f:
            self.config = json.load(f)

        self.op_map = self.config["op_to_id"]
        self.ins_map = self.config["insert_to_id"]

        # Create reverse maps (ID -> string)
        self.id2op = {v: k for k, v in self.op_map.items()}
        self.id2ins = {v: k for k, v in self.ins_map.items()}

        # 2) G2P frontend
        self.g2p = G2p()

        # 3) Load model
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")

        vocab_size = max(self.ins_map.values()) + 1

        if checkpoint_path.lower().endswith(".safetensors"):
            # FIX: torch.load() cannot parse safetensors files — use the
            # dedicated loader and restore weights via load_state_dict.
            from safetensors.torch import load_file as _safetensors_load_file
            state_dict = _safetensors_load_file(checkpoint_path, device="cpu")
            # Recover the audio vocab size from the weights when possible.
            audio_emb = state_dict.get("audio_embedding.weight")
            audio_vocab_size = audio_emb.shape[0] if audio_emb is not None else 2048
            self.model = PhonemeCorrector(
                vocab_size=vocab_size,
                audio_vocab_size=audio_vocab_size,
            )
            self.model.load_state_dict(state_dict, strict=False)
        else:
            # Lightning-style checkpoint: hyperparameters are embedded.
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            hparams = checkpoint.get('hyper_parameters', {}) if isinstance(checkpoint, dict) else {}
            audio_vocab_size = hparams.get('audio_vocab_size', 2048)

            self.model = PhonemeCorrector.load_from_checkpoint(
                checkpoint_path,
                map_location=self.device,
                vocab_size=vocab_size,
                audio_vocab_size=audio_vocab_size
            )

        self.model.to(self.device)
        self.model.eval()

        # 4) Load audio tokenizer
        print(f"Loading Audio Tokenizer: {audio_model_name}")
        self.audio_processor = Wav2Vec2Processor.from_pretrained(audio_model_name)
        self.audio_model = HubertModel.from_pretrained(audio_model_name).eval().to(self.device)
|
|
|
| 9 |
|
| 10 |
from .model import PhonemeCorrector
|
| 11 |
from transformers import Wav2Vec2Processor, HubertModel
|
| 12 |
+
from safetensors.torch import load_file as safetensors_load_file
|
| 13 |
|
| 14 |
class PhonemeCorrectionInference:
    """Inference wrapper for the PhonemeCorrector model.

    Restores model weights from either a ``.safetensors`` file or a
    torch/Lightning checkpoint (``.ckpt``/``.pt``/``.pth``), resolving
    hyperparameters from an ``hparams.json`` next to (or one level above)
    the checkpoint, falling back to the checkpoint's embedded
    ``hyper_parameters`` and finally to the weight tensors themselves.

    Args:
        checkpoint_path: Path to the checkpoint file.
        vocab_path: Path to the JSON config holding ``op_to_id`` / ``insert_to_id``.
        audio_model_name: HF model id for the audio feature extractor.
        device: Optional torch device; defaults to CUDA when available.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
        ValueError: On an unsupported checkpoint extension, or when the
            vocab file disagrees with the weight shapes.
    """

    def __init__(self, checkpoint_path, vocab_path, audio_model_name="facebook/hubert-large-ls960-ft", device=None):
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # 1) Load vocab
        print(f"Loading config from {vocab_path}...")
        with open(vocab_path, "r") as f:
            self.config = json.load(f)

        self.op_map = self.config["op_to_id"]
        self.ins_map = self.config["insert_to_id"]
        self.id2op = {v: k for k, v in self.op_map.items()}
        self.id2ins = {v: k for k, v in self.ins_map.items()}

        # 2) Load G2P
        self.g2p = G2p()

        # 3) Load hparams.json (prefer same dir as checkpoint, fallback to parent)
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")

        hparams = {}
        hp_candidates = [
            os.path.join(os.path.dirname(checkpoint_path), "hparams.json"),
            os.path.join(os.path.dirname(os.path.dirname(checkpoint_path)), "hparams.json"),
        ]
        for hp in hp_candidates:
            if os.path.exists(hp):
                with open(hp, "r") as f:
                    hparams = json.load(f)
                break

        # 4) Load weights/state_dict
        print(f"Loading model weights from {checkpoint_path}...")
        lower = checkpoint_path.lower()
        if lower.endswith(".safetensors"):
            state_dict = safetensors_load_file(checkpoint_path, device="cpu")
        elif lower.endswith(".ckpt") or lower.endswith(".pt") or lower.endswith(".pth"):
            # NOTE: weights_only=False is needed for Lightning-style checkpoints in PyTorch 2.6+
            ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
            # FIX: guard against non-dict payloads (a raw torch.save'd object
            # has no .get and would raise AttributeError here).
            if isinstance(ckpt, dict):
                state_dict = ckpt.get("state_dict", ckpt)
                if not hparams:
                    hparams = ckpt.get("hyper_parameters", {}) or {}
            else:
                state_dict = ckpt
        else:
            raise ValueError(f"Unsupported checkpoint format: {checkpoint_path}")

        # 5) Build model with correct hyperparams
        vocab_size_from_vocab = max(self.ins_map.values()) + 1

        # Prefer hparams.json, but also sanity-check against state_dict shapes
        vocab_size = int(hparams.get("vocab_size", vocab_size_from_vocab))
        audio_vocab_size = int(hparams.get("audio_vocab_size", 2048))
        d_model = int(hparams.get("d_model", 256))
        nhead = int(hparams.get("nhead", 4))
        num_layers = int(hparams.get("num_layers", 4))
        dropout = float(hparams.get("dropout", 0.1))
        lr = float(hparams.get("lr", 1e-4))
        weight_decay = float(hparams.get("weight_decay", 0.01))
        scheduler_config = hparams.get("scheduler_config", None)
        optimizer_config = hparams.get("optimizer_config", None)

        # Hard check: vocab.json and weights must agree
        if "text_embedding.weight" in state_dict:
            vsd, dsd = state_dict["text_embedding.weight"].shape
            if vsd != vocab_size_from_vocab:
                raise ValueError(
                    f"vocab.json (vocab_size={vocab_size_from_vocab}) does not match weights (vocab_size={vsd}). "
                    "Please upload the matching vocab.json."
                )
            # Override to match weights exactly (safer)
            vocab_size = vsd
            d_model = dsd
            # FIX: only override audio_vocab_size when the key actually exists;
            # the unconditional index would KeyError on checkpoints without it.
            audio_emb = state_dict.get("audio_embedding.weight")
            if audio_emb is not None:
                audio_vocab_size = audio_emb.shape[0]

        self.model = PhonemeCorrector(
            vocab_size=vocab_size,
            audio_vocab_size=audio_vocab_size,
            d_model=d_model,
            nhead=nhead,
            num_layers=num_layers,
            dropout=dropout,
            lr=lr,
            weight_decay=weight_decay,
            scheduler_config=scheduler_config,
            optimizer_config=optimizer_config,
        )
        # strict=False: tolerate auxiliary keys (e.g. loss buffers) while
        # still surfacing a summary of any mismatches for debugging.
        missing, unexpected = self.model.load_state_dict(state_dict, strict=False)
        if missing or unexpected:
            print(f"[load_state_dict] missing={len(missing)} unexpected={len(unexpected)}")
            if missing[:5]:
                print("  missing (first 5):", missing[:5])
            if unexpected[:5]:
                print("  unexpected (first 5):", unexpected[:5])

        self.model.to(self.device).eval()

        # 6) Load Audio Tokenizer
        print(f"Loading Audio Tokenizer: {audio_model_name}")
        self.audio_processor = Wav2Vec2Processor.from_pretrained(audio_model_name)
        self.audio_model = HubertModel.from_pretrained(audio_model_name).eval().to(self.device)
|