import contextlib
import os
import sys
import tarfile
import tempfile
import traceback
from unittest import mock

import torch
from omegaconf import OmegaConf
import nemo.collections.asr as nemo_asr
from nemo.collections.asr.models.hybrid_rnnt_ctc_bpe_models import EncDecHybridRNNTCTCBPEModel
|
|
# Default .nemo archive path used by the __main__ block of this script.
MODEL_PATH = (
    "c:/Users/USER/queue-buddy/mon_modele_soloni/"
    "soloni-114m-tdt-ctc-v3.nemo"
)
|
|
def _noop_setup(self, *args, **kwargs):
    """Stand-in for the ``setup_*_data`` hooks: accept any arguments, do nothing."""
    return None


def _extract_archive(nemo_path, dest_dir):
    """Unpack the ``.nemo`` tar archive at *nemo_path* into *dest_dir*."""
    print(f"Extracting {nemo_path} to {dest_dir}...")
    with tarfile.open(nemo_path, "r") as tar:
        # filter="data" (PEP 706) refuses absolute paths, ".." components and
        # links escaping dest_dir. It only exists on interpreters that ship
        # tarfile.data_filter; older ones fall back to the unfiltered call.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=dest_dir, filter="data")
        else:
            tar.extractall(path=dest_dir)


def _load_inference_config(config_path):
    """Load the YAML config at *config_path* and strip keys unwanted for loading.

    Removes the ``train_ds`` / ``validation_ds`` / ``test_ds`` sections and
    ``decoding.greedy.boosting_tree.key_phrase_items_list`` when present.
    """
    conf = OmegaConf.load(config_path)
    # Make sure keys can be deleted even if the config is in struct mode.
    OmegaConf.set_struct(conf, False)

    for key in ("train_ds", "validation_ds", "test_ds"):
        if key in conf:
            del conf[key]

    # select() yields None (instead of raising) when any level of the path is
    # absent or null, so a missing/null decoding or greedy section is tolerated.
    # NOTE(review): this key is presumably rejected by the installed NeMo
    # version's decoding config - confirm why it has to go.
    boosting_tree = OmegaConf.select(conf, "decoding.greedy.boosting_tree")
    if OmegaConf.is_dict(boosting_tree) and "key_phrase_items_list" in boosting_tree:
        del boosting_tree["key_phrase_items_list"]

    print("Config patched.")
    return conf


def _instantiate_without_data_setup(conf):
    """Build the model from *conf* while the ``setup_*_data`` hooks are stubbed."""
    stubs = {
        "setup_training_data": _noop_setup,
        "setup_validation_data": _noop_setup,
        "setup_test_data": _noop_setup,
    }
    print("Instantiating model from config...")
    with contextlib.ExitStack() as stack:
        # Stub the hooks on the ASRModel base and on the concrete class: an
        # attribute set on the concrete class wins over anything it inherits,
        # so the stub applies even if an intermediate class overrides a hook.
        # The patches are undone on exit, so no other class keeps the stubs.
        for cls in (nemo_asr.models.ASRModel, EncDecHybridRNNTCTCBPEModel):
            stack.enter_context(mock.patch.multiple(cls, create=True, **stubs))
        return EncDecHybridRNNTCTCBPEModel.from_config_dict(conf)


def _load_weights(model, ckpt_path, max_keys_shown=10):
    """Load the state dict at *ckpt_path* into *model* non-strictly.

    Key mismatches are tolerated (``strict=False``) but printed rather than
    silently dropped, since missing keys leave parameters at their initial values.
    """
    print(f"Loading weights from {ckpt_path}...")
    # weights_only=True limits unpickling to tensors and plain containers, so a
    # tampered checkpoint cannot run arbitrary code during loading.
    state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    incompatible = model.load_state_dict(state_dict, strict=False)
    for label, keys in (
        ("missing", incompatible.missing_keys),
        ("unexpected", incompatible.unexpected_keys),
    ):
        if keys:
            print(f"Warning: {len(keys)} {label} keys, e.g. {keys[:max_keys_shown]}")


def load_with_manual_patch(nemo_path):
    """Load an EncDecHybridRNNTCTCBPEModel from a ``.nemo`` archive by hand.

    The archive is unpacked into a temporary directory. ``model_config.yaml`` is
    stripped of its dataset sections (and one boosting-tree option). The model is
    built from that config with the ``setup_*_data`` hooks temporarily stubbed
    out. Finally, ``model_weights.ckpt`` is loaded with ``strict=False``.

    Args:
        nemo_path: Path to the ``.nemo`` tar archive.

    Returns:
        The constructed model with the checkpoint weights loaded. No
        ``.eval()`` call is made here.

    Raises:
        Whatever ``tarfile``, OmegaConf, NeMo or ``torch`` raise for a missing or
        malformed archive, config or checkpoint (e.g. ``FileNotFoundError``).
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        _extract_archive(nemo_path, tmpdir)
        conf = _load_inference_config(os.path.join(tmpdir, "model_config.yaml"))
        # NOTE(review): tmpdir is deleted when this block exits; if the config
        # references files inside the archive (e.g. tokenizer files), confirm
        # the model reads them during construction and not lazily afterwards.
        model = _instantiate_without_data_setup(conf)
        _load_weights(model, os.path.join(tmpdir, "model_weights.ckpt"))
        print("Model loaded successfully!")
    return model
|
|
if __name__ == "__main__":
    # Optional first CLI argument overrides the default archive location.
    model_path = sys.argv[1] if len(sys.argv) > 1 else MODEL_PATH
    try:
        model = load_with_manual_patch(model_path)

        print("Transcription test...")
        # TODO: the transcription step itself is not implemented yet.

    except Exception as e:
        # Top-level boundary: report the failure, then exit non-zero so shells
        # and CI see that the run failed instead of a successful exit status.
        print(f"Execution failed: {e}")
        traceback.print_exc()
        sys.exit(1)
|
|