Update app.py
Browse files
app.py
CHANGED
|
@@ -6,7 +6,7 @@ from huggingface_hub import snapshot_download
|
|
| 6 |
from omegaconf import OmegaConf
|
| 7 |
from nemo.collections.asr.models import ASRModel
|
| 8 |
|
| 9 |
-
#
|
| 10 |
MODEL_NAME = "trysem/stt_ml_fastconformer_ctc_med_punct"
|
| 11 |
|
| 12 |
print("1. Downloading unzipped model files from Hugging Face...")
|
|
@@ -24,11 +24,18 @@ if 'encoder' in config:
|
|
| 24 |
config.encoder.pop('use_pytorch_sdpa', None)
|
| 25 |
config.encoder.pop('use_pytorch_sdpa_backends', None)
|
| 26 |
|
| 27 |
-
# Patch 2:
|
| 28 |
if 'decoding' in config:
|
|
|
|
| 29 |
if config.decoding.get('strategy') == 'greedy_batch':
|
| 30 |
print(" -> Downgrading decoding strategy from 'greedy_batch' to 'greedy'")
|
| 31 |
config.decoding.strategy = 'greedy'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
print("3. Packaging files into a standard .nemo archive...")
|
| 34 |
patched_dir = "patched_nemo_env"
|
|
|
|
| 6 |
from omegaconf import OmegaConf
|
| 7 |
from nemo.collections.asr.models import ASRModel
|
| 8 |
|
| 9 |
+
# Pointing to your cloned model repository
|
| 10 |
MODEL_NAME = "trysem/stt_ml_fastconformer_ctc_med_punct"
|
| 11 |
|
| 12 |
print("1. Downloading unzipped model files from Hugging Face...")
|
|
|
|
| 24 |
config.encoder.pop('use_pytorch_sdpa', None)
|
| 25 |
config.encoder.pop('use_pytorch_sdpa_backends', None)
|
| 26 |
|
| 27 |
+
# Patch 2 & 3: Fix decoding strategy and confidence config
|
| 28 |
if 'decoding' in config:
|
| 29 |
+
# Patch 2: Downgrade 'greedy_batch' strategy to 'greedy'
|
| 30 |
if config.decoding.get('strategy') == 'greedy_batch':
|
| 31 |
print(" -> Downgrading decoding strategy from 'greedy_batch' to 'greedy'")
|
| 32 |
config.decoding.strategy = 'greedy'
|
| 33 |
+
|
| 34 |
+
# Patch 3: Remove 'tdt_include_duration' which NeMo 1.23 doesn't recognize
|
| 35 |
+
if 'confidence_cfg' in config.decoding and config.decoding.confidence_cfg is not None:
|
| 36 |
+
if 'tdt_include_duration' in config.decoding.confidence_cfg:
|
| 37 |
+
print(" -> Removing 'tdt_include_duration' from confidence_cfg")
|
| 38 |
+
config.decoding.confidence_cfg.pop('tdt_include_duration', None)
|
| 39 |
|
| 40 |
print("3. Packaging files into a standard .nemo archive...")
|
| 41 |
patched_dir = "patched_nemo_env"
|