trysem commited on
Commit
4fbb118
·
verified ·
1 Parent(s): b02797a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -2
app.py CHANGED
@@ -6,6 +6,7 @@ from huggingface_hub import snapshot_download
6
  from omegaconf import OmegaConf
7
  from nemo.collections.asr.models import ASRModel
8
 
 
9
  MODEL_NAME = "trysem/stt_ml_fastconformer_ctc_med_punct"
10
 
11
  print("1. Downloading unzipped model files from Hugging Face...")
@@ -15,12 +16,20 @@ print("2. Patching the configuration...")
15
  config_path = os.path.join(model_dir, "model_config.yaml")
16
  config = OmegaConf.load(config_path)
17
 
18
- # Allow modifications to the config object and remove PyTorch 2.0 SDPA keys
19
  OmegaConf.set_struct(config, False)
 
 
20
  if 'encoder' in config:
21
  config.encoder.pop('use_pytorch_sdpa', None)
22
  config.encoder.pop('use_pytorch_sdpa_backends', None)
23
 
 
 
 
 
 
 
24
  print("3. Packaging files into a standard .nemo archive...")
25
  patched_dir = "patched_nemo_env"
26
  os.makedirs(patched_dir, exist_ok=True)
@@ -43,7 +52,6 @@ with tarfile.open(nemo_filepath, "w") as tar:
43
  tar.add(os.path.join(patched_dir, item), arcname=item)
44
 
45
  print("4. Restoring model from patched .nemo file (this may take a moment)...")
46
- # NeMo will be happy because it is receiving a proper file archive!
47
  model = ASRModel.restore_from(restore_path=nemo_filepath)
48
  model.eval()
49
  print("Model loaded successfully!")
 
6
  from omegaconf import OmegaConf
7
  from nemo.collections.asr.models import ASRModel
8
 
9
+ # Use the correct repository name
10
  MODEL_NAME = "trysem/stt_ml_fastconformer_ctc_med_punct"
11
 
12
  print("1. Downloading unzipped model files from Hugging Face...")
 
16
  config_path = os.path.join(model_dir, "model_config.yaml")
17
  config = OmegaConf.load(config_path)
18
 
19
+ # Allow modifications to the config object
20
  OmegaConf.set_struct(config, False)
21
+
22
+ # Patch 1: Remove PyTorch 2.0 SDPA keys for NeMo 1.23 compatibility
23
  if 'encoder' in config:
24
  config.encoder.pop('use_pytorch_sdpa', None)
25
  config.encoder.pop('use_pytorch_sdpa_backends', None)
26
 
27
+ # Patch 2: Downgrade 'greedy_batch' strategy to 'greedy'
28
+ if 'decoding' in config:
29
+ if config.decoding.get('strategy') == 'greedy_batch':
30
+ print(" -> Downgrading decoding strategy from 'greedy_batch' to 'greedy'")
31
+ config.decoding.strategy = 'greedy'
32
+
33
  print("3. Packaging files into a standard .nemo archive...")
34
  patched_dir = "patched_nemo_env"
35
  os.makedirs(patched_dir, exist_ok=True)
 
52
  tar.add(os.path.join(patched_dir, item), arcname=item)
53
 
54
  print("4. Restoring model from patched .nemo file (this may take a moment)...")
 
55
  model = ASRModel.restore_from(restore_path=nemo_filepath)
56
  model.eval()
57
  print("Model loaded successfully!")