Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -63,7 +63,11 @@ s2l = Speech2Language.from_pretrained(
|
|
| 63 |
device=device,
|
| 64 |
nbest=1,
|
| 65 |
)
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
model_tag=f"espnet/owsm_v4_medium_1B",
|
| 68 |
device=device,
|
| 69 |
beam_size=5,
|
|
@@ -74,6 +78,51 @@ s2t_ar = ARSpeech2Text.from_pretrained(
|
|
| 74 |
task_sym="<asr>",
|
| 75 |
predict_time=False,
|
| 76 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
s2t_ctc = CTCSpeech2Text.from_pretrained(
|
| 78 |
model_tag=f"espnet/owsm_ctc_v4_1B",
|
| 79 |
device=device,
|
|
|
|
| 63 |
device=device,
|
| 64 |
nbest=1,
|
| 65 |
)
|
| 66 |
+
|
| 67 |
+
# Hacking to cange config
|
| 68 |
+
# 1. download files
|
| 69 |
+
try:
|
| 70 |
+
s2t_ar = ARSpeech2Text.from_pretrained(
|
| 71 |
model_tag=f"espnet/owsm_v4_medium_1B",
|
| 72 |
device=device,
|
| 73 |
beam_size=5,
|
|
|
|
| 78 |
task_sym="<asr>",
|
| 79 |
predict_time=False,
|
| 80 |
)
|
| 81 |
+
except Exception as e:
|
| 82 |
+
print("File downloaded")
|
| 83 |
+
|
| 84 |
+
# 2. Remove unrequired file
|
| 85 |
+
import yaml
|
| 86 |
+
from pathlib import Path
|
| 87 |
+
import espnet_model_zoo
|
| 88 |
+
|
| 89 |
+
d = "models--espnet--owsm_v4_medium_1B/snapshots/471418ddaf0b03c9ab1fd75f1f5d26fc3aea3aa9/exp/s2t_train_conv2d8_size1024_e18_d18_mel128_raw_bpe50000/config.yaml"
|
| 90 |
+
p = Path(espnet_model_zoo.__file__)
|
| 91 |
+
config_path = p.parent / d
|
| 92 |
+
|
| 93 |
+
def remove_key(obj, key="gradient_checkpoint_layers"):
|
| 94 |
+
if isinstance(obj, dict):
|
| 95 |
+
if key in obj:
|
| 96 |
+
del obj[key]
|
| 97 |
+
for k, v in list(obj.items()):
|
| 98 |
+
remove_key(v, key)
|
| 99 |
+
elif isinstance(obj, list):
|
| 100 |
+
for item in obj:
|
| 101 |
+
remove_key(item, key)
|
| 102 |
+
|
| 103 |
+
with open(config_path, "r") as f:
|
| 104 |
+
config = yaml.safe_load(f)
|
| 105 |
+
|
| 106 |
+
remove_key(config)
|
| 107 |
+
|
| 108 |
+
with open(config_path, "w") as f:
|
| 109 |
+
yaml.safe_dump(config, f, sort_keys=False, allow_unicode=True)
|
| 110 |
+
|
| 111 |
+
print("Done! All 'gradient_checkpoint_layers' keys removed.")
|
| 112 |
+
|
| 113 |
+
s2t_ar = ARSpeech2Text.from_pretrained(
|
| 114 |
+
model_tag=f"espnet/owsm_v4_medium_1B",
|
| 115 |
+
device=device,
|
| 116 |
+
beam_size=5,
|
| 117 |
+
ctc_weight=0.0,
|
| 118 |
+
maxlenratio=0.0,
|
| 119 |
+
# below are default values which can be overwritten in __call__
|
| 120 |
+
lang_sym="<eng>",
|
| 121 |
+
task_sym="<asr>",
|
| 122 |
+
predict_time=False,
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
# CTC looks okay.
|
| 126 |
s2t_ctc = CTCSpeech2Text.from_pretrained(
|
| 127 |
model_tag=f"espnet/owsm_ctc_v4_1B",
|
| 128 |
device=device,
|