Update app.py
Browse files
app.py
CHANGED
|
@@ -6,6 +6,7 @@ from huggingface_hub import snapshot_download
|
|
| 6 |
from omegaconf import OmegaConf
|
| 7 |
from nemo.collections.asr.models import ASRModel
|
| 8 |
|
|
|
|
| 9 |
MODEL_NAME = "trysem/stt_ml_fastconformer_ctc_med_punct"
|
| 10 |
|
| 11 |
print("1. Downloading unzipped model files from Hugging Face...")
|
|
@@ -15,12 +16,20 @@ print("2. Patching the configuration...")
|
|
| 15 |
config_path = os.path.join(model_dir, "model_config.yaml")
|
| 16 |
config = OmegaConf.load(config_path)
|
| 17 |
|
| 18 |
-
# Allow modifications to the config object
|
| 19 |
OmegaConf.set_struct(config, False)
|
|
|
|
|
|
|
| 20 |
if 'encoder' in config:
|
| 21 |
config.encoder.pop('use_pytorch_sdpa', None)
|
| 22 |
config.encoder.pop('use_pytorch_sdpa_backends', None)
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
print("3. Packaging files into a standard .nemo archive...")
|
| 25 |
patched_dir = "patched_nemo_env"
|
| 26 |
os.makedirs(patched_dir, exist_ok=True)
|
|
@@ -43,7 +52,6 @@ with tarfile.open(nemo_filepath, "w") as tar:
|
|
| 43 |
tar.add(os.path.join(patched_dir, item), arcname=item)
|
| 44 |
|
| 45 |
print("4. Restoring model from patched .nemo file (this may take a moment)...")
|
| 46 |
-
# NeMo will be happy because it is receiving a proper file archive!
|
| 47 |
model = ASRModel.restore_from(restore_path=nemo_filepath)
|
| 48 |
model.eval()
|
| 49 |
print("Model loaded successfully!")
|
|
|
|
| 6 |
from omegaconf import OmegaConf
|
| 7 |
from nemo.collections.asr.models import ASRModel
|
| 8 |
|
| 9 |
+
# Use the correct repository name
|
| 10 |
MODEL_NAME = "trysem/stt_ml_fastconformer_ctc_med_punct"
|
| 11 |
|
| 12 |
print("1. Downloading unzipped model files from Hugging Face...")
|
|
|
|
| 16 |
config_path = os.path.join(model_dir, "model_config.yaml")
|
| 17 |
config = OmegaConf.load(config_path)
|
| 18 |
|
| 19 |
+
# Allow modifications to the config object
|
| 20 |
OmegaConf.set_struct(config, False)
|
| 21 |
+
|
| 22 |
+
# Patch 1: Remove PyTorch 2.0 SDPA keys for NeMo 1.23 compatibility
|
| 23 |
if 'encoder' in config:
|
| 24 |
config.encoder.pop('use_pytorch_sdpa', None)
|
| 25 |
config.encoder.pop('use_pytorch_sdpa_backends', None)
|
| 26 |
|
| 27 |
+
# Patch 2: Downgrade 'greedy_batch' strategy to 'greedy'
|
| 28 |
+
if 'decoding' in config:
|
| 29 |
+
if config.decoding.get('strategy') == 'greedy_batch':
|
| 30 |
+
print(" -> Downgrading decoding strategy from 'greedy_batch' to 'greedy'")
|
| 31 |
+
config.decoding.strategy = 'greedy'
|
| 32 |
+
|
| 33 |
print("3. Packaging files into a standard .nemo archive...")
|
| 34 |
patched_dir = "patched_nemo_env"
|
| 35 |
os.makedirs(patched_dir, exist_ok=True)
|
|
|
|
| 52 |
tar.add(os.path.join(patched_dir, item), arcname=item)
|
| 53 |
|
| 54 |
print("4. Restoring model from patched .nemo file (this may take a moment)...")
|
|
|
|
| 55 |
model = ASRModel.restore_from(restore_path=nemo_filepath)
|
| 56 |
model.eval()
|
| 57 |
print("Model loaded successfully!")
|