sira-asr / test_manual_patch.py
Antigravity AI
fix: resolve build dependencies and add CORS support
7ef1aa8
import os
import tarfile
import tempfile
import torch
from omegaconf import OmegaConf
import nemo.collections.asr as nemo_asr
from nemo.collections.asr.models.hybrid_rnnt_ctc_bpe_models import EncDecHybridRNNTCTCBPEModel
# Location of the .nemo checkpoint to load. The default is the original
# developer-machine path; set the SOLONI_MODEL_PATH environment variable to
# point the script at a checkpoint elsewhere without editing this file.
MODEL_PATH = os.environ.get(
    "SOLONI_MODEL_PATH",
    "c:/Users/USER/queue-buddy/mon_modele_soloni/soloni-114m-tdt-ctc-v3.nemo",
)
def load_with_manual_patch(nemo_path):
    """Load a ``.nemo`` archive by hand with a patched configuration.

    The archive is unpacked into a temporary directory, its
    ``model_config.yaml`` is edited (dataset sections and the
    ``key_phrase_items_list`` entry of the greedy decoding boosting tree are
    removed), the model is instantiated from that edited config, and the
    weights from ``model_weights.ckpt`` are loaded non-strictly.

    Args:
        nemo_path: Path to the ``.nemo`` archive (a tar file).

    Returns:
        The instantiated ``EncDecHybridRNNTCTCBPEModel`` with weights loaded.

    Raises:
        Whatever ``tarfile``, ``OmegaConf``, NeMo or ``torch`` raise when the
        archive, config or checkpoint cannot be read; nothing is caught here.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        print(f"Extracting {nemo_path} to {tmpdir}...")
        with tarfile.open(nemo_path, "r") as tar:
            # Prefer the "data" extraction filter when this interpreter has
            # it, so archive members cannot be written outside tmpdir.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(path=tmpdir, filter="data")
            else:
                tar.extractall(path=tmpdir)

        config_path = os.path.join(tmpdir, "model_config.yaml")
        conf = OmegaConf.load(config_path)

        # --- PATCH CONFIG ---
        # Struct mode must be off so keys can be deleted.
        OmegaConf.set_struct(conf, False)

        # Remove training/validation data parts to avoid abstract method errors
        for section in ("train_ds", "validation_ds", "test_ds"):
            if section in conf:
                del conf[section]

        # Patch the specific greedy decoding error
        if 'decoding' in conf and 'greedy' in conf.decoding and 'boosting_tree' in conf.decoding.greedy:
            boosting_tree = conf.decoding.greedy.boosting_tree
            # Guard against a null boosting_tree entry before the membership test.
            if boosting_tree is not None and 'key_phrase_items_list' in boosting_tree:
                del boosting_tree['key_phrase_items_list']

        print("Config patched.")

        # --- MONKEY PATCH ASRModel ---
        def mock_setup(self, cfg):
            pass

        asr_model_cls = nemo_asr.models.ASRModel
        hook_names = ("setup_training_data", "setup_validation_data", "setup_test_data")
        # Remember only what ASRModel defines itself, so the class can be put
        # back exactly as it was (inherited hooks are restored by deletion).
        saved_hooks = {
            name: vars(asr_model_cls)[name]
            for name in hook_names
            if name in vars(asr_model_cls)
        }
        for name in hook_names:
            setattr(asr_model_cls, name, mock_setup)

        try:
            # --- INSTANTIATE ---
            # Note: We use the specific class to be sure
            print("Instantiating model from config...")
            model = EncDecHybridRNNTCTCBPEModel.from_config_dict(conf)
        finally:
            # Undo the monkey patch so the stubs do not leak into any other
            # use of ASRModel in this process.
            for name in hook_names:
                if name in saved_hooks:
                    setattr(asr_model_cls, name, saved_hooks[name])
                else:
                    delattr(asr_model_cls, name)

        # --- LOAD WEIGHTS ---
        ckpt_path = os.path.join(tmpdir, "model_weights.ckpt")
        print(f"Loading weights from {ckpt_path}...")
        # NOTE(review): torch.load unpickles the file; only use this on
        # trusted archives, or consider passing weights_only=True.
        state_dict = torch.load(ckpt_path, map_location='cpu')
        load_result = model.load_state_dict(state_dict, strict=False)

        # strict=False tolerates key mismatches; report them so a partially
        # loaded model is not mistaken for a complete one.
        missing_keys = getattr(load_result, "missing_keys", None)
        unexpected_keys = getattr(load_result, "unexpected_keys", None)
        if missing_keys:
            print(f"Warning: {len(missing_keys)} missing keys: {list(missing_keys)}")
        if unexpected_keys:
            print(f"Warning: {len(unexpected_keys)} unexpected keys: {list(unexpected_keys)}")

        print("Model loaded successfully!")
        return model
if __name__ == "__main__":
    # Imported up front so the failure handler can dump the stack trace.
    import traceback

    # Script entry point: try to load the model and report any failure with
    # its message followed by the full traceback.
    try:
        model = load_with_manual_patch(MODEL_PATH)

        # Placeholder for a real transcription check once the model loads.
        print("Transcription test...")
        # (Add transcription test here if needed)
    except Exception as exc:
        print(f"Execution failed: {exc}")
        traceback.print_exc()