File size: 9,238 Bytes
7d0662d
 
 
 
 
 
 
 
 
 
 
08836fe
7d0662d
 
 
 
08836fe
 
7d0662d
08836fe
7d0662d
08836fe
 
 
7d0662d
 
08836fe
 
7d0662d
08836fe
 
 
7d0662d
 
08836fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d0662d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
import torch
import torch.nn as nn
import torchaudio
import json
import re
import os
from g2p_en import G2p
import pytorch_lightning as pl

from .model import PhonemeCorrector
from transformers import Wav2Vec2Processor, HubertModel
from safetensors.torch import load_file as safetensors_load_file

class PhonemeCorrectionInference:
    """End-to-end inference wrapper for the phoneme-correction model.

    Loads the operation/insertion vocabularies, the trained
    ``PhonemeCorrector`` weights (Lightning ``.ckpt`` or ``.safetensors``),
    and the HuBERT audio tokenizer, then exposes :meth:`predict`, which maps
    (wav file, phonemized text) to a corrected phoneme sequence plus a
    per-position edit log.
    """

    def __init__(self, checkpoint_path, vocab_path, audio_model_name="facebook/hubert-large-ls960-ft", device=None):
        """
        Args:
            checkpoint_path: Path to model weights (.ckpt/.pt/.pth/.safetensors).
            vocab_path: Path to vocab.json containing "op_to_id" and "insert_to_id".
            audio_model_name: HuggingFace id of the HuBERT model used as the
                audio tokenizer (must match the one used to build training data).
            device: Optional torch device; defaults to CUDA when available.
        """
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # 1) Load vocab (operation and insertion label maps + inverses for decoding)
        print(f"Loading config from {vocab_path}...")
        with open(vocab_path, "r") as f:
            self.config = json.load(f)

        self.op_map = self.config["op_to_id"]
        self.ins_map = self.config["insert_to_id"]
        self.id2op = {v: k for k, v in self.op_map.items()}
        self.id2ins = {v: k for k, v in self.ins_map.items()}

        # 2) Load G2P (currently unused by predict(), which expects pre-phonemized
        #    input; kept so callers can switch back to on-the-fly G2P)
        self.g2p = G2p()

        # 3) Load hparams.json (prefer same dir as checkpoint, fallback to parent)
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")

        hparams = {}
        hp_candidates = [
            os.path.join(os.path.dirname(checkpoint_path), "hparams.json"),
            os.path.join(os.path.dirname(os.path.dirname(checkpoint_path)), "hparams.json"),
        ]
        for hp in hp_candidates:
            if os.path.exists(hp):
                with open(hp, "r") as f:
                    hparams = json.load(f)
                break

        # 4) Load weights/state_dict
        print(f"Loading model weights from {checkpoint_path}...")
        lower = checkpoint_path.lower()
        if lower.endswith(".safetensors"):
            state_dict = safetensors_load_file(checkpoint_path, device="cpu")
        elif lower.endswith(".ckpt") or lower.endswith(".pt") or lower.endswith(".pth"):
            # NOTE: weights_only=False is needed for Lightning-style checkpoints in PyTorch 2.6+
            ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
            state_dict = ckpt.get("state_dict", ckpt)
            # Lightning checkpoints embed hyperparameters; use them when no
            # hparams.json was found on disk.
            if not hparams and isinstance(ckpt, dict):
                hparams = ckpt.get("hyper_parameters", {}) or {}
        else:
            raise ValueError(f"Unsupported checkpoint format: {checkpoint_path}")

        # 5) Build model with correct hyperparams
        vocab_size_from_vocab = max(self.ins_map.values()) + 1

        # Prefer hparams.json, but also sanity-check against state_dict shapes
        vocab_size = int(hparams.get("vocab_size", vocab_size_from_vocab))
        audio_vocab_size = int(hparams.get("audio_vocab_size", 2048))
        d_model = int(hparams.get("d_model", 256))
        nhead = int(hparams.get("nhead", 4))
        num_layers = int(hparams.get("num_layers", 4))
        dropout = float(hparams.get("dropout", 0.1))
        lr = float(hparams.get("lr", 1e-4))
        weight_decay = float(hparams.get("weight_decay", 0.01))
        scheduler_config = hparams.get("scheduler_config", None)
        optimizer_config = hparams.get("optimizer_config", None)

        # Hard check: vocab.json and weights must agree
        if "text_embedding.weight" in state_dict:
            vsd, dsd = state_dict["text_embedding.weight"].shape
            asd = state_dict["audio_embedding.weight"].shape[0]
            if vsd != vocab_size_from_vocab:
                raise ValueError(
                    f"vocab.json (vocab_size={vocab_size_from_vocab}) does not match weights (vocab_size={vsd}). "
                    "Please upload the matching vocab.json."
                )
            # Override to match weights exactly (safer)
            vocab_size = vsd
            audio_vocab_size = asd
            d_model = dsd

        self.model = PhonemeCorrector(
            vocab_size=vocab_size,
            audio_vocab_size=audio_vocab_size,
            d_model=d_model,
            nhead=nhead,
            num_layers=num_layers,
            dropout=dropout,
            lr=lr,
            weight_decay=weight_decay,
            scheduler_config=scheduler_config,
            optimizer_config=optimizer_config,
        )
        # strict=False so renamed/auxiliary keys don't abort loading; report
        # the mismatch so a silently-partial load is at least visible.
        missing, unexpected = self.model.load_state_dict(state_dict, strict=False)
        if missing or unexpected:
            print(f"[load_state_dict] missing={len(missing)} unexpected={len(unexpected)}")
            if missing[:5]:
                print("  missing (first 5):", missing[:5])
            if unexpected[:5]:
                print("  unexpected (first 5):", unexpected[:5])

        self.model.to(self.device).eval()

        # 6) Load Audio Tokenizer
        print(f"Loading Audio Tokenizer: {audio_model_name}")
        self.audio_processor = Wav2Vec2Processor.from_pretrained(audio_model_name)
        self.audio_model = HubertModel.from_pretrained(audio_model_name).eval().to(self.device)

    def _clean_phn(self, phn_list):
        """Standard cleaning to match training: strip stress digits (0/1/2)
        and drop silence/noise markers."""
        IGNORED = {"SIL", "'", "SPN", " "} 
        return [p.rstrip('012') for p in phn_list if p.rstrip('012') not in IGNORED]

    def _get_audio_tokens(self, wav_path):
        """
        Runs the audio tokenizer: wav file -> (1, T) LongTensor of audio tokens.
        IMPORTANT: This must match your training data generation logic.
        """
        waveform, sr = torchaudio.load(wav_path)
        if sr != 16000:
            resampler = torchaudio.transforms.Resample(sr, 16000)
            waveform = resampler(waveform)

        # Downmix to mono: torchaudio returns (channels, T); a bare squeeze()
        # would leave stereo audio 2-D and the processor would treat the
        # channels as a batch of two utterances.
        if waveform.dim() > 1 and waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        inputs = self.audio_processor(waveform.squeeze().numpy(), return_tensors="pt", sampling_rate=16000)
        input_values = inputs.input_values.to(self.device)

        with torch.no_grad():
            outputs = self.audio_model(input_values)

        # Placeholder Quantization (Argmax) - Replace if using K-Means
        features = outputs.last_hidden_state
        # squeeze(0) (not squeeze()) so a length-1 time axis is preserved and
        # the stride slice below stays valid.
        tokens = torch.argmax(features, dim=-1).squeeze(0)

        # Downsample to 25Hz (Assuming model is 50Hz — TODO confirm against
        # the training tokenizer's frame rate)
        tokens = tokens[::2]
        return tokens.unsqueeze(0)  # (1, T)

    def predict(self, wav_path, text):
        """Correct a phoneme sequence against the audio.

        Args:
            wav_path: Path to a wav file.
            text: Whitespace-separated phoneme string (already phonemized).

        Returns:
            (final_phonemes, log): the corrected phoneme list and a list of
            per-source-position dicts {"src", "op", "ins"}.
        """
        # A. Prepare Inputs
        # 1. Text -> Phonemes -> IDs
        # raw_phns = self.g2p(text)
        raw_phns = text.split()  # Assuming input text is already phonemized for inference
        src_phns = self._clean_phn(raw_phns)

        # Create text vocab from insert_to_id (same as dataset)
        text_vocab = {k: v for k, v in self.ins_map.items() if k not in ['<NONE>', '<PAD>']}
        # NOTE(review): OOV phonemes fall back to the "AA" id (or 2) — presumably
        # mirrors the dataset builder; verify against training code.
        text_ids = [text_vocab.get(p, text_vocab.get("AA", 2)) for p in src_phns]
        text_tensor = torch.tensor([text_ids], dtype=torch.long).to(self.device)

        # 2. Audio -> Tokens
        audio_tensor = self._get_audio_tokens(wav_path)

        # B. Run Model
        with torch.no_grad():
            # Create masks (no padding at inference, so all-ones)
            txt_mask = torch.ones_like(text_tensor)
            aud_mask = torch.ones_like(audio_tensor)

            logits_op, logits_ins = self.model(
                text_tensor, audio_tensor, txt_mask, aud_mask
            )

            # C. Decode
            pred_ops = torch.argmax(logits_op, dim=-1).squeeze().tolist()
            pred_ins = torch.argmax(logits_ins, dim=-1).squeeze().tolist()

        # Ensure lists (squeeze() on a length-1 sequence yields a scalar)
        if not isinstance(pred_ops, list): pred_ops = [pred_ops]
        if not isinstance(pred_ins, list): pred_ins = [pred_ins]

        # D. Reconstruct Sequence
        # zip() truncates to the shortest sequence, so a length mismatch
        # between predictions and src_phns is silently absorbed here.
        final_phonemes = []
        log = []

        for i, (orig, op_id, ins_id) in enumerate(zip(src_phns, pred_ops, pred_ins)):

            # 1. Apply Operation
            op_str = self.id2op.get(op_id, "KEEP")
            curr_log = {"src": orig, "op": op_str, "ins": "NONE"}

            if op_str == "KEEP":
                final_phonemes.append(orig)
            elif op_str == "DEL":
                pass # Do not append
            elif op_str.startswith("SUB:"):
                # Extract phoneme: "SUB:AA" -> "AA" (maxsplit guards any
                # pathological label containing a second colon)
                new_phn = op_str.split(":", 1)[1]
                final_phonemes.append(new_phn)

            # 2. Apply Insertion
            ins_str = self.id2ins.get(ins_id, "<NONE>")
            if ins_str != "<NONE>":
                final_phonemes.append(ins_str)
                curr_log["ins"] = ins_str

            log.append(curr_log)

        return final_phonemes, log

if __name__ == "__main__":
    # Smoke-test driver: edit these paths/inputs locally before running.
    ckpt_path = "/data/chenxu/checkpoints/edit_seq_speech/phoneme-corrector/last.ckpt"
    vocab_path = "edit_seq_speech/config/vocab.json"
    wav_file = "test.wav"
    text_input = "Last Sunday"

    if not (os.path.exists(ckpt_path) and os.path.exists(wav_file)):
        print("Please set valid paths for checkpoint and wav file.")
    else:
        engine = PhonemeCorrectionInference(ckpt_path, vocab_path)
        phonemes, edit_log = engine.predict(wav_file, text_input)

        print(f"Input Text: {text_input}")
        print(f"Result Phn: {phonemes}")
        print("-" * 20)
        for entry in edit_log:
            print(f"{entry['src']} -> {entry['op']} + Insert({entry['ins']})")