# File size: 2,616 Bytes
# 7ef1aa8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import tarfile
import tempfile
import torch
from omegaconf import OmegaConf
import nemo.collections.asr as nemo_asr
from nemo.collections.asr.models.hybrid_rnnt_ctc_bpe_models import EncDecHybridRNNTCTCBPEModel

# Hard-coded absolute path to the .nemo archive that the script entry point
# passes to load_with_manual_patch(); adjust it for the local machine.
MODEL_PATH = "c:/Users/USER/queue-buddy/mon_modele_soloni/soloni-114m-tdt-ctc-v3.nemo"

def load_with_manual_patch(nemo_path):
    """Build a hybrid RNNT/CTC BPE model from a ``.nemo`` archive by hand.

    The archive is unpacked into a temporary directory, its
    ``model_config.yaml`` is edited to drop the dataset sections and the
    ``decoding.greedy.boosting_tree.key_phrase_items_list`` entry, the model
    is instantiated from the edited config, and the weights from
    ``model_weights.ckpt`` are loaded non-strictly on CPU.

    While the model is being instantiated, the three ``setup_*_data`` methods
    on ``nemo_asr.models.ASRModel`` are replaced with no-ops; they are put
    back afterwards, even if instantiation fails, so the patch does not leak
    into the rest of the process.

    Args:
        nemo_path: Path to the ``.nemo`` tar archive.

    Returns:
        The instantiated model with the checkpoint weights applied.

    Raises:
        Any exception raised while opening/extracting the archive, loading
        the config, instantiating the model or loading the checkpoint is
        propagated unchanged.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        print(f"Extracting {nemo_path} to {tmpdir}...")
        with tarfile.open(nemo_path, "r") as tar:
            # Use the "data" extraction filter when this Python provides it,
            # so archive members cannot be written outside tmpdir.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(path=tmpdir, filter="data")
            else:
                tar.extractall(path=tmpdir)

        config_path = os.path.join(tmpdir, "model_config.yaml")
        conf = OmegaConf.load(config_path)

        # --- PATCH CONFIG ---
        OmegaConf.set_struct(conf, False)
        # Remove training/validation data parts to avoid abstract method errors
        for section in ("train_ds", "validation_ds", "test_ds"):
            if section in conf:
                del conf[section]

        # Patch the specific greedy decoding error
        if 'decoding' in conf and 'greedy' in conf.decoding and 'boosting_tree' in conf.decoding.greedy:
            if 'key_phrase_items_list' in conf.decoding.greedy.boosting_tree:
                del conf.decoding.greedy.boosting_tree['key_phrase_items_list']

        print("Config patched.")

        # --- MONKEY PATCH ASRModel ---
        def mock_setup(self, cfg): pass

        asr_model_cls = nemo_asr.models.ASRModel
        setup_names = ("setup_training_data", "setup_validation_data", "setup_test_data")
        # Remember what ASRModel itself defines (as opposed to inherits) so the
        # class can be restored to exactly its previous state afterwards.
        not_defined = object()
        saved = {name: asr_model_cls.__dict__.get(name, not_defined) for name in setup_names}
        for name in setup_names:
            setattr(asr_model_cls, name, mock_setup)

        try:
            # --- INSTANTIATE ---
            # Note: We use the specific class to be sure
            print("Instantiating model from config...")
            model = EncDecHybridRNNTCTCBPEModel.from_config_dict(conf)
        finally:
            # Undo the global patch regardless of whether instantiation worked.
            for name, original in saved.items():
                if original is not_defined:
                    delattr(asr_model_cls, name)
                else:
                    setattr(asr_model_cls, name, original)

        # --- LOAD WEIGHTS ---
        ckpt_path = os.path.join(tmpdir, "model_weights.ckpt")
        print(f"Loading weights from {ckpt_path}...")
        # NOTE(review): torch.load unpickles the checkpoint; only use this on
        # archives from a trusted source.
        state_dict = torch.load(ckpt_path, map_location='cpu')
        load_result = model.load_state_dict(state_dict, strict=False)

        # strict=False hides mismatches; report them instead of dropping them.
        missing = list(getattr(load_result, "missing_keys", None) or ())
        unexpected = list(getattr(load_result, "unexpected_keys", None) or ())
        if missing:
            print(f"Warning: {len(missing)} key(s) missing from checkpoint: {missing}")
        if unexpected:
            print(f"Warning: {len(unexpected)} unexpected key(s) in checkpoint: {unexpected}")

        print("Model loaded successfully!")
        return model

# Script entry point: load the model from MODEL_PATH and report the outcome.
if __name__ == "__main__":
    try:
        model = load_with_manual_patch(MODEL_PATH)
        # Test a transcription if model is valid
        print("Transcription test...")
        # (Add transcription test here if needed)
    except Exception as e:
        print(f"Execution failed: {e}")
        import traceback
        traceback.print_exc()
        # Signal the failure to the caller (shell, CI) instead of exiting 0.
        raise SystemExit(1) from None