pineapple-lover committed on
Commit
08836fe
·
1 Parent(s): 941a3bf

Fix: load safetensors checkpoints correctly

Browse files
Files changed (1) hide show
  1. edit_seq_speech/inference.py +89 -31
edit_seq_speech/inference.py CHANGED
@@ -9,48 +9,106 @@ import pytorch_lightning as pl
9
 
10
  from .model import PhonemeCorrector
11
  from transformers import Wav2Vec2Processor, HubertModel
 
12
 
13
  class PhonemeCorrectionInference:
14
  def __init__(self, checkpoint_path, vocab_path, audio_model_name="facebook/hubert-large-ls960-ft", device=None):
15
  self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
-
17
- # 1. Load Vocab / Config
18
  print(f"Loading config from {vocab_path}...")
19
- with open(vocab_path, 'r') as f:
20
  self.config = json.load(f)
21
-
22
- self.op_map = self.config['op_to_id']
23
- self.ins_map = self.config['insert_to_id']
24
-
25
- # Create Reverse Maps (ID -> String)
26
  self.id2op = {v: k for k, v in self.op_map.items()}
27
  self.id2ins = {v: k for k, v in self.ins_map.items()}
28
-
29
- # 2. Load G2P
30
  self.g2p = G2p()
31
-
32
- # 3. Load Model
33
- print(f"Loading model from {checkpoint_path}...")
34
- if os.path.exists(checkpoint_path):
35
- checkpoint = torch.load(checkpoint_path, map_location=self.device)
36
- hparams = checkpoint.get('hyper_parameters', {})
37
-
38
- vocab_size = max(self.ins_map.values()) + 1
39
- audio_vocab_size = hparams.get('audio_vocab_size', 2048)
40
-
41
- self.model = PhonemeCorrector.load_from_checkpoint(
42
- checkpoint_path,
43
- map_location=self.device,
44
- vocab_size=vocab_size,
45
- audio_vocab_size=audio_vocab_size
46
- )
47
- else:
48
  raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")
49
-
50
- self.model.to(self.device)
51
- self.model.eval()
52
 
53
- # 4. Load Audio Tokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  print(f"Loading Audio Tokenizer: {audio_model_name}")
55
  self.audio_processor = Wav2Vec2Processor.from_pretrained(audio_model_name)
56
  self.audio_model = HubertModel.from_pretrained(audio_model_name).eval().to(self.device)
 
9
 
10
  from .model import PhonemeCorrector
11
  from transformers import Wav2Vec2Processor, HubertModel
12
+ from safetensors.torch import load_file as safetensors_load_file
13
 
14
  class PhonemeCorrectionInference:
15
  def __init__(self, checkpoint_path, vocab_path, audio_model_name="facebook/hubert-large-ls960-ft", device=None):
16
  self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+
18
+ # 1) Load vocab
19
  print(f"Loading config from {vocab_path}...")
20
+ with open(vocab_path, "r") as f:
21
  self.config = json.load(f)
22
+
23
+ self.op_map = self.config["op_to_id"]
24
+ self.ins_map = self.config["insert_to_id"]
 
 
25
  self.id2op = {v: k for k, v in self.op_map.items()}
26
  self.id2ins = {v: k for k, v in self.ins_map.items()}
27
+
28
+ # 2) Load G2P
29
  self.g2p = G2p()
30
+
31
+ # 3) Load hparams.json (prefer same dir as checkpoint, fallback to parent)
32
+ if not os.path.exists(checkpoint_path):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")
 
 
 
34
 
35
+ hparams = {}
36
+ hp_candidates = [
37
+ os.path.join(os.path.dirname(checkpoint_path), "hparams.json"),
38
+ os.path.join(os.path.dirname(os.path.dirname(checkpoint_path)), "hparams.json"),
39
+ ]
40
+ for hp in hp_candidates:
41
+ if os.path.exists(hp):
42
+ with open(hp, "r") as f:
43
+ hparams = json.load(f)
44
+ break
45
+
46
+ # 4) Load weights/state_dict
47
+ print(f"Loading model weights from {checkpoint_path}...")
48
+ lower = checkpoint_path.lower()
49
+ if lower.endswith(".safetensors"):
50
+ state_dict = safetensors_load_file(checkpoint_path, device="cpu")
51
+ elif lower.endswith(".ckpt") or lower.endswith(".pt") or lower.endswith(".pth"):
52
+ # NOTE: weights_only=False is needed for Lightning-style checkpoints in PyTorch 2.6+
53
+ ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
54
+ state_dict = ckpt.get("state_dict", ckpt)
55
+ if not hparams and isinstance(ckpt, dict):
56
+ hparams = ckpt.get("hyper_parameters", {}) or {}
57
+ else:
58
+ raise ValueError(f"Unsupported checkpoint format: {checkpoint_path}")
59
+
60
+ # 5) Build model with correct hyperparams
61
+ vocab_size_from_vocab = max(self.ins_map.values()) + 1
62
+
63
+ # Prefer hparams.json, but also sanity-check against state_dict shapes
64
+ vocab_size = int(hparams.get("vocab_size", vocab_size_from_vocab))
65
+ audio_vocab_size = int(hparams.get("audio_vocab_size", 2048))
66
+ d_model = int(hparams.get("d_model", 256))
67
+ nhead = int(hparams.get("nhead", 4))
68
+ num_layers = int(hparams.get("num_layers", 4))
69
+ dropout = float(hparams.get("dropout", 0.1))
70
+ lr = float(hparams.get("lr", 1e-4))
71
+ weight_decay = float(hparams.get("weight_decay", 0.01))
72
+ scheduler_config = hparams.get("scheduler_config", None)
73
+ optimizer_config = hparams.get("optimizer_config", None)
74
+
75
+ # Hard check: vocab.json and weights must agree
76
+ if "text_embedding.weight" in state_dict:
77
+ vsd, dsd = state_dict["text_embedding.weight"].shape
78
+ asd = state_dict["audio_embedding.weight"].shape[0]
79
+ if vsd != vocab_size_from_vocab:
80
+ raise ValueError(
81
+ f"vocab.json (vocab_size={vocab_size_from_vocab}) does not match weights (vocab_size={vsd}). "
82
+ "Please upload the matching vocab.json."
83
+ )
84
+ # Override to match weights exactly (safer)
85
+ vocab_size = vsd
86
+ audio_vocab_size = asd
87
+ d_model = dsd
88
+
89
+ self.model = PhonemeCorrector(
90
+ vocab_size=vocab_size,
91
+ audio_vocab_size=audio_vocab_size,
92
+ d_model=d_model,
93
+ nhead=nhead,
94
+ num_layers=num_layers,
95
+ dropout=dropout,
96
+ lr=lr,
97
+ weight_decay=weight_decay,
98
+ scheduler_config=scheduler_config,
99
+ optimizer_config=optimizer_config,
100
+ )
101
+ missing, unexpected = self.model.load_state_dict(state_dict, strict=False)
102
+ if missing or unexpected:
103
+ print(f"[load_state_dict] missing={len(missing)} unexpected={len(unexpected)}")
104
+ if missing[:5]:
105
+ print(" missing (first 5):", missing[:5])
106
+ if unexpected[:5]:
107
+ print(" unexpected (first 5):", unexpected[:5])
108
+
109
+ self.model.to(self.device).eval()
110
+
111
+ # 6) Load Audio Tokenizer
112
  print(f"Loading Audio Tokenizer: {audio_model_name}")
113
  self.audio_processor = Wav2Vec2Processor.from_pretrained(audio_model_name)
114
  self.audio_model = HubertModel.from_pretrained(audio_model_name).eval().to(self.device)